feat(similarity): 添加文本相似度计算功能
- 实现了计算两个文本之间相似度的完整流程 - 包括 TF-IDF 计算、余弦相似度计算等功能 - 使用 Spark 广播变量优化计算效率 - 支持从 HDFS 读取数据进行计算
This commit is contained in:
parent
1043551309
commit
250e1b99e0
186
4-1.py
Normal file
186
4-1.py
Normal file
@ -0,0 +1,186 @@
|
|||||||
|
import math
|
||||||
|
import re
|
||||||
|
from pyspark import SparkContext, Broadcast
|
||||||
|
from pyspark.sql import SparkSession
|
||||||
|
|
||||||
|
# 1. Compute dot product
def dotprod(a, b):
    """ Compute dot product

    Args:
        a (dict): first dictionary of record to value
        b (dict): second dictionary of record to value

    Returns:
        float: result of the dot product with the two input dictionaries
    """
    # Iterate over the smaller dict and probe the larger one by key:
    # O(min(len(a), len(b))) instead of the original quadratic
    # nested-loop scan. Also avoids shadowing the builtin `sum`.
    if len(b) < len(a):
        a, b = b, a
    total = 0.0
    for key, value in a.items():
        if key in b:
            total += value * b[key]
    return total
|
||||||
|
|
||||||
|
# 2. Compute the norm
def norm(a):
    """Return the Euclidean norm of the vector represented by *a*.

    Args:
        a (dict): a dictionary of record to value

    Returns:
        float: square root of the dot product of *a* with itself
    """
    return math.sqrt(dotprod(a, a))
|
||||||
|
|
||||||
|
# 3. Compute cosine similarity
def cossim(a, b):
    """ Compute cosine similarity

    Args:
        a (dict): first dictionary of record to value
        b (dict): second dictionary of record to value

    Returns:
        float: cosine similarity value; 0.0 when either vector has zero
            norm (the previous version raised ZeroDivisionError on
            empty or all-zero dictionaries)
    """
    denominator = norm(a) * norm(b)
    if denominator == 0.0:
        # A zero-norm vector has no direction; define similarity as 0
        # rather than dividing by zero.
        return 0.0
    return dotprod(a, b) / denominator
|
||||||
|
|
||||||
|
# 4. Compute TF-IDF
def tfidf(tokens, idfsDictionary):
    """ Calculate TF-IDF values for token list

    Args:
        tokens (list): list of tokens
        idfsDictionary (dict): IDF values

    Returns:
        dict: dictionary of token -> TF-IDF value; empty when *tokens*
            is empty (the previous version raised ZeroDivisionError)
    """
    if not tokens:
        # Guard: an empty document has no term frequencies.
        return {}
    total_tokens = len(tokens)
    # Raw counts in one pass, then normalize to term frequency.
    # (A local name other than `tfidf` is used so the function name
    # is not shadowed inside its own body.)
    counts = {}
    for token in tokens:
        counts[token] = counts.get(token, 0) + 1
    # Tokens absent from the IDF dictionary get weight 0.
    return {
        token: (count / total_tokens) * idfsDictionary.get(token, 0)
        for token, count in counts.items()
    }
|
||||||
|
|
||||||
|
# 5. Cosine similarity between two raw strings
def cosineSimilarity(string1, string2, idfsDictionary):
    """ Compute cosine similarity between two strings using TF-IDF weights

    Args:
        string1 (str): first string
        string2 (str): second string
        idfsDictionary (dict): IDF dictionary

    Returns:
        float: cosine similarity
    """
    # Tokenize each string, weight the tokens by TF-IDF, then compare
    # the two weight vectors.
    weights1 = tfidf(tokenize(string1), idfsDictionary)
    weights2 = tfidf(tokenize(string2), idfsDictionary)
    return cossim(weights1, weights2)
|
||||||
|
|
||||||
|
# 6. Tokenize function (split by spaces or punctuation)
def tokenize(text):
    """ Tokenizes a string into a list of words

    Args:
        text (str): input string

    Returns:
        list: list of tokens (words)
    """
    # Lowercase first, drop every character that is neither a word
    # character nor whitespace, then split on runs of whitespace.
    cleaned = re.sub(r'[^\w\s]', '', text.lower())
    return cleaned.split()
|
||||||
|
|
||||||
|
# 7. RDD worker: similarity for one (google, amazon) record pair
def computeSimilarity(record, idfsDictionary):
    """ Compute similarity on a combination record

    Args:
        record (tuple): (google record, amazon record)
        idfsDictionary (dict): IDF dictionary

    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    # Each half of the pair is an (identifier, text-value) tuple.
    googleURL, googleValue = record[0]
    amazonID, amazonValue = record[1]
    similarity = cosineSimilarity(googleValue, amazonValue, idfsDictionary)
    return (googleURL, amazonID, similarity)
|
||||||
|
|
||||||
|
# 8. Parse the golden-standard data
def parse_goldfile_line(goldfile_line):
    """ Parse a line from the 'golden standard' data file

    Args:
        goldfile_line (str): a line of data

    Returns:
        tuple: ((key, 'gold'), 1) on success,
               (goldfile_line, 0) for the header line,
               (goldfile_line, -1) for an invalid line
    """
    # Raw string so the regex is taken literally by the re engine.
    GOLDFILE_PATTERN = r'^(.+),(.+)'
    match = re.search(GOLDFILE_PATTERN, goldfile_line)
    if match is None:
        print(f'Invalid goldfile line: {goldfile_line}')
        return (goldfile_line, -1)
    if match.group(1) == '"idAmazon"':
        # The first line of the file carries column headers, not data.
        print(f'Header datafile line: {goldfile_line}')
        return (goldfile_line, 0)
    # Join the two IDs into a single composite key.
    key = f'{match.group(1)} {match.group(2)}'
    return ((key, 'gold'), 1)
|
||||||
|
|
||||||
|
# 9. RDD worker using a Broadcast variable for the IDF weights
def computeSimilarityBroadcast(record, idfsBroadcast):
    """ Compute similarity on a combination record, using Broadcast variable

    Args:
        record (tuple): (google record, amazon record)
        idfsBroadcast (Broadcast): broadcasted IDF dictionary

    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    # Each half of the pair is an (identifier, text-value) tuple.
    googleURL, googleValue = record[0]
    amazonID, amazonValue = record[1]
    # .value unwraps the broadcast payload on the worker.
    similarity = cosineSimilarity(googleValue, amazonValue, idfsBroadcast.value)
    return (googleURL, amazonID, similarity)
|
||||||
|
|
||||||
|
# Entry point: set up the Spark context and broadcast variable, then
# score every Google/Amazon record pair.
if __name__ == "__main__":
    # Create the SparkSession (and grab its underlying SparkContext).
    spark = SparkSession.builder \
        .appName("TextSimilarity") \
        .getOrCreate()

    sc = spark.sparkContext

    # HDFS input paths
    amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
    google_path = "hdfs://master:9000/user/root/Google_small.csv"

    # Toy IDF weight dictionary (stand-in for a computed one).
    idfsDictionary = {
        "hello": 1.2,
        "world": 1.3,
        "goodbye": 1.1,
        "photoshop": 2.5,
        "illustrator": 2.7
    }

    # Ship the IDF table to every executor once via a broadcast variable.
    idfsBroadcast = sc.broadcast(idfsDictionary)

    # Load the CSV data.
    amazon_data = spark.read.csv(amazon_path, header=True, inferSchema=True)
    google_data = spark.read.csv(google_path, header=True, inferSchema=True)

    # Column names assumed by the sample data; adjust to the real schema.
    amazon_small = amazon_data.select("asin", "description").rdd.map(lambda x: (x[0], x[1]))
    google_small = google_data.select("url", "description").rdd.map(lambda x: (x[0], x[1]))

    # Compute similarities. computeSimilarityBroadcast expects
    # (google record, amazon record), so the cartesian product must put
    # the Google records first — the previous amazon.cartesian(google)
    # order silently swapped the URL and ID labels in the output.
    cross_small = google_small.cartesian(amazon_small)
    similarities = cross_small.map(lambda x: computeSimilarityBroadcast(x, idfsBroadcast))

    # Materialize the results on the driver (values are discarded here).
    similarities.collect()

    # Shut Spark down.
    sc.stop()
    spark.stop()
|
Loading…
Reference in New Issue
Block a user