refactor(5-1):重构代码以提高可读性和性能
- 移除了未使用的 data_source 参数 - 优化了 parse_data_file 和 load_data 函数- 使用 .get() 方法安全访问字典元素 - 改进了 cosine_sim 函数,使用广播变量计算余弦相似度
This commit is contained in:
parent
dc883eaf72
commit
38917b896f
29
5-1.py
29
5-1.py
@ -18,7 +18,7 @@ def tokenize(text):
|
||||
return re.findall(r'\w+', text.lower())
|
||||
|
||||
|
||||
def parse_data_file(line, data_source='amazon'):
|
||||
def parse_data_file(line):
|
||||
""" 解析数据文件的每一行 """
|
||||
line = line.strip()
|
||||
if not line:
|
||||
@ -27,27 +27,19 @@ def parse_data_file(line, data_source='amazon'):
|
||||
if len(parts) < 5:
|
||||
return None
|
||||
doc_id = parts[0].strip()
|
||||
|
||||
# 对不同数据集进行处理
|
||||
if data_source == 'amazon':
|
||||
# Amazon 文件格式: id, title, description, manufacturer, price
|
||||
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
|
||||
else:
|
||||
# Google 文件格式: id, name, description, manufacturer, price
|
||||
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
|
||||
|
||||
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
|
||||
return (doc_id, text)
|
||||
|
||||
|
||||
# 读取和解析数据
|
||||
def load_data(path, data_source='amazon'):
|
||||
def load_data(path):
|
||||
""" 读取并解析数据文件 """
|
||||
raw_data = sc.textFile(path).map(lambda line: parse_data_file(line, data_source)).filter(lambda x: x is not None)
|
||||
raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None)
|
||||
return raw_data
|
||||
|
||||
|
||||
amazon = load_data(amazon_path, data_source='amazon')
|
||||
google = load_data(google_path, data_source='google')
|
||||
amazon = load_data(amazon_path)
|
||||
google = load_data(google_path)
|
||||
|
||||
# 对数据进行分词化
|
||||
amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
|
||||
@ -99,7 +91,6 @@ google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_f
|
||||
amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap())
|
||||
google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap())
|
||||
|
||||
|
||||
# 计算权重范数
|
||||
def norm(weights):
|
||||
""" 计算向量的范数 """
|
||||
@ -138,9 +129,15 @@ def fast_cosine_similarity(record):
|
||||
amazon_id = record[0][0]
|
||||
google_url = record[0][1]
|
||||
tokens = record[1]
|
||||
|
||||
# 使用 .get() 方法来安全地访问字典中的元素,避免 KeyError
|
||||
s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get(
|
||||
token, 0) for token in tokens])
|
||||
token, 0)
|
||||
for token in tokens])
|
||||
|
||||
# 使用广播变量计算余弦相似度
|
||||
value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
|
||||
|
||||
return ((amazon_id, google_url), value)
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user