refactor(env): optimize the DOA_SAC_sim2real script
- Changed the URDF file path
- Removed redundant code and comments
- Adjusted code formatting and indentation to improve readability
- Updated the model save path
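Note: the new URDF path ("urdf/jaka_description/urdf/jaka_description.urdf") and model path ("best_model.zip") are relative, so they resolve against the current working directory. A minimal sketch (illustrative only, not part of this commit; the ROOT, URDF_PATH, and MODEL_PATH names are invented here) of anchoring them to the script's own directory instead:

    # Illustrative sketch -- not part of this commit.
    # Resolve the commit's new relative paths against this script's directory
    # so they work no matter where the script is launched from.
    from pathlib import Path

    ROOT = Path(__file__).resolve().parent
    URDF_PATH = ROOT / "urdf" / "jaka_description" / "urdf" / "jaka_description.urdf"
    MODEL_PATH = ROOT / "best_model.zip"

    # Then pass str(URDF_PATH) to p.loadURDF(...) and str(MODEL_PATH) to SAC.load(...).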
parent 3d914c8c53
commit e3b17f5eb2
@@ -17,13 +17,16 @@ import matplotlib as mpl
 from mpl_toolkits.mplot3d import Axes3D
 import matplotlib.pyplot as plt
 import pandas as pd
-os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
+
+os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+
 
 MAX_EPISODE_LEN = 20 * 100
 x = []
 y = []
 z = []
-#Motion mode
+
+# Motion mode
 # PI=3.1415926
 # ABS = 0  # absolute motion
 # INCR = 1  # incremental motion
@@ -122,12 +125,9 @@ class jakaEnv(gym.Env):
 p.resetBaseVelocity(self.blockId, linearVelocity=[0, random.choice(self.speed), 0])
 p.stepSimulation()
 # robot.servo_move_enable(True)
 
 # robot.servo_p(cartesian_pose=[dx*100, dy*100, dz*50, 0, 0, 0], move_mode=1)
 # time.sleep(0.008)
 
-
-
-
 state_object, state_object_orienation = p.getBasePositionAndOrientation(self.objectId)
 twist_object, twist_object_orienation = p.getBaseVelocity(self.objectId)
@@ -282,7 +282,7 @@ class jakaEnv(gym.Env):
 reset_realposes = [0.8, 2.4, 1.3, 1, -1.57, 0]
 # Change this path to the URDF file of your robot arm
 self.jaka_id = p.loadURDF(
-r"D:\Python\robot_DRL\env\lib64\urdf\jaka_description\urdf\jaka_description.urdf",
+"urdf/jaka_description/urdf/jaka_description.urdf",
 basePosition=[0, 0.5, 0.65],
 baseOrientation=p.getQuaternionFromEuler([0, 0, 3.14]),
 useFixedBase=True,
@@ -308,7 +308,7 @@ class jakaEnv(gym.Env):
 block_vel_range = np.array(self.speed)
 block_loc = np.array([0, 0, 0])
 block_rel_pos = np.array([0, 0, 0])
 block2_pos = np.array([0.6, 0.66, 0.43])
 
 obs = np.concatenate(
 [
@@ -425,6 +425,7 @@ if __name__ == "__main__":
 It must contains the file created by the ``Monitor`` wrapper.
 :param verbose: Verbosity level.
 """
+
 def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
 super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
 self.check_freq = check_freq
@@ -440,32 +441,34 @@ if __name__ == "__main__":
 def _on_step(self) -> bool:
 if self.n_calls % self.check_freq == 0:
 
 # Retrieve training reward
 x, y = ts2xy(load_results(self.log_dir), 'timesteps')
 if len(x) > 0:
 # Mean training reward over the last 100 episodes
 mean_reward = np.mean(y[-100:])
 if self.verbose > 0:
 print(f"Num timesteps: {self.num_timesteps}")
-print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
+print(
+    f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
 
 # New best model, you could save the agent here
 if mean_reward > self.best_mean_reward:
 self.best_mean_reward = mean_reward
 # Example for saving best model
 if self.verbose > 0:
 print(f"Saving new best model to {self.save_path}")
 self.model.save(self.save_path)
 
 return True
 
+
 tempt = 1
 log_dir = './tensorboard/DOA_SAC_callback/'
 os.makedirs(log_dir, exist_ok=True)
 env = jakaEnv()
 env = Monitor(env, log_dir)
 if tempt:
 model = SAC('MlpPolicy', env=env, verbose=1, tensorboard_log=log_dir,
 device="cuda"
 )
 callback = SaveOnBestTrainingRewardCallback(check_freq=1000, log_dir=log_dir)
@@ -476,8 +479,7 @@ if __name__ == "__main__":
 else:
 obs = env.reset()
 # Change this path to wherever you saved the model
-model = SAC.load(r'C:\Users\fly\PycharmProjects\RL-PowerTracking-new\best_model.zip', env=env)
-
+model = SAC.load(r'best_model.zip', env=env)
 
 for j in range(50):
 for i in range(2000):
||||||
@ -498,7 +500,6 @@ if __name__ == "__main__":
|
|||||||
break
|
break
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|
||||||
# 三维
|
# 三维
|
||||||
# fig1 = plt.figure("机械臂运行轨迹")
|
# fig1 = plt.figure("机械臂运行轨迹")
|
||||||
# ax = fig1.add_subplot(projection="3d") # 三维图形
|
# ax = fig1.add_subplot(projection="3d") # 三维图形
|
||||||