""" 强化学习训练示例 """ import os import torch from contextlib import nullcontext from torch.cuda.amp import autocast as autocast_ctx_decorator # 添加项目根目录到Python路径 import sys sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from src.rl_env.cartesian_env import CartesianSpaceEnv from src.utils.test_data import generate_test_scenarios, save_scenarios_to_json def create_mixed_precision_wrapper(train_func, autocast_ctx): """创建支持混合精度训练的包装器函数""" def mixed_precision_train(*args, **kwargs): with autocast_ctx: return train_func(*args, **kwargs) return mixed_precision_train # 强制在文件顶部禁用GPU os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # 添加缺失的monitor导入 from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv # 使用Gymnasium的导入 import gymnasium as gym from gymnasium.wrappers import RecordEpisodeStatistics # 添加缺失的环境注册导入 from gymnasium.envs.registration import register # 显式添加缺失的导入 from stable_baselines3 import PPO from stable_baselines3.common.callbacks import EvalCallback def train_rl_model(): """训练强化学习模型""" # 强制禁用GPU os.environ['CUDA_VISIBLE_DEVICES'] = '-1' # 显式设置设备为CPU device = 'cpu' # 注册自定义环境 register( id='CartesianSpace-v0', entry_point='src.rl_env.cartesian_env:CartesianSpaceEnv', max_episode_steps=1000, ) # 创建Gymnasium环境并应用标准包装器 env = DummyVecEnv([ lambda: Monitor( RecordEpisodeStatistics(gym.make('CartesianSpace-v0')) ) ]) # 创建PPO模型时确保使用CPU model = PPO( "MlpPolicy", env, verbose=1, device=device, tensorboard_log="tensorboard_logs/", policy_kwargs=dict( net_arch=dict(pi=[256, 256], vf=[256, 256]), log_std_init=-2, ), batch_size=1024, use_sde=True ) # 添加缺失的评估回调 eval_callback = EvalCallback( env, eval_freq=1000, best_model_save_path="models/best_model/", deterministic=True, render=False ) # 训练模型 model.learn( total_timesteps=50000, callback=eval_callback, tb_log_name="PPO_Cartesian" ) # 保存最终模型 model.save("models/ppo_cartesian_final") # 测试训练好的模型 obs = env.reset() for _ in range(100): action, _states = model.predict(obs, deterministic=True) obs, reward, done, truncated, info = env.step(action) env.render() if done or truncated: obs = env.reset() env.close() if __name__ == "__main__": train_rl_model()