DM-exp-3/decision_tree_algorithms.py
fly6516 1eb8f57c77 feat: 添加决策树算法实现及数据集训练
- 新增 decision_tree_algorithms.py 文件,实现 ID3、C4.5 和 CART 决策树算法
- 新增 wine_dataset_training.py 文件,使用 Wine 数据集进行训练和可视化- 使用 Iris 和 Wine 数据集进行十折交叉验证,比较算法性能
- 实现决策树的可视化功能
2025-03-26 16:16:15 +08:00

77 lines
3.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 导入必要的库
from ucimlrepo import fetch_ucirepo # 用于从UCI机器学习库中获取数据集
from sklearn.model_selection import cross_val_score # 用于交叉验证
from sklearn.tree import DecisionTreeClassifier # 用于构建决策树分类器
from sklearn import tree # 用于可视化决策树
import matplotlib.pyplot as plt # 用于绘图
# 获取Iris数据集
iris = fetch_ucirepo(id=53) # 获取Iris数据集id=53表示Iris数据集
# 数据 (作为pandas数据框)
X = iris.data.features # 特征数据包含Iris数据集的四个特征
y = iris.data.targets # 目标标签包含Iris数据集的类别标签
# 元数据
print(iris.metadata) # 打印数据集的元数据信息
# 变量信息
print(iris.variables) # 打印数据集的变量信息
# 定义ID3算法
def id3_algorithm():
# 使用信息增益作为划分标准
clf = DecisionTreeClassifier(criterion="entropy") # 创建决策树分类器,使用信息增益作为划分标准
return clf
# 定义C4.5算法
def c45_algorithm():
# 使用信息增益比作为划分标准
clf = DecisionTreeClassifier(criterion="entropy", splitter="best") # 创建决策树分类器,使用信息增益比作为划分标准
return clf
# 定义CART算法
def cart_algorithm():
# 使用基尼指数作为划分标准
clf = DecisionTreeClassifier(criterion="gini") # 创建决策树分类器,使用基尼指数作为划分标准
return clf
# 十折交叉验证
def cross_validation(clf, X, y):
scores = cross_val_score(clf, X, y, cv=10) # 进行十折交叉验证,返回每折的准确率
return scores.mean(), scores.std() # 返回平均准确率和标准差
# 比较三个算法的精度和速度
def compare_algorithms():
algorithms = {
"ID3": id3_algorithm(), # ID3算法
"C4.5": c45_algorithm(), # C4.5算法
"CART": cart_algorithm() # CART算法
}
results = {}
for name, clf in algorithms.items():
mean_score, std_score = cross_validation(clf, X, y) # 对每个算法进行交叉验证
results[name] = (mean_score, std_score) # 存储结果
print(f"{name} - Mean Accuracy: {mean_score}, Std: {std_score}") # 打印每个算法的平均准确率和标准差
return results
# 可视化决策树
def visualize_tree(clf, feature_names, class_names):
plt.figure(figsize=(12,8)) # 设置图像大小
# 手动指定类别名称列表
class_names_list = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] # 手动指定类别名称列表
tree.plot_tree(clf, filled=True, feature_names=feature_names, class_names=class_names_list) # 绘制决策树
plt.show() # 显示图像
# 主函数
if __name__ == "__main__":
# 比较算法
results = compare_algorithms() # 比较三个算法的性能
# 可视化ID3算法的决策树
clf = id3_algorithm() # 创建ID3算法的决策树分类器
clf.fit(X, y) # 训练模型
# 使用手动指定的类别名称列表
visualize_tree(clf, iris.data.features.columns, ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']) # 可视化决策树