- 实现了计算两个文本之间相似度的完整流程
- 包括 TF-IDF 计算、余弦相似度计算等功能
- 使用 Spark 广播变量优化计算效率
- 支持从 HDFS 读取数据进行计算
187 lines
5.6 KiB
Python
187 lines
5.6 KiB
Python
import math
|
||
import re
|
||
from pyspark import SparkContext, Broadcast
|
||
from pyspark.sql import SparkSession
|
||
|
||
# 1. Dot product of two sparse vectors
def dotprod(a, b):
    """Compute the dot product of two sparse vectors stored as dicts.

    Only keys present in both dictionaries contribute; a key missing from
    either side behaves like a zero entry.

    Args:
        a (dict): first mapping of key -> numeric value
        b (dict): second mapping of key -> numeric value

    Returns:
        float: sum of ``a[k] * b[k]`` over the keys shared by ``a`` and ``b``
        (0.0 when they share none)
    """
    # One hash lookup per key of ``a`` replaces the former nested scan over
    # both dicts (O(len(a) * len(b))). Walking ``a`` in its own order keeps
    # the summation order, and therefore the floating-point result, unchanged.
    total = 0.0
    for key, value in a.items():
        if key in b:
            total += value * b[key]
    return total
|
||
|
||
# 2. Euclidean (L2) norm of a sparse vector
def norm(a):
    """Return the Euclidean length of a dict-based sparse vector.

    Args:
        a (dict): mapping of key -> numeric value

    Returns:
        float: square root of the vector's dot product with itself
    """
    squared_length = dotprod(a, a)
    return math.sqrt(squared_length)
|
||
|
||
# 3. Cosine similarity of two sparse vectors
def cossim(a, b):
    """Compute the cosine similarity of two dict-based sparse vectors.

    Args:
        a (dict): first mapping of key -> numeric value
        b (dict): second mapping of key -> numeric value

    Returns:
        float: ``dotprod(a, b) / (norm(a) * norm(b))``, or 0.0 when either
        vector has zero norm (e.g. an empty TF-IDF vector, or one whose
        weights are all zero)
    """
    denominator = norm(a) * norm(b)
    if denominator == 0:
        # A zero vector has no direction, so the cosine is undefined. Treat it
        # as "no similarity" rather than raising ZeroDivisionError, which would
        # abort the whole Spark job for a single empty or unweighted record.
        return 0.0
    return dotprod(a, b) / denominator
|
||
|
||
# 4. TF-IDF weights for a token list
def tfidf(tokens, idfsDictionary):
    """Weight each distinct token by term frequency times inverse document frequency.

    Term frequency is the token's count divided by the total number of
    tokens. Tokens absent from ``idfsDictionary`` get an IDF of 0.

    Args:
        tokens (list): tokens of a single document
        idfsDictionary (dict): token -> IDF weight

    Returns:
        dict: token -> TF-IDF value, in first-occurrence order of the tokens
    """
    counts = {}
    for tok in tokens:
        counts[tok] = counts.get(tok, 0) + 1

    n_tokens = len(tokens)
    return {
        tok: (count / n_tokens) * idfsDictionary.get(tok, 0)
        for tok, count in counts.items()
    }
|
||
|
||
# 5. Cosine similarity between two raw strings
def cosineSimilarity(string1, string2, idfsDictionary):
    """Score two strings by the cosine similarity of their TF-IDF vectors.

    Both strings are tokenized first, then each token list is converted to a
    TF-IDF vector against the same IDF table, and the two vectors are
    compared with ``cossim``.

    Args:
        string1 (str): first string
        string2 (str): second string
        idfsDictionary (dict): token -> IDF weight

    Returns:
        float: cosine similarity of the two TF-IDF vectors
    """
    token_lists = [tokenize(string1), tokenize(string2)]
    weighted = [tfidf(tokens, idfsDictionary) for tokens in token_lists]
    return cossim(*weighted)
|
||
|
||
# 6. Tokenizer: lower-case the text and split it on whitespace and punctuation
def tokenize(text):
    """Tokenize a string into a list of lower-case words.

    Any character that is neither a word character nor whitespace is treated
    as a separator, so "CS3/Mac" yields ["cs3", "mac"] rather than the glued
    token "cs3mac".

    Args:
        text (str | None): input string; ``None`` (e.g. a null CSV cell read
            through Spark) is treated as empty text

    Returns:
        list: list of tokens (words); empty for ``None`` or blank input
    """
    if text is None:
        return []
    # Replace punctuation with a space (not with "") so that words separated
    # only by punctuation stay separate tokens.
    return re.sub(r'[^\w\s]', ' ', text.lower()).split()
|
||
|
||
# 7. RDD mapper: score one (google record, amazon record) pair
def computeSimilarity(record, idfsDictionary):
    """Score a single cross-product record with TF-IDF cosine similarity.

    Args:
        record (tuple): (google record, amazon record); each record is
            indexed as (identifier, text)
        idfsDictionary (dict): token -> IDF weight

    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    google_rec, amazon_rec = record[0], record[1]
    score = cosineSimilarity(google_rec[1], amazon_rec[1], idfsDictionary)
    return (google_rec[0], amazon_rec[0], score)
|
||
|
||
# 8. Parse one line of the gold-standard mapping file
def parse_goldfile_line(goldfile_line):
    """Parse a line from the 'gold standard' data file.

    The line is split at its LAST comma (the first capture group is greedy).

    Args:
        goldfile_line (str): a line of data

    Returns:
        tuple: one of
            - ``((f"{left} {right}", 'gold'), 1)`` for a normal data line,
            - ``(goldfile_line, 0)`` for the header line, whose first field
              is ``"idAmazon"`` (with the quotes),
            - ``(goldfile_line, -1)`` when the line does not match at all.
        A diagnostic message is printed for the header and invalid cases.
    """
    found = re.search('^(.+),(.+)', goldfile_line)
    if found is None:
        print(f'Invalid goldfile line: {goldfile_line}')
        return (goldfile_line, -1)

    left, right = found.group(1), found.group(2)
    if left == '"idAmazon"':
        print(f'Header datafile line: {goldfile_line}')
        return (goldfile_line, 0)

    return ((f'{left} {right}', 'gold'), 1)
|
||
|
||
# 9. Same as computeSimilarity, but reads the IDF table from a broadcast variable
def computeSimilarityBroadcast(record, idfsBroadcast):
    """Score a cross-product record using a broadcast IDF dictionary.

    Args:
        record (tuple): (google record, amazon record); each record is
            indexed as (identifier, text)
        idfsBroadcast (Broadcast): broadcast wrapper around the
            token -> IDF weight dictionary

    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    # Unwrap the broadcast value and reuse the plain implementation, so the
    # two entry points cannot drift apart.
    return computeSimilarity(record, idfsBroadcast.value)
|
||
|
||
# Entry point: set up Spark, broadcast the IDF weights and score every pair
if __name__ == "__main__":
    # Create (or reuse) the SparkSession
    spark = (
        SparkSession.builder
        .appName("TextSimilarity")
        .getOrCreate()
    )
    sc = spark.sparkContext

    try:
        # HDFS input paths
        amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
        google_path = "hdfs://master:9000/user/root/Google_small.csv"

        # Placeholder IDF weights.
        # NOTE(review): tfidf() gives weight 0 to any token missing from this
        # table, so most records produce all-zero vectors until a real IDF
        # table is computed from the corpus.
        idfsDictionary = {
            "hello": 1.2,
            "world": 1.3,
            "goodbye": 1.1,
            "photoshop": 2.5,
            "illustrator": 2.7,
        }

        # Broadcast the dictionary so executors read it via .value instead of
        # having it captured in every task closure.
        idfsBroadcast = sc.broadcast(idfsDictionary)

        # Load the CSV data
        amazon_data = spark.read.csv(amazon_path, header=True, inferSchema=True)
        google_data = spark.read.csv(google_path, header=True, inferSchema=True)

        # Assumed column names -- adjust to the actual data
        amazon_small = amazon_data.select("asin", "description").rdd.map(lambda x: (x[0], x[1]))
        google_small = google_data.select("url", "description").rdd.map(lambda x: (x[0], x[1]))

        # computeSimilarityBroadcast expects (google record, amazon record),
        # so the Google RDD must be the left side of the cartesian product.
        # (The previous amazon.cartesian(google) swapped the output labels.)
        cross_small = google_small.cartesian(amazon_small)
        similarities = cross_small.map(lambda x: computeSimilarityBroadcast(x, idfsBroadcast))

        # Print the results (they were previously collected and discarded)
        for record in similarities.collect():
            print(record)
    finally:
        # SparkSession.stop() also stops the underlying SparkContext, so a
        # separate sc.stop() is unnecessary; run it even if the job fails.
        spark.stop()
|