refactor(env): 优化 DOA_SAC_sim2real 脚本

- 修改了 URDF 文件路径
- 删除了冗余代码和注释
- 调整了代码格式和缩进,提高了可读性
- 更新了模型保存路径
This commit is contained in:
fly6516 2025-05-27 21:55:49 +08:00
parent 3d914c8c53
commit e3b17f5eb2

View File

@ -17,13 +17,16 @@ import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
import pandas as pd
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
MAX_EPISODE_LEN = 20 * 100
x = []
y = []
z = []
#运动模式
# 运动模式
# PI=3.1415926
# ABS = 0 # 绝对运动
# INCR = 1 # 增量运动
@ -126,9 +129,6 @@ class jakaEnv(gym.Env):
# robot.servo_p(cartesian_pose=[dx*100, dy*100, dz*50, 0, 0, 0], move_mode=1)
# time.sleep(0.008)
state_object, state_object_orienation = p.getBasePositionAndOrientation(self.objectId)
twist_object, twist_object_orienation = p.getBaseVelocity(self.objectId)
state_robot = p.getLinkState(self.jaka_id, 6)[0]
@ -282,7 +282,7 @@ class jakaEnv(gym.Env):
reset_realposes = [0.8, 2.4, 1.3, 1, -1.57, 0]
# 改变路径为你机械臂的URDF文件路径
self.jaka_id = p.loadURDF(
r"D:\Python\robot_DRL\env\lib64\urdf\jaka_description\urdf\jaka_description.urdf",
"urdf/jaka_description/urdf/jaka_description.urdf",
basePosition=[0, 0.5, 0.65],
baseOrientation=p.getQuaternionFromEuler([0, 0, 3.14]),
useFixedBase=True,
@ -425,6 +425,7 @@ if __name__ == "__main__":
It must contains the file created by the ``Monitor`` wrapper.
:param verbose: Verbosity level.
"""
def __init__(self, check_freq: int, log_dir: str, verbose: int = 1):
super(SaveOnBestTrainingRewardCallback, self).__init__(verbose)
self.check_freq = check_freq
@ -447,7 +448,8 @@ if __name__ == "__main__":
mean_reward = np.mean(y[-100:])
if self.verbose > 0:
print(f"Num timesteps: {self.num_timesteps}")
print(f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
print(
f"Best mean reward: {self.best_mean_reward:.2f} - Last mean reward per episode: {mean_reward:.2f}")
# New best model, you could save the agent here
if mean_reward > self.best_mean_reward:
@ -459,6 +461,7 @@ if __name__ == "__main__":
return True
tempt = 1
log_dir = './tensorboard/DOA_SAC_callback/'
os.makedirs(log_dir, exist_ok=True)
@ -476,8 +479,7 @@ if __name__ == "__main__":
else:
obs = env.reset()
# 改变路径为你保存模型的路径
model = SAC.load(r'C:\Users\fly\PycharmProjects\RL-PowerTracking-new\best_model.zip', env=env)
model = SAC.load(r'best_model.zip', env=env)
for j in range(50):
for i in range(2000):
@ -498,7 +500,6 @@ if __name__ == "__main__":
break
break
# 三维
# fig1 = plt.figure("机械臂运行轨迹")
# ax = fig1.add_subplot(projection="3d") # 三维图形