2025-02-23 16:45:15 +00:00
|
|
|
|
import requests
|
|
|
|
|
import zipfile
|
|
|
|
|
import os
|
|
|
|
|
import pandas as pd
|
2025-02-23 18:08:05 +00:00
|
|
|
|
from sklearn.model_selection import train_test_split, cross_val_score
|
2025-02-23 16:45:15 +00:00
|
|
|
|
from sklearn.ensemble import RandomForestRegressor
|
|
|
|
|
from sklearn.metrics import mean_squared_error
|
2025-02-23 18:08:05 +00:00
|
|
|
|
from sklearn.preprocessing import StandardScaler, LabelEncoder
|
2025-02-23 16:45:15 +00:00
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import matplotlib
|
2025-02-23 18:08:05 +00:00
|
|
|
|
import numpy as np
|
|
|
|
|
|
2025-02-23 16:45:15 +00:00
|
|
|
|
|
|
|
|
|
def _download_and_extract(url, zip_path, extract_dir, timeout=60):
    """Download the ZIP archive at *url* and unpack it into *extract_dir*.

    Args:
        url: Address of the ZIP archive to download.
        zip_path: Local path the downloaded archive is written to.
        extract_dir: Directory the archive members are extracted into.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        True when the archive was downloaded and extracted, False when the
        request failed or the payload was not a valid ZIP file (an error
        message is printed in that case).
    """
    try:
        # A timeout keeps the script from hanging forever on a stalled
        # connection; raise_for_status turns 4xx/5xx replies into errors
        # instead of silently saving an error page as the "ZIP" file.
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as exc:
        print(f"错误: 下载数据失败: {exc}")
        return False

    with open(zip_path, 'wb') as file:
        file.write(response.content)

    try:
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(extract_dir)
    except zipfile.BadZipFile as exc:
        print(f"错误: 下载的文件不是有效的 ZIP 压缩包: {exc}")
        return False
    return True


def _load_and_clean(csv_file, generation_columns):
    """Load the power-plant CSV, encode categoricals and drop outlier rows.

    Args:
        csv_file: Path of the CSV file to read.
        generation_columns: Yearly generation columns that must be present;
            their row-wise sum is stored in a ``total_generation`` column.

    Returns:
        The cleaned DataFrame, or None when any of *generation_columns* is
        missing (a warning is printed in that case).
    """
    data = pd.read_csv(csv_file, dtype={'other_fuel3': str})

    # Missing values are replaced with 0 across the whole frame.
    data.fillna(0, inplace=True)

    # After fillna(0) a categorical column may mix strings with the integer 0,
    # which LabelEncoder cannot sort; cast to str first. A separate encoder is
    # used per column so each one keeps its own class mapping.
    for column in ('primary_fuel', 'owner'):
        data[column] = LabelEncoder().fit_transform(data[column].astype(str))

    missing_cols = [col for col in generation_columns if col not in data.columns]
    if missing_cols:
        print(f"警告: 缺少以下列: {', '.join(missing_cols)}")
        return None

    # Aggregate generation across all listed years.
    data['total_generation'] = data[generation_columns].sum(axis=1)

    # Drop rows treated as invalid: capacity must lie in (0, 5000) and the
    # aggregated generation must not be negative.
    data = data[(data['capacity_mw'] > 0) & (data['capacity_mw'] < 5000)]
    data = data[data['total_generation'] >= 0]
    return data


def _plot_predictions(y_test, y_pred):
    """Scatter predicted against actual values, with the actual values in red."""
    plt.scatter(y_test, y_pred, color='blue', label='预测发电量', alpha=0.6)
    # Plotting y_test against itself draws the y = x reference points.
    plt.scatter(y_test, y_test, color='red', label='实际发电量', alpha=0.6)
    plt.xlabel('实际发电量 (GWh)')
    plt.ylabel('预测发电量 (GWh)')
    plt.title('实际 vs 预测发电量')
    plt.legend()
    plt.show()


def _report_cross_validation(cv_scores):
    """Print per-fold and mean RMSE for *cv_scores* and plot the per-fold MSE.

    Args:
        cv_scores: Array of scores produced with
            ``scoring='neg_mean_squared_error'`` (i.e. negated MSE values).
    """
    if np.any(np.isnan(cv_scores)):
        print("警告: 交叉验证中发现 NaN 值,可能由数据问题导致。")

    print("交叉验证的每折结果(负均方误差):")
    for i, score in enumerate(cv_scores, 1):
        if np.isnan(score):
            print(f"折 {i}: 无效得分")
        else:
            print(f"score: {score}")
            # Scores are negated MSE, so flip the sign before the square root.
            print(f"折 {i}: {(-score) ** 0.5:.4f} RMSE")

    # Square root of the mean MSE over all folds.
    mean_rmse = (-cv_scores.mean()) ** 0.5
    print(f"交叉验证的平均RMSE:{mean_rmse:.4f}")

    # -cv_scores is the (positive) MSE, so the plot is labelled as MSE rather
    # than "negative MSE"; the x axis follows the actual number of folds.
    folds = range(1, len(cv_scores) + 1)
    plt.figure(figsize=(8, 6))
    plt.plot(folds, -cv_scores, marker='o', label='每折均方误差')
    plt.xlabel('折数')
    plt.ylabel('均方误差')
    plt.title('交叉验证结果(每折均方误差)')
    plt.xticks(folds)
    plt.legend()
    plt.grid(True)
    plt.show()


def main():
    """Download the power-plant dataset, train a random forest and report it.

    The pipeline downloads and extracts the dataset ZIP, cleans the CSV,
    trains a RandomForestRegressor to predict ``generation_gwh_2017``,
    prints the test RMSE, plots predictions against actual values and runs
    a 10-fold cross-validation on the training split.

    Returns early (after printing a message) when the download fails, the
    expected CSV or generation columns are missing, or no rows survive
    cleaning.
    """
    # Use a font that can render the Chinese labels (e.g. SimHei, or
    # alternatively 'Microsoft YaHei').
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']
    # Keep the minus sign from being rendered as a box with that font.
    matplotlib.rcParams['axes.unicode_minus'] = False

    # Dataset resource URL (ZIP archive).
    url = "https://datasets.wri.org/private-admin/dataset/53623dfd-3df6-4f15-a091-67457cdb571f/resource/66bcdacc-3d0e-46ad-9271-a5a76b1853d2/download/globalpowerplantdatabasev130.zip"
    zip_path = "global_power_plant_data.zip"
    if not _download_and_extract(url, zip_path, "data"):
        return

    # Make sure the expected file came out of the archive.
    csv_file = "data/global_power_plant_database.csv"
    if not os.path.exists(csv_file):
        print(f"错误: 文件 {csv_file} 不存在!")
        return

    target_column = 'generation_gwh_2017'
    generation_columns = ['generation_gwh_2013', 'generation_gwh_2014', 'generation_gwh_2015',
                          'generation_gwh_2016', 'generation_gwh_2017']

    data = _load_and_clean(csv_file, generation_columns)
    if data is None:
        return
    if data.empty:
        print("错误: 清理后没有可用数据。")
        return

    # Summary statistics of the cleaned data.
    print(data.describe())

    # Features (X) and target (y). The generation feature is built only from
    # the years before the target year: summing all years would include the
    # 2017 target itself and leak it into the features.
    prior_columns = [col for col in generation_columns if col != target_column]
    X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel']].assign(
        prior_generation=data[prior_columns].sum(axis=1)
    )
    y = data[target_column]

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Standardize features; the scaler is fitted on the training split only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)

    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train_scaled, y_train)

    # Evaluate on the held-out test split.
    y_pred = model.predict(X_test_scaled)
    mse = mean_squared_error(y_test, y_pred)
    rmse = mse ** 0.5
    print(f"均方根误差(RMSE):{rmse}")

    _plot_predictions(y_test, y_pred)

    # 10-fold cross-validation on the training split.
    cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=10, scoring='neg_mean_squared_error')
    _report_cross_validation(cv_scores)
|
2025-02-23 16:45:15 +00:00
|
|
|
|
|
|
|
|
|
# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|