import math
import re
from pyspark import SparkContext, Broadcast
from pyspark.sql import SparkSession


# 1. Dot product of two sparse vectors represented as dicts
def dotprod(a, b):
    """ Compute dot product
    Args:
        a (dict): first dictionary of record to value
        b (dict): second dictionary of record to value
    Returns:
        float: result of the dot product with the two input dictionaries
    """
    # Iterate the smaller dict and probe the larger one: O(min(|a|, |b|))
    # instead of the original quadratic scan over every key pair.
    if len(a) > len(b):
        a, b = b, a
    return float(sum(value * b[key] for key, value in a.items() if key in b))


# 2. Euclidean norm of a sparse vector
def norm(a):
    """ Compute square root of the dot product
    Args:
        a (dict): a dictionary of record to value
    Returns:
        float: norm of the dictionary
    """
    return math.sqrt(dotprod(a, a))


# 3. Cosine similarity of two sparse vectors
def cossim(a, b):
    """ Compute cosine similarity
    Args:
        a (dict): first dictionary of record to value
        b (dict): second dictionary of record to value
    Returns:
        float: cosine similarity value; 0.0 when either vector has zero norm
    """
    # Guard against zero vectors (empty documents, or tokens absent from the
    # IDF table), which would otherwise raise ZeroDivisionError.
    denominator = norm(a) * norm(b)
    return dotprod(a, b) / denominator if denominator else 0.0
# 4. TF-IDF weighting
def tfidf(tokens, idfsDictionary):
    """ Calculate TF-IDF values for token list
    Args:
        tokens (list): list of tokens
        idfsDictionary (dict): IDF values
    Returns:
        dict: dictionary of token -> TF-IDF value (empty for an empty list)
    """
    # Guard: an empty document has no weights; avoids ZeroDivisionError below.
    if not tokens:
        return {}
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    total_tokens = len(tokens)
    # Tokens missing from the IDF table get weight 0 so they contribute
    # nothing to the similarity score.
    return {token: (count / total_tokens) * idfsDictionary.get(token, 0)
            for token, count in counts.items()}


# 5. Cosine similarity between two raw strings
def cosineSimilarity(string1, string2, idfsDictionary):
    """ Compute cosine similarity between two strings using TF-IDF weights
    Args:
        string1 (str): first string
        string2 (str): second string
        idfsDictionary (dict): IDF dictionary
    Returns:
        float: cosine similarity
    """
    tokens1 = tokenize(string1)
    tokens2 = tokenize(string2)
    tfidf1 = tfidf(tokens1, idfsDictionary)
    tfidf2 = tfidf(tokens2, idfsDictionary)
    return cossim(tfidf1, tfidf2)


# 6. Tokenize function (split by spaces or punctuation)
def tokenize(text):
    """ Tokenizes a string into a list of words
    Args:
        text (str): input string
    Returns:
        list: list of tokens (words)
    """
    # Lowercase and strip everything except word characters and whitespace,
    # then split on whitespace.
    text = re.sub(r'[^\w\s]', '', text.lower())
    return text.split()


# 7. Similarity function for RDD records
def computeSimilarity(record, idfsDictionary):
    """ Compute similarity on a combination record
    Args:
        record (tuple): (google record, amazon record)
        idfsDictionary (dict): IDF dictionary
    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    googleRec = record[0]
    amazonRec = record[1]
    googleURL = googleRec[0]
    amazonID = amazonRec[0]
    googleValue = googleRec[1]
    amazonValue = amazonRec[1]
    cs = cosineSimilarity(googleValue, amazonValue, idfsDictionary)
    return (googleURL, amazonID, cs)
# 8. Parse the gold-standard data
def parse_goldfile_line(goldfile_line):
    """ Parse a line from the 'golden standard' data file
    Args:
        goldfile_line (str): a line of data
    Returns:
        tuple: ((key, 'gold'), 1) on success, (line, 0) for the header line,
               or (line, -1) for an invalid line
    """
    GOLDFILE_PATTERN = '^(.+),(.+)'
    match = re.search(GOLDFILE_PATTERN, goldfile_line)
    if match is None:
        print(f'Invalid goldfile line: {goldfile_line}')
        return (goldfile_line, -1)
    elif match.group(1) == '"idAmazon"':
        print(f'Header datafile line: {goldfile_line}')
        return (goldfile_line, 0)
    else:
        key = f'{match.group(1)} {match.group(2)}'
        return ((key, 'gold'), 1)


# 9. Broadcast variant: ship the IDF dict to executors once instead of
#    serializing it into every task closure.
def computeSimilarityBroadcast(record, idfsBroadcast):
    """ Compute similarity on a combination record, using Broadcast variable
    Args:
        record (tuple): (google record, amazon record)
        idfsBroadcast (Broadcast): broadcasted IDF dictionary
    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    googleRec = record[0]
    amazonRec = record[1]
    googleURL = googleRec[0]
    amazonID = amazonRec[0]
    googleValue = googleRec[1]
    amazonValue = amazonRec[1]
    cs = cosineSimilarity(googleValue, amazonValue, idfsBroadcast.value)
    return (googleURL, amazonID, cs)


# Main entry point: set up Spark, broadcast the IDF weights, and score pairs.
if __name__ == "__main__":
    spark = SparkSession.builder \
        .appName("TextSimilarity") \
        .getOrCreate()

    sc = spark.sparkContext

    # HDFS input paths
    amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
    google_path = "hdfs://master:9000/user/root/Google_small.csv"

    # Stand-in IDF weights; replace with values computed from the corpus.
    idfsDictionary = {
        "hello": 1.2,
        "world": 1.3,
        "goodbye": 1.1,
        "photoshop": 2.5,
        "illustrator": 2.7
    }

    # Broadcast once so every executor reuses the same read-only copy.
    idfsBroadcast = sc.broadcast(idfsDictionary)

    # Load the CSV data
    amazon_data = spark.read.csv(amazon_path, header=True, inferSchema=True)
    google_data = spark.read.csv(google_path, header=True, inferSchema=True)

    # Column names assumed from the sample data; adjust to the real schema.
    amazon_small = amazon_data.select("asin", "description").rdd.map(lambda x: (x[0], x[1]))
    google_small = google_data.select("url", "description").rdd.map(lambda x: (x[0], x[1]))

    # Bug fix: computeSimilarityBroadcast expects (google record, amazon
    # record), so google_small must be the left side of the cartesian product
    # (the original amazon-first order swapped URLs and IDs in the output).
    cross_small = google_small.cartesian(amazon_small)
    similarities = cross_small.map(lambda x: computeSimilarityBroadcast(x, idfsBroadcast))

    # Bug fix: results were collected but never shown; print each one.
    for result in similarities.collect():
        print(result)

    # Shut down the Spark context
    sc.stop()
    spark.stop()