style(5-1): remove f-strings from test assertions

- Removed the unnecessary f-string expressions from the test assertions
- Simplified the format of the assertion error messages
This commit is contained in:
parent 8fcedbec41
commit 8bccc2cad7

5-1.py (44 lines changed)
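The assertion edit described in the commit message falls below the point where this diff view stops loading, so it is not visible in the hunks that follow. Purely as an illustrative sketch of that kind of change (the assertion and its message are hypothetical, not taken from 5-1.py), removing an f-string from a test assertion could look like this:

# Before: failure message built with an f-string (hypothetical example)
assert len(similarity_test) == 1, f"expected 1 record, got {len(similarity_test)}"

# After: plain assertion with simpler failure output (hypothetical example)
assert len(similarity_test) == 1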
@@ -12,11 +12,13 @@ sqlContext = SQLContext(sc)
 amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
 google_path = "hdfs://master:9000/user/root/Google_small.csv"
 
+
 def tokenize(text):
     """ Tokenize: lower-case the text and extract the alphanumeric words """
     return re.findall(r'\w+', text.lower())
 
-def parse_data_file(line):
+
+def parse_data_file(line, data_source='amazon'):
     """ Parse each line of the data file """
     line = line.strip()
     if not line:
@@ -25,17 +27,27 @@ def parse_data_file(line):
     if len(parts) < 5:
         return None
     doc_id = parts[0].strip()
-    text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
+    # Handle the two data sources separately
+    if data_source == 'amazon':
+        # Amazon file format: id, title, description, manufacturer, price
+        text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
+    else:
+        # Google file format: id, name, description, manufacturer, price
+        text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
+
     return (doc_id, text)
 
+
 # Read and parse the data
-def load_data(path):
+def load_data(path, data_source='amazon'):
     """ Read and parse the data file """
-    raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None)
+    raw_data = sc.textFile(path).map(lambda line: parse_data_file(line, data_source)).filter(lambda x: x is not None)
     return raw_data
 
-amazon = load_data(amazon_path)
-google = load_data(google_path)
+
+amazon = load_data(amazon_path, data_source='amazon')
+google = load_data(google_path, data_source='google')
 
+
 # Tokenize the data
 amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
@@ -44,6 +56,7 @@ google_rec_to_token = google.map(lambda x: (x[0], tokenize(x[1])))
 # Combine the Amazon and Google datasets
 full_corpus_rdd = amazon_rec_to_token.union(google_rec_to_token)
 
+
 # Compute IDF
 def idfs(corpus):
     """ Compute the inverse document frequency (IDF) """
@@ -53,6 +66,7 @@ def idfs(corpus):
     idf_rdd = df_rdd.map(lambda x: (x[0], float(N) / float(x[1])))
     return idf_rdd
 
+
 # Compute the IDF of the full dataset
 idfs_full = idfs(full_corpus_rdd)
 
@@ -60,6 +74,7 @@ idfs_full = idfs(full_corpus_rdd)
 idfs_full_weights = idfs_full.collectAsMap()
 idfs_full_broadcast = sc.broadcast(idfs_full_weights)
 
+
 # Compute TF-IDF
 def tf(tokens):
     """ Compute the term frequency (TF) """
@@ -69,20 +84,24 @@ def tf(tokens):
         counts[token] = counts.get(token, 0) + 1
     return {k: float(v) / total for k, v in counts.items()}
 
+
 def tfidf(tokens, idfs):
     """ Compute TF-IDF """
     tfs = tf(tokens)
     return {k: v * idfs.get(k, 0) for k, v in tfs.items()}
 
+
 # Compute the TF-IDF of Amazon and Google
 amazon_weights_rdd = amazon_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
 google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
 
+
 # Compute the weight norms
 def norm(weights):
     """ Compute the norm of a vector """
     return math.sqrt(sum([w * w for w in weights.values()]))
 
+
 # Compute the weight norms for Amazon and Google
 amazon_norms = amazon_weights_rdd.map(lambda x: (x[0], norm(x[1])))
 google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1])))
@@ -91,6 +110,7 @@ google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1])))
 amazon_norms_broadcast = sc.broadcast(amazon_norms.collectAsMap())
 google_norms_broadcast = sc.broadcast(google_norms.collectAsMap())
 
+
 # Create the inverted index
 def invert(record):
     """ Invert (ID, tokens) into (token, ID) """
@@ -98,12 +118,15 @@ def invert(record):
     weights = record[1]
     return [(token, id) for token in weights]
 
+
 # Create the inverted index
 amazon_inv_pairs_rdd = amazon_weights_rdd.flatMap(lambda x: invert(x)).cache()
 google_inv_pairs_rdd = google_weights_rdd.flatMap(lambda x: invert(x)).cache()
 
 # Find the tokens common to both datasets
-common_tokens = amazon_inv_pairs_rdd.join(google_inv_pairs_rdd).map(lambda x: (x[0], x[1])).groupByKey().map(lambda x: (x[0], list(x[1]))).cache()
+common_tokens = amazon_inv_pairs_rdd.join(google_inv_pairs_rdd).map(lambda x: (x[0], x[1])).groupByKey().map(
+    lambda x: (x[0], list(x[1]))).cache()
 
+
 # Compute the cosine similarity
 def fast_cosine_similarity(record):
@@ -111,10 +134,12 @@ def fast_cosine_similarity(record):
     amazon_id = record[0][0]
     google_url = record[0][1]
     tokens = record[1]
-    s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get(token, 0) for token in tokens])
+    s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get(
+        token, 0) for token in tokens])
     value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
     return ((amazon_id, google_url), value)
 
+
 # Compute the similarities
 similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache()
 
@@ -122,7 +147,8 @@ similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache()
 print("Number of similarity records: {}".format(similarities_full_rdd.count()))
 
 # Compute and test the similarity
-similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect()
+similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][
+    1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect()
 print(len(similarity_test))
 
 # Test
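For context (not part of the commit): the script builds a TF-IDF weight vector per record and then a cosine similarity per (amazon_id, google_url) pair. Below is a minimal local sketch in plain Python, with made-up IDF weights and token lists, showing the same arithmetic that tf, tfidf, norm, and fast_cosine_similarity distribute over Spark (the diff elides how tf computes total; here it is assumed to be the token count):

import math


def tf(tokens):
    """ Term frequency: the share of the document each token accounts for. """
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    total = len(tokens)
    return {k: float(v) / total for k, v in counts.items()}


def tfidf(tokens, idfs):
    """ Scale each term frequency by its IDF weight (0 for unknown tokens). """
    tfs = tf(tokens)
    return {k: v * idfs.get(k, 0) for k, v in tfs.items()}


def norm(weights):
    """ Euclidean norm of a sparse weight vector. """
    return math.sqrt(sum(w * w for w in weights.values()))


# Made-up IDF weights and token lists, for illustration only.
idf_weights = {'premier': 2.0, 'image': 1.5, 'pack': 1.2, 'software': 1.1}
amazon_tokens = ['premier', 'image', 'pack']
google_tokens = ['premier', 'image', 'software']

a = tfidf(amazon_tokens, idf_weights)
g = tfidf(google_tokens, idf_weights)

# Dot product over the shared tokens divided by the product of the norms,
# mirroring what fast_cosine_similarity computes for each record pair.
shared = set(a) & set(g)
dot = sum(a[t] * g[t] for t in shared)
print(dot / (norm(a) * norm(g)))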