From 6756de70cc2a07d95f26f496adc2459e4a9baf4c Mon Sep 17 00:00:00 2001 From: fly6516 Date: Mon, 24 Feb 2025 00:45:15 +0800 Subject: [PATCH] =?UTF-8?q?feat(Test):=20=E6=B7=BB=E5=8A=A0=E8=83=BD?= =?UTF-8?q?=E6=BA=90=E9=A2=84=E6=B5=8B=E8=84=9A=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 energy-prediction.py 脚本,实现全球发电厂数据库的下载、处理和分析 - 使用随机森林回归模型预测发电量,并评估模型性能- 可视化预测结果与实际值的对比 - 设置中文字体支持,确保图表显示正常 --- Test/energy-prediction.py | 94 ++++++++++++++++++++++++ Test/forecasting.py | 151 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 245 insertions(+) create mode 100644 Test/energy-prediction.py create mode 100644 Test/forecasting.py diff --git a/Test/energy-prediction.py b/Test/energy-prediction.py new file mode 100644 index 0000000..1e0d935 --- /dev/null +++ b/Test/energy-prediction.py @@ -0,0 +1,94 @@ +import requests +import zipfile +import os +import pandas as pd +from sklearn.model_selection import train_test_split +from sklearn.ensemble import RandomForestRegressor +from sklearn.metrics import mean_squared_error +from sklearn.preprocessing import LabelEncoder +import matplotlib.pyplot as plt +import matplotlib + +def main(): + # 设置支持中文的字体(例如:SimHei 字体) + matplotlib.rcParams['font.sans-serif'] = ['SimHei'] # 或使用 'Microsoft YaHei' + matplotlib.rcParams['axes.unicode_minus'] = False # 防止负号显示为方块 + + # 获取数据资源 URL(选择最新版本的 ZIP 文件) + url = "https://datasets.wri.org/private-admin/dataset/53623dfd-3df6-4f15-a091-67457cdb571f/resource/66bcdacc-3d0e-46ad-9271-a5a76b1853d2/download/globalpowerplantdatabasev130.zip" + + # 下载 ZIP 文件 + response = requests.get(url) + zip_path = "global_power_plant_data.zip" + with open(zip_path, 'wb') as file: + file.write(response.content) + + # 解压 ZIP 文件 + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + zip_ref.extractall("data") + + # 确保解压后的文件名和路径 + csv_file = "data/global_power_plant_database.csv" + if not os.path.exists(csv_file): + print(f"错误: 文件 {csv_file} 不存在!") + return + + # 加载解压后的 CSV 数据 + data = pd.read_csv(csv_file, dtype={'other_fuel3': str}) + # data = pd.read_csv(csv_file) + + # 数据预处理 + data.fillna(0, inplace=True) # 用 0 填充缺失值 + + # 确保 'owner' 列的类型统一为字符串类型 + data['owner'] = data['owner'].astype(str) + + # 对分类列进行编码(例如 primary_fuel, owner) + label_encoder = LabelEncoder() + data['primary_fuel'] = label_encoder.fit_transform(data['primary_fuel']) + data['owner'] = label_encoder.fit_transform(data['owner']) + + # 确保发电量列存在并计算总发电量 + generation_columns = ['generation_gwh_2013', 'generation_gwh_2014', 'generation_gwh_2015', + 'generation_gwh_2016', 'generation_gwh_2017'] + + # 确保所有发电量列都存在 + missing_cols = [col for col in generation_columns if col not in data.columns] + if missing_cols: + print(f"警告: 缺少以下列: {', '.join(missing_cols)}") + return + + # 聚合不同年份的发电数据 + data['total_generation'] = data[generation_columns].sum(axis=1) + + # 选择特征(X)和目标变量(y) + X = data[['capacity_mw', 'latitude', 'longitude', 'primary_fuel', 'total_generation']] # 示例特征 + y = data['generation_gwh_2017'] # 预测目标:2017年的发电量 + + # 将数据分割为训练集和测试集 + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + # 初始化并训练随机森林回归模型 + model = RandomForestRegressor(n_estimators=100, random_state=42) + model.fit(X_train, y_train) + + # 在测试集上进行预测 + y_pred = model.predict(X_test) + + # 评估模型 + mse = mean_squared_error(y_test, y_pred) + rmse = mse ** 0.5 + print(f"均方根误差(RMSE):{rmse}") + + # 可视化预测值与实际值,使用不同颜色标记 + plt.scatter(y_test, y_pred, color='blue', label='预测发电量', alpha=0.6) + plt.scatter(y_test, y_test, color='red', label='实际发电量', alpha=0.6) + plt.xlabel('实际发电量 (GWh)') + plt.ylabel('预测发电量 (GWh)') + plt.title('实际 vs 预测发电量') + plt.legend() + plt.show() + + +if __name__ == "__main__": + main() diff --git a/Test/forecasting.py b/Test/forecasting.py new file mode 100644 index 0000000..15941ce --- /dev/null +++ b/Test/forecasting.py @@ -0,0 +1,151 @@ +import requests +import zipfile +import io +import pandas as pd +import os +from sklearn.model_selection import train_test_split +from sklearn.ensemble import RandomForestClassifier +from sklearn.metrics import accuracy_score +from sklearn.svm import SVR +from sklearn.metrics import mean_squared_error +from statsmodels.tsa.arima.model import ARIMA + +# 下载和解压ZIP文件 +def download_and_extract_zip(url, extract_to='.'): + print(f"Downloading ZIP file from {url}") + response = requests.get(url) + + if response.status_code == 200: + with zipfile.ZipFile(io.BytesIO(response.content)) as z: + z.extractall(extract_to) + print(f"Extracted ZIP file to {extract_to}") + else: + print(f"Failed to download ZIP file. Status code: {response.status_code}") + +# 加载CSV数据 +def load_data_from_csv(directory): + for filename in os.listdir(directory): + if filename.endswith(".csv"): + file_path = os.path.join(directory, filename) + print(f"Loading data from {file_path}") + return pd.read_csv(file_path) + print("No CSV files found.") + return None + +# 发电量预测模型(时间序列分析) +def generate_forecast(df): + # 确保使用正确的时间列(这里使用'year_of_capacity_data'列) + df['year_of_capacity_data'] = pd.to_datetime(df['year_of_capacity_data'], format='%Y', errors='coerce') + df = df.dropna(subset=['year_of_capacity_data']) # 删除无效的日期数据 + + # 设置日期为索引,按年分组 + df.set_index('year_of_capacity_data', inplace=True) + df = df.resample('A').sum() # 按年重采样并求和 + + # 检查是否有生成数据列 + if 'generation_gwh_2013' in df.columns: + model = ARIMA(df['generation_gwh_2013'], order=(5, 1, 0)) + model_fit = model.fit() + forecast = model_fit.forecast(steps=30) + print(f"Forecasted Generation for the next 30 years: {forecast}") + else: + print("Generation data not found for ARIMA model.") + +# 设备故障预测模型(分类模型) +def device_fault_prediction(df): + # 假设数据包含故障标签'fault'和其他特征 + if 'fault' in df.columns: + X = df.drop(columns='fault') + y = df['fault'] + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + model = RandomForestClassifier() + model.fit(X_train, y_train) + + y_pred = model.predict(X_test) + print("Device Fault Prediction Accuracy:", accuracy_score(y_test, y_pred)) + else: + print("Fault column not found in the dataset.") + +# 能源效率优化模型(回归模型) +def energy_efficiency_optimization(df): + # 假设数据包含'energy_consumption'和'energy_output'列 + if 'capacity_mw' in df.columns and 'generation_gwh_2013' in df.columns: + # 用capacity_mw作为能源消耗的代理,generation_gwh_2013作为能源输出 + X = df[['capacity_mw']] + y = df['generation_gwh_2013'] + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + model = SVR(kernel='rbf') + model.fit(X_train, y_train) + + y_pred = model.predict(X_test) + print("Mean Squared Error of Energy Efficiency Prediction:", mean_squared_error(y_test, y_pred)) + else: + print("Required columns for energy efficiency not found.") + +# 电力市场预测模型(回归模型) +def power_market_prediction(df): + if 'generation_gwh_2013' in df.columns and 'capacity_mw' in df.columns: + X = df[['capacity_mw']] + y = df['generation_gwh_2013'] + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + model = SVR(kernel='rbf') + model.fit(X_train, y_train) + + y_pred = model.predict(X_test) + print("Mean Squared Error of Power Market Price Prediction:", mean_squared_error(y_test, y_pred)) + else: + print("Required columns for power market prediction not found.") + +# 主函数 +def main(): + url = "https://datasets.wri.org/api/3/action/package_show?id=global-power-plant-database" + + # 获取API响应 + response = requests.get(url) + if response.status_code == 200: + data = response.json() + + # 提取资源下载链接 + resources = data["result"]["resources"] + latest_version_url = None + + for resource in resources: + if resource["name"].startswith("Version 1.3.0"): + latest_version_url = resource["url"] + break + + if latest_version_url: + # 下载并解压最新版本的ZIP文件 + download_and_extract_zip(latest_version_url, extract_to='./data') + + # 加载CSV文件并查看数据 + df = load_data_from_csv('./data') + if df is not None: + print("Loaded data:") + print(df.head()) + print("Columns in the dataset:", df.columns) + + # 1. 发电量预测 + generate_forecast(df) + + # 2. 设备故障预测 + device_fault_prediction(df) + + # 3. 能源效率优化 + energy_efficiency_optimization(df) + + # 4. 电力市场预测 + power_market_prediction(df) + else: + print("No download link found for Version 1.3.0.") + else: + print(f"Failed to retrieve data. Status code: {response.status_code}") + +if __name__ == "__main__": + main()