ML-exp-1/experiment.py
fly6516 1d2138e0d2 feat: 实现 C4.5 决策树算法并进行实验
- 实现了 C4.5 决策树算法,包括信息熵、信息增益、树的构建和预测等功能
- 添加了实验脚本,使用 Iris 和 Wine Quality 数据集进行性能比较- 生成了实验报告,总结了 C4.5 算法与 Logistic 回归的性能差异
2025-03-07 11:25:27 +08:00

60 lines
2.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 导入必要的库
from ucimlrepo import fetch_ucirepo # 用于加载UCI数据集
from sklearn.model_selection import cross_val_score, StratifiedKFold # 用于交叉验证
from sklearn.linear_model import LogisticRegression # 用于Logistic回归算法
from sklearn.preprocessing import StandardScaler # 用于数据标准化
from c45_algorithm import C45DecisionTree # 自定义的C4.5算法实现
import time # 用于计时
import pandas as pd # 用于数据处理
# 加载数据集
def load_dataset(dataset_id):
dataset = fetch_ucirepo(id=dataset_id) # 根据ID加载数据集
X = dataset.data.features.values # 提取特征并转换为NumPy数组
y = dataset.data.targets.values.ravel() # 提取标签并转换为一维数组
y, _ = pd.factorize(y) # 将字符串标签映射为整数标签
return X, y # 返回特征和标签
# 比较算法性能
def compare_algorithms(X, y, algorithm_name):
print(f"正在比较算法性能:{algorithm_name}") # 打印当前数据集名称
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 初始化算法
c45 = C45DecisionTree() # 初始化C4.5决策树
lr = LogisticRegression(max_iter=5000) # 初始化Logistic回归模型设置最大迭代次数为5000
# 十折交叉验证
if algorithm_name == "Wine Quality":
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) # 定义分层五折交叉验证
else:
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) # 定义分层十折交叉验证
# C4.5算法
start_time = time.time() # 记录开始时间
c45_scores = cross_val_score(c45, X_scaled, y, cv=skf, scoring='accuracy') # 计算C4.5算法的交叉验证精度
c45_time = time.time() - start_time # 计算运行时间
# Logistic回归算法
start_time = time.time() # 记录开始时间
lr_scores = cross_val_score(lr, X_scaled, y, cv=skf, scoring='accuracy') # 计算Logistic回归的交叉验证精度
lr_time = time.time() - start_time # 计算运行时间
# 输出结果
print(f"C4.5算法 - 平均精度: {c45_scores.mean():.4f}, 平均时间: {c45_time:.4f}") # 打印C4.5算法的结果
print(f"Logistic回归 - 平均精度: {lr_scores.mean():.4f}, 平均时间: {lr_time:.4f}") # 打印Logistic回归的结果
# 主函数
if __name__ == "__main__":
# Iris数据集实验
print("Iris数据集实验:") # 打印实验标题
X_iris, y_iris = load_dataset(53) # 加载Iris数据集
compare_algorithms(X_iris, y_iris, "Iris") # 比较算法性能
# Wine Quality数据集实验
print("\nWine Quality数据集实验:") # 打印实验标题
X_wine, y_wine = load_dataset(186) # 加载Wine Quality数据集
compare_algorithms(X_wine, y_wine, "Wine Quality") # 比较算法性能