refactor(5-1):重构代码以提高可读性和效率
- 重新组织代码结构,使逻辑更清晰 - 使用更有意义的变量名,提高代码可读性 - 移除冗余的中间变量,简化代码 - 添加注释以解释关键步骤
This commit is contained in:
parent
036a740505
commit
1d16bebe43
115
5-1.py
115
5-1.py
@@ -59,37 +59,28 @@ def idfs(corpus):
|
||||
return idf_rdd
|
||||
|
||||
|
||||
# Tokenize the full datasets: each record becomes (ID, token list).
amazonFullRecToToken = amazon.map(lambda rec: (rec[0], tokenize(rec[1])))
googleFullRecToToken = google.map(lambda rec: (rec[0], tokenize(rec[1])))

print('Amazon full dataset is {} products, Google full dataset is {} products'.format(
    amazonFullRecToToken.count(), googleFullRecToToken.count()))
|
||||
|
||||
# Compute IDF over the combined (Amazon + Google) full corpus.
# NOTE(review): removed a leftover pre-refactor line
# (`idfs_full = idfs(full_corpus_rdd)`) that referenced an undefined
# variable; the camelCase lines below are the current version.
fullCorpusRDD = amazonFullRecToToken.union(googleFullRecToToken)
idfsFull = idfs(fullCorpusRDD)
idfsFullCount = idfsFull.count()
print('There are %s unique tokens in the full datasets.' % idfsFullCount)
|
||||
|
||||
# Broadcast the IDF weight map so every worker can look tokens up locally.
# NOTE(review): removed the duplicated snake_case lines
# (`idfs_full_weights` / `idfs_full_broadcast`) left behind by the rename.
idfsFullWeights = idfsFull.collectAsMap()
idfsFullBroadcast = sc.broadcast(idfsFullWeights)
|
||||
|
||||
|
||||
# Compute TF-IDF
def tf(tokens):
    """Return the term frequency of each token: count / total token count.

    An empty token list yields an empty dict (no division is performed).
    """
    occurrences = {}
    for t in tokens:
        occurrences[t] = occurrences.get(t, 0) + 1
    size = len(tokens)
    return {t: float(n) / size for t, n in occurrences.items()}
|
||||
|
||||
|
||||
def tfidf(tokens, idfs):
    """Weight each token by TF times its IDF (0 for tokens absent from idfs)."""
    weighted = {}
    for token, freq in tf(tokens).items():
        weighted[token] = freq * idfs.get(token, 0)
    return weighted
|
||||
|
||||
|
||||
# Compute TF-IDF weight vectors for every record in the full datasets,
# using the broadcast IDF map.
# NOTE(review): removed leftover pre-refactor lines (`amazon_weights_rdd`,
# `google_weights_rdd` and their broadcasts) that referenced the undefined
# `amazon_rec_to_token` / `google_rec_to_token` names.
amazonWeightsRDD = amazonFullRecToToken.map(lambda x: (x[0], tfidf(x[1], idfsFullBroadcast.value)))
googleWeightsRDD = googleFullRecToToken.map(lambda x: (x[0], tfidf(x[1], idfsFullBroadcast.value)))
print('There are {} Amazon weights and {} Google weights.'.format(amazonWeightsRDD.count(),
                                                                  googleWeightsRDD.count()))
|
||||
|
||||
# 计算权重范数
|
||||
def norm(weights):
|
||||
@@ -98,53 +89,61 @@ def norm(weights):
|
||||
|
||||
|
||||
# Compute the norm of each record's TF-IDF weight vector and broadcast the
# resulting (ID -> norm) maps for use in the similarity computation.
# NOTE(review): removed the duplicated snake_case lines (`amazon_norms`,
# `google_norms` and their broadcasts) left behind by the rename.
amazonNorms = amazonWeightsRDD.map(lambda x: (x[0], norm(x[1])))
amazonNormsBroadcast = sc.broadcast(amazonNorms.collectAsMap())
googleNorms = googleWeightsRDD.map(lambda x: (x[0], norm(x[1])))
googleNormsBroadcast = sc.broadcast(googleNorms.collectAsMap())
|
||||
|
||||
# Build an inverted index
def invert(record):
    """Invert an (ID, weights) record to a list of (token, ID) pairs.

    The merged diff had left two docstrings and dead code after the first
    return; this keeps the single live behavior. `record_id` avoids
    shadowing the builtin `id`.
    """
    record_id, weights = record
    # Iterating a dict yields its keys, i.e. the tokens.
    return [(token, record_id) for token in weights]
|
||||
|
||||
# Materialize the inverted indices: one (token, ID) pair per token per record.
# .cache() because each RDD is reused below (count + join).
# NOTE(review): removed the duplicated snake_case lines
# (`amazon_inv_pairs_rdd` / `google_inv_pairs_rdd`) left behind by the rename.
amazonInvPairsRDD = amazonWeightsRDD.flatMap(lambda x: invert(x)).cache()
googleInvPairsRDD = googleWeightsRDD.flatMap(lambda x: invert(x)).cache()
print('There are {} Amazon inverted pairs and {} Google inverted pairs.'.format(amazonInvPairsRDD.count(),
                                                                                googleInvPairsRDD.count()))
|
||||
# Identify shared tokens
def swap(record):
    """Swap a (token, (ID, URL)) pair into ((ID, URL), token)."""
    token, keys = record
    return (keys, token)
|
||||
|
||||
# For every (Amazon ID, Google URL) pair, collect the list of tokens the two
# records have in common. The join matches records sharing a token; swap()
# re-keys by the (ID, URL) pair before grouping.
# NOTE(review): removed the duplicated pre-refactor `common_tokens` chain.
commonTokens = (amazonInvPairsRDD
                .join(googleInvPairsRDD)
                .map(lambda x: swap(x))
                .groupByKey()
                .map(lambda x: (x[0], list(x[1])))
                .cache())

print('Found %d common tokens' % commonTokens.count())
|
||||
|
||||
# Broadcast the full TF-IDF weight maps so workers can look them up cheaply.
# NOTE(review): removed a stray fragment of the old
# `fast_cosine_similarity` definition that the merged diff had split across
# the file; the current implementation is `fastCosineSimilarity` below.
amazonWeightsBroadcast = sc.broadcast(amazonWeightsRDD.collectAsMap())
googleWeightsBroadcast = sc.broadcast(googleWeightsRDD.collectAsMap())
|
||||
|
||||
def fastCosineSimilarity(record):
    """Compute cosine similarity for one ((amazonID, googleURL), tokens) record.

    Looks up TF-IDF weights and precomputed norms from the broadcast
    variables; only the shared tokens contribute to the dot product.
    """
    (amazonRec, googleRec), tokens = record
    amazonWeights = amazonWeightsBroadcast.value[amazonRec]
    googleWeights = googleWeightsBroadcast.value[googleRec]
    dotProduct = sum([amazonWeights.get(token, 0) * googleWeights.get(token, 0)
                      for token in tokens])
    normProduct = amazonNormsBroadcast.value[amazonRec] * googleNormsBroadcast.value[googleRec]
    return ((amazonRec, googleRec), dotProduct / normProduct)
|
||||
|
||||
# Compute cosine similarity for every candidate (Amazon, Google) pair.
# NOTE(review): removed orphaned fragments of the old
# `fast_cosine_similarity` body that the merged diff had left at module
# level — including a bare top-level `return`, which is a SyntaxError —
# plus the obsolete `similarities_full_rdd` line.
similaritiesFullRDD = commonTokens.map(lambda x: fastCosineSimilarity(x)).cache()

print(similaritiesFullRDD.count())
|
||||
|
||||
# Inspect the result.
# NOTE(review): removed the duplicated pre-refactor print that referenced
# the renamed `similarities_full_rdd`.
print("Number of similarity records: {}".format(similaritiesFullRDD.count()))
|
||||
|
||||
# 计算并测试相似度
|
||||
similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][
|
||||
|
Loading…
Reference in New Issue
Block a user