PowerPlant-analysis/Test/energy-prediction.py
fly6516 11761ceb84 feat(Test): 使用线性回归预测未来30年发电量增长趋势
- 移除了标准缩放和交叉验证的相关代码
- 添加了线性回归模型,用于预测未来30年的发电量
- 绘制了已知发电量和未来预测发电量的趋势图- 优化了散点图的绘制,增加了对角线以更好地展示预测与实际值的对比
2025-02-24 02:23:10 +08:00

131 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import zipfile
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import matplotlib
from sklearn.linear_model import LinearRegression
import numpy as np
def main():
# 设置支持中文的字体例如SimHei 字体)
matplotlib.rcParams['font.sans-serif'] = ['SimHei'] # 或使用 'Microsoft YaHei'
matplotlib.rcParams['axes.unicode_minus'] = False # 防止负号显示为方块
# 获取数据资源 URL选择最新版本的 ZIP 文件)
url = "https://datasets.wri.org/private-admin/dataset/53623dfd-3df6-4f15-a091-67457cdb571f/resource/66bcdacc-3d0e-46ad-9271-a5a76b1853d2/download/globalpowerplantdatabasev130.zip"
# 下载 ZIP 文件
response = requests.get(url)
zip_path = "global_power_plant_data.zip"
with open(zip_path, 'wb') as file:
file.write(response.content)
# 解压 ZIP 文件
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall("data")
# 确保解压后的文件名和路径
csv_file = "data/global_power_plant_database.csv"
if not os.path.exists(csv_file):
print(f"错误: 文件 {csv_file} 不存在!")
return
# 加载解压后的 CSV 数据
data = pd.read_csv(csv_file, dtype={'other_fuel3': str})
# 数据预处理
data.fillna(0, inplace=True) # 用 0 填充缺失值
# 确保 'owner' 列的类型统一为字符串类型
data['owner'] = data['owner'].astype(str)
# 对分类列进行编码(例如 primary_fuel, owner
label_encoder = LabelEncoder()
data['primary_fuel'] = label_encoder.fit_transform(data['primary_fuel'])
data['owner'] = label_encoder.fit_transform(data['owner'])
# 确保发电量列存在并计算总发电量
generation_columns = ['generation_gwh_2013', 'generation_gwh_2014', 'generation_gwh_2015',
'generation_gwh_2016', 'generation_gwh_2017']
# 聚合不同年份的发电数据
data['total_generation'] = data[generation_columns].sum(axis=1)
# 选择特征X和目标变量y
X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']] # 示例特征
y = data['generation_gwh_2017'] # 预测目标2017年的发电量
# 将数据分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化并训练随机森林回归模型
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# 在测试集上进行预测
y_pred = model.predict(X_test)
# 评估模型
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(f"均方根误差RMSE{rmse}")
# 绘制预测值与实际值的对比图(预测值为 y 轴,实际值为 x 轴)
plt.figure(figsize=(10, 6))
plt.scatter(y_test, y_pred, color='blue', alpha=0.6)
plt.plot([min(y_test), max(y_test)], [min(y_test), max(y_test)], color='red', linestyle='--')
plt.xlabel('实际发电量 (GWh)')
plt.ylabel('预测发电量 (GWh)')
plt.title('实际发电量与预测发电量对比')
plt.grid(True)
plt.show()
# 使用线性回归预测未来30年的增长趋势
years = [2013, 2014, 2015, 2016, 2017] # 使用已有的年份
generation_values = data[generation_columns].mean(axis=0) # 使用这些年份的平均发电量
# 创建一个线性回归模型来拟合这些数据
linear_regressor = LinearRegression()
linear_regressor.fit(np.array(years).reshape(-1, 1), generation_values)
# 预测未来30年的发电量
future_years = list(range(2018, 2018 + 30))
future_generation = linear_regressor.predict(np.array(future_years).reshape(-1, 1))
# 输出未来30年的预测发电量
print("未来30年发电量预测单位GWh")
for i, generation in enumerate(future_generation):
print(f"{i+1} 年 ({future_years[i]}): {generation:.2f} GWh")
# 获取2013到2017年已知的发电量数据
known_generation = data[generation_columns].mean(axis=0)
# 合并已知发电量数据和未来30年预测的发电量
all_years = years + future_years
all_generation = list(known_generation) + list(future_generation)
# 可视化年份和发电量的增长趋势(已知数据和预测数据)
plt.figure(figsize=(10, 6))
# 绘制已知数据的发电量趋势2013-2017年用绿色标色
plt.plot(years, known_generation, marker='o', color='green', label='已知发电量 (2013-2017)', linestyle='-', markersize=8)
# 绘制未来预测数据的发电量趋势2018-2047年用蓝色标色
plt.plot(future_years, future_generation, marker='o', color='blue', label='预测发电量 (2018-2047)', linestyle='-', markersize=8)
plt.xlabel('年份')
plt.ylabel('发电量 (GWh)')
plt.title('2013-2047年发电量增长趋势已知数据和预测数据')
plt.grid(True)
plt.legend()
plt.show()
if __name__ == "__main__":
main()