BD-exp-9/3-1.py
fly6516 1043551309 feat(3-1.py): 实现 TF-IDF 权重计算并优化代码结构
- 新增辅助函数:tokenize、tf、idfs、tfidf
-优化数据加载与预处理逻辑- 实现全局 IDF 计算并绘制直方图
-完成全局 TF-IDF 计算并保存结果到 HDFS- 增加针对特定 Amazon 记录的 TF-IDF计算示例
- 优化代码注释和结构,提高可读性
2025-04-16 10:08:50 +08:00

194 lines
7.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
实验步骤3利用 TF-IDF 加权提升文本相似性计算准确性
功能说明:
1. 从 HDFS 上读取 Amazon_small.csv 和 Google_small.csv提取文档ID及文本标题、描述、制造商的组合
2. 对文本进行分词,构建语料库,格式为:(doc_id, [token列表])
3. 实现 tf(tokens) 函数,计算 token 的相对词频;
4. 实现 idfs(corpus) 函数,计算每个唯一 token 的逆文档频率使用公式IDF(t) = N / n(t)
5. 实现 tfidf(tokens, idfs) 函数,结合 TF 与 IDF 权重计算每个 token 的 TF-IDF
6. 利用 RDD 操作计算全局 TF-IDF并保存结果
7. 对 Amazon 记录 "b'000hkgj8k" 调用 tfidf() 测试,打印该记录中各 token 的 TF-IDF 权重。
注意:
(1) 代码使用 Python 3.5 语法;
(2) 请保证 HDFS 上已上传 Amazon_small.csv 和 Google_small.csv 文件;
(3) 如需停用词过滤,可在 tokenize() 函数中扩展,本示例未特别去除停用词。
"""
from pyspark import SparkContext
import re
import matplotlib.pyplot as plt
sc = SparkContext(appName="TFIDF_Analysis")
###############################
# 1. 定义辅助函数
###############################
def tokenize(text):
"""
分词:转换成小写后提取所有字母或数字组成的单词
"""
return re.findall(r'\w+', text.lower())
def tf(tokens):
"""
计算 TF词频
Args:
tokens (list of str): 输入的 token 列表
Returns:
dict: 每个 token 映射到其 TF 值(出现次数/总 token 数)
"""
total = len(tokens)
counts = {}
for token in tokens:
if token in counts:
counts[token] = counts[token] + 1
else:
counts[token] = 1
return {k: float(v) / total for k, v in counts.items()}
def idfs(corpus):
"""
计算语料库中每个唯一 token 的 IDF 权重
Args:
corpus (RDD): 每个元素格式为 (doc_id, [token列表])
Returns:
RDD: (token, IDF值) 的 RDD
"""
N = corpus.count()
# 对每个文档取唯一 token 集合(避免在同一文档中重复计数)
uniqueTokens = corpus.map(lambda x: set(x[1]))
# 对每个文档生成 (token, 1) 对,再求和计算包含该 token 的文档数 n(t)
tokenCountPair = uniqueTokens.flatMap(lambda tokens: [(token, 1) for token in tokens])
tokenDocCounts = tokenCountPair.reduceByKey(lambda a, b: a + b)
# 计算 IDF不取对数IDF(t) = N / n(t)
return tokenDocCounts.map(lambda x: (x[0], float(N) / float(x[1])))
def tfidf(tokens, idfs_dict):
"""
计算 TF-IDF 权重
Args:
tokens (list of str): 输入 token 列表
idfs_dict (dict): token 到其 IDF 权重的字典
Returns:
dict: 每个 token 映射到其 TF-IDF 权重
"""
tfs = tf(tokens)
# 对于在 idfs 字典中存在的 token 计算 TF * IDF
tfIdfDict = {token: tfs[token] * idfs_dict[token] for token in tfs if token in idfs_dict}
return tfIdfDict
###############################
# 2. 数据加载与预处理
###############################
# 修改以下 HDFS 路径,根据你的集群配置,如 "hdfs://master:9000"
amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
google_path = "hdfs://master:9000/user/root/Google_small.csv"
def parse_csv_line(line):
"""
解析 CSV 行(假设字段使用 '","' 分隔,且首尾有引号)
返回 (doc_id, text),其中 text 为标题、描述、制造商字段拼接后的字符串
"""
line = line.strip()
if not line:
return None
# 去除首尾引号后按 '","' 拆分
parts = line.strip('"').split('","')
if len(parts) < 4:
return None
doc_id = parts[0].strip()
text = "{} {} {}".format(parts[1].strip(), parts[2].strip(), parts[3].strip())
return (doc_id, text)
# 读取数据,并过滤掉表头(假设表头中 doc_id 为 "id"
amazon_rdd = sc.textFile(amazon_path).map(parse_csv_line) \
.filter(lambda x: x is not None and x[0].lower() != "id")
google_rdd = sc.textFile(google_path).map(parse_csv_line) \
.filter(lambda x: x is not None and x[0].lower() != "id")
# 转换为 (doc_id, [token列表])
amazonRecToToken = amazon_rdd.map(lambda x: (x[0], tokenize(x[1])))
googleRecToToken = google_rdd.map(lambda x: (x[0], tokenize(x[1])))
# 创建语料库corpusRDD
corpusRDD = amazonRecToToken.union(googleRecToToken)
print("Corpus document count: {}".format(corpusRDD.count()))
# 注:测试要求 corpusRDD.count() 为 400这里根据实际数据
###############################
# 3. 全局 IDF 计算
###############################
idfsRDD = idfs(corpusRDD)
uniqueTokenCount = idfsRDD.count()
print("There are {} unique tokens in the small datasets.".format(uniqueTokenCount))
# 打印 IDF 值最小的 11 个 token
smallIDFTokens = idfsRDD.takeOrdered(11, key=lambda s: s[1])
print("Smallest 11 IDF tokens:")
for token, idf_value in smallIDFTokens:
print("{}: {}".format(token, idf_value))
# 绘制 IDF 直方图
small_idf_values = idfsRDD.map(lambda s: s[1]).collect()
fig = plt.figure(figsize=(8, 3))
plt.hist(small_idf_values, 50, log=True)
plt.title("IDF Histogram")
plt.xlabel("IDF value")
plt.ylabel("Frequency (log scale)")
plt.show()
###############################
# 4. 全局 TF-IDF 计算RDD实现
###############################
# 计算 TF 部分:为每个文档生成 ((doc_id, term), tf) 对
doc_term_pairs = corpusRDD.flatMap(lambda x: [((x[0], term), 1) for term in x[1]])
doc_term_counts = doc_term_pairs.reduceByKey(lambda a, b: a + b)
doc_lengths = corpusRDD.map(lambda x: (x[0], len(x[1])))
doc_term_counts_mapped = doc_term_counts.map(lambda x: (x[0][0], (x[0][1], x[1])))
tf_joined = doc_term_counts_mapped.join(doc_lengths)
tf_rdd = tf_joined.map(lambda x: ((x[0], x[1][0][0]), float(x[1][0][1]) / float(x[1][1])))
# 将 tf_rdd 转换为以 term 为 key便于 join
tf_rdd_by_term = tf_rdd.map(lambda x: (x[0][1], (x[0][0], x[1])))
# idfsRDD 格式为 (term, idf)
tfidf_joined = tf_rdd_by_term.join(idfsRDD)
tfidf_rdd = tfidf_joined.map(lambda x: ((x[1][0][0], x[0]), x[1][0][1] * x[1][1]))
# 将全局 TF-IDF 结果保存到 HDFS输出目录若已存在请删除或修改
output_path = "hdfs://master:9000/user/root/output/tfidf"
tfidf_rdd.saveAsTextFile(output_path)
print("Global TF-IDF result saved to: {}".format(output_path))
for item in tfidf_rdd.take(5):
print(item)
###############################
# 5. 针对 Amazon 记录 "b'000hkgj8k" 的 TF-IDF 计算(使用封装函数)
###############################
# 提取记录 "b'000hkgj8k" 的 token 列表
# 注意:实际记录的 doc_id 是否包含引号需与数据一致,此处示例与测试文本一致
rec_tokens = amazonRecToToken.filter(lambda x: x[0] == "b'000hkgj8k").collect()
if rec_tokens:
rec_tokens = rec_tokens[0][1]
# 将全局 idfsRDD 转换成 Python 字典
idfsWeights = idfsRDD.collectAsMap()
rec_tf_idf = tfidf(rec_tokens, idfsWeights)
print('Amazon record "b000hkgj8k" has tokens and TF-IDF weights:')
for token, weight in rec_tf_idf.items():
print("{}: {}".format(token, weight))
else:
print('Record "b000hkgj8k" not found in Amazon_small.csv')
sc.stop()