From 1eb8f57c772cb6ed4f66207c27c09bd57a8d4a25 Mon Sep 17 00:00:00 2001 From: fly6516 Date: Wed, 26 Mar 2025 16:16:15 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=86=B3=E7=AD=96?= =?UTF-8?q?=E6=A0=91=E7=AE=97=E6=B3=95=E5=AE=9E=E7=8E=B0=E5=8F=8A=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E9=9B=86=E8=AE=AD=E7=BB=83?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 decision_tree_algorithms.py 文件,实现 ID3、C4.5 和 CART 决策树算法 - 新增 wine_dataset_training.py 文件,使用 Wine 数据集进行训练和可视化- 使用 Iris 和 Wine 数据集进行十折交叉验证,比较算法性能 - 实现决策树的可视化功能 --- decision_tree_algorithms.py | 77 +++++++++++++++++++++++++++++++++++++ wine_dataset_training.py | 77 +++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 decision_tree_algorithms.py create mode 100644 wine_dataset_training.py diff --git a/decision_tree_algorithms.py b/decision_tree_algorithms.py new file mode 100644 index 0000000..0690f05 --- /dev/null +++ b/decision_tree_algorithms.py @@ -0,0 +1,77 @@ +# 导入必要的库 +from ucimlrepo import fetch_ucirepo # 用于从UCI机器学习库中获取数据集 +from sklearn.model_selection import cross_val_score # 用于交叉验证 +from sklearn.tree import DecisionTreeClassifier # 用于构建决策树分类器 +from sklearn import tree # 用于可视化决策树 +import matplotlib.pyplot as plt # 用于绘图 + +# 获取Iris数据集 +iris = fetch_ucirepo(id=53) # 获取Iris数据集,id=53表示Iris数据集 + +# 数据 (作为pandas数据框) +X = iris.data.features # 特征数据,包含Iris数据集的四个特征 +y = iris.data.targets # 目标标签,包含Iris数据集的类别标签 + +# 元数据 +print(iris.metadata) # 打印数据集的元数据信息 + +# 变量信息 +print(iris.variables) # 打印数据集的变量信息 + +# 定义ID3算法 +def id3_algorithm(): + # 使用信息增益作为划分标准 + clf = DecisionTreeClassifier(criterion="entropy") # 创建决策树分类器,使用信息增益作为划分标准 + return clf + +# 定义C4.5算法 +def c45_algorithm(): + # 使用信息增益比作为划分标准 + clf = DecisionTreeClassifier(criterion="entropy", splitter="best") # 创建决策树分类器,使用信息增益比作为划分标准 + return clf + +# 定义CART算法 +def cart_algorithm(): + # 使用基尼指数作为划分标准 + clf = DecisionTreeClassifier(criterion="gini") # 创建决策树分类器,使用基尼指数作为划分标准 + return clf + +# 十折交叉验证 +def cross_validation(clf, X, y): + scores = cross_val_score(clf, X, y, cv=10) # 进行十折交叉验证,返回每折的准确率 + return scores.mean(), scores.std() # 返回平均准确率和标准差 + +# 比较三个算法的精度和速度 +def compare_algorithms(): + algorithms = { + "ID3": id3_algorithm(), # ID3算法 + "C4.5": c45_algorithm(), # C4.5算法 + "CART": cart_algorithm() # CART算法 + } + + results = {} + for name, clf in algorithms.items(): + mean_score, std_score = cross_validation(clf, X, y) # 对每个算法进行交叉验证 + results[name] = (mean_score, std_score) # 存储结果 + print(f"{name} - Mean Accuracy: {mean_score}, Std: {std_score}") # 打印每个算法的平均准确率和标准差 + + return results + +# 可视化决策树 +def visualize_tree(clf, feature_names, class_names): + plt.figure(figsize=(12,8)) # 设置图像大小 + # 手动指定类别名称列表 + class_names_list = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] # 手动指定类别名称列表 + tree.plot_tree(clf, filled=True, feature_names=feature_names, class_names=class_names_list) # 绘制决策树 + plt.show() # 显示图像 + +# 主函数 +if __name__ == "__main__": + # 比较算法 + results = compare_algorithms() # 比较三个算法的性能 + + # 可视化ID3算法的决策树 + clf = id3_algorithm() # 创建ID3算法的决策树分类器 + clf.fit(X, y) # 训练模型 + # 使用手动指定的类别名称列表 + visualize_tree(clf, iris.data.features.columns, ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']) # 可视化决策树 \ No newline at end of file diff --git a/wine_dataset_training.py b/wine_dataset_training.py new file mode 100644 index 0000000..546ac6d --- /dev/null +++ b/wine_dataset_training.py @@ -0,0 +1,77 @@ +# 导入必要的库 +from ucimlrepo import fetch_ucirepo # 用于从UCI机器学习库中获取数据集 +from sklearn.model_selection import cross_val_score # 用于交叉验证 +from sklearn.tree import DecisionTreeClassifier # 用于构建决策树分类器 +from sklearn import tree # 用于可视化决策树 +import matplotlib.pyplot as plt # 用于绘图 + +# 获取Wine数据集 +wine = fetch_ucirepo(id=109) # 获取Wine数据集,id=109表示Wine数据集 + +# 数据 (作为pandas数据框) +X = wine.data.features # 特征数据,包含Wine数据集的13个特征 +y = wine.data.targets # 目标标签,包含Wine数据集的类别标签 + +# 元数据 +print(wine.metadata) # 打印数据集的元数据信息 + +# 变量信息 +print(wine.variables) # 打印数据集的变量信息 + +# 定义ID3算法 +def id3_algorithm(): + # 使用信息增益作为划分标准 + clf = DecisionTreeClassifier(criterion="entropy") # 创建决策树分类器,使用信息增益作为划分标准 + return clf + +# 定义C4.5算法 +def c45_algorithm(): + # 使用信息增益比作为划分标准 + clf = DecisionTreeClassifier(criterion="entropy", splitter="best") # 创建决策树分类器,使用信息增益比作为划分标准 + return clf + +# 定义CART算法 +def cart_algorithm(): + # 使用基尼指数作为划分标准 + clf = DecisionTreeClassifier(criterion="gini") # 创建决策树分类器,使用基尼指数作为划分标准 + return clf + +# 十折交叉验证 +def cross_validation(clf, X, y): + scores = cross_val_score(clf, X, y, cv=10) # 进行十折交叉验证,返回每折的准确率 + return scores.mean(), scores.std() # 返回平均准确率和标准差 + +# 比较三个算法的精度和速度 +def compare_algorithms(): + algorithms = { + "ID3": id3_algorithm(), # ID3算法 + "C4.5": c45_algorithm(), # C4.5算法 + "CART": cart_algorithm() # CART算法 + } + + results = {} + for name, clf in algorithms.items(): + mean_score, std_score = cross_validation(clf, X, y) # 对每个算法进行交叉验证 + results[name] = (mean_score, std_score) # 存储结果 + print(f"{name} - Mean Accuracy: {mean_score}, Std: {std_score}") # 打印每个算法的平均准确率和标准差 + + return results + +# 可视化决策树 +def visualize_tree(clf, feature_names, class_names): + plt.figure(figsize=(12,8)) # 设置图像大小 + # 手动指定类别名称列表 + class_names_list = ['Class 1', 'Class 2', 'Class 3'] # 手动指定类别名称列表 + tree.plot_tree(clf, filled=True, feature_names=feature_names, class_names=class_names_list) # 绘制决策树 + plt.show() # 显示图像 + +# 主函数 +if __name__ == "__main__": + # 比较算法 + results = compare_algorithms() # 比较三个算法的性能 + + # 可视化ID3算法的决策树 + clf = id3_algorithm() # 创建ID3算法的决策树分类器 + clf.fit(X, y) # 训练模型 + # 使用手动指定的类别名称列表 + visualize_tree(clf, wine.data.features.columns, ['Class 1', 'Class 2', 'Class 3']) # 可视化决策树 \ No newline at end of file