diff --git a/3-1.py b/3-1.py
index 17856a7..bf31c98 100644
--- a/3-1.py
+++ b/3-1.py
@@ -1,39 +1,66 @@
-# 3-1.py
+# coding=utf-8
 from pyspark import SparkContext
-from collections import defaultdict
+import csv
+import re
 
-sc = SparkContext()
+# Python 3.5 has no f-strings, so use str.format
+def tokenize(text):
+    # Tokenize, keeping letters and digits only
+    return re.findall(r'\w+', text.lower())
 
-corpus = google_tokens.union(amazon_tokens)
-N = corpus.count()
+def load_stopwords(sc):
+    try:
+        return set(sc.textFile("hdfs:///user/root/stopwords.txt").collect())
+    except Exception:
+        # Fall back to a local copy of the stopword list
+        with open("stopwords.txt", "r") as f:
+            return set([line.strip() for line in f])
 
-def compute_tf(record):
-    doc_id, tokens = record
-    tf = defaultdict(float)
-    for token in tokens:
-        tf[token] += 1.0
-    total = float(len(tokens))
-    for key in tf:
-        tf[key] = tf[key] / total
-    return (doc_id, tf)
+def parse_csv_line(line):
+    # Use csv.reader so quoted fields containing commas are parsed correctly
+    reader = csv.reader([line])
+    return next(reader)
 
-tf_rdd = corpus.map(compute_tf)
+def extract_info(line, source):
+    try:
+        fields = parse_csv_line(line)
+        if source == "google":
+            # Google: id, name, description, manufacturer...
+            pid = fields[0].strip()
+            text = "{} {} {}".format(fields[1], fields[2], fields[3])
+        else:
+            # Amazon: id, title, description, manufacturer...
+            pid = fields[0].strip()
+            text = "{} {} {}".format(fields[1], fields[2], fields[3])
+        return (pid, text)
+    except Exception:
+        return (None, None)
 
-token_docs = corpus.flatMap(lambda x: [(token, x[0]) for token in set(x[1])])
-doc_freq = token_docs.groupByKey().mapValues(lambda x: len(set(x)))
-idf_dict = doc_freq.map(lambda x: (x[0], float(N) / x[1])).collectAsMap()
-idf_bcast = sc.broadcast(idf_dict)
+if __name__ == "__main__":
+    sc = SparkContext(appName="InvertedIndex")
+    stopwords = load_stopwords(sc)
 
-def compute_tfidf(record):
-    doc_id, tf_map = record
-    idf_map = idf_bcast.value
-    tfidf = {}
-    for token in tf_map:
-        tfidf[token] = tf_map[token] * idf_map.get(token, 0.0)
-    return (doc_id, tfidf)
+    # Load the two source datasets
+    google = sc.textFile("hdfs://master:9000/user/root/GoogleProducts.csv")
+    amazon = sc.textFile("hdfs://master:9000/user/root/AmazonProducts.csv")
 
-tfidf_rdd = tf_rdd.map(compute_tfidf)
+    # Extract (product id, concatenated text) pairs
+    google_rdd = google.map(lambda line: extract_info(line, "google")) \
+        .filter(lambda x: x[0] is not None)
+    amazon_rdd = amazon.map(lambda line: extract_info(line, "amazon")) \
+        .filter(lambda x: x[0] is not None)
 
-print("TF-IDF sample: ", tfidf_rdd.take(1))
+    # Merge the two datasets
+    all_data = google_rdd.union(amazon_rdd)
+
+    # Build the inverted index: word -> list of product ids
+    inverted_index = all_data.flatMap(lambda x: [(word, x[0]) for word in tokenize(x[1]) if word not in stopwords]) \
+        .groupByKey() \
+        .mapValues(lambda ids: list(set(ids)))
+
+    # Output (can be saved to HDFS)
+    inverted_index.saveAsTextFile("hdfs:///user/root/output/inverted_index")
+
+    sc.stop()
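
As a quick sanity check of the flatMap / groupByKey pattern used to build the index, here is a minimal local sketch. The local[*] master, the toy product records, and the tiny stopword set are assumptions for illustration only; they are not part of the job above.

from pyspark import SparkContext
import re

sc = SparkContext("local[*]", appName="InvertedIndexSmokeTest")

# Hypothetical (product id, text) pairs standing in for the parsed CSV rows
sample = sc.parallelize([
    ("g1", "Adobe Photoshop CS3 image editing"),
    ("a1", "Photoshop Elements photo editing software"),
])

def tokenize(text):
    # Same tokenization rule as in 3-1.py: lowercase, keep \w+ tokens
    return re.findall(r'\w+', text.lower())

stopwords = {"the", "and", "of"}  # assumed tiny stopword set for the test

index = sample.flatMap(lambda x: [(w, x[0]) for w in tokenize(x[1]) if w not in stopwords]) \
    .groupByKey() \
    .mapValues(lambda ids: sorted(set(ids)))

# Expected shape: e.g. ('photoshop', ['a1', 'g1']), ('editing', ['a1', 'g1']), ...
print(index.take(5))
sc.stop()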