diff --git a/5-1.py b/5-1.py index 3a0e512..3b50023 100644 --- a/5-1.py +++ b/5-1.py @@ -12,11 +12,13 @@ sqlContext = SQLContext(sc) amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv" google_path = "hdfs://master:9000/user/root/Google_small.csv" + def tokenize(text): """ 分词化:将文本转成小写并提取字母数字组合的词 """ return re.findall(r'\w+', text.lower()) -def parse_data_file(line): + +def parse_data_file(line, data_source='amazon'): """ 解析数据文件的每一行 """ line = line.strip() if not line: @@ -25,17 +27,27 @@ def parse_data_file(line): if len(parts) < 5: return None doc_id = parts[0].strip() - text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip()) + + # 对不同数据集进行处理 + if data_source == 'amazon': + # Amazon 文件格式: id, title, description, manufacturer, price + text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip()) + else: + # Google 文件格式: id, name, description, manufacturer, price + text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip()) + return (doc_id, text) + # 读取和解析数据 -def load_data(path): +def load_data(path, data_source='amazon'): """ 读取并解析数据文件 """ - raw_data = sc.textFile(path).map(parse_data_file).filter(lambda x: x is not None) + raw_data = sc.textFile(path).map(lambda line: parse_data_file(line, data_source)).filter(lambda x: x is not None) return raw_data -amazon = load_data(amazon_path) -google = load_data(google_path) + +amazon = load_data(amazon_path, data_source='amazon') +google = load_data(google_path, data_source='google') # 对数据进行分词化 amazon_rec_to_token = amazon.map(lambda x: (x[0], tokenize(x[1]))) @@ -44,6 +56,7 @@ google_rec_to_token = google.map(lambda x: (x[0], tokenize(x[1]))) # 合并 Amazon 和 Google 数据集 full_corpus_rdd = amazon_rec_to_token.union(google_rec_to_token) + # 计算 IDF def idfs(corpus): """ 计算逆文档频率 IDF """ @@ -53,6 +66,7 @@ def idfs(corpus): idf_rdd = df_rdd.map(lambda x: (x[0], float(N) / float(x[1]))) return idf_rdd + # 计算完整数据集的 IDF idfs_full = idfs(full_corpus_rdd) @@ -60,6 +74,7 @@ idfs_full = idfs(full_corpus_rdd) idfs_full_weights = idfs_full.collectAsMap() idfs_full_broadcast = sc.broadcast(idfs_full_weights) + # 计算 TF-IDF def tf(tokens): """ 计算词频 TF """ @@ -69,20 +84,24 @@ def tf(tokens): counts[token] = counts.get(token, 0) + 1 return {k: float(v) / total for k, v in counts.items()} + def tfidf(tokens, idfs): """ 计算 TF-IDF """ tfs = tf(tokens) return {k: v * idfs.get(k, 0) for k, v in tfs.items()} + # 计算 Amazon 和 Google 的 TF-IDF amazon_weights_rdd = amazon_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value))) google_weights_rdd = google_rec_to_token.map(lambda x: (x[0], tfidf(x[1], idfs_full_broadcast.value))) + # 计算权重范数 def norm(weights): """ 计算向量的范数 """ return math.sqrt(sum([w * w for w in weights.values()])) + # 计算 Amazon 和 Google 的权重范数 amazon_norms = amazon_weights_rdd.map(lambda x: (x[0], norm(x[1]))) google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1]))) @@ -91,6 +110,7 @@ google_norms = google_weights_rdd.map(lambda x: (x[0], norm(x[1]))) amazon_norms_broadcast = sc.broadcast(amazon_norms.collectAsMap()) google_norms_broadcast = sc.broadcast(google_norms.collectAsMap()) + # 创建反向索引 def invert(record): """ 反转 (ID, tokens) 到 (token, ID) """ @@ -98,12 +118,15 @@ def invert(record): weights = record[1] return [(token, id) for token in weights] + # 创建反向索引 amazon_inv_pairs_rdd = amazon_weights_rdd.flatMap(lambda x: invert(x)).cache() google_inv_pairs_rdd = google_weights_rdd.flatMap(lambda x: invert(x)).cache() # 计算共有的 token -common_tokens = amazon_inv_pairs_rdd.join(google_inv_pairs_rdd).map(lambda x: (x[0], x[1])).groupByKey().map(lambda x: (x[0], list(x[1]))).cache() +common_tokens = amazon_inv_pairs_rdd.join(google_inv_pairs_rdd).map(lambda x: (x[0], x[1])).groupByKey().map( + lambda x: (x[0], list(x[1]))).cache() + # 计算余弦相似度 def fast_cosine_similarity(record): @@ -111,10 +134,12 @@ def fast_cosine_similarity(record): amazon_id = record[0][0] google_url = record[0][1] tokens = record[1] - s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get(token, 0) for token in tokens]) + s = sum([amazon_weights_broadcast.value[amazon_id].get(token, 0) * google_weights_broadcast.value[google_url].get( + token, 0) for token in tokens]) value = s / (amazon_norms_broadcast.value[amazon_id] * google_norms_broadcast.value[google_url]) return ((amazon_id, google_url), value) + # 计算相似度 similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache() @@ -122,7 +147,8 @@ similarities_full_rdd = common_tokens.map(fast_cosine_similarity).cache() print("Number of similarity records: {}".format(similarities_full_rdd.count())) # 计算并测试相似度 -similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect() +similarity_test = similarities_full_rdd.filter(lambda x: x[0][0] == 'b00005lzly' and x[0][ + 1] == 'http://www.google.com/base/feeds/snippets/13823221823254120257').collect() print(len(similarity_test)) # 测试