feat: 添加决策树算法实现及数据集训练
- 新增 decision_tree_algorithms.py 文件,实现 ID3、C4.5 和 CART 决策树算法 - 新增 wine_dataset_training.py 文件,使用 Wine 数据集进行训练和可视化- 使用 Iris 和 Wine 数据集进行十折交叉验证,比较算法性能 - 实现决策树的可视化功能
This commit is contained in:
commit
1eb8f57c77
77
decision_tree_algorithms.py
Normal file
77
decision_tree_algorithms.py
Normal file
@ -0,0 +1,77 @@
|
||||
# 导入必要的库
|
||||
from ucimlrepo import fetch_ucirepo # 用于从UCI机器学习库中获取数据集
|
||||
from sklearn.model_selection import cross_val_score # 用于交叉验证
|
||||
from sklearn.tree import DecisionTreeClassifier # 用于构建决策树分类器
|
||||
from sklearn import tree # 用于可视化决策树
|
||||
import matplotlib.pyplot as plt # 用于绘图
|
||||
|
||||
# 获取Iris数据集
|
||||
iris = fetch_ucirepo(id=53) # 获取Iris数据集,id=53表示Iris数据集
|
||||
|
||||
# 数据 (作为pandas数据框)
|
||||
X = iris.data.features # 特征数据,包含Iris数据集的四个特征
|
||||
y = iris.data.targets # 目标标签,包含Iris数据集的类别标签
|
||||
|
||||
# 元数据
|
||||
print(iris.metadata) # 打印数据集的元数据信息
|
||||
|
||||
# 变量信息
|
||||
print(iris.variables) # 打印数据集的变量信息
|
||||
|
||||
# 定义ID3算法
|
||||
def id3_algorithm():
|
||||
# 使用信息增益作为划分标准
|
||||
clf = DecisionTreeClassifier(criterion="entropy") # 创建决策树分类器,使用信息增益作为划分标准
|
||||
return clf
|
||||
|
||||
# 定义C4.5算法
|
||||
def c45_algorithm():
|
||||
# 使用信息增益比作为划分标准
|
||||
clf = DecisionTreeClassifier(criterion="entropy", splitter="best") # 创建决策树分类器,使用信息增益比作为划分标准
|
||||
return clf
|
||||
|
||||
# 定义CART算法
|
||||
def cart_algorithm():
|
||||
# 使用基尼指数作为划分标准
|
||||
clf = DecisionTreeClassifier(criterion="gini") # 创建决策树分类器,使用基尼指数作为划分标准
|
||||
return clf
|
||||
|
||||
# 十折交叉验证
|
||||
def cross_validation(clf, X, y):
|
||||
scores = cross_val_score(clf, X, y, cv=10) # 进行十折交叉验证,返回每折的准确率
|
||||
return scores.mean(), scores.std() # 返回平均准确率和标准差
|
||||
|
||||
# 比较三个算法的精度和速度
|
||||
def compare_algorithms():
|
||||
algorithms = {
|
||||
"ID3": id3_algorithm(), # ID3算法
|
||||
"C4.5": c45_algorithm(), # C4.5算法
|
||||
"CART": cart_algorithm() # CART算法
|
||||
}
|
||||
|
||||
results = {}
|
||||
for name, clf in algorithms.items():
|
||||
mean_score, std_score = cross_validation(clf, X, y) # 对每个算法进行交叉验证
|
||||
results[name] = (mean_score, std_score) # 存储结果
|
||||
print(f"{name} - Mean Accuracy: {mean_score}, Std: {std_score}") # 打印每个算法的平均准确率和标准差
|
||||
|
||||
return results
|
||||
|
||||
# 可视化决策树
|
||||
def visualize_tree(clf, feature_names, class_names):
|
||||
plt.figure(figsize=(12,8)) # 设置图像大小
|
||||
# 手动指定类别名称列表
|
||||
class_names_list = ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica'] # 手动指定类别名称列表
|
||||
tree.plot_tree(clf, filled=True, feature_names=feature_names, class_names=class_names_list) # 绘制决策树
|
||||
plt.show() # 显示图像
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
# 比较算法
|
||||
results = compare_algorithms() # 比较三个算法的性能
|
||||
|
||||
# 可视化ID3算法的决策树
|
||||
clf = id3_algorithm() # 创建ID3算法的决策树分类器
|
||||
clf.fit(X, y) # 训练模型
|
||||
# 使用手动指定的类别名称列表
|
||||
visualize_tree(clf, iris.data.features.columns, ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']) # 可视化决策树
|
77
wine_dataset_training.py
Normal file
77
wine_dataset_training.py
Normal file
@ -0,0 +1,77 @@
|
||||
# 导入必要的库
|
||||
from ucimlrepo import fetch_ucirepo # 用于从UCI机器学习库中获取数据集
|
||||
from sklearn.model_selection import cross_val_score # 用于交叉验证
|
||||
from sklearn.tree import DecisionTreeClassifier # 用于构建决策树分类器
|
||||
from sklearn import tree # 用于可视化决策树
|
||||
import matplotlib.pyplot as plt # 用于绘图
|
||||
|
||||
# 获取Wine数据集
|
||||
wine = fetch_ucirepo(id=109) # 获取Wine数据集,id=109表示Wine数据集
|
||||
|
||||
# 数据 (作为pandas数据框)
|
||||
X = wine.data.features # 特征数据,包含Wine数据集的13个特征
|
||||
y = wine.data.targets # 目标标签,包含Wine数据集的类别标签
|
||||
|
||||
# 元数据
|
||||
print(wine.metadata) # 打印数据集的元数据信息
|
||||
|
||||
# 变量信息
|
||||
print(wine.variables) # 打印数据集的变量信息
|
||||
|
||||
# 定义ID3算法
|
||||
def id3_algorithm():
|
||||
# 使用信息增益作为划分标准
|
||||
clf = DecisionTreeClassifier(criterion="entropy") # 创建决策树分类器,使用信息增益作为划分标准
|
||||
return clf
|
||||
|
||||
# 定义C4.5算法
|
||||
def c45_algorithm():
|
||||
# 使用信息增益比作为划分标准
|
||||
clf = DecisionTreeClassifier(criterion="entropy", splitter="best") # 创建决策树分类器,使用信息增益比作为划分标准
|
||||
return clf
|
||||
|
||||
# 定义CART算法
|
||||
def cart_algorithm():
|
||||
# 使用基尼指数作为划分标准
|
||||
clf = DecisionTreeClassifier(criterion="gini") # 创建决策树分类器,使用基尼指数作为划分标准
|
||||
return clf
|
||||
|
||||
# 十折交叉验证
|
||||
def cross_validation(clf, X, y):
|
||||
scores = cross_val_score(clf, X, y, cv=10) # 进行十折交叉验证,返回每折的准确率
|
||||
return scores.mean(), scores.std() # 返回平均准确率和标准差
|
||||
|
||||
# 比较三个算法的精度和速度
|
||||
def compare_algorithms():
|
||||
algorithms = {
|
||||
"ID3": id3_algorithm(), # ID3算法
|
||||
"C4.5": c45_algorithm(), # C4.5算法
|
||||
"CART": cart_algorithm() # CART算法
|
||||
}
|
||||
|
||||
results = {}
|
||||
for name, clf in algorithms.items():
|
||||
mean_score, std_score = cross_validation(clf, X, y) # 对每个算法进行交叉验证
|
||||
results[name] = (mean_score, std_score) # 存储结果
|
||||
print(f"{name} - Mean Accuracy: {mean_score}, Std: {std_score}") # 打印每个算法的平均准确率和标准差
|
||||
|
||||
return results
|
||||
|
||||
# 可视化决策树
|
||||
def visualize_tree(clf, feature_names, class_names):
|
||||
plt.figure(figsize=(12,8)) # 设置图像大小
|
||||
# 手动指定类别名称列表
|
||||
class_names_list = ['Class 1', 'Class 2', 'Class 3'] # 手动指定类别名称列表
|
||||
tree.plot_tree(clf, filled=True, feature_names=feature_names, class_names=class_names_list) # 绘制决策树
|
||||
plt.show() # 显示图像
|
||||
|
||||
# 主函数
|
||||
if __name__ == "__main__":
|
||||
# 比较算法
|
||||
results = compare_algorithms() # 比较三个算法的性能
|
||||
|
||||
# 可视化ID3算法的决策树
|
||||
clf = id3_algorithm() # 创建ID3算法的决策树分类器
|
||||
clf.fit(X, y) # 训练模型
|
||||
# 使用手动指定的类别名称列表
|
||||
visualize_tree(clf, wine.data.features.columns, ['Class 1', 'Class 2', 'Class 3']) # 可视化决策树
|
Loading…
Reference in New Issue
Block a user