diff --git a/DOA_SAC_sim2real.py b/DOA_SAC_sim2real.py
index 4697738..b643887 100644
--- a/DOA_SAC_sim2real.py
+++ b/DOA_SAC_sim2real.py
@@ -13,13 +13,13 @@ import numpy as np
 import random
 from scipy import signal
 import sys
-# sys.path.append('D:\\vs2019ws\PythonCtt\PythonCtt')
-import jkrc
-
-import matplotlib as mpl
-from mpl_toolkits.mplot3d import Axes3D
-import matplotlib.pyplot as plt
-import pandas as pd
+# Optional hardware import: fall back gracefully when the jkrc SDK is absent, so the script also loads in simulation-only environments.
+try:
+    import jkrc
+    JAKA_AVAILABLE = True
+except ImportError:
+    jkrc = None
+    JAKA_AVAILABLE = False
 
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
 
@@ -89,7 +89,11 @@ class jakaEnv(gym.Env):
         self.z_high = 1.3
         self.speed = [0.3, 0.5, 0.8, 1.0]
         self.action_space = spaces.Box(np.array([-1] * 3), np.array([1] * 3))
-        self.observation_space = spaces.Box(np.array([-1] * 38, np.float32), np.array([1] * 38, np.float32))
+        # 38-dimensional observation space
+        self.observation_space = spaces.Box(
+            np.array([-1] * 38, np.float32),
+            np.array([1] * 38, np.float32)
+        )
 
     def compute_reward(self, achieved_goal, goal):
         d = goal_distance(achieved_goal, goal)
@@ -415,7 +419,7 @@ if __name__ == "__main__":
     from stable_baselines3.common.monitor import Monitor
     from stable_baselines3.common.results_plotter import load_results, ts2xy, plot_results
     from stable_baselines3.common.noise import NormalActionNoise
-    from stable_baselines3.common.callbacks import BaseCallback
+    from stable_baselines3.common.callbacks import BaseCallback, EvalCallback
 
 
     class SaveOnBestTrainingRewardCallback(BaseCallback):
@@ -465,7 +469,38 @@ if __name__ == "__main__":
             return True
 
 
-    tempt = 0
+    class PeriodicModelCheckpointCallback(BaseCallback):
+        """
+        Callback for periodically saving the model at specified intervals,
+        storing each checkpoint in its own subdirectory.
+        """
+        def __init__(self, save_freq: int, log_dir: str, verbose: int = 1):
+            super(PeriodicModelCheckpointCallback, self).__init__(verbose)
+            self.save_freq = save_freq
+            self.log_dir = log_dir
+            self.checkpoint_count = 0
+
+        def _init_callback(self) -> None:
+            # Create the base directory that holds all checkpoints
+            self.checkpoint_dir = os.path.join(self.log_dir, 'checkpoints')
+            os.makedirs(self.checkpoint_dir, exist_ok=True)
+
+        def _on_step(self) -> bool:
+            if self.n_calls % self.save_freq == 0:
+                # Create a fresh subdirectory for this checkpoint
+                checkpoint_path = os.path.join(self.checkpoint_dir, f'checkpoint_{self.n_calls}')
+                os.makedirs(checkpoint_path, exist_ok=True)
+
+                # Save the model
+                if self.verbose > 0:
+                    print(f'Saving model checkpoint to {checkpoint_path}')
+                self.model.save(os.path.join(checkpoint_path, 'model'))
+
+                self.checkpoint_count += 1
+
+            return True
+
+    tempt = 1
     log_dir = './tensorboard/DOA_SAC_callback/'
     os.makedirs(log_dir, exist_ok=True)
     env = jakaEnv()
@@ -474,8 +509,17 @@ if __name__ == "__main__":
     model = SAC('MlpPolicy', env=env,
                 verbose=1, tensorboard_log=log_dir, device="cuda"
                 )
-    callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
-    model.learn(total_timesteps=4000000, callback=callback)
+    # Callback list: keep the original best-model callback and add the new periodic checkpoint callback
+    callback_list = [
+        SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir),
+        PeriodicModelCheckpointCallback(save_freq=5000, log_dir=log_dir)
+    ]
+
+    model.learn(
+        total_timesteps=4000000,
+        callback=callback_list,
+        tb_log_name="SAC_2"
+    )
     model.save('model/DOA_SAC_ENV_callback')
     del model
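Note: the import hunk introduces JAKA_AVAILABLE but none of the shown hunks consult it, and EvalCallback is imported without being used. The two sketches below illustrate how both are presumably meant to be wired in; every name not present in the diff (connect_robot, the robot IP, jkrc.RC/login, eval_env) is an assumption, not the author's code.

# Minimal sketch of a hardware guard, assuming jkrc.RC(ip) and login() from
# the JAKA Python SDK; the helper name and IP are illustrative only.
def connect_robot(ip="192.168.1.100"):
    if not JAKA_AVAILABLE:
        raise RuntimeError("jkrc SDK not installed; running in simulation-only mode")
    robot = jkrc.RC(ip)  # SDK robot handle (assumed API)
    robot.login()        # establish the connection (assumed API)
    return robot

# Minimal sketch of plugging EvalCallback into callback_list, assuming a
# separate monitored copy of jakaEnv as the evaluation environment. This
# append would have to happen before model.learn() is called for the
# callback to take effect.
eval_env = Monitor(jakaEnv(), log_dir)
callback_list.append(
    EvalCallback(
        eval_env,
        best_model_save_path=log_dir,
        log_path=log_dir,
        eval_freq=10000,
        deterministic=True,
    )
)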