- Create the project structure and core modules - Implement robotic arm control and target tracking - Develop the reinforcement learning environment and training scripts - Add documentation and usage examples - Set up logging and TensorBoard visualization
"""
|
|
强化学习训练示例
|
|
"""
|
|
import os
|
|
import torch
|
|
from contextlib import nullcontext
|
|
from torch.cuda.amp import autocast as autocast_ctx_decorator
|
|
|
|
# 添加项目根目录到Python路径
|
|
import sys
|
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from src.rl_env.cartesian_env import CartesianSpaceEnv
|
|
from src.utils.test_data import generate_test_scenarios, save_scenarios_to_json
|
|
|
|
|
|
def create_mixed_precision_wrapper(train_func, autocast_ctx):
|
|
"""创建支持混合精度训练的包装器函数"""
|
|
def mixed_precision_train(*args, **kwargs):
|
|
with autocast_ctx:
|
|
return train_func(*args, **kwargs)
|
|
return mixed_precision_train
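
# A hypothetical usage sketch for the wrapper above (it is not exercised in
# this script, which forces CPU training below). On a CUDA machine one might
# write something like:
#     ctx = autocast_ctx_decorator() if torch.cuda.is_available() else nullcontext()
#     learn = create_mixed_precision_wrapper(model.learn, ctx)
#     learn(total_timesteps=50000)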


# Force-disable the GPU at the top of the file, before any CUDA initialization
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Monitoring wrapper and vectorized-environment container from Stable-Baselines3
from stable_baselines3.common.monitor import Monitor
from stable_baselines3.common.vec_env import DummyVecEnv
# Gymnasium imports
import gymnasium as gym
from gymnasium.wrappers import RecordEpisodeStatistics
# Environment registration
from gymnasium.envs.registration import register
# PPO algorithm and evaluation callback
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import EvalCallback


def train_rl_model():
    """Train the reinforcement learning model."""
    # Force-disable the GPU (redundant with the module-level setting, but kept
    # so the function is safe to call on its own)
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

    # Explicitly run on the CPU
    device = 'cpu'

    # Register the custom environment with Gymnasium
    register(
        id='CartesianSpace-v0',
        entry_point='src.rl_env.cartesian_env:CartesianSpaceEnv',
        max_episode_steps=1000,
    )

    # Create the Gymnasium environment and apply the standard wrappers
    env = DummyVecEnv([
        lambda: Monitor(
            RecordEpisodeStatistics(gym.make('CartesianSpace-v0'))
        )
    ])
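
    # Optional sanity check: Stable-Baselines3 ships an environment checker
    # that validates a custom env's spaces and step/reset API. One could run
    # it on the raw (unvectorized) env before training:
    #     from stable_baselines3.common.env_checker import check_env
    #     check_env(gym.make('CartesianSpace-v0'))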

    # Create the PPO model, making sure it uses the CPU
    model = PPO(
        "MlpPolicy",
        env,
        verbose=1,
        device=device,
        tensorboard_log="tensorboard_logs/",
        policy_kwargs=dict(
            net_arch=dict(pi=[256, 256], vf=[256, 256]),
            log_std_init=-2,
        ),
        batch_size=1024,
        # gSDE exploration; requires a continuous (Box) action space
        use_sde=True,
    )

    # Evaluation callback: periodically evaluates the policy and saves the
    # best model seen so far
    eval_callback = EvalCallback(
        env,
        eval_freq=1000,
        best_model_save_path="models/best_model/",
        deterministic=True,
        render=False,
    )
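
    # Note: the callback above reuses the training env for evaluation, which
    # works but can bias the reported returns. A dedicated eval env is the
    # usual pattern, e.g.:
    #     eval_env = DummyVecEnv([lambda: Monitor(gym.make('CartesianSpace-v0'))])
    #     eval_callback = EvalCallback(eval_env, ...)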

    # Train the model
    model.learn(
        total_timesteps=50000,
        callback=eval_callback,
        tb_log_name="PPO_Cartesian"
    )

    # Save the final model
    model.save("models/ppo_cartesian_final")

    # Test the trained model. Note that a Stable-Baselines3 VecEnv uses the
    # old-style API: reset() returns only the observations, step() returns
    # four values, and finished sub-environments are reset automatically.
    obs = env.reset()
    for _ in range(100):
        action, _states = model.predict(obs, deterministic=True)
        obs, rewards, dones, infos = env.step(action)
        env.render()

    env.close()


if __name__ == "__main__":
    train_rl_model()
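
# To reload the saved policy later, the standard Stable-Baselines3 call is:
#     model = PPO.load("models/ppo_cartesian_final")
#     action, _ = model.predict(obs, deterministic=True)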