diff --git a/6-1.py b/6-1.py
new file mode 100644
index 0000000..1e92078
--- /dev/null
+++ b/6-1.py
@@ -0,0 +1,113 @@
+from pyspark import SparkContext
+from pyspark.accumulators import AccumulatorParam
+import matplotlib.pyplot as plt
+
+# Create the SparkContext
+sc = SparkContext(appName="TextAnalysis")
+
+# Assume similaritiesFullRDD and goldStandard already exist:
+# similaritiesFullRDD: RDD of ((Amazon ID, Google URL), similarity)
+# goldStandard: RDD of ("Amazon ID Google URL", 1) for each true duplicate pair
+
+# Build simsFullRDD (string key -> similarity) and simsFullValuesRDD (similarity values only)
+simsFullRDD = similaritiesFullRDD.map(lambda x: ("%s %s" % (x[0][0], x[0][1]), x[1]))
+simsFullValuesRDD = simsFullRDD.map(lambda x: x[1]).cache()
+
+# Similarity scores of the true duplicate pairs (pairs missing from simsFullRDD count as 0)
+def gs_value(record):
+    if record[1][1] is None:
+        return 0
+    else:
+        return record[1][1]
+
+trueDupSimsRDD = (goldStandard
+                  .leftOuterJoin(simsFullRDD)
+                  .map(gs_value)
+                  .cache())
+print('There are %s true duplicates.' % trueDupSimsRDD.count())
+
+# Accumulator parameter for element-wise addition of integer vectors
+class VectorAccumulatorParam(AccumulatorParam):
+    def zero(self, value):
+        return [0] * len(value)
+    def addInPlace(self, val1, val2):
+        for i in range(len(val1)):
+            val1[i] += val2[i]
+        return val1
+
+def set_bit(x, value, length):
+    bits = []
+    for y in range(length):
+        if x == y:
+            bits.append(value)
+        else:
+            bits.append(0)
+    return bits
+
+BINS = 101
+nthresholds = 100
+def bin(similarity):
+    return int(similarity * nthresholds)
+
+zeros = [0] * BINS
+fpCounts = sc.accumulator(zeros, VectorAccumulatorParam())
+
+def add_element(score):
+    global fpCounts
+    b = bin(score)
+    fpCounts += set_bit(b, 1, BINS)
+
+simsFullValuesRDD.foreach(add_element)
+
+def sub_element(score):
+    global fpCounts
+    b = bin(score)
+    fpCounts += set_bit(b, -1, BINS)
+
+trueDupSimsRDD.foreach(sub_element)
+
+def falsepos(threshold):
+    fpList = fpCounts.value
+    return sum([fpList[b] for b in range(0, BINS) if float(b) / nthresholds >= threshold])
+
+def falseneg(threshold):
+    return trueDupSimsRDD.filter(lambda x: x < threshold).count()
+
+def truepos(threshold):
+    return trueDupSimsRDD.count() - falseneg(threshold)
+
+# Precision, recall, and F-measure as functions of the similarity threshold
+def precision(threshold):
+    tp = truepos(threshold)
+    return float(tp) / (tp + falsepos(threshold))
+
+def recall(threshold):
+    tp = truepos(threshold)
+    return float(tp) / (tp + falseneg(threshold))
+
+def fmeasure(threshold):
+    r = recall(threshold)
+    p = precision(threshold)
+    if r == 0 and p == 0:
+        return 0
+    else:
+        return 2 * r * p / (r + p)
+
+# Thresholds from 0.00 to 0.99 in steps of 1 / nthresholds
+thresholds = [float(n) / nthresholds for n in range(0, nthresholds)]
+
+# Evaluate precision, recall, and F-measure at every threshold
+precisions = [precision(t) for t in thresholds]
+recalls = [recall(t) for t in thresholds]
+fmeasures = [fmeasure(t) for t in thresholds]
+
+# Plot all three curves against the threshold
+fig = plt.figure()
+plt.plot(thresholds, precisions)
+plt.plot(thresholds, recalls)
+plt.plot(thresholds, fmeasures)
+plt.legend(['Precision', 'Recall', 'F-measure'])
+plt.show()
+
+# Stop the SparkContext
+sc.stop()
\ No newline at end of file
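
Note: similaritiesFullRDD and goldStandard are assumed to be produced by earlier steps of the pipeline and are not defined in this file. A minimal, hypothetical stub for running the script in isolation (the IDs, URLs, and scores below are invented for illustration; goldStandard must use the same "Amazon ID Google URL" string key that simsFullRDD builds) could look like:

    # Hypothetical sample data, only for trying the script standalone
    similaritiesFullRDD = sc.parallelize([
        (("amazon1", "http://google/url1"), 0.92),
        (("amazon1", "http://google/url2"), 0.15),
        (("amazon2", "http://google/url2"), 0.80),
    ])
    goldStandard = sc.parallelize([
        ("amazon1 http://google/url1", 1),
        ("amazon2 http://google/url2", 1),
    ])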