refactor(5-1): refactor the code to improve readability and performance
- Removed the unused data_source parameter
- Simplified the parse_data_file and load_data functions
- Used the .get() method for safe access to dictionary elements
- Improved the cosine_sim function to compute cosine similarity with broadcast variables
parent dc883eaf72 · commit 38917b896f
5-1.py · 29 lines changed
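For context, the pattern behind the last two bullets of the commit message is sketched below as a minimal, Spark-free example: the plain dicts stand in for the .value of the broadcast variables in the real code, and the document IDs and TF-IDF weights are invented purely for illustration.

import math

# Plain dicts standing in for the .value of the Spark broadcast variables;
# the IDs and TF-IDF weights below are invented for illustration.
amazon_weights = {'b000jz4hqo': {'clickart': 0.81, '950': 0.54}}
google_weights = {'http://g.co/p1': {'clickart': 0.62, 'premier': 0.33}}
amazon_norms = {k: math.sqrt(sum(w * w for w in v.values())) for k, v in amazon_weights.items()}
google_norms = {k: math.sqrt(sum(w * w for w in v.values())) for k, v in google_weights.items()}

def fast_cosine_similarity(record):
    """ Cosine similarity: dot(a, b) / (||a|| * ||b||) over the shared tokens """
    (amazon_id, google_url), tokens = record
    # .get(token, 0) yields 0 for tokens missing on one side instead of raising KeyError
    s = sum(amazon_weights[amazon_id].get(token, 0) * google_weights[google_url].get(token, 0)
            for token in tokens)
    return ((amazon_id, google_url), s / (amazon_norms[amazon_id] * google_norms[google_url]))

print(fast_cosine_similarity((('b000jz4hqo', 'http://g.co/p1'), ['clickart', '950', 'premier'])))
# (('b000jz4hqo', 'http://g.co/p1'), 0.734...)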
@@ -18,7 +18,7 @@ def tokenize(text):
     return re.findall(r'\w+', text.lower())
 
 
-def parse_data_file(line, data_source='amazon'):
+def parse_data_file(line):
     """ Parse each line of the data file """
     line = line.strip()
     if not line:
@@ -27,27 +27,19 @@ def parse_data_file(line, data_source='amazon'):
     if len(parts) < 5:
         return None
     doc_id = parts[0].strip()
-    # Handle the different dataset formats
-    if data_source == 'amazon':
-        # Amazon file format: id, title, description, manufacturer, price
-        text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
-    else:
-        # Google file format: id, name, description, manufacturer, price
-        text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
-
+    text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
     return (doc_id, text)
 
 
 # Read and parse the data
-def load_data(path, data_source='amazon'):
+def load_data(path):
     """ Read and parse the data file """
-    raw_data = sc.textFile(path).map(lambda line: parse_data_file(line, data_source)).filter(lambda x: x is not None)
+    raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None)
     return raw_data
 
 
-amazon = load_data(amazon_path, data_source='amazon')
-google = load_data(google_path, data_source='google')
+amazon = load_data(amazon_path)
+google = load_data(google_path)
 
 # Tokenize the data
 amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
@@ -99,7 +91,6 @@ google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_f
 amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap())
 google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap())
 
-
 # Compute the weight norms
 def norm(weights):
     """ Compute the norm of a vector """
@@ -138,9 +129,15 @@ def fast_cosine_similarity(record):
     amazon_id = record[0][0]
     google_url = record[0][1]
     tokens = record[1]
 
+    # Use .get() to access dictionary elements safely and avoid a KeyError
     s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get(
-        token, 0) for token in tokens])
+        token, 0)
+        for token in tokens])
 
+    # Compute the cosine similarity using the broadcast variables
     value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
+
     return ((amazon_id, google_url), value)
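To see why dropping data_source preserves behavior, here is a hedged local sketch of the refactored parse path. The comma split (parts = line.split(',')) and the sample record are assumptions, since the actual split happens between the hunks shown above, and the sc.textFile read is replaced by a single in-memory line.

import re

def tokenize(text):
    return re.findall(r'\w+', text.lower())

def parse_data_file(line):
    """ Parse each line of the data file """
    line = line.strip()
    if not line:
        return None
    parts = line.split(',')  # assumption: the actual split is not visible in this diff
    if len(parts) < 5:
        return None
    doc_id = parts[0].strip()
    # Both branches of the removed data_source check built this same string,
    # which is why the parameter was dead and could be dropped.
    text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
    return (doc_id, text)

# Hypothetical five-field record: id, title, description, manufacturer, price
sample = 'b000jz4hqo,clickart 950 000,premier image pack,broderbund,0.0'
record = parse_data_file(sample)
print(record)               # ('b000jz4hqo', 'clickart 950 000 premier image pack')
print(tokenize(record[1]))  # ['clickart', '950', '000', 'premier', 'image', 'pack']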
|