feat(5-1.py): implement scalable entity resolution algorithm

- Create the SparkContext and SQLContext
- Load and parse the Amazon and Google datasets
- Implement tokenization, TF-IDF weighting, and cosine-similarity computation
- Create and use broadcast variables to speed up the computation
- Optimize the entity-resolution algorithm to handle large datasets
fly6516 2025-04-20 02:47:48 +08:00
parent 33687f9fcd
commit 9921a965ae

5-1.py Normal file

@@ -0,0 +1,132 @@
import re
import math
from pyspark import SparkContext
from pyspark.sql import SQLContext
# Create the SparkContext and SQLContext
sc = SparkContext(appName="ScalableER")
sqlContext = SQLContext(sc)
# Paths of the input data files on HDFS
amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
google_path = "hdfs://master:9000/user/root/Google_small.csv"
def tokenize(text):
    """ Tokenize: lowercase the text and extract alphanumeric words """
    return re.findall(r'\w+', text.lower())
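# Hypothetical example (not from the dataset): tokenize("iPad 2, 16GB - Black")
# returns ['ipad', '2', '16gb', 'black'].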
def parse_data_file(line):
    """ Parse one line of the data file into (doc_id, text) """
    line = line.strip()
    if not line:
        return None
    parts = line.split(',')
    if len(parts) < 5:
        return None
    doc_id = parts[0].strip()
    text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
    return (doc_id, text)
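# Note: this assumes a simple comma-separated layout (id, title/name, description,
# manufacturer, ...); the naive split(',') would mis-split fields that contain
# quoted commas.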
# Load and parse the data
def load_data(path):
    """ Read a data file and drop malformed lines """
    raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None)
    return raw_data
amazon = load_data(amazon_path)
google = load_data(google_path)
# Tokenize the records
amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
google_rec_to_token = google.map(lambda x: (x[0], tokenize(x[1])))
# Combine the Amazon and Google datasets into a single corpus
full_corpus_rdd = amazon_rec_to_token.union(google_rec_to_token)
# Compute IDF
def idfs(corpus):
    """ Compute the inverse document frequency (IDF) for every term in the corpus """
    N = corpus.count()  # total number of documents
    term_doc_pairs = corpus.flatMap(lambda x: [(term, x[0]) for term in set(x[1])])
    df_rdd = term_doc_pairs.distinct().map(lambda x: (x[0], 1)).reduceByKey(lambda a, b: a + b)
    idf_rdd = df_rdd.map(lambda x: (x[0], float(N) / float(x[1])))
    return idf_rdd
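# Here idf(term) = N / df(term): the raw ratio of the corpus size to the number
# of documents containing the term (no log scaling is applied).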
# Compute IDF over the full corpus
idfs_full = idfs(full_corpus_rdd)
# Collect the IDF weights and broadcast them to the workers
idfs_full_weights = idfs_full.collectAsMap()
idfs_full_broadcast = sc.broadcast(idfs_full_weights)
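# Broadcasting ships the collected IDF map to each executor once instead of
# serializing it with every task closure that uses it.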
# Compute TF-IDF
def tf(tokens):
    """ Compute term frequency (TF) """
    total = len(tokens)
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    return {k: float(v) / total for k, v in counts.items()}
def tfidf(tokens, idfs):
    """ Compute TF-IDF weights for a list of tokens """
    tfs = tf(tokens)
    return {k: v * idfs.get(k, 0) for k, v in tfs.items()}
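# tf(token)    = count(token) / len(tokens)
# tfidf(token) = tf(token) * idf(token), with idf defaulting to 0 for unseen tokens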
# Compute TF-IDF weights for the Amazon and Google records
amazon_weights_rdd = amazon_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
# Compute weight norms
def norm(weights):
    """ Compute the Euclidean (L2) norm of a weight vector """
    return math.sqrt(sum([w * w for w in weights.values()]))
# Compute the norms of the Amazon and Google weight vectors
amazon_norms = amazon_weights_rdd.map(lambda x: (x[0], norm(x[1])))
google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1])))
# Broadcast the norms and the full weight maps (both are needed by fast_cosine_similarity)
amazon_norms_broadcast = sc.broadcast(amazon_norms.collectAsMap())
google_norms_broadcast = sc.broadcast(google_norms.collectAsMap())
amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap())
google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap())
# Build the inverted index
def invert(record):
    """ Invert (ID, weights) into (token, ID) pairs """
    doc_id = record[0]
    weights = record[1]
    return [(token, doc_id) for token in weights]
# Invert the Amazon and Google weight RDDs and cache them
amazon_inv_pairs_rdd = amazon_weights_rdd.flatMap(lambda x: invert(x)).cache()
google_inv_pairs_rdd = google_weights_rdd.flatMap(lambda x: invert(x)).cache()
# For each (Amazon, Google) pair, collect the tokens the two records share
common_tokens = (amazon_inv_pairs_rdd.join(google_inv_pairs_rdd)   # (token, (amazon_id, google_url))
                 .map(lambda x: (x[1], x[0]))                      # ((amazon_id, google_url), token)
                 .groupByKey()
                 .map(lambda x: (x[0], list(x[1])))
                 .cache())
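# Each element of common_tokens is ((amazon_id, google_url), [shared tokens]),
# so only record pairs that share at least one token are ever compared below.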
# Compute cosine similarity
def fast_cosine_similarity(record):
    """ Compute the cosine similarity of an Amazon record and a Google record """
    amazon_id = record[0][0]
    google_url = record[0][1]
    tokens = record[1]
    s = sum(amazon_weights_broadcast.value[amazon_id].get(token, 0)
            * google_weights_broadcast.value[google_url].get(token, 0)
            for token in tokens)
    value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
    return ((amazon_id, google_url), value)
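# This is the standard cosine similarity dot(a, g) / (||a|| * ||g||); tokens that
# appear in only one of the two records contribute zero to the dot product, so
# summing over the shared tokens is sufficient.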
# Compute similarities for all candidate pairs
similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache()
# Inspect the results
print(f"Number of similarity records: {similarities_full_rdd.count()}")
# Look up the similarity of one specific known pair
similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect()
print(len(similarity_test))
# Sanity checks
assert len(similarity_test) == 1, "incorrect len(similarity_test)"
assert similarities_full_rdd.count() == 2441088, "incorrect similarities_full_rdd.count()"
sc.stop()