refactor(DOA_SAC_sim2real): 优化代码设置并准备模型加载路径

- 将 tempt 变量初始化为 0- 更新模型加载路径为具体文件夹位置
- 显式指定共享模型在 GPU 上运行
- 修正本地模型初始化,使用 SAC 替代 SAC_Model
This commit is contained in:
Asuka 2025-05-28 20:07:40 +08:00
parent c76dab54b0
commit 3a111b67e2

View File

@ -465,7 +465,7 @@ if __name__ == "__main__":
return True
tempt = 1
tempt = 0
log_dir = './tensorboard/DOA_SAC_callback/'
os.makedirs(log_dir, exist_ok=True)
env = jakaEnv()
@ -482,7 +482,7 @@ if __name__ == "__main__":
else:
obs = env.reset()
# 改变路径为你保存模型的路径
model = SAC.load(r'best_model.zip', env=env)
model = SAC.load(r'D:\Python-Project\RL-PowerTracking-new\model\best_model.zip', env=env)
for j in range(50):
for i in range(2000):
@ -540,7 +540,7 @@ def train_parallel(num_processes):
# 创建共享模型使用Stable Baselines3的SAC
env = jakaEnv() # 创建环境实例
model = SAC('MlpPolicy', env=env, verbose=1, device="cuda") # 使用CUDA加速
shared_model = model.policy.to(torch.device('cuda')) # 确保模型在GPU上
shared_model = model.policy.to(torch.device('cuda')) # 显式指定模型在GPU上
shared_model.share_memory() # 共享模型参数
# 创建进程列表
@ -564,7 +564,7 @@ def train_process(rank, shared_model):
env = create_arm_environment() # 替换为实际的环境创建函数
# 创建本地模型副本
local_model = SAC_Model().to(device)
local_model = SAC().to(device)
local_model.load_state_dict(shared_model.state_dict())
# 在此处替换原有训练循环为并行版本