feat(Test): 添加能源预测脚本

- 新增 energy-prediction.py 脚本,实现全球发电厂数据库的下载、处理和分析
- 使用随机森林回归模型预测发电量,并评估模型性能- 可视化预测结果与实际值的对比
- 设置中文字体支持,确保图表显示正常
This commit is contained in:
fly6516 2025-02-24 00:45:15 +08:00
commit 6756de70cc
2 changed files with 245 additions and 0 deletions

94
Test/energy-prediction.py Normal file
View File

@ -0,0 +1,94 @@
import requests
import zipfile
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
import matplotlib
def main():
# 设置支持中文的字体例如SimHei 字体)
matplotlib.rcParams['font.sans-serif'] = ['SimHei'] # 或使用 'Microsoft YaHei'
matplotlib.rcParams['axes.unicode_minus'] = False # 防止负号显示为方块
# 获取数据资源 URL选择最新版本的 ZIP 文件)
url = "https://datasets.wri.org/private-admin/dataset/53623dfd-3df6-4f15-a091-67457cdb571f/resource/66bcdacc-3d0e-46ad-9271-a5a76b1853d2/download/globalpowerplantdatabasev130.zip"
# 下载 ZIP 文件
response = requests.get(url)
zip_path = "global_power_plant_data.zip"
with open(zip_path, 'wb') as file:
file.write(response.content)
# 解压 ZIP 文件
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall("data")
# 确保解压后的文件名和路径
csv_file = "data/global_power_plant_database.csv"
if not os.path.exists(csv_file):
print(f"错误: 文件 {csv_file} 不存在!")
return
# 加载解压后的 CSV 数据
data = pd.read_csv(csv_file, dtype={'other_fuel3': str})
# data = pd.read_csv(csv_file)
# 数据预处理
data.fillna(0, inplace=True) # 用 0 填充缺失值
# 确保 'owner' 列的类型统一为字符串类型
data['owner'] = data['owner'].astype(str)
# 对分类列进行编码(例如 primary_fuel, owner
label_encoder = LabelEncoder()
data['primary_fuel'] = label_encoder.fit_transform(data['primary_fuel'])
data['owner'] = label_encoder.fit_transform(data['owner'])
# 确保发电量列存在并计算总发电量
generation_columns = ['generation_gwh_2013', 'generation_gwh_2014', 'generation_gwh_2015',
'generation_gwh_2016', 'generation_gwh_2017']
# 确保所有发电量列都存在
missing_cols = [col for col in generation_columns if col not in data.columns]
if missing_cols:
print(f"警告: 缺少以下列: {', '.join(missing_cols)}")
return
# 聚合不同年份的发电数据
data['total_generation'] = data[generation_columns].sum(axis=1)
# 选择特征X和目标变量y
X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']] # 示例特征
y = data['generation_gwh_2017'] # 预测目标2017年的发电量
# 将数据分割为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 初始化并训练随机森林回归模型
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# 在测试集上进行预测
y_pred = model.predict(X_test)
# 评估模型
mse = mean_squared_error(y_test, y_pred)
rmse = mse ** 0.5
print(f"均方根误差RMSE{rmse}")
# 可视化预测值与实际值,使用不同颜色标记
plt.scatter(y_test, y_pred, color='blue', label='预测发电量', alpha=0.6)
plt.scatter(y_test, y_test, color='red', label='实际发电量', alpha=0.6)
plt.xlabel('实际发电量 (GWh)')
plt.ylabel('预测发电量 (GWh)')
plt.title('实际 vs 预测发电量')
plt.legend()
plt.show()
if __name__ == "__main__":
main()

151
Test/forecasting.py Normal file
View File

@ -0,0 +1,151 @@
import requests
import zipfile
import io
import pandas as pd
import os
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.svm import SVR
from sklearn.metrics import mean_squared_error
from statsmodels.tsa.arima.model import ARIMA
# 下载和解压ZIP文件
def download_and_extract_zip(url, extract_to='.'):
print(f"Downloading ZIP file from {url}")
response = requests.get(url)
if response.status_code == 200:
with zipfile.ZipFile(io.BytesIO(response.content)) as z:
z.extractall(extract_to)
print(f"Extracted ZIP file to {extract_to}")
else:
print(f"Failed to download ZIP file. Status code: {response.status_code}")
# 加载CSV数据
def load_data_from_csv(directory):
for filename in os.listdir(directory):
if filename.endswith(".csv"):
file_path = os.path.join(directory, filename)
print(f"Loading data from {file_path}")
return pd.read_csv(file_path)
print("No CSV files found.")
return None
# 发电量预测模型(时间序列分析)
def generate_forecast(df):
# 确保使用正确的时间列(这里使用'year_of_capacity_data'列)
df['year_of_capacity_data'] = pd.to_datetime(df['year_of_capacity_data'], format='%Y', errors='coerce')
df = df.dropna(subset=['year_of_capacity_data']) # 删除无效的日期数据
# 设置日期为索引,按年分组
df.set_index('year_of_capacity_data', inplace=True)
df = df.resample('A').sum() # 按年重采样并求和
# 检查是否有生成数据列
if 'generation_gwh_2013' in df.columns:
model = ARIMA(df['generation_gwh_2013'], order=(5, 1, 0))
model_fit = model.fit()
forecast = model_fit.forecast(steps=30)
print(f"Forecasted Generation for the next 30 years: {forecast}")
else:
print("Generation data not found for ARIMA model.")
# 设备故障预测模型(分类模型)
def device_fault_prediction(df):
# 假设数据包含故障标签'fault'和其他特征
if 'fault' in df.columns:
X = df.drop(columns='fault')
y = df['fault']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestClassifier()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print("Device Fault Prediction Accuracy:", accuracy_score(y_test, y_pred))
else:
print("Fault column not found in the dataset.")
# 能源效率优化模型(回归模型)
def energy_efficiency_optimization(df):
# 假设数据包含'energy_consumption'和'energy_output'列
if 'capacity_mw' in df.columns and 'generation_gwh_2013' in df.columns:
# 用capacity_mw作为能源消耗的代理generation_gwh_2013作为能源输出
X = df[['capacity_mw']]
y = df['generation_gwh_2013']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = SVR(kernel='rbf')
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print("Mean Squared Error of Energy Efficiency Prediction:", mean_squared_error(y_test, y_pred))
else:
print("Required columns for energy efficiency not found.")
# 电力市场预测模型(回归模型)
def power_market_prediction(df):
if 'generation_gwh_2013' in df.columns and 'capacity_mw' in df.columns:
X = df[['capacity_mw']]
y = df['generation_gwh_2013']
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = SVR(kernel='rbf')
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
print("Mean Squared Error of Power Market Price Prediction:", mean_squared_error(y_test, y_pred))
else:
print("Required columns for power market prediction not found.")
# 主函数
def main():
url = "https://datasets.wri.org/api/3/action/package_show?id=global-power-plant-database"
# 获取API响应
response = requests.get(url)
if response.status_code == 200:
data = response.json()
# 提取资源下载链接
resources = data["result"]["resources"]
latest_version_url = None
for resource in resources:
if resource["name"].startswith("Version 1.3.0"):
latest_version_url = resource["url"]
break
if latest_version_url:
# 下载并解压最新版本的ZIP文件
download_and_extract_zip(latest_version_url, extract_to='./data')
# 加载CSV文件并查看数据
df = load_data_from_csv('./data')
if df is not None:
print("Loaded data:")
print(df.head())
print("Columns in the dataset:", df.columns)
# 1. 发电量预测
generate_forecast(df)
# 2. 设备故障预测
device_fault_prediction(df)
# 3. 能源效率优化
energy_efficiency_optimization(df)
# 4. 电力市场预测
power_market_prediction(df)
else:
print("No download link found for Version 1.3.0.")
else:
print(f"Failed to retrieve data. Status code: {response.status_code}")
if __name__ == "__main__":
main()