PowerPlant-analysis/Test/energy-prediction.py
fly6516 0bf372d423 feat(Test): 优化能源预测模型并添加交叉验证
- 添加数据清理步骤,移除异常值
- 对特征进行标准化处理
- 引入交叉验证评估模型性能
- 增加可视化交叉验证结果的图表
- 优化代码结构,提高可读性和可维护性
2025-02-24 02:08:05 +08:00

138 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import zipfile
import os
import pandas as pd
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler, LabelEncoder
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
def _download_and_extract(url, zip_path, extract_dir):
    """Download the ZIP archive at *url* and unpack it into *extract_dir*.

    Args:
        url: Address of the ZIP archive.
        zip_path: Local path the downloaded archive is written to.
        extract_dir: Directory the archive members are extracted into.

    Returns:
        True when the archive was downloaded and extracted; False (after
        printing an error message) when the request failed or the downloaded
        file is not a valid ZIP archive.
    """
    try:
        # The timeout keeps the script from hanging on a stalled connection;
        # raise_for_status() stops an HTTP error page from being saved as if
        # it were the archive.
        response = requests.get(url, timeout=60)
        response.raise_for_status()
    except requests.RequestException as exc:
        print(f"错误: 下载数据失败: {exc}")
        return False

    with open(zip_path, 'wb') as file:
        file.write(response.content)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile as exc:
        print(f"错误: 无法解压 {zip_path}: {exc}")
        return False
    return True


def main():
    """Train and evaluate a random-forest model that predicts 2017 generation.

    Downloads and extracts the power-plant dataset, cleans it, fits a
    RandomForestRegressor on an 80/20 train/test split, prints the hold-out
    RMSE, runs 10-fold cross-validation on the training set, and shows two
    matplotlib figures. Returns None; on a failed download, a missing CSV or
    missing generation columns it prints a message and returns early.
    """
    # Use a font with Chinese glyphs so the plot labels render
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']  # or 'Microsoft YaHei'
    # Keep minus signs from being drawn as boxes
    matplotlib.rcParams['axes.unicode_minus'] = False

    # Dataset resource URL (ZIP archive)
    url = "https://datasets.wri.org/private-admin/dataset/53623dfd-3df6-4f15-a091-67457cdb571f/resource/66bcdacc-3d0e-46ad-9271-a5a76b1853d2/download/globalpowerplantdatabasev130.zip"
    zip_path = "global_power_plant_data.zip"
    if not _download_and_extract(url, zip_path, "data"):
        return

    # Make sure the expected CSV came out of the archive
    csv_file = "data/global_power_plant_database.csv"
    if not os.path.exists(csv_file):
        print(f"错误: 文件 {csv_file} 不存在!")
        return

    data = pd.read_csv(csv_file, dtype={'other_fuel3': str})

    # Fill every missing value with 0.
    # NOTE(review): this is applied to all columns, so a plant with no
    # reported 2017 generation ends up with a target of 0 GWh; consider
    # dropping such rows instead of imputing the target.
    data.fillna(0, inplace=True)

    # 'owner' must be uniformly str before it can be label-encoded
    data['owner'] = data['owner'].astype(str)

    # Encode the categorical columns; one encoder per column so that neither
    # fitted mapping overwrites the other.
    data['primary_fuel'] = LabelEncoder().fit_transform(data['primary_fuel'])
    data['owner'] = LabelEncoder().fit_transform(data['owner'])

    target_column = 'generation_gwh_2017'
    generation_columns = ['generation_gwh_2013', 'generation_gwh_2014', 'generation_gwh_2015',
                          'generation_gwh_2016', 'generation_gwh_2017']

    # All yearly generation columns must be present
    missing_cols = [col for col in generation_columns if col not in data.columns]
    if missing_cols:
        print(f"警告: 缺少以下列: {', '.join(missing_cols)}")
        return

    # Sum over all years; used only for the sanity filter and describe() below
    data['total_generation'] = data[generation_columns].sum(axis=1)
    # The model feature must not contain the target year, otherwise the
    # target leaks into the inputs and the reported error is meaningless.
    history_columns = [col for col in generation_columns if col != target_column]
    data['past_generation'] = data[history_columns].sum(axis=1)

    # Drop outliers: capacity must lie strictly between 0 and 5000 MW,
    # and aggregated generation must be non-negative.
    data = data[(data['capacity_mw'] > 0) & (data['capacity_mw'] < 5000)]
    data = data[data['total_generation'] >= 0]

    # Summary statistics of the cleaned data
    print(data.describe())

    # Features (X) and target (y): predict 2017 generation
    X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'past_generation']]
    y = data[target_column]

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Standardise features; the scaler is fitted on the training split only
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train_scaled, y_train)

    # Hold-out evaluation
    y_pred = model.predict(X_test_scaled)
    mse = mean_squared_error(y_test, y_pred)
    rmse = mse ** 0.5
    print(f"均方根误差RMSE{rmse}")

    # Predicted vs. actual; the red points (y_test against itself) mark the
    # diagonal a perfect prediction would fall on.
    plt.scatter(y_test, y_pred, color='blue', label='预测发电量', alpha=0.6)
    plt.scatter(y_test, y_test, color='red', label='实际发电量', alpha=0.6)
    plt.xlabel('实际发电量 (GWh)')
    plt.ylabel('预测发电量 (GWh)')
    plt.title('实际 vs 预测发电量')
    plt.legend()
    plt.show()

    # 10-fold cross-validation on the training split
    cv_folds = 10
    cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=cv_folds,
                                scoring='neg_mean_squared_error')

    if np.any(np.isnan(cv_scores)):
        print("警告: 交叉验证中发现 NaN 值,可能由数据问题导致。")

    # Per-fold results: the raw (negative MSE) score and the RMSE derived from it
    print("交叉验证的每折结果(负均方误差):")
    for i, score in enumerate(cv_scores, 1):
        if np.isnan(score):
            print(f"{i}: 无效得分")
        else:
            print(f"score: {score}")
            print(f"{i}: {(-score) ** 0.5:.4f} RMSE")

    # Square root of the mean MSE across folds
    mean_rmse = (-cv_scores.mean()) ** 0.5
    print(f"交叉验证的平均RMSE{mean_rmse:.4f}")

    # Plot per-fold MSE. The scores are negated before plotting, so the curve
    # shows the positive MSE and is labelled accordingly. The x-range follows
    # the number of scores rather than a hard-coded fold count.
    fold_numbers = range(1, len(cv_scores) + 1)
    plt.figure(figsize=(8, 6))
    plt.plot(fold_numbers, -cv_scores, marker='o', label='每折均方误差')
    plt.xlabel('折数')
    plt.ylabel('均方误差')
    plt.title('交叉验证结果(每折均方误差)')
    plt.xticks(fold_numbers)
    plt.legend()
    plt.grid(True)
    plt.show()
# Run the pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()