From 3a111b67e2278d9ec9f1e2c4d0552e6efd42bbe6 Mon Sep 17 00:00:00 2001 From: Asuka <15019597+asuka-civil@user.noreply.gitee.com> Date: Wed, 28 May 2025 20:07:40 +0800 Subject: [PATCH] =?UTF-8?q?refactor(DOA=5FSAC=5Fsim2real):=20=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E4=BB=A3=E7=A0=81=E8=AE=BE=E7=BD=AE=E5=B9=B6=E5=87=86?= =?UTF-8?q?=E5=A4=87=E6=A8=A1=E5=9E=8B=E5=8A=A0=E8=BD=BD=E8=B7=AF=E5=BE=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 将 tempt 变量初始化为 0- 更新模型加载路径为具体文件夹位置 - 显式指定共享模型在 GPU 上运行 - 修正本地模型初始化,使用 SAC 替代 SAC_Model --- DOA_SAC_sim2real.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/DOA_SAC_sim2real.py b/DOA_SAC_sim2real.py index b0bfa9d..4697738 100644 --- a/DOA_SAC_sim2real.py +++ b/DOA_SAC_sim2real.py @@ -465,7 +465,7 @@ if __name__ == "__main__": return True - tempt = 1 + tempt = 0 log_dir = './tensorboard/DOA_SAC_callback/' os.makedirs(log_dir, exist_ok=True) env = jakaEnv() @@ -482,7 +482,7 @@ if __name__ == "__main__": else: obs = env.reset() # 改变路径为你保存模型的路径 - model = SAC.load(r'best_model.zip', env=env) + model = SAC.load(r'D:\Python-Project\RL-PowerTracking-new\model\best_model.zip', env=env) for j in range(50): for i in range(2000): @@ -540,7 +540,7 @@ def train_parallel(num_processes): # 创建共享模型(使用Stable Baselines3的SAC) env = jakaEnv() # 创建环境实例 model = SAC('MlpPolicy', env=env, verbose=1, device="cuda") # 使用CUDA加速 - shared_model = model.policy.to(torch.device('cuda')) # 确保模型在GPU上 + shared_model = model.policy.to(torch.device('cuda')) # 显式指定模型在GPU上 shared_model.share_memory() # 共享模型参数 # 创建进程列表 @@ -564,7 +564,7 @@ def train_process(rank, shared_model): env = create_arm_environment() # 替换为实际的环境创建函数 # 创建本地模型副本 - local_model = SAC_Model().to(device) + local_model = SAC().to(device) local_model.load_state_dict(shared_model.state_dict()) # 在此处替换原有训练循环为并行版本