ML-exp-1/experiment.py

60 lines
2.8 KiB
Python
Raw Permalink Normal View History

# 导入必要的库
from ucimlrepo import fetch_ucirepo # 用于加载UCI数据集
from sklearn.model_selection import cross_val_score, StratifiedKFold # 用于交叉验证
from sklearn.linear_model import LogisticRegression # 用于Logistic回归算法
from sklearn.preprocessing import StandardScaler # 用于数据标准化
from c45_algorithm import C45DecisionTree # 自定义的C4.5算法实现
import time # 用于计时
import pandas as pd # 用于数据处理
# 加载数据集
def load_dataset(dataset_id):
dataset = fetch_ucirepo(id=dataset_id) # 根据ID加载数据集
X = dataset.data.features.values # 提取特征并转换为NumPy数组
y = dataset.data.targets.values.ravel() # 提取标签并转换为一维数组
y, _ = pd.factorize(y) # 将字符串标签映射为整数标签
return X, y # 返回特征和标签
# 比较算法性能
def compare_algorithms(X, y, algorithm_name):
print(f"正在比较算法性能:{algorithm_name}") # 打印当前数据集名称
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 初始化算法
c45 = C45DecisionTree() # 初始化C4.5决策树
lr = LogisticRegression(max_iter=5000) # 初始化Logistic回归模型设置最大迭代次数为5000
# 十折交叉验证
if algorithm_name == "Wine Quality":
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42) # 定义分层五折交叉验证
else:
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42) # 定义分层十折交叉验证
# C4.5算法
start_time = time.time() # 记录开始时间
c45_scores = cross_val_score(c45, X_scaled, y, cv=skf, scoring='accuracy') # 计算C4.5算法的交叉验证精度
c45_time = time.time() - start_time # 计算运行时间
# Logistic回归算法
start_time = time.time() # 记录开始时间
lr_scores = cross_val_score(lr, X_scaled, y, cv=skf, scoring='accuracy') # 计算Logistic回归的交叉验证精度
lr_time = time.time() - start_time # 计算运行时间
# 输出结果
print(f"C4.5算法 - 平均精度: {c45_scores.mean():.4f}, 平均时间: {c45_time:.4f}") # 打印C4.5算法的结果
print(f"Logistic回归 - 平均精度: {lr_scores.mean():.4f}, 平均时间: {lr_time:.4f}") # 打印Logistic回归的结果
# 主函数
if __name__ == "__main__":
# Iris数据集实验
print("Iris数据集实验:") # 打印实验标题
X_iris, y_iris = load_dataset(53) # 加载Iris数据集
compare_algorithms(X_iris, y_iris, "Iris") # 比较算法性能
# Wine Quality数据集实验
print("\nWine Quality数据集实验:") # 打印实验标题
X_wine, y_wine = load_dataset(186) # 加载Wine Quality数据集
compare_algorithms(X_wine, y_wine, "Wine Quality") # 比较算法性能