feat(similarity): 添加文本相似度计算功能
- 实现了计算两个文本之间相似度的完整流程 - 包括 TF-IDF 计算、余弦相似度计算等功能 - 使用 Spark 广播变量优化计算效率 -支持从 HDFS 读取数据进行计算
This commit is contained in:
parent
1043551309
commit
250e1b99e0
186
4-1.py
Normal file
186
4-1.py
Normal file
@ -0,0 +1,186 @@
|
||||
import math
|
||||
import re
|
||||
from pyspark import SparkContext, Broadcast
|
||||
from pyspark.sql import SparkSession
|
||||
|
||||
# 1. 计算点积
|
||||
def dotprod(a, b):
|
||||
""" Compute dot product
|
||||
Args:
|
||||
a (dict): first dictionary of record to value
|
||||
b (dict): second dictionary of record to value
|
||||
Returns:
|
||||
float: result of the dot product with the two input dictionaries
|
||||
"""
|
||||
sum = 0.0
|
||||
for k1, v1 in a.items():
|
||||
for k2, v2 in b.items():
|
||||
if k1 == k2:
|
||||
sum += v1 * v2
|
||||
return sum
|
||||
|
||||
# 2. 计算范数
|
||||
def norm(a):
|
||||
""" Compute square root of the dot product
|
||||
Args:
|
||||
a (dict): a dictionary of record to value
|
||||
Returns:
|
||||
float: norm of the dictionary
|
||||
"""
|
||||
return math.sqrt(dotprod(a, a))
|
||||
|
||||
# 3. 计算余弦相似度
|
||||
def cossim(a, b):
|
||||
""" Compute cosine similarity
|
||||
Args:
|
||||
a (dict): first dictionary of record to value
|
||||
b (dict): second dictionary of record to value
|
||||
Returns:
|
||||
float: cosine similarity value
|
||||
"""
|
||||
return dotprod(a, b) / (norm(a) * norm(b))
|
||||
|
||||
# 4. 计算TF-IDF
|
||||
def tfidf(tokens, idfsDictionary):
|
||||
""" Calculate TF-IDF values for token list
|
||||
Args:
|
||||
tokens (list): list of tokens
|
||||
idfsDictionary (dict): IDF values
|
||||
Returns:
|
||||
dict: dictionary of token -> TF-IDF value
|
||||
"""
|
||||
tf = {}
|
||||
for token in tokens:
|
||||
tf[token] = tf.get(token, 0) + 1
|
||||
total_tokens = len(tokens)
|
||||
for token in tf:
|
||||
tf[token] = tf[token] / total_tokens
|
||||
tfidf = {token: tf[token] * idfsDictionary.get(token, 0) for token in tf}
|
||||
return tfidf
|
||||
|
||||
# 5. 余弦相似度计算函数
|
||||
def cosineSimilarity(string1, string2, idfsDictionary):
|
||||
""" Compute cosine similarity between two strings using TF-IDF weights
|
||||
Args:
|
||||
string1 (str): first string
|
||||
string2 (str): second string
|
||||
idfsDictionary (dict): IDF dictionary
|
||||
Returns:
|
||||
float: cosine similarity
|
||||
"""
|
||||
tokens1 = tokenize(string1)
|
||||
tokens2 = tokenize(string2)
|
||||
tfidf1 = tfidf(tokens1, idfsDictionary)
|
||||
tfidf2 = tfidf(tokens2, idfsDictionary)
|
||||
return cossim(tfidf1, tfidf2)
|
||||
|
||||
# 6. Tokenize function (split by spaces or punctuation)
|
||||
def tokenize(text):
|
||||
""" Tokenizes a string into a list of words
|
||||
Args:
|
||||
text (str): input string
|
||||
Returns:
|
||||
list: list of tokens (words)
|
||||
"""
|
||||
text = re.sub(r'[^\w\s]', '', text.lower())
|
||||
return text.split()
|
||||
|
||||
# 7. 计算相似度的RDD处理函数
|
||||
def computeSimilarity(record, idfsDictionary):
|
||||
""" Compute similarity on a combination record
|
||||
Args:
|
||||
record (tuple): (google record, amazon record)
|
||||
idfsDictionary (dict): IDF dictionary
|
||||
Returns:
|
||||
tuple: (google URL, amazon ID, cosine similarity value)
|
||||
"""
|
||||
googleRec = record[0]
|
||||
amazonRec = record[1]
|
||||
googleURL = googleRec[0]
|
||||
amazonID = amazonRec[0]
|
||||
googleValue = googleRec[1]
|
||||
amazonValue = amazonRec[1]
|
||||
cs = cosineSimilarity(googleValue, amazonValue, idfsDictionary)
|
||||
return (googleURL, amazonID, cs)
|
||||
|
||||
# 8. 解析黄金标准数据
|
||||
def parse_goldfile_line(goldfile_line):
|
||||
""" Parse a line from the 'golden standard' data file
|
||||
Args:
|
||||
goldfile_line (str): a line of data
|
||||
Returns:
|
||||
tuple: ((key, 'gold', 1 if successful or else 0))
|
||||
"""
|
||||
GOLDFILE_PATTERN = '^(.+),(.+)'
|
||||
match = re.search(GOLDFILE_PATTERN, goldfile_line)
|
||||
if match is None:
|
||||
print(f'Invalid goldfile line: {goldfile_line}')
|
||||
return (goldfile_line, -1)
|
||||
elif match.group(1) == '"idAmazon"':
|
||||
print(f'Header datafile line: {goldfile_line}')
|
||||
return (goldfile_line, 0)
|
||||
else:
|
||||
key = f'{match.group(1)} {match.group(2)}'
|
||||
return ((key, 'gold'), 1)
|
||||
|
||||
# 9. 使用广播变量提高效率
|
||||
def computeSimilarityBroadcast(record, idfsBroadcast):
|
||||
""" Compute similarity on a combination record, using Broadcast variable
|
||||
Args:
|
||||
record (tuple): (google record, amazon record)
|
||||
idfsBroadcast (Broadcast): broadcasted IDF dictionary
|
||||
Returns:
|
||||
tuple: (google URL, amazon ID, cosine similarity value)
|
||||
"""
|
||||
googleRec = record[0]
|
||||
amazonRec = record[1]
|
||||
googleURL = googleRec[0]
|
||||
amazonID = amazonRec[0]
|
||||
googleValue = googleRec[1]
|
||||
amazonValue = amazonRec[1]
|
||||
cs = cosineSimilarity(googleValue, amazonValue, idfsBroadcast.value)
|
||||
return (googleURL, amazonID, cs)
|
||||
|
||||
# 主函数,设置Spark上下文和广播变量
|
||||
if __name__ == "__main__":
|
||||
# 创建SparkSession
|
||||
spark = SparkSession.builder \
|
||||
.appName("TextSimilarity") \
|
||||
.getOrCreate()
|
||||
|
||||
sc = spark.sparkContext
|
||||
|
||||
# HDFS路径
|
||||
amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
|
||||
google_path = "hdfs://master:9000/user/root/Google_small.csv"
|
||||
|
||||
# 假设的IDF权重字典
|
||||
idfsDictionary = {
|
||||
"hello": 1.2,
|
||||
"world": 1.3,
|
||||
"goodbye": 1.1,
|
||||
"photoshop": 2.5,
|
||||
"illustrator": 2.7
|
||||
}
|
||||
|
||||
# 创建广播变量
|
||||
idfsBroadcast = sc.broadcast(idfsDictionary)
|
||||
|
||||
# 加载CSV数据
|
||||
amazon_data = spark.read.csv(amazon_path, header=True, inferSchema=True)
|
||||
google_data = spark.read.csv(google_path, header=True, inferSchema=True)
|
||||
|
||||
# 假设的列名,根据实际数据进行调整
|
||||
amazon_small = amazon_data.select("asin", "description").rdd.map(lambda x: (x[0], x[1]))
|
||||
google_small = google_data.select("url", "description").rdd.map(lambda x: (x[0], x[1]))
|
||||
|
||||
# 计算相似度
|
||||
cross_small = amazon_small.cartesian(google_small)
|
||||
similarities = cross_small.map(lambda x: computeSimilarityBroadcast(x, idfsBroadcast))
|
||||
|
||||
# 打印结果
|
||||
similarities.collect()
|
||||
|
||||
# 关闭Spark上下文
|
||||
sc.stop()
|
||||
spark.stop()
|
Loading…
Reference in New Issue
Block a user