import re
import math

from pyspark import SparkContext
from pyspark.sql import SQLContext

# Create the SparkContext and SQLContext
sc = SparkContext(appName="ScalableER")
sqlContext = SQLContext(sc)

# Paths to the data files on HDFS
amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
google_path = "hdfs://master:9000/user/root/Google_small.csv"


def tokenize(text):
    """Lowercase the text and extract alphanumeric word tokens."""
    return re.findall(r'\w+', text.lower())


def parse_data_file(line, data_source='amazon'):
    """Parse one line of a data file into a (doc_id, text) pair.

    Both files share the layout: id, title/name, description, manufacturer, price.
    The three free-text fields are concatenated into a single string.
    """
    line = line.strip()
    if not line:
        return None
    parts = line.split(',')
    if len(parts) < 5:
        return None
    doc_id = parts[0].strip()
    if data_source == 'amazon':
        # Amazon format: id, title, description, manufacturer, price
        text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
    else:
        # Google format: id, name, description, manufacturer, price
        text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
    return (doc_id, text)


def load_data(path, data_source='amazon'):
    """Read and parse a data file, dropping empty or malformed lines."""
    raw_data = (sc.textFile(path)
                .map(lambda line: parse_data_file(line, data_source))
                .filter(lambda x: x is not None))
    return raw_data


amazon = load_data(amazon_path, data_source='amazon')
google = load_data(google_path, data_source='google')

# Tokenize each record: (doc_id, text) -> (doc_id, [tokens])
amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1])))
google_rec_to_token = google.map(lambda x: (x[0], tokenize(x[1])))

# Combine the Amazon and Google records into one corpus
full_corpus_rdd = amazon_rec_to_token.union(google_rec_to_token)


def idfs(corpus):
    """Compute the inverse document frequency (IDF) of every token: N / document frequency."""
    N = corpus.count()  # total number of documents
    term_doc_pairs = corpus.flatMap(lambda x: [(term, x[0]) for term in set(x[1])])
    df_rdd = term_doc_pairs.distinct().map(lambda x: (x[0], 1)).reduceByKey(lambda a, b: a + b)
    idf_rdd = df_rdd.map(lambda x: (x[0], float(N) / float(x[1])))
    return idf_rdd


# Compute IDF weights over the combined corpus and broadcast them to the workers
idfs_full = idfs(full_corpus_rdd)
idfs_full_weights = idfs_full.collectAsMap()
idfs_full_broadcast = sc.broadcast(idfs_full_weights)


def tf(tokens):
    """Compute the normalized term frequency (TF) of each token."""
    total = len(tokens)
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    return {k: float(v) / total for k, v in counts.items()}


def tfidf(tokens, idf_weights):
    """Compute TF-IDF weights for a list of tokens."""
    tfs = tf(tokens)
    return {k: v * idf_weights.get(k, 0) for k, v in tfs.items()}


# Compute TF-IDF weights for the Amazon and Google records and broadcast them
amazon_weights_rdd = amazon_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value)))
amazon_weights_broadcast = sc.broadcast(amazon_weights_rdd.collectAsMap())
google_weights_broadcast = sc.broadcast(google_weights_rdd.collectAsMap())


def norm(weights):
    """Euclidean norm of a TF-IDF weight vector."""
    return math.sqrt(sum([w * w for w in weights.values()]))


# Compute the norm of each record's weight vector and broadcast them
amazon_norms = amazon_weights_rdd.map(lambda x: (x[0], norm(x[1])))
google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1])))
amazon_norms_broadcast = sc.broadcast(amazon_norms.collectAsMap())
google_norms_broadcast = sc.broadcast(google_norms.collectAsMap())


def invert(record):
    """Invert a (doc_id, weights) record into (token, doc_id) pairs."""
    doc_id = record[0]
    weights = record[1]
    return [(token, doc_id) for token in weights]


# Build inverted indices: token -> document ID
amazon_inv_pairs_rdd = amazon_weights_rdd.flatMap(lambda x: invert(x)).cache()
google_inv_pairs_rdd = google_weights_rdd.flatMap(lambda x: invert(x)).cache()
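# Optional driver-side sanity check of the local helpers. This is a minimal sketch:
# the sample string and IDF weights below are made up for illustration and are not
# drawn from the Amazon/Google data.
sample_tokens = tokenize("Sample Product 12 sample product")
sample_idfs = {"sample": 2.0, "product": 2.0, "12": 4.0}  # hypothetical IDF weights
sample_weights = tfidf(sample_tokens, sample_idfs)
# Each weight is (count / total tokens) * IDF, e.g. "12" -> (1 / 5) * 4.0 = 0.8
assert abs(sample_weights["12"] - 0.8) < 1e-9
assert abs(norm(sample_weights) - math.sqrt(3 * 0.8 ** 2)) < 1e-9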
# For every (Amazon ID, Google ID) pair that shares at least one token, collect the
# shared tokens. The join key is the token, so the pair of IDs becomes the new key.
common_tokens = (amazon_inv_pairs_rdd
                 .join(google_inv_pairs_rdd)
                 .map(lambda x: (x[1], x[0]))
                 .groupByKey()
                 .map(lambda x: (x[0], list(x[1])))
                 .cache())


def fast_cosine_similarity(record):
    """Cosine similarity of an Amazon/Google pair, computed from their shared tokens
    and the broadcast TF-IDF weights and norms."""
    amazon_id = record[0][0]
    google_url = record[0][1]
    tokens = record[1]
    s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0)
             * google_weights_broadcast.value[google_url].get(token, 0)
             for token in tokens])
    value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url])
    return ((amazon_id, google_url), value)


# Compute the similarity of every candidate pair
similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache()

# Inspect the results
print("Number of similarity records: {}".format(similarities_full_rdd.count()))

# Look up the similarity for one known Amazon/Google pair
similarity_test = similarities_full_rdd.filter(
    lambda x: x[0][0] == 'b00005lzly'
    and x[0][1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect()
print(len(similarity_test))

# Sanity checks
assert len(similarity_test) == 1, "incorrect len(similarity_test)"
assert similarities_full_rdd.count() == 2441088, "incorrect similarities_full_rdd.count()"

sc.stop()
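# A possible way to launch this script on a cluster (illustrative only: the script
# file name is hypothetical, and the master URL and memory setting must be adapted
# to your environment):
#
#   spark-submit --master spark://master:7077 --executor-memory 2g scalable_er.py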