From 9921a965ae47862ebbf5dff3008b617e5ee256cb Mon Sep 17 00:00:00 2001
From: fly6516
Date: Sun, 20 Apr 2025 02:47:48 +0800
Subject: [PATCH] feat(5-1.py): implement a scalable entity-matching algorithm
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- Create the SparkContext and SQLContext
- Load and parse the Amazon and Google datasets
- Implement tokenization, TF-IDF weighting, and cosine similarity
- Create and use broadcast variables to speed up the computation
- Optimize the entity-matching algorithm for large datasets
---
 5-1.py | 132 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 132 insertions(+)
 create mode 100644 5-1.py

diff --git a/5-1.py b/5-1.py
new file mode 100644
index 0000000..c87a46b
--- /dev/null
+++ b/5-1.py
@@ -0,0 +1,132 @@
+import re
+import math
+from pyspark import SparkContext
+from pyspark.sql import SQLContext
+
+# Create the SparkContext and SQLContext
+sc = SparkContext(appName="ScalableER")
+sqlContext = SQLContext(sc)
+
+# Data file paths
+amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
+google_path = "hdfs://master:9000/user/root/Google_small.csv"
+
+def tokenize(text):
+    """ Tokenize: lower-case the text and extract alphanumeric words """
+    return re.findall(r'\w+', text.lower())
+
+def parse_data_file(line):
+    """ Parse one line of a data file into (doc_id, text) """
+    line = line.strip()
+    if not line:
+        return None
+    parts = line.split(',')
+    if len(parts) < 5:
+        return None
+    doc_id = parts[0].strip()
+    text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
+    return (doc_id, text)
+
+def load_data(path):
+    """ Read and parse a data file, dropping malformed lines """
+    raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None)
+    return raw_data
+
+amazon = load_data(amazon_path)
+google = load_data(google_path)
+
+# Tokenize both datasets
+amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
+google_rec_to_token = google.map(lambda x: (x[0], tokenize(x[1])))
+
+# Combine the Amazon and Google datasets into one corpus
+full_corpus_rdd = amazon_rec_to_token.union(google_rec_to_token)
+
+def idfs(corpus):
+    """ Compute the inverse document frequency: IDF(t) = N / df(t) """
+    N = corpus.count()  # total number of documents
+    term_doc_pairs = corpus.flatMap(lambda x: [(term, x[0]) for term in set(x[1])])
+    df_rdd = term_doc_pairs.distinct().map(lambda x: (x[0], 1)).reduceByKey(lambda a, b: a + b)
+    idf_rdd = df_rdd.map(lambda x: (x[0], float(N) / float(x[1])))
+    return idf_rdd
+
+# Compute IDF weights over the full corpus
+idfs_full = idfs(full_corpus_rdd)
+
+# Broadcast the IDF weights so each executor receives one read-only copy
+idfs_full_weights = idfs_full.collectAsMap()
+idfs_full_broadcast = sc.broadcast(idfs_full_weights)
+
+def tf(tokens):
+    """ Compute the term frequency TF """
+    total = len(tokens)
+    counts = {}
+    for token in tokens:
+        counts[token] = counts.get(token, 0) + 1
+    return {k: float(v) / total for k, v in counts.items()}
+
+def tfidf(tokens, idfs):
+    """ Compute TF-IDF weights for one list of tokens """
+    tfs = tf(tokens)
+    return {k: v * idfs.get(k, 0) for k, v in tfs.items()}
+
+# Compute TF-IDF weights for the Amazon and Google records
+amazon_weights_rdd = amazon_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
+google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
+
+def norm(weights):
+    """ Compute the Euclidean norm of a weight vector """
+    return math.sqrt(sum([w * w for w in weights.values()]))
+
+# Compute the weight norms for Amazon and Google
+amazon_norms = amazon_weights_rdd.map(lambda x: (x[0], norm(x[1])))
+google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1])))
+
+# Broadcast the norms and the TF-IDF weight maps so that
+# fast_cosine_similarity can look them up on the executors
+amazon_norms_broadcast = sc.broadcast(amazon_norms.collectAsMap())
+google_norms_broadcast = sc.broadcast(google_norms.collectAsMap())
+amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap())
+google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap())
+
+def invert(record):
+    """ Invert (ID, weights) into a list of (token, ID) pairs """
+    doc_id = record[0]
+    weights = record[1]
+    return [(token, doc_id) for token in weights]
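+
+# Illustrative trace of the helpers above (toy tokens and weights, not
+# taken from the real datasets):
+#   tf(['spark', 'spark', 'er'])                   -> {'spark': 0.667, 'er': 0.333}
+#   with IDFs {'spark': 2.0, 'er': 4.0}, tfidf     -> {'spark': 1.333, 'er': 1.333}
+#   invert(('doc1', {'spark': 1.333, 'er': 1.333})) -> [('spark', 'doc1'), ('er', 'doc1')]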
+
+# Build the inverted indices mapping token -> document ID
+amazon_inv_pairs_rdd = amazon_weights_rdd.flatMap(lambda x: invert(x)).cache()
+google_inv_pairs_rdd = google_weights_rdd.flatMap(lambda x: invert(x)).cache()
+
+# Collect the tokens shared by each (Amazon ID, Google URL) pair:
+# joining on token yields (token, (amazon_id, google_url)), so swap the
+# key and value before grouping to get ((amazon_id, google_url), [tokens])
+common_tokens = (amazon_inv_pairs_rdd.join(google_inv_pairs_rdd)
+                 .map(lambda x: (x[1], x[0]))
+                 .groupByKey()
+                 .map(lambda x: (x[0], list(x[1])))
+                 .cache())
+
+def fast_cosine_similarity(record):
+    """ Compute cosine similarity using only the shared tokens """
+    amazon_id = record[0][0]
+    google_url = record[0][1]
+    tokens = record[1]
+    s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0)
+             * google_weights_broadcast.value[google_url].get(token, 0)
+             for token in tokens])
+    value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
+    return ((amazon_id, google_url), value)
+
+# Compute the similarity scores
+similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache()
+
+# Inspect the result
+print(f"Number of similarity records: {similarities_full_rdd.count()}")
+
+# Look up one known matching pair as a sanity check
+similarity_test = similarities_full_rdd.filter(
+    lambda x: x[0][0] == 'b00005lzly'
+    and x[0][1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect()
+print(len(similarity_test))
+
+assert len(similarity_test) == 1, "incorrect len(similarity_test)"
+assert similarities_full_rdd.count() == 2441088, "incorrect similarities_full_rdd.count()"
+
+sc.stop()
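+
+# Worked example of fast_cosine_similarity (toy numbers, not from the
+# real datasets): if a pair shares tokens {'usb', 'cable'} with weights
+# (0.5, 0.4) and (0.2, 0.1), the dot product is 0.5*0.4 + 0.2*0.1 = 0.22;
+# with norms 0.6 and 0.5 the similarity is 0.22 / (0.6 * 0.5) ~= 0.733.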