refactor(5-1):重构代码以提高可读性和效率
- 重新组织代码结构,使逻辑更清晰 - 使用更有意义的变量名,提高代码可读性 - 移除冗余的中间变量,简化代码 - 添加注释以解释关键步骤
This commit is contained in:
parent
036a740505
commit
1d16bebe43
115
5-1.py
115
5-1.py
@ -59,37 +59,28 @@ def idfs(corpus):
|
|||||||
return idf_rdd
|
return idf_rdd
|
||||||
|
|
||||||
|
|
||||||
|
# Tokenize the complete datasets: each record becomes (record_id, [tokens]).
amazonFullRecToToken = amazon.map(lambda line: (line[0], tokenize(line[1])))
googleFullRecToToken = google.map(lambda line: (line[0], tokenize(line[1])))
print('Amazon full dataset is {} products, Google full dataset is {} products'.format(
    amazonFullRecToToken.count(),
    googleFullRecToToken.count()))
||||||
# IDF weights over the combined (union) full corpus.
fullCorpusRDD = amazonFullRecToToken.union(googleFullRecToToken)
idfsFull = idfs(fullCorpusRDD)
idfsFullCount = idfsFull.count()
print('There are %s unique tokens in the full datasets.' % idfsFullCount)
||||||
# Broadcast the full-corpus IDF map so each executor gets one local copy
# instead of shipping it with every task closure.
idfsFullWeights = idfsFull.collectAsMap()
idfsFullBroadcast = sc.broadcast(idfsFullWeights)
|
# --- TF / TF-IDF helpers -----------------------------------------------------


def tf(tokens):
    """Return normalized term frequencies for *tokens*.

    Each token maps to count(token) / len(tokens).
    NOTE(review): an empty token list raises ZeroDivisionError — assumed
    not to occur after tokenization; confirm upstream.
    """
    total = len(tokens)
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    return {k: float(v) / total for k, v in counts.items()}


def tfidf(tokens, idfs):
    """Return TF-IDF weights for *tokens* given an *idfs* mapping.

    Tokens missing from *idfs* get weight 0 (via dict.get), so unseen
    tokens never raise KeyError.
    """
    tfs = tf(tokens)
    return {k: v * idfs.get(k, 0) for k, v in tfs.items()}
|
||||||
# TF-IDF weight vectors for every record in the full datasets:
# (record_id, {token: weight}).
amazonWeightsRDD = amazonFullRecToToken.map(lambda x: (x[0], tfidf(x[1], idfsFullBroadcast.value)))
googleWeightsRDD = googleFullRecToToken.map(lambda x: (x[0], tfidf(x[1], idfsFullBroadcast.value)))
print('There are {} Amazon weights and {} Google weights.'.format(amazonWeightsRDD.count(),
                                                                  googleWeightsRDD.count()))
||||||
def norm(weights):
|
def norm(weights):
|
||||||
@ -98,53 +89,61 @@ def norm(weights):
|
|||||||
|
|
||||||
|
|
||||||
# Per-record weight-vector norms, broadcast for the cosine-similarity step.
amazonNorms = amazonWeightsRDD.map(lambda x: (x[0], norm(x[1])))
amazonNormsBroadcast = sc.broadcast(amazonNorms.collectAsMap())
googleNorms = googleWeightsRDD.map(lambda x: (x[0], norm(x[1])))
googleNormsBroadcast = sc.broadcast(googleNorms.collectAsMap())
# --- inverted index ----------------------------------------------------------


def invert(record):
    """Invert (record_id, weights) into a list of (token, record_id) pairs."""
    # Renamed local from `id`, which shadowed the builtin.
    record_id, weights = record
    return [(token, record_id) for token in weights]
|
|
||||||
|
# Inverted (token, record_id) pair RDDs, cached because they feed the join below.
amazonInvPairsRDD = amazonWeightsRDD.flatMap(invert).cache()
googleInvPairsRDD = googleWeightsRDD.flatMap(invert).cache()
print('There are {} Amazon inverted pairs and {} Google inverted pairs.'.format(amazonInvPairsRDD.count(),
                                                                                googleInvPairsRDD.count()))
||||||
# Identify tokens shared between the two datasets.


def swap(record):
    """Swap (token, (amazon_id, google_url)) to ((amazon_id, google_url), token)."""
    token, ids = record
    return (ids, token)
||||||
# Tokens shared by each (Amazon, Google) record pair:
# ((amazon_id, google_url), [common tokens]).
commonTokens = (amazonInvPairsRDD
                .join(googleInvPairsRDD)
                .map(swap)
                .groupByKey()
                .map(lambda x: (x[0], list(x[1])))
                .cache())
print('Found %d common tokens' % commonTokens.count())
|
||||||
|
|
||||||
# Broadcast the full weight maps for the similarity computation.
amazonWeightsBroadcast = sc.broadcast(amazonWeightsRDD.collectAsMap())
googleWeightsBroadcast = sc.broadcast(googleWeightsRDD.collectAsMap())


def fastCosineSimilarity(record):
    """Compute cosine similarity for one ((amazon_id, google_url), tokens)
    entry using the broadcast weight and norm maps.

    Returns ((amazon_id, google_url), similarity).
    """
    amazonRec = record[0][0]
    googleRec = record[0][1]
    tokens = record[1]
    # dict.get(token, 0) keeps tokens missing on one side from raising
    # KeyError; a missing token simply contributes 0 to the dot product.
    dot = sum(amazonWeightsBroadcast.value[amazonRec].get(token, 0)
              * googleWeightsBroadcast.value[googleRec].get(token, 0)
              for token in tokens)
    value = dot / (amazonNormsBroadcast.value[amazonRec] * googleNormsBroadcast.value[googleRec])
    return ((amazonRec, googleRec), value)


similaritiesFullRDD = commonTokens.map(fastCosineSimilarity).cache()
print(similaritiesFullRDD.count())

# Inspect the result.
print("Number of similarity records: {}".format(similaritiesFullRDD.count()))
# 计算并测试相似度
|
# 计算并测试相似度
|
||||||
similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][
|
similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][
|
||||||
|
Loading…
Reference in New Issue
Block a user