refactor(DOA_SAC_sim2real): 优化代码设置并准备模型加载路径
- 将 tempt 变量初始化为 0- 更新模型加载路径为具体文件夹位置 - 显式指定共享模型在 GPU 上运行 - 修正本地模型初始化,使用 SAC 替代 SAC_Model
This commit is contained in:
parent
c76dab54b0
commit
3a111b67e2
@ -465,7 +465,7 @@ if __name__ == "__main__":
|
||||
return True
|
||||
|
||||
|
||||
tempt = 1
|
||||
tempt = 0
|
||||
log_dir = './tensorboard/DOA_SAC_callback/'
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
env = jakaEnv()
|
||||
@ -482,7 +482,7 @@ if __name__ == "__main__":
|
||||
else:
|
||||
obs = env.reset()
|
||||
# 改变路径为你保存模型的路径
|
||||
model = SAC.load(r'best_model.zip', env=env)
|
||||
model = SAC.load(r'D:\Python-Project\RL-PowerTracking-new\model\best_model.zip', env=env)
|
||||
|
||||
for j in range(50):
|
||||
for i in range(2000):
|
||||
@ -540,7 +540,7 @@ def train_parallel(num_processes):
|
||||
# 创建共享模型(使用Stable Baselines3的SAC)
|
||||
env = jakaEnv() # 创建环境实例
|
||||
model = SAC('MlpPolicy', env=env, verbose=1, device="cuda") # 使用CUDA加速
|
||||
shared_model = model.policy.to(torch.device('cuda')) # 确保模型在GPU上
|
||||
shared_model = model.policy.to(torch.device('cuda')) # 显式指定模型在GPU上
|
||||
shared_model.share_memory() # 共享模型参数
|
||||
|
||||
# 创建进程列表
|
||||
@ -564,7 +564,7 @@ def train_process(rank, shared_model):
|
||||
env = create_arm_environment() # 替换为实际的环境创建函数
|
||||
|
||||
# 创建本地模型副本
|
||||
local_model = SAC_Model().to(device)
|
||||
local_model = SAC().to(device)
|
||||
local_model.load_state_dict(shared_model.state_dict())
|
||||
|
||||
# 在此处替换原有训练循环为并行版本
|
||||
|
Loading…
Reference in New Issue
Block a user