feat(Test): 优化能源预测模型并添加交叉验证

- 添加数据清理步骤,移除异常值
- 对特征进行标准化处理
- 引入交叉验证评估模型性能
- 增加可视化交叉验证结果的图表
- 优化代码结构,提高可读性和可维护性
This commit is contained in:
fly6516 2025-02-24 02:08:05 +08:00
parent 6756de70cc
commit 0bf372d423

View File

@ -2,12 +2,14 @@ import requests
import zipfile
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder
from sklearn.preprocessing import StandardScaler, LabelEncoder
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
def main():
# 设置支持中文的字体(例如 SimHei 字体)
@ -35,7 +37,6 @@ def main():
# 加载解压后的 CSV 数据
data = pd.read_csv(csv_file, dtype={'other_fuel3': str})
# data = pd.read_csv(csv_file)
# 数据预处理
data.fillna(0, inplace=True) # 用 0 填充缺失值
@ -61,19 +62,31 @@ def main():
# 聚合不同年份的发电数据
data['total_generation'] = data[generation_columns].sum(axis=1)
# 清理数据,移除异常值
data = data[(data['capacity_mw'] > 0) & (data['capacity_mw'] < 5000)] # 假设容量过大或过小的数据无效
data = data[data['total_generation'] >= 0] # 确保发电量是正数
# 再次检查数据的统计信息
print(data.describe())
# 选择特征X和目标变量y
X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']] # 示例特征
X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']]
y = data['generation_gwh_2017'] # 预测目标2017年的发电量
# 将数据分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 特征标准化
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
# 初始化并训练随机森林回归模型
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
model.fit(X_train_scaled, y_train)
# 在测试集上进行预测
y_pred = model.predict(X_test)
y_pred = model.predict(X_test_scaled)
# 评估模型
mse = mean_squared_error(y_test, y_pred)
@ -89,6 +102,36 @@ def main():
plt.legend()
plt.show()
# 交叉验证(例如使用 10 折交叉验证)
cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=10, scoring='neg_mean_squared_error')
# 检查是否存在 NaN 值
if np.any(np.isnan(cv_scores)):
print("警告: 交叉验证中发现 NaN 值,可能由数据问题导致。")
# 输出交叉验证的每折得分(RMSE)
print("交叉验证的每折结果(负均方误差):")
for i, score in enumerate(cv_scores, 1):
if np.isnan(score):
print(f"{i}: 无效得分")
else:
print(f"score: {score}")
print(f"{i}: {(-score) ** 0.5:.4f} RMSE")
# 输出交叉验证的平均RMSE
mean_rmse = (-cv_scores.mean()) ** 0.5
print(f"交叉验证的平均RMSE{mean_rmse:.4f}")
# 可视化交叉验证结果
plt.figure(figsize=(8, 6))
plt.plot(range(1, 11), -cv_scores, marker='o', label='每折负均方误差')
plt.xlabel('折数')
plt.ylabel('负均方误差')
plt.title('交叉验证结果(每折负均方误差)')
plt.xticks(range(1, 11))
plt.legend()
plt.grid(True)
plt.show()
if __name__ == "__main__":
main()