From 38917b896f7feb41dea506a87c9d1a1e346e49e0 Mon Sep 17 00:00:00 2001 From: fly6516 <fly6516@outlook.com> Date: Sun, 20 Apr 2025 03:00:37 +0800 Subject: [PATCH] =?UTF-8?q?refactor(5-1):=E9=87=8D=E6=9E=84=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E4=BB=A5=E6=8F=90=E9=AB=98=E5=8F=AF=E8=AF=BB=E6=80=A7?= =?UTF-8?q?=E5=92=8C=E6=80=A7=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除了未使用的 data_source 参数 - 优化了 parse_data_file 和 load_data 函数 - 使用 .get() 方法安全访问字典元素 - 改进了 cosine_sim 函数,使用广播变量计算余弦相似度 --- 5-1.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/5-1.py b/5-1.py index ac727c0..5b90235 100644 --- a/5-1.py +++ b/5-1.py @@ -18,7 +18,7 @@ def tokenize(text): return re.findall(r'\w+', text.lower()) -def parse_data_file(line, data_source='amazon'): +def parse_data_file(line): """ 解析数据文件的每一行 """ line = line.strip() if not line: @@ -27,27 +27,19 @@ def parse_data_file(line, data_source='amazon'): if len(parts) < 5: return None doc_id = parts[0].strip() - - # 对不同数据集进行处理 - if data_source == 'amazon': - # Amazon 文件格式: id, title, description, manufacturer, price - text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip()) - else: - # Google 文件格式: id, name, description, manufacturer, price - text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip()) - + text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip()) return (doc_id, text) # 读取和解析数据 -def load_data(path, data_source='amazon'): +def load_data(path): """ 读取并解析数据文件 """ - raw_data = sc.textFile(path).map(lambda line: parse_data_file(line, data_source)).filter(lambda x: x is not None) + raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None) return raw_data -amazon = load_data(amazon_path, data_source='amazon') -google = load_data(google_path, data_source='google') +amazon = load_data(amazon_path) +google = load_data(google_path) # 对数据进行分词化 amazon_rec_to_token = 
amazon.map(lambda x: (x[0], tokenize(x[1]))) @@ -99,7 +91,6 @@ google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_f amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap()) google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap()) - # 计算权重范数 def norm(weights): """ 计算向量的范数 """ @@ -138,9 +129,15 @@ def fast_cosine_similarity(record): amazon_id = record[0][0] google_url = record[0][1] tokens = record[1] + + # 使用 .get() 方法来安全地访问字典中的元素,避免 KeyError s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get( - token, 0) for token in tokens]) + token, 0) + for token in tokens]) + + # 使用广播变量计算余弦相似度 value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url]) + return ((amazon_id, google_url), value)