refactor(5-1):重构代码以提高可读性和效率

- 重新组织代码结构,使逻辑更清晰
- 使用更有意义的变量名,提高代码可读性
- 移除冗余的中间变量,简化代码
- 添加注释以解释关键步骤
This commit is contained in:
fly6516 2025-04-20 03:04:39 +08:00
parent 036a740505
commit 1d16bebe43

115
5-1.py
View File

@@ -59,37 +59,28 @@ def idfs(corpus):
return idf_rdd
# Tokenize the full datasets: each record becomes (ID, [tokens]).
amazonFullRecToToken = amazon.map(lambda line: (line[0], tokenize(line[1])))
googleFullRecToToken = google.map(lambda line: (line[0], tokenize(line[1])))
print('Amazon full dataset is {} products, Google full dataset is {} products'.format(
    amazonFullRecToToken.count(),
    googleFullRecToToken.count()))
# Compute IDF weights over the combined Amazon + Google corpus.
fullCorpusRDD = amazonFullRecToToken.union(googleFullRecToToken)
idfsFull = idfs(fullCorpusRDD)
idfsFullCount = idfsFull.count()
print('There are %s unique tokens in the full datasets.' % idfsFullCount)
# Broadcast the token -> IDF map so executors can read it without reshipping it per task.
idfsFullWeights = idfsFull.collectAsMap()
idfsFullBroadcast = sc.broadcast(idfsFullWeights)
# 计算 TF-IDF
def tf(tokens):
    """Return the term frequency of each token: occurrences / total tokens."""
    total = float(len(tokens))
    occurrences = {}
    for token in tokens:
        if token in occurrences:
            occurrences[token] += 1
        else:
            occurrences[token] = 1
    return {token: count / total for token, count in occurrences.items()}
def tfidf(tokens, idfs):
    """Return TF-IDF weights: term frequency scaled by inverse document frequency.

    Tokens absent from the idfs map get a weight of 0.
    """
    termFrequencies = tf(tokens)
    return {token: freq * idfs.get(token, 0) for token, freq in termFrequencies.items()}
# Compute TF-IDF weight vectors for every product in the full datasets.
amazonWeightsRDD = amazonFullRecToToken.map(lambda x: (x[0], tfidf(x[1], idfsFullBroadcast.value)))
googleWeightsRDD = googleFullRecToToken.map(lambda x: (x[0], tfidf(x[1], idfsFullBroadcast.value)))
print('There are {} Amazon weights and {} Google weights.'.format(amazonWeightsRDD.count(),
                                                                  googleWeightsRDD.count()))
# 计算权重范数
def norm(weights):
@@ -98,53 +89,61 @@ def norm(weights):
# Compute the norm of each product's TF-IDF weight vector and broadcast
# the ID -> norm maps for use in the similarity computation.
amazonNorms = amazonWeightsRDD.map(lambda x: (x[0], norm(x[1])))
amazonNormsBroadcast = sc.broadcast(amazonNorms.collectAsMap())
googleNorms = googleWeightsRDD.map(lambda x: (x[0], norm(x[1])))
googleNormsBroadcast = sc.broadcast(googleNorms.collectAsMap())
# 创建反向索引
def invert(record):
    """Invert an (ID, weights) record into a list of (token, ID) pairs.

    The diff merge had left two docstrings and unreachable code after the
    first return; this is the single coherent version. Pair order follows
    the insertion order of the weights dict.
    """
    record_id = record[0]  # renamed from `id` to avoid shadowing the builtin
    weights = record[1]
    return [(token, record_id) for token in weights]
# Build inverted indices: one (token, productID) pair per token a product contains.
amazonInvPairsRDD = amazonWeightsRDD.flatMap(lambda x: invert(x)).cache()
googleInvPairsRDD = googleWeightsRDD.flatMap(lambda x: invert(x)).cache()
print('There are {} Amazon inverted pairs and {} Google inverted pairs.'.format(amazonInvPairsRDD.count(),
                                                                                googleInvPairsRDD.count()))
# 识别共有 token
def swap(record):
    """Swap a (token, (ID, URL)) pair into ((ID, URL), token) form."""
    return (record[1], record[0])
# Join the two inverted indices on token, then regroup by (amazonID, googleURL)
# so each pair maps to the list of tokens the two records share.
commonTokens = (amazonInvPairsRDD
                .join(googleInvPairsRDD)
                .map(lambda x: swap(x))
                .groupByKey()
                .map(lambda x: (x[0], list(x[1])))
                .cache())
print('Found %d common tokens' % commonTokens.count())
# Broadcast the per-product weight maps so the similarity function can look
# them up without shipping them with every task.
amazonWeightsBroadcast = sc.broadcast(amazonWeightsRDD.collectAsMap())
googleWeightsBroadcast = sc.broadcast(googleWeightsRDD.collectAsMap())
def fastCosineSimilarity(record):
    """ Compute Cosine Similarity using Broadcast variables

    record is ((amazonID, googleURL), [shared tokens]); returns the same key
    with the cosine similarity of the two products' TF-IDF vectors.
    """
    amazonRec = record[0][0]
    googleRec = record[0][1]
    tokens = record[1]
    # .get(token, 0) avoids KeyError for tokens absent from one side's weight map
    s = sum([(amazonWeightsBroadcast.value[amazonRec].get(token, 0) * googleWeightsBroadcast.value[googleRec].get(token, 0))
             for token in tokens])
    value = s / (amazonNormsBroadcast.value[amazonRec] * googleNormsBroadcast.value[googleRec])
    key = (amazonRec, googleRec)
    return (key, value)
similaritiesFullRDD = commonTokens.map(lambda x: fastCosineSimilarity(x)).cache()
print(similaritiesFullRDD.count())
print("Number of similarity records: {}".format(similaritiesFullRDD.count()))
# 计算并测试相似度
similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][