PowerPlant-analysis/Test/energy-prediction.py
fly6516 6756de70cc feat(Test): 添加能源预测脚本
- 新增 energy-prediction.py 脚本,实现全球发电厂数据库的下载、处理和分析
- 使用随机森林回归模型预测发电量,并评估模型性能
- 可视化预测结果与实际值的对比
- 设置中文字体支持,确保图表显示正常
2025-02-24 00:45:15 +08:00

95 lines
3.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import zipfile
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import matplotlib
def _download_and_extract(url, zip_path, extract_dir, timeout=60):
    """Download the ZIP archive at *url* to *zip_path* and unpack it into *extract_dir*.

    Raises ``requests.HTTPError`` on a non-2xx response, so an HTML error page
    is never saved and then handed to ``zipfile`` (which would only fail later
    with a confusing ``BadZipFile``). *timeout* (seconds) keeps the script from
    hanging indefinitely on a stalled connection.
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    with open(zip_path, 'wb') as file:
        file.write(response.content)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)


def _build_dataset(data, history_columns, target_column):
    """Return ``(X, y)`` for predicting *target_column* from plant attributes and past generation.

    Rows whose target is missing are dropped rather than zero-filled. A plant
    with no reported generation is not a plant that generated 0 GWh, and
    training on fabricated zeros biases the model.

    The ``historical_generation`` feature is summed only over *history_columns*.
    Callers must not include the target year there, or the target leaks into
    the features.
    """
    data = data.dropna(subset=[target_column]).copy()
    # Missing plant attributes / historical generation are treated as 0,
    # matching the original script's fillna(0) for these columns.
    numeric_inputs = ['capacity_mw', 'latitude', 'longitude'] + list(history_columns)
    data[numeric_inputs] = data[numeric_inputs].fillna(0)
    data['historical_generation'] = data[list(history_columns)].sum(axis=1)
    # Dedicated encoder for the fuel category. astype(str) guards against NaN / mixed types.
    data['primary_fuel'] = LabelEncoder().fit_transform(data['primary_fuel'].astype(str))
    X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'historical_generation']]
    y = data[target_column]
    return X, y


def _plot_predictions(y_test, y_pred):
    """Scatter predicted vs. actual generation, with a dashed y = x reference line."""
    plt.scatter(y_test, y_pred, color='blue', label='预测发电量', alpha=0.6)
    # Points on this diagonal are perfect predictions. This replaces the
    # original scatter of y_test against itself, which drew the same
    # diagonal as points but labeled it as "actual" values.
    low = float(min(y_test.min(), y_pred.min()))
    high = float(max(y_test.max(), y_pred.max()))
    plt.plot([low, high], [low, high], color='red', linestyle='--', label='理想预测 (y = x)')
    plt.xlabel('实际发电量 (GWh)')
    plt.ylabel('预测发电量 (GWh)')
    plt.title('实际 vs 预测发电量')
    plt.legend()
    plt.show()


def main():
    """Run the 2017 generation-prediction pipeline on the Global Power Plant Database.

    The pipeline:
    1. Downloads and extracts the Global Power Plant Database.
    2. Trains a random forest to predict each plant's 2017 generation (GWh)
       from capacity, location, primary fuel and 2013-2016 generation.
    3. Prints the test-set RMSE.
    4. Plots predicted vs. actual values.
    """
    # Use a font with CJK glyphs (SimHei, or e.g. 'Microsoft YaHei') so Chinese labels render.
    matplotlib.rcParams['font.sans-serif'] = ['SimHei']
    # Keep the minus sign from rendering as a box under CJK fonts.
    matplotlib.rcParams['axes.unicode_minus'] = False

    # Download URL of the database ZIP archive (v1.3.0 per the file name).
    url = "https://datasets.wri.org/private-admin/dataset/53623dfd-3df6-4f15-a091-67457cdb571f/resource/66bcdacc-3d0e-46ad-9271-a5a76b1853d2/download/globalpowerplantdatabasev130.zip"
    zip_path = "global_power_plant_data.zip"
    _download_and_extract(url, zip_path, "data")

    # Expected location of the CSV inside the extracted archive.
    csv_file = "data/global_power_plant_database.csv"
    if not os.path.exists(csv_file):
        print(f"错误: 文件 {csv_file} 不存在!")
        return
    # other_fuel3 is read explicitly as str (presumably to avoid a mixed-type dtype warning).
    data = pd.read_csv(csv_file, dtype={'other_fuel3': str})

    # Earlier years are features; 2017 is the prediction target and must not
    # appear among the feature columns.
    history_columns = ['generation_gwh_2013', 'generation_gwh_2014',
                       'generation_gwh_2015', 'generation_gwh_2016']
    target_column = 'generation_gwh_2017'
    missing_cols = [col for col in history_columns + [target_column] if col not in data.columns]
    if missing_cols:
        print(f"警告: 缺少以下列: {', '.join(missing_cols)}")
        return

    X, y = _build_dataset(data, history_columns, target_column)
    if X.empty:
        print(f"错误: 没有包含 {target_column} 的数据行, 无法训练模型!")
        return

    # 80/20 train/test split with a fixed seed for reproducibility.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    y_pred = model.predict(X_test)

    rmse = mean_squared_error(y_test, y_pred) ** 0.5
    print(f"均方根误差 (RMSE): {rmse}")

    _plot_predictions(y_test, y_pred)
# Script entry point: run the pipeline only when executed directly, not on import.
if __name__ == '__main__':
    main()