refactor(5-1):重构代码以提高可读性和性能

- 移除了未使用的 data_source 参数
- 优化了 parse_data_file 和 load_data 函数- 使用 .get() 方法安全访问字典元素
- 改进了 cosine_sim 函数,使用广播变量计算余弦相似度
This commit is contained in:
fly6516 2025-04-20 03:00:37 +08:00
parent dc883eaf72
commit 38917b896f

29
5-1.py
View File

@ -18,7 +18,7 @@ def tokenize(text):
return re.findall(r'\w+', text.lower())
def parse_data_file(line, data_source='amazon'):
def parse_data_file(line):
""" 解析数据文件的每一行 """
line = line.strip()
if not line:
@ -27,27 +27,19 @@ def parse_data_file(line, data_source='amazon'):
if len(parts) < 5:
return None
doc_id = parts[0].strip()
# 对不同数据集进行处理
if data_source == 'amazon':
# Amazon 文件格式: id, title, description, manufacturer, price
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
else:
# Google 文件格式: id, name, description, manufacturer, price
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
return (doc_id, text)
# 读取和解析数据
def load_data(path, data_source='amazon'):
def load_data(path):
""" 读取并解析数据文件 """
raw_data = sc.textFile(path).map(lambda line: parse_data_file(line, data_source)).filter(lambda x: x is not None)
raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None)
return raw_data
amazon = load_data(amazon_path, data_source='amazon')
google = load_data(google_path, data_source='google')
amazon = load_data(amazon_path)
google = load_data(google_path)
# 对数据进行分词化
amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
@ -99,7 +91,6 @@ google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_f
amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap())
google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap())
# 计算权重范数
def norm(weights):
""" 计算向量的范数 """
@ -138,9 +129,15 @@ def fast_cosine_similarity(record):
amazon_id = record[0][0]
google_url = record[0][1]
tokens = record[1]
# 使用 .get() 方法来安全地访问字典中的元素,避免 KeyError
s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get(
token, 0) for token in tokens])
token, 0)
for token in tokens])
# 使用广播变量计算余弦相似度
value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
return ((amazon_id, google_url), value)