refactor(4-1): rework data loading and parsing logic

- Removed unnecessary imports and unused code
- Added parseData and loadData functions for parsing and loading data files
- Improved the regular expression and logic used for data parsing
- Simplified the code structure for better readability and maintainability
fly6516 2025-04-20 02:32:18 +08:00
parent e84c0ff633
commit 15fcc21975

4-1.py (229 changed lines)

@@ -1,186 +1,71 @@
import math
import os
import re
from pyspark import SparkContext, Broadcast
from pyspark.sql import SparkSession
# 1. Compute the dot product
def dotprod(a, b):
    """ Compute dot product
    Args:
        a (dict): first dictionary of record to value
        b (dict): second dictionary of record to value
    Returns:
        float: result of the dot product with the two input dictionaries
    """
    # Look each key of a up in b directly (O(n) rather than the original
    # nested O(n^2) scan), and avoid shadowing the builtin sum()
    return sum(v * b[k] for k, v in a.items() if k in b)
# Initialize the SparkContext
sc = SparkContext(appName="TextAnalysis")
# 2. Compute the norm
def norm(a):
    """ Compute square root of the dot product
    Args:
        a (dict): a dictionary of record to value
    Returns:
        float: norm of the dictionary
    """
    return math.sqrt(dotprod(a, a))
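# A quick worked example (hypothetical values, not from the dataset):
# norm({'a': 3.0, 'b': 4.0}) = sqrt(3.0*3.0 + 4.0*4.0) = 5.0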
# Data file paths
GOOGLE_PATH = 'Google.csv'
GOOGLE_SMALL_PATH = 'Google_small.csv'
AMAZON_PATH = 'Amazon.csv'
AMAZON_SMALL_PATH = 'Amazon_small.csv'
STOPWORDS_PATH = 'stopwords.txt'
# 3. Compute cosine similarity
def cossim(a, b):
    """ Compute cosine similarity
    Args:
        a (dict): first dictionary of record to value
        b (dict): second dictionary of record to value
    Returns:
        float: cosine similarity value
    """
    return dotprod(a, b) / (norm(a) * norm(b))
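# Worked example (hypothetical vectors): for a = {'x': 1.0, 'y': 2.0} and
# b = {'x': 3.0, 'z': 4.0}, dotprod(a, b) = 3.0, norm(a) = sqrt(5),
# norm(b) = 5.0, so cossim(a, b) = 3.0 / (sqrt(5) * 5.0) ≈ 0.268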
# Regular expression pattern for parsing data file lines
DATAFILE_PATTERN = '^(.+),"(.+)",(.*),(.*),(.*)'
# 4. Compute TF-IDF
def tfidf(tokens, idfsDictionary):
    """ Calculate TF-IDF values for token list
    Args:
        tokens (list): list of tokens
        idfsDictionary (dict): IDF values
    Returns:
        dict: dictionary of token -> TF-IDF value
    """
    tf = {}
    for token in tokens:
        tf[token] = tf.get(token, 0) + 1
    total_tokens = len(tokens)
    for token in tf:
        # float() guards against integer division under Python 2
        tf[token] = float(tf[token]) / total_tokens
    return {token: tf[token] * idfsDictionary.get(token, 0) for token in tf}
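# Worked example (hypothetical IDF weights): tokens ['a', 'b', 'a'] give
# TF = {'a': 2/3, 'b': 1/3}; with idfsDictionary = {'a': 1.5, 'b': 3.0}
# the result is {'a': 1.0, 'b': 1.0}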
def removeQuotes(s):
    """ Remove quotation marks from the input string """
    return ''.join(i for i in s if i != '"')
# 5. Cosine similarity between two strings
def cosineSimilarity(string1, string2, idfsDictionary):
    """ Compute cosine similarity between two strings using TF-IDF weights
    Args:
        string1 (str): first string
        string2 (str): second string
        idfsDictionary (dict): IDF dictionary
    Returns:
        float: cosine similarity
    """
    tokens1 = tokenize(string1)
    tokens2 = tokenize(string2)
    tfidf1 = tfidf(tokens1, idfsDictionary)
    tfidf2 = tfidf(tokens2, idfsDictionary)
    return cossim(tfidf1, tfidf2)
# 6. Tokenize function (split by spaces or punctuation)
def tokenize(text):
    """ Tokenizes a string into a list of words
    Args:
        text (str): input string
    Returns:
        list: list of tokens (words)
    """
    text = re.sub(r'[^\w\s]', '', text.lower())
    return text.split()
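# Example: tokenize('Adobe Photoshop, CS2!') returns ['adobe', 'photoshop', 'cs2']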
# 7. RDD helper for computing similarity on one record pair
def computeSimilarity(record, idfsDictionary):
    """ Compute similarity on a combination record
    Args:
        record (tuple): (google record, amazon record)
        idfsDictionary (dict): IDF dictionary
    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    googleRec = record[0]
    amazonRec = record[1]
    googleURL = googleRec[0]
    amazonID = amazonRec[0]
    googleValue = googleRec[1]
    amazonValue = amazonRec[1]
    cs = cosineSimilarity(googleValue, amazonValue, idfsDictionary)
    return (googleURL, amazonID, cs)
# 8. Parse the gold-standard data
def parse_goldfile_line(goldfile_line):
    """ Parse a line from the 'golden standard' data file
    Args:
        goldfile_line (str): a line of data
    Returns:
        tuple: ((key, 'gold'), 1) if successful, or (line, 0 or -1) otherwise
    """
    GOLDFILE_PATTERN = '^(.+),(.+)'
    match = re.search(GOLDFILE_PATTERN, goldfile_line)
    if match is None:
        print('Invalid goldfile line: {0}'.format(goldfile_line))
        return (goldfile_line, -1)
    elif match.group(1) == '"idAmazon"':
        print('Header datafile line: {0}'.format(goldfile_line))
        return (goldfile_line, 0)
    else:
        key = '{0} {1}'.format(match.group(1), match.group(2))
        return ((key, 'gold'), 1)

def parseDatafileLine(datafileLine):
    """ Parse a single line of a data file """
    match = re.search(DATAFILE_PATTERN, str(datafileLine))
    if match is None:
        print('Invalid datafile line: %s' % datafileLine)
        return (datafileLine, -1)
    elif match.group(1) == '"id"':
        print('Header datafile line: %s' % datafileLine)
        return (datafileLine, 0)
    else:
        product = '%s %s %s' % (match.group(2), match.group(3), match.group(4))
        return ((removeQuotes(match.group(1)), product), 1)
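# Example (hypothetical record): the line
#   "b001","product name",description,manufacturer,3.99
# matches DATAFILE_PATTERN, so parseDatafileLine returns
#   (('b001', 'product name description manufacturer'), 1)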
# 9. Use a broadcast variable for efficiency
def computeSimilarityBroadcast(record, idfsBroadcast):
    """ Compute similarity on a combination record, using Broadcast variable
    Args:
        record (tuple): (google record, amazon record)
        idfsBroadcast (Broadcast): broadcasted IDF dictionary
    Returns:
        tuple: (google URL, amazon ID, cosine similarity value)
    """
    googleRec = record[0]
    amazonRec = record[1]
    googleURL = googleRec[0]
    amazonID = amazonRec[0]
    googleValue = googleRec[1]
    amazonValue = amazonRec[1]
    cs = cosineSimilarity(googleValue, amazonValue, idfsBroadcast.value)
    return (googleURL, amazonID, cs)
def parseData(filename):
    """ Parse a data file into an RDD of parsed lines """
    # textFile(name, minPartitions, use_unicode): read with at least 4
    # partitions and unicode decoding disabled
    return (sc
            .textFile(filename, 4, 0)
            .map(parseDatafileLine)
            .cache())
def loadData(path):
    """ Load and parse a data file, reporting any lines that fail to parse """
    raw = parseData(path).cache()
    failed = (raw
              .filter(lambda s: s[1] == -1)
              .map(lambda s: s[0]))
    for line in failed.take(1):
        print('{0} - Invalid datafile line: {1}'.format(path, line))
    valid = (raw
             .filter(lambda s: s[1] == 1)
             .map(lambda s: s[0])
             .cache())
    print('{0} - Read {1} lines, successfully parsed {2} lines, failed to parse {3} lines'.format(
        path, raw.count(), valid.count(), failed.count()))
    return valid
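# Example of the expected output shape (hypothetical counts): calling
# loadData(GOOGLE_SMALL_PATH) prints something like
#   'Google_small.csv - Read 201 lines, successfully parsed 200 lines,
#    failed to parse 0 lines'
# and returns the valid (id, text) pairs as an RDD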
# Main: set up the Spark context and broadcast variables
if __name__ == "__main__":
    # Create a SparkSession (getOrCreate reuses the SparkContext created above)
    spark = SparkSession.builder \
        .appName("TextSimilarity") \
        .getOrCreate()
    sc = spark.sparkContext

    # Load the data
    googleSmall = loadData(GOOGLE_SMALL_PATH)
    google = loadData(GOOGLE_PATH)
    amazonSmall = loadData(AMAZON_SMALL_PATH)
    amazon = loadData(AMAZON_PATH)

    # HDFS paths
    amazon_path = "hdfs://master:9000/user/root/Amazon_small.csv"
    google_path = "hdfs://master:9000/user/root/Google_small.csv"

    # Print a few records as a sanity check
    for line in googleSmall.take(3):
        print('google: %s: %s\n' % (line[0], line[1]))
    for line in amazonSmall.take(3):
        print('amazon: %s: %s\n' % (line[0], line[1]))

    # Placeholder IDF weight dictionary (assumed values for demonstration)
    idfsDictionary = {
        "hello": 1.2,
        "world": 1.3,
        "goodbye": 1.1,
        "photoshop": 2.5,
        "illustrator": 2.7
    }

    # Create the broadcast variable
    idfsBroadcast = sc.broadcast(idfsDictionary)
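    # Note: broadcasting ships idfsDictionary to each executor once, instead
    # of re-serializing it inside every task closure as the plain
    # computeSimilarity variant would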
    # Load CSV data
    amazon_data = spark.read.csv(amazon_path, header=True, inferSchema=True)
    google_data = spark.read.csv(google_path, header=True, inferSchema=True)

    # Assumed column names; adjust to match the actual data
    amazon_small = amazon_data.select("asin", "description").rdd.map(lambda x: (x[0], x[1]))
    google_small = google_data.select("url", "description").rdd.map(lambda x: (x[0], x[1]))

    # Compute similarities over all record pairs
    cross_small = amazon_small.cartesian(google_small)
    similarities = cross_small.map(lambda x: computeSimilarityBroadcast(x, idfsBroadcast))

    # Collect the results
    similarities.collect()

    # Shut down Spark
    sc.stop()
    spark.stop()
    # Assuming the data is now loaded correctly, further analysis can follow
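
# A minimal sketch (an assumption, not part of this commit) of deriving real
# IDF weights from the parsed (id, text) records instead of the placeholder
# dictionary above; names like `corpus` are illustrative:
#
#     corpus = amazonSmall.union(googleSmall)
#     N = corpus.count()
#     idfsDictionary = (corpus
#                       .flatMap(lambda rec: set(tokenize(rec[1])))
#                       .map(lambda token: (token, 1))
#                       .reduceByKey(lambda a, b: a + b)
#                       .map(lambda kv: (kv[0], float(N) / kv[1]))
#                       .collectAsMap())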