feat(DOA_SAC_sim2real): 增加周期性模型检查点回调并优化路径配置- 动态路径配置:为硬件模块(如jkrc)实现动态路径配置,确保在不同环境中均可正确加载。
- 观测空间修改:调整观测空间为38维。 - 回调增强:添加 PeriodicModelCheckpointCallback,用于定期保存训练过程中的模型检查点。 - 日志与保存:修改日志目录和模型保存逻辑,使其更加清晰和灵活。
This commit is contained in:
parent
3a111b67e2
commit
686164f670
@ -13,13 +13,13 @@ import numpy as np
|
||||
import random
|
||||
from scipy import signal
|
||||
import sys
|
||||
# sys.path.append('D:\\vs2019ws\PythonCtt\PythonCtt')
|
||||
import jkrc
|
||||
|
||||
import matplotlib as mpl
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
# 动态路径配置:为硬件模块(如jkrc)实现动态路径配置,确保在不同环境中均可正确加载。
|
||||
try:
|
||||
import jkrc
|
||||
JAKA_AVAILABLE = True
|
||||
except ImportError:
|
||||
jkrc = None
|
||||
JAKA_AVAILABLE = False
|
||||
|
||||
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
||||
|
||||
@ -89,7 +89,11 @@ class jakaEnv(gym.Env):
|
||||
self.z_high = 1.3
|
||||
self.speed = [0.3, 0.5, 0.8, 1.0]
|
||||
self.action_space = spaces.Box(np.array([-1] * 3), np.array([1] * 3))
|
||||
self.observation_space = spaces.Box(np.array([-1] * 38, np.float32), np.array([1] * 38, np.float32))
|
||||
# 修改为38维观测空间
|
||||
self.observation_space = spaces.Box(
|
||||
np.array([-1] * 38, np.float32),
|
||||
np.array([1] * 38, np.float32)
|
||||
)
|
||||
|
||||
def compute_reward(self, achieved_goal, goal):
|
||||
d = goal_distance(achieved_goal, goal)
|
||||
@ -415,7 +419,7 @@ if __name__ == "__main__":
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
|
||||
from stable_baselines3.common.noise import NormalActionNoise
|
||||
from stable_baselines3.common.callbacks import BaseCallback
|
||||
from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
|
||||
|
||||
|
||||
class SaveOnBestTrainingRewardCallback(BaseCallback):
|
||||
@ -465,7 +469,38 @@ if __name__ == "__main__":
|
||||
return True
|
||||
|
||||
|
||||
tempt = 0
|
||||
class PeriodicModelCheckpointCallback(BaseCallback):
|
||||
"""
|
||||
Callback for periodically saving the model at specified intervals,
|
||||
storing each checkpoint in its own subdirectory.
|
||||
"""
|
||||
def __init__(self, save_freq: int, log_dir: str, verbose: int = 1):
|
||||
super(PeriodicModelCheckpointCallback, self).__init__(verbose)
|
||||
self.save_freq = save_freq
|
||||
self.log_dir = log_dir
|
||||
self.checkpoint_count = 0
|
||||
|
||||
def _init_callback(self) -> None:
|
||||
# Create base directory for checkpoints
|
||||
self.checkpoint_dir = os.path.join(self.log_dir, 'checkpoints')
|
||||
os.makedirs(self.checkpoint_dir, exist_ok=True)
|
||||
|
||||
def _on_step(self) -> bool:
|
||||
if self.n_calls % self.save_freq == 0:
|
||||
# Create new checkpoint directory
|
||||
checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{self.n_calls}')
|
||||
os.makedirs(checkpoint_path, exist_ok=True)
|
||||
|
||||
# Save model
|
||||
if self.verbose > 0:
|
||||
print(f'Saving model checkpoint to {checkpoint_path}')
|
||||
self.model.save(os.path.join(checkpoint_path, 'model'))
|
||||
|
||||
self.checkpoint_count += 1
|
||||
|
||||
return True
|
||||
|
||||
tempt = 1
|
||||
log_dir = './tensorboard/DOA_SAC_callback/'
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
env = jakaEnv()
|
||||
@ -474,8 +509,17 @@ if __name__ == "__main__":
|
||||
model = SAC('MlpPolicy', env=env, verbose=1, tensorboard_log=log_dir,
|
||||
device="cuda"
|
||||
)
|
||||
callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
|
||||
model.learn(total_timesteps=4000000, callback=callback)
|
||||
# 创建回调函数列表,包含原有的最佳模型保存回调和新的周期性检查点回调
|
||||
callback_list = [
|
||||
SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir),
|
||||
PeriodicModelCheckpointCallback(save_freq=5000, log_dir=log_dir)
|
||||
]
|
||||
|
||||
model.learn(
|
||||
total_timesteps=4000000,
|
||||
callback=callback_list,
|
||||
tb_log_name="SAC_2"
|
||||
)
|
||||
model.save('model/DOA_SAC_ENV_callback')
|
||||
del model
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user