refactor(data_prepare): 重构数据准备脚本

- Reworked the parsing logic for the ratings and movies data (see the parsing example below)
- Added a sorting function and test cases to verify deterministic sorting (see the note after the diff)
- Updated the Spark cluster connection and the Python interpreter settings
- Improved code formatting and variable naming for readability
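For reference, a quick local illustration of what the reworked parsers produce. The two helpers are copied from the new version of the script; the sample `::`-delimited records are made up for illustration and are not taken from the repository:

# Hypothetical MovieLens-style '::'-delimited records; sample values are illustrative only.
def get_ratings_tuple(entry):
    items = entry.split('::')
    return int(items[0]), int(items[1]), float(items[2])

def get_movie_tuple(entry):
    items = entry.split('::')
    return int(items[0]), items[1]

print(get_ratings_tuple("1::1193::5::978300760"))            # (1, 1193, 5.0)
print(get_movie_tuple("1193::Example Movie (1999)::Drama"))  # (1193, 'Example Movie (1999)')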
fly6516 2025-04-22 13:16:54 +08:00
parent 3cb7ec6dba
commit 554928b81f


@@ -1,40 +1,71 @@
 from pyspark import SparkContext, SparkConf
 import os

-os.environ['JAVA_HOME'] = "/opt/module/jdk1.8.0_171"
+# Set the Java environment variable
+os.environ['JAVA_HOME'] = '/opt/module/jdk1.8.0_171'

+# Parse a ratings line into (userID, movieID, rating)
 def get_ratings_tuple(entry):
-    user, movie, rating, _ = entry.split('::')
-    return int(user), int(movie), float(rating)
+    items = entry.split('::')
+    return int(items[0]), int(items[1]), float(items[2])

+# Parse a movies line into (movieID, title)
 def get_movie_tuple(entry):
-    mid, title, _ = entry.split('::')
-    return int(mid), title
+    items = entry.split('::')
+    return int(items[0]), items[1]

-def sort_key(rec):
-    score, name = rec
-    return f"{score:06.3f} {name}"
+# Generate a deterministic key for sorting
+def sortFunction(tuple):
+    key = str('%06.3f ' % tuple[0])
+    value = tuple[1]
+    return (key + ' ' + value)

-def init_rdds(sc, hdfs_base='hdfs://master:9000/user/root/als_movie'):
-    ratings_path = f"{hdfs_base}/ratings.txt"
-    movies_path = f"{hdfs_base}/movies.dat"
-    raw_r = sc.textFile(ratings_path).repartition(2)
-    raw_m = sc.textFile(movies_path)
-    ratings_rdd = raw_r.map(get_ratings_tuple).cache()
-    movies_rdd = raw_m.map(get_movie_tuple).cache()
-    return ratings_rdd, movies_rdd
+# Initialize and return ratingsRDD and moviesRDD
+def init_rdds(sc):
+    ratingsFilename = "hdfs://master:9000/user/root/als_movie/ratings.txt"
+    moviesFilename = "hdfs://master:9000/user/root/als_movie/movies.dat"
+    numPartitions = 2
+    rawRatings = sc.textFile(ratingsFilename).repartition(numPartitions)
+    rawMovies = sc.textFile(moviesFilename)
+    ratingsRDD = rawRatings.map(get_ratings_tuple).cache()
+    moviesRDD = rawMovies.map(get_movie_tuple).cache()
+    return ratingsRDD, moviesRDD

-if __name__ == '__main__':
-    conf = SparkConf().setMaster('spark://master:7077').setAppName('als_movie')
+if __name__ == "__main__":
+    import sys, os
+    os.environ["PYSPARK_PYTHON"] = "/usr/bin/python3"
+    os.environ["PYSPARK_DRIVER_PYTHON"] = "/usr/bin/python3"
+    conf = SparkConf().setMaster("spark://master:7077").setAppName("als_movie")
+    # Connect to the master of the Spark standalone cluster at master:7077; application name: als_movie
     sc = SparkContext.getOrCreate(conf)
-    sc.setLogLevel('ERROR')
+    sc.setLogLevel("ERROR")

-    rdd_ratings, rdd_movies = init_rdds(sc)
-    print(f"Ratings count: {rdd_ratings.count()}")
-    print(f"Movies count: {rdd_movies.count()}")
-    sc.stop()
+    ratingsRDD, moviesRDD = init_rdds(sc)
+    ratingsCount = ratingsRDD.count()
+    moviesCount = moviesRDD.count()
+    print('There are %s ratings and %s movies in the datasets' % (ratingsCount, moviesCount))
+    print('Ratings: %s' % ratingsRDD.take(3))
+    print('Movies: %s' % moviesRDD.take(3))
+
+    # Added test cases: verify that sorting key-value pairs with sortFunction is deterministic
+    tmp1 = [(1, u'alpha'), (2, u'alpha'), (2, u'beta'), (3, u'alpha'), (1, u'epsilon'), (1, u'delta')]
+    tmp2 = [(1, u'delta'), (2, u'alpha'), (2, u'beta'), (3, u'alpha'), (1, u'epsilon'), (1, u'alpha')]
+    oneRDD = sc.parallelize(tmp1)
+    twoRDD = sc.parallelize(tmp2)
+    oneSorted = oneRDD.sortByKey(True).collect()
+    twoSorted = twoRDD.sortByKey(True).collect()
+    print(oneSorted)
+    print(twoSorted)
+    assert set(oneSorted) == set(twoSorted)      # both results contain the same elements
+    assert twoSorted[0][0] < twoSorted.pop()[0]  # keys come back in ascending order
+    assert oneSorted[0:2] != twoSorted[0:2]      # ties on the key may appear in different orders
+    print(oneRDD.sortBy(sortFunction, True).collect())
+    print(twoRDD.sortBy(sortFunction, True).collect())
+    oneSorted1 = oneRDD.takeOrdered(oneRDD.count(), key=sortFunction)
+    twoSorted1 = twoRDD.takeOrdered(twoRDD.count(), key=sortFunction)
+    print('one is %s' % oneSorted1)
+    print('two is %s' % twoSorted1)
+    assert oneSorted1 == twoSorted1              # sortFunction makes the ordering deterministic
+    sc.stop()
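For context on the added test cases: sorting by the numeric key alone keeps tied records in whatever order they arrived, so differently ordered inputs can give different results, which is exactly what the assertions above probe. A minimal, non-Spark sketch of the same idea using plain Python sorted(); the helper mirrors sortFunction from the script (parameter renamed to t) and the lists mirror tmp1/tmp2:

# Local illustration only (no Spark needed): the composite "score name" key breaks ties,
# so both input orderings sort to exactly the same sequence.
def sortFunction(t):
    return '%06.3f ' % t[0] + ' ' + t[1]

tmp1 = [(1, u'alpha'), (2, u'alpha'), (2, u'beta'), (3, u'alpha'), (1, u'epsilon'), (1, u'delta')]
tmp2 = [(1, u'delta'), (2, u'alpha'), (2, u'beta'), (3, u'alpha'), (1, u'epsilon'), (1, u'alpha')]

# Every (score, name) pair maps to a unique composite key, so the result no longer
# depends on the input order.
assert sorted(tmp1, key=sortFunction) == sorted(tmp2, key=sortFunction)
print(sorted(tmp1, key=sortFunction))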