From 0bf372d4238f85e8cc89656d6a3abd3e1f435669 Mon Sep 17 00:00:00 2001 From: fly6516 Date: Mon, 24 Feb 2025 02:08:05 +0800 Subject: [PATCH] =?UTF-8?q?feat(Test):=20=E4=BC=98=E5=8C=96=E8=83=BD?= =?UTF-8?q?=E6=BA=90=E9=A2=84=E6=B5=8B=E6=A8=A1=E5=9E=8B=E5=B9=B6=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E4=BA=A4=E5=8F=89=E9=AA=8C=E8=AF=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加数据清理步骤,移除异常值- 对特征进行标准化处理 - 引入交叉验证评估模型性能 - 增加可视化交叉验证结果的图表 - 优化代码结构,提高可读性和可维护性 --- Test/energy-prediction.py | 55 ++++++++++++++++++++++++++++++++++----- 1 file changed, 49 insertions(+), 6 deletions(-) diff --git a/Test/energy-prediction.py b/Test/energy-prediction.py index 1e0d935..52dcaa0 100644 --- a/Test/energy-prediction.py +++ b/Test/energy-prediction.py @@ -2,12 +2,14 @@ import requests import zipfile import os import pandas as pd -from sklearn.model_selection import train_test_split +from sklearn.model_selection import train_test_split, cross_val_score from sklearn.ensemble import RandomForestRegressor from sklearn.metrics import mean_squared_error -from sklearn.preprocessing import LabelEncoder +from sklearn.preprocessing import StandardScaler, LabelEncoder import matplotlib.pyplot as plt import matplotlib +import numpy as np + def main(): # 设置支持中文的字体(例如:SimHei 字体) @@ -35,7 +37,6 @@ def main(): # 加载解压后的 CSV 数据 data = pd.read_csv(csv_file, dtype={'other_fuel3': str}) - # data = pd.read_csv(csv_file) # 数据预处理 data.fillna(0, inplace=True) # 用 0 填充缺失值 @@ -61,19 +62,31 @@ def main(): # 聚合不同年份的发电数据 data['total_generation'] = data[generation_columns].sum(axis=1) + # 清理数据,移除异常值 + data = data[(data['capacity_mw'] > 0) & (data['capacity_mw'] < 5000)] # 假设容量过大或过小的数据无效 + data = data[data['total_generation'] >= 0] # 确保发电量是正数 + + # 再次检查数据的统计信息 + print(data.describe()) + # 选择特征(X)和目标变量(y) - X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']] # 示例特征 + X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']] y = data['generation_gwh_2017'] # 预测目标:2017年的发电量 # 将数据分割为训练集和测试集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + # 特征标准化 + scaler = StandardScaler() + X_train_scaled = scaler.fit_transform(X_train) + X_test_scaled = scaler.transform(X_test) + # 初始化并训练随机森林回归模型 model = RandomForestRegressor(n_estimators=100, random_state=42) - model.fit(X_train, y_train) + model.fit(X_train_scaled, y_train) # 在测试集上进行预测 - y_pred = model.predict(X_test) + y_pred = model.predict(X_test_scaled) # 评估模型 mse = mean_squared_error(y_test, y_pred) @@ -89,6 +102,36 @@ def main(): plt.legend() plt.show() + # 交叉验证(例如使用10折交叉验证) + cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=10, scoring='neg_mean_squared_error') + + # 检查是否存在 NaN 值 + if np.any(np.isnan(cv_scores)): + print("警告: 交叉验证中发现 NaN 值,可能由数据问题导致。") + + # 输出交叉验证的每折得分(RMSE) + print("交叉验证的每折结果(负均方误差):") + for i, score in enumerate(cv_scores, 1): + if np.isnan(score): + print(f"折 {i}: 无效得分") + else: + print(f"score: {score}") + print(f"折 {i}: {(-score) ** 0.5:.4f} RMSE") + + # 输出交叉验证的平均RMSE + mean_rmse = (-cv_scores.mean()) ** 0.5 + print(f"交叉验证的平均RMSE:{mean_rmse:.4f}") + + # 可视化交叉验证结果 + plt.figure(figsize=(8, 6)) + plt.plot(range(1, 11), -cv_scores, marker='o', label='每折负均方误差') + plt.xlabel('折数') + plt.ylabel('负均方误差') + plt.title('交叉验证结果(每折负均方误差)') + plt.xticks(range(1, 11)) + plt.legend() + plt.grid(True) + plt.show() if __name__ == "__main__": main()