PowerPlant-analysis/Test/forecasting.py
fly6516 6756de70cc feat(Test): 添加能源预测脚本
- 新增 energy-prediction.py 脚本,实现全球发电厂数据库的下载、处理和分析
- 使用随机森林回归模型预测发电量,并评估模型性能
- 可视化预测结果与实际值的对比
- 设置中文字体支持,确保图表显示正常
2025-02-24 00:45:15 +08:00

152 lines
5.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import requests
import zipfile
import io
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error
from statsmodels.tsa.arima.model import ARIMA
# Download and extract a ZIP archive.
def download_and_extract_zip(url, extract_to='.', timeout=60):
    """Download the ZIP archive at *url* and extract its contents into *extract_to*.

    Args:
        url: HTTP(S) URL of the ZIP file.
        extract_to: Destination directory for the extracted members.
        timeout: Seconds ``requests`` may wait on the server (connect/read)
            before raising; without it a stalled connection blocks forever.

    Returns:
        True if the archive was downloaded and extracted, False if the
        server answered with a status code other than 200.
    """
    print(f"Downloading ZIP file from {url}")
    response = requests.get(url, timeout=timeout)
    if response.status_code != 200:
        print(f"Failed to download ZIP file. Status code: {response.status_code}")
        return False
    # Extract straight from the in-memory payload; no temporary file needed.
    with zipfile.ZipFile(io.BytesIO(response.content)) as z:
        z.extractall(extract_to)
    print(f"Extracted ZIP file to {extract_to}")
    return True
# Load CSV data.
def load_data_from_csv(directory):
    """Load the first ``.csv`` file in *directory* into a DataFrame.

    Entries are examined in sorted name order so the chosen file is
    deterministic (``os.listdir`` returns entries in arbitrary order), and
    only regular files are considered.

    Args:
        directory: Directory to search (not recursive).

    Returns:
        The parsed DataFrame, or None if *directory* does not exist or holds
        no ``.csv`` file.
    """
    if not os.path.isdir(directory):
        # e.g. the download step failed and never created the directory.
        print(f"Directory not found: {directory}")
        return None
    for filename in sorted(os.listdir(directory)):
        file_path = os.path.join(directory, filename)
        if filename.endswith(".csv") and os.path.isfile(file_path):
            print(f"Loading data from {file_path}")
            return pd.read_csv(file_path)
    print("No CSV files found.")
    return None
# Power-generation forecasting model (time-series analysis).
def generate_forecast(df, steps=30):
    """Fit ARIMA(5, 1, 0) to yearly summed 2013 generation and forecast ahead.

    Rows are grouped by ``year_of_capacity_data`` and their
    ``generation_gwh_2013`` values summed per year (see ``_yearly_generation``).

    NOTE(review): this groups a single year's generation figure by the year
    the capacity data was reported, so the series is not a true generation
    history — confirm this is the intended signal.

    Args:
        df: Power-plant table. It is only read, never modified.
        steps: Number of yearly periods to forecast.

    Returns:
        The forecast Series, or None if a required column is missing or there
        are too few yearly points to fit the model.
    """
    order = (5, 1, 0)
    if 'year_of_capacity_data' not in df.columns:
        print("Year column 'year_of_capacity_data' not found for ARIMA model.")
        return None
    if 'generation_gwh_2013' not in df.columns:
        print("Generation data not found for ARIMA model.")
        return None
    series = _yearly_generation(df)
    # With fewer than p + d + 1 yearly points the AR(5) terms on the
    # differenced series cannot be estimated meaningfully.
    min_obs = order[0] + order[1] + 1
    if len(series) < min_obs:
        print(f"Not enough yearly data for ARIMA model: {len(series)} point(s), "
              f"need at least {min_obs}.")
        return None
    model_fit = ARIMA(series, order=order).fit()
    forecast = model_fit.forecast(steps=steps)
    print(f"Forecasted Generation for the next {steps} years: {forecast}")
    return forecast


def _yearly_generation(df):
    """Return ``generation_gwh_2013`` summed per ``year_of_capacity_data`` year.

    The result has a yearly PeriodIndex covering every year from the first
    to the last observed one; years with no plant rows contribute 0 and NaN
    generation values are skipped in the sums.
    """
    # pandas reads an integer column containing NaN as float, and parsing
    # such values (e.g. 2017.0) with format='%Y' can fail and yield NaT,
    # so work with the years as numbers instead.
    years = pd.to_numeric(df['year_of_capacity_data'], errors='coerce')
    generation = pd.to_numeric(df['generation_gwh_2013'], errors='coerce')
    # Keep whole-number years only; plain boolean arrays avoid index alignment.
    valid = (years.notna() & (years % 1 == 0)).to_numpy()
    year_values = years[valid].astype(int).to_numpy()
    totals = generation[valid].groupby(year_values).sum()
    if totals.empty:
        return totals
    first, last = int(totals.index.min()), int(totals.index.max())
    totals = totals.reindex(range(first, last + 1), fill_value=0.0)
    totals.index = pd.period_range(start=str(first), periods=len(totals), freq='Y')
    return totals
# Device fault prediction model (classification).
def device_fault_prediction(df, random_state=42):
    """Train a random-forest classifier on the ``fault`` label and report accuracy.

    Only numeric columns are used as features (the classifier cannot consume
    text columns); feature columns that are entirely missing are discarded,
    then rows with a missing label or feature value are dropped.

    NOTE(review): assumes the input carries a ``fault`` label column; nothing
    in this script creates one — confirm the intended data source.

    Args:
        df: Input table. It is only read, never modified.
        random_state: Seed for both the train/test split and the forest, so
            the reported accuracy is reproducible.

    Returns:
        Test-set accuracy, or None if the model could not be trained.
    """
    if 'fault' not in df.columns:
        print("Fault column not found in the dataset.")
        return None
    features = df.drop(columns='fault').select_dtypes(include='number')
    features = features.dropna(axis=1, how='all')
    if features.shape[1] == 0:
        print("No numeric feature columns available for fault prediction.")
        return None
    labels = df['fault']
    # Boolean array (not Series) so selection never depends on index alignment.
    complete = (labels.notna() & features.notna().all(axis=1)).to_numpy()
    X, y = features[complete], labels[complete]
    # train_test_split needs at least one training and one test row.
    if len(y) < 2:
        print("Not enough complete rows for fault prediction.")
        return None
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=random_state)
    model = RandomForestClassifier(random_state=random_state)
    model.fit(X_train, y_train)
    accuracy = accuracy_score(y_test, model.predict(X_test))
    print("Device Fault Prediction Accuracy:", accuracy)
    return accuracy
# Energy-efficiency optimisation model (regression).
def energy_efficiency_optimization(df):
    """Regress 2013 generation on capacity with an RBF SVR and report the test MSE.

    ``capacity_mw`` is used as a proxy for energy input and
    ``generation_gwh_2013`` as the energy output.

    Returns:
        Test-set mean squared error, or None if the model could not be fit.
    """
    if 'capacity_mw' not in df.columns or 'generation_gwh_2013' not in df.columns:
        print("Required columns for energy efficiency not found.")
        return None
    mse = _svr_capacity_to_generation_mse(df)
    if mse is not None:
        print("Mean Squared Error of Energy Efficiency Prediction:", mse)
    return mse


# Power-market prediction model (regression).
def power_market_prediction(df):
    """Regress 2013 generation on capacity with an RBF SVR and report the test MSE.

    NOTE(review): despite the "price" wording in the printed message, no
    price data is used — this is the same capacity -> generation regression
    as ``energy_efficiency_optimization``.

    Returns:
        Test-set mean squared error, or None if the model could not be fit.
    """
    if 'generation_gwh_2013' not in df.columns or 'capacity_mw' not in df.columns:
        print("Required columns for power market prediction not found.")
        return None
    mse = _svr_capacity_to_generation_mse(df)
    if mse is not None:
        print("Mean Squared Error of Power Market Price Prediction:", mse)
    return mse


def _svr_capacity_to_generation_mse(df, test_size=0.2, random_state=42):
    """Fit SVR(capacity_mw -> generation_gwh_2013) and return the test-set MSE.

    Rows where either column is missing or non-numeric are dropped first,
    because SVR rejects NaN inputs. Returns None (after printing why) when
    fewer than two usable rows remain — with the default test_size that is
    the minimum train_test_split needs for non-empty train and test sets.
    """
    capacity = pd.to_numeric(df['capacity_mw'], errors='coerce')
    generation = pd.to_numeric(df['generation_gwh_2013'], errors='coerce')
    usable = (capacity.notna() & generation.notna()).to_numpy()
    X = capacity[usable].to_frame()  # keeps the 'capacity_mw' column name
    y = generation[usable]
    if len(y) < 2:
        print("Not enough rows with both capacity_mw and generation_gwh_2013 for SVR.")
        return None
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state)
    model = SVR(kernel='rbf')
    model.fit(X_train, y_train)
    return mean_squared_error(y_test, model.predict(X_test))
# Main entry point.
def main():
    """Fetch GPPD Version 1.3.0 via the WRI dataset API, load it, and run every model.

    Steps: query the CKAN-style ``package_show`` endpoint, pick the
    "Version 1.3.0" resource, download/extract it into ``./data``, load the
    CSV, then run the forecasting, fault, efficiency and market models.
    """
    url = "https://datasets.wri.org/api/3/action/package_show?id=global-power-plant-database"
    # Bound the wait so a stalled server cannot hang the script indefinitely.
    response = requests.get(url, timeout=60)
    if response.status_code != 200:
        print(f"Failed to retrieve data. Status code: {response.status_code}")
        return
    data = response.json()
    # Tolerate a payload without result/resources instead of raising KeyError.
    resources = (data.get("result") or {}).get("resources") or []
    latest_version_url = _find_resource_url(resources, "Version 1.3.0")
    if not latest_version_url:
        print("No download link found for Version 1.3.0.")
        return
    download_and_extract_zip(latest_version_url, extract_to='./data')
    df = load_data_from_csv('./data')
    if df is None:
        return
    print("Loaded data:")
    print(df.head())
    print("Columns in the dataset:", df.columns)
    generate_forecast(df)                # 1. generation forecast
    device_fault_prediction(df)          # 2. device fault prediction
    energy_efficiency_optimization(df)   # 3. energy-efficiency optimisation
    power_market_prediction(df)          # 4. power-market prediction


def _find_resource_url(resources, name_prefix):
    """Return the ``url`` of the first resource whose name starts with *name_prefix*, else None."""
    for resource in resources:
        # Guard against a missing or null name rather than crash on startswith.
        name = resource.get("name") or ""
        if name.startswith(name_prefix):
            return resource.get("url")
    return None
# Run the full pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()