feat: 初始化RL-PowerTracking项目
- 创建项目结构和核心模块 - 实现机械手臂控制和目标跟踪功能 - 开发强化学习环境和训练脚本 - 添加文档和使用示例 - 设置日志记录和TensorBoard可视化
This commit is contained in:
commit
21304165a8
29
README.md
Normal file
29
README.md
Normal file
@ -0,0 +1,29 @@
|
||||
# RL-PowerTracking
|
||||
|
||||
基于强化学习的电力目标跟踪系统
|
||||
|
||||
## 项目结构
|
||||
- `src/` - 核心代码
|
||||
- `docs/` - 文档
|
||||
- `tests/` - 测试用例
|
||||
- `examples/` - 使用示例
|
||||
- `data/` - 数据集
|
||||
- `models/` - 模型保存目录
|
||||
- `tensorboard_logs/` - TensorBoard日志
|
||||
|
||||
## 目录自动创建
|
||||
为确保项目运行时所需目录存在,新增了`create_directories.py`脚本。在运行训练前,请先执行:
|
||||
```bash
|
||||
python create_directories.py
|
||||
```
|
||||
|
||||
## 安装依赖
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 最新优化
|
||||
1. 环境注册问题已修复,自定义环境正确注册为`CartesianSpace-v0`
|
||||
2. 训练速度优化:使用MlpPolicy、较大的batch_size=1024、use_sde=True提升CPU利用率
|
||||
3. 解决了向量化环境观测值解包问题
|
||||
4. 环境包装器嵌套顺序优化(Monitor包裹RecordEpisodeStatistics)
|
19
create_directories.py
Normal file
19
create_directories.py
Normal file
@ -0,0 +1,19 @@
|
||||
import os
|
||||
|
||||
directories = [
|
||||
'src',
|
||||
'docs',
|
||||
'tests',
|
||||
'examples',
|
||||
'data',
|
||||
'src/rl_env',
|
||||
'src/robot_control',
|
||||
'src/vision',
|
||||
'src/utils',
|
||||
'models',
|
||||
'tensorboard_logs'
|
||||
]
|
||||
|
||||
for dir in directories:
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
print(f'Created directory: {dir}')
|
32
docs/index.md
Normal file
32
docs/index.md
Normal file
@ -0,0 +1,32 @@
|
||||
# RL-PowerTracking 文档
|
||||
|
||||
## 项目概述
|
||||
本项目实现了一个基于强化学习的电力目标跟踪系统,包含机械手臂控制、目标识别与跟踪等功能。
|
||||
|
||||
## 安装指南
|
||||
|
||||
### 环境要求
|
||||
- Python 3.8+
|
||||
- Windows/Linux/MacOS
|
||||
|
||||
### 安装步骤
|
||||
```bash
|
||||
# 创建虚拟环境
|
||||
python -m venv .venv
|
||||
source .venv/bin/activate # Windows: .venv\Scripts\activate
|
||||
|
||||
# 安装依赖
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 训练强化学习模型
|
||||
```bash
|
||||
python examples/train_rl_model.py
|
||||
```
|
||||
|
||||
## 模块文档
|
||||
- [强化学习环境](modules/rl_env.md)
|
||||
- [机械手臂控制](modules/robot_control.md)
|
||||
- [视觉识别模块](modules/vision.md)
|
25
docs/modules/rl_env.md
Normal file
25
docs/modules/rl_env.md
Normal file
@ -0,0 +1,25 @@
|
||||
# 强化学习环境模块
|
||||
|
||||
## CartesianSpaceEnv 类
|
||||
|
||||
### 描述
|
||||
实现了一个机械手臂在笛卡尔空间中运动规划的强化学习环境,基于OpenAI Gym接口。
|
||||
|
||||
### 功能
|
||||
- 提供3自由度的位置控制
|
||||
- 计算当前位置与目标位置之间的误差
|
||||
- 提供基于距离的奖励函数
|
||||
|
||||
### 使用示例
|
||||
```python
|
||||
from src.rl_env.cartesian_env import CartesianSpaceEnv
|
||||
|
||||
# 创建环境
|
||||
env = CartesianSpaceEnv()
|
||||
|
||||
# 重置环境
|
||||
obs = env.reset()
|
||||
|
||||
# 执行动作
|
||||
action = env.action_space.sample()
|
||||
obs, reward, done, truncated, info = env.step(action)
|
132
examples/data/training_scenarios.json
Normal file
132
examples/data/training_scenarios.json
Normal file
@ -0,0 +1,132 @@
|
||||
[
|
||||
{
|
||||
"scenario_id": 0,
|
||||
"init_pos": [
|
||||
-0.3851543281195917,
|
||||
0.23622170885579574,
|
||||
-0.3504204256464779
|
||||
],
|
||||
"target_pos": [
|
||||
-0.0014902990848678632,
|
||||
-0.34198739704281766,
|
||||
-0.2163872669679774
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 1,
|
||||
"init_pos": [
|
||||
0.3586915049870687,
|
||||
0.4663490964064959,
|
||||
-0.07029539508983895
|
||||
],
|
||||
"target_pos": [
|
||||
0.3219930783202991,
|
||||
-0.4321836991042829,
|
||||
-0.25417837068744065
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 2,
|
||||
"init_pos": [
|
||||
-0.38247076322451923,
|
||||
0.48122889084068443,
|
||||
-0.06763670046735482
|
||||
],
|
||||
"target_pos": [
|
||||
-0.35682694658868086,
|
||||
0.062285073029546334,
|
||||
-0.4166566267909795
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 3,
|
||||
"init_pos": [
|
||||
0.19529006823279427,
|
||||
0.1187002364968689,
|
||||
0.2855794636803419
|
||||
],
|
||||
"target_pos": [
|
||||
-0.3399272040323661,
|
||||
-0.47575458600588605,
|
||||
-0.2059009418702218
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 4,
|
||||
"init_pos": [
|
||||
-0.24366443768753376,
|
||||
0.22250245546839742,
|
||||
-0.43859687158676164
|
||||
],
|
||||
"target_pos": [
|
||||
0.29408296030315706,
|
||||
0.36045438818494446,
|
||||
0.41069252046135574
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 5,
|
||||
"init_pos": [
|
||||
-0.16311247368497173,
|
||||
-0.08433052178776723,
|
||||
-0.011765488986697492
|
||||
],
|
||||
"target_pos": [
|
||||
0.3857500541808496,
|
||||
0.16125642952352803,
|
||||
0.38244009584453975
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 6,
|
||||
"init_pos": [
|
||||
0.24265699548741348,
|
||||
-0.308374294493032,
|
||||
-0.253119206920678
|
||||
],
|
||||
"target_pos": [
|
||||
0.19486565729452343,
|
||||
0.1322450012939419,
|
||||
0.33993681365266004
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 7,
|
||||
"init_pos": [
|
||||
0.47939449711779103,
|
||||
0.06129058017887501,
|
||||
0.197264862656631
|
||||
],
|
||||
"target_pos": [
|
||||
-0.43113693403897635,
|
||||
-0.0876026102519023,
|
||||
0.258142080216726
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 8,
|
||||
"init_pos": [
|
||||
-0.4088245272018959,
|
||||
0.49738735842249504,
|
||||
-0.4947383188654939
|
||||
],
|
||||
"target_pos": [
|
||||
0.26305788406141506,
|
||||
0.15670787216531112,
|
||||
-0.46228543975760295
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 9,
|
||||
"init_pos": [
|
||||
-0.06846851174488988,
|
||||
-0.04387989406670256,
|
||||
0.30249047933660234
|
||||
],
|
||||
"target_pos": [
|
||||
-0.13672602393638544,
|
||||
0.3888877361972328,
|
||||
0.30250508674023746
|
||||
]
|
||||
}
|
||||
]
|
67
examples/evaluate_model.py
Normal file
67
examples/evaluate_model.py
Normal file
@ -0,0 +1,67 @@
|
||||
"""
|
||||
模型评估示例
|
||||
"""
|
||||
import os
|
||||
import numpy as np
|
||||
import gym
|
||||
from stable_baselines3 import PPO
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from src.rl_env.cartesian_env import CartesianSpaceEnv
|
||||
|
||||
def evaluate_model(model_path, num_episodes=5):
|
||||
"""
|
||||
评估训练好的模型
|
||||
|
||||
参数:
|
||||
model_path: 模型文件路径
|
||||
num_episodes: 要运行的测试回合数
|
||||
"""
|
||||
# 创建环境
|
||||
env = CartesianSpaceEnv()
|
||||
|
||||
# 加载模型
|
||||
model = PPO.load(model_path)
|
||||
|
||||
print(f"开始评估,共 {num_episodes} 个episode")
|
||||
|
||||
for episode in range(num_episodes):
|
||||
obs, _ = env.reset()
|
||||
done = False
|
||||
total_reward = 0.0
|
||||
steps = 0
|
||||
|
||||
print(f"\nEpisode {episode + 1}/{num_episodes}")
|
||||
print(f"目标位置: {env._target_pos}")
|
||||
|
||||
while not done:
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, truncated, info = env.step(action)
|
||||
total_reward += reward
|
||||
steps += 1
|
||||
|
||||
# 显示中间步骤信息
|
||||
if steps % 10 == 0:
|
||||
distance = np.linalg.norm(env._target_pos - env._current_pos)
|
||||
print(f"Step {steps}: 距离={distance:.4f}, 累计奖励={total_reward:.4f}")
|
||||
|
||||
# 渲染最后一步
|
||||
if done or truncated:
|
||||
env.render()
|
||||
distance = np.linalg.norm(env._target_pos - env._current_pos)
|
||||
print(f"\n最终结果:")
|
||||
print(f"总步数: {steps}")
|
||||
print(f"最终距离: {distance:.4f}")
|
||||
print(f"总奖励: {total_reward:.4f}")
|
||||
|
||||
env.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 使用最佳模型进行评估
|
||||
MODEL_PATH = "models/best_model"
|
||||
|
||||
# 运行评估
|
||||
evaluate_model(MODEL_PATH)
|
BIN
examples/models/best_model/best_model.zip
Normal file
BIN
examples/models/best_model/best_model.zip
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
110
examples/train_rl_model.py
Normal file
110
examples/train_rl_model.py
Normal file
@ -0,0 +1,110 @@
|
||||
"""
|
||||
强化学习训练示例
|
||||
"""
|
||||
import os
|
||||
import torch
|
||||
from contextlib import nullcontext
|
||||
from torch.cuda.amp import autocast as autocast_ctx_decorator
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
import sys
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
|
||||
from src.rl_env.cartesian_env import CartesianSpaceEnv
|
||||
from src.utils.test_data import generate_test_scenarios, save_scenarios_to_json
|
||||
|
||||
|
||||
def create_mixed_precision_wrapper(train_func, autocast_ctx):
|
||||
"""创建支持混合精度训练的包装器函数"""
|
||||
def mixed_precision_train(*args, **kwargs):
|
||||
with autocast_ctx:
|
||||
return train_func(*args, **kwargs)
|
||||
return mixed_precision_train
|
||||
|
||||
|
||||
# 强制在文件顶部禁用GPU
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
|
||||
|
||||
# 添加缺失的monitor导入
|
||||
from stable_baselines3.common.monitor import Monitor
|
||||
from stable_baselines3.common.vec_env import DummyVecEnv
|
||||
# 使用Gymnasium的导入
|
||||
import gymnasium as gym
|
||||
from gymnasium.wrappers import RecordEpisodeStatistics
|
||||
# 添加缺失的环境注册导入
|
||||
from gymnasium.envs.registration import register
|
||||
# 显式添加缺失的导入
|
||||
from stable_baselines3 import PPO
|
||||
from stable_baselines3.common.callbacks import EvalCallback
|
||||
|
||||
|
||||
def train_rl_model():
|
||||
"""训练强化学习模型"""
|
||||
# 强制禁用GPU
|
||||
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
|
||||
|
||||
# 显式设置设备为CPU
|
||||
device = 'cpu'
|
||||
|
||||
# 注册自定义环境
|
||||
register(
|
||||
id='CartesianSpace-v0',
|
||||
entry_point='src.rl_env.cartesian_env:CartesianSpaceEnv',
|
||||
max_episode_steps=1000,
|
||||
)
|
||||
|
||||
# 创建Gymnasium环境并应用标准包装器
|
||||
env = DummyVecEnv([
|
||||
lambda: Monitor(
|
||||
RecordEpisodeStatistics(gym.make('CartesianSpace-v0'))
|
||||
)
|
||||
])
|
||||
|
||||
# 创建PPO模型时确保使用CPU
|
||||
model = PPO(
|
||||
"MlpPolicy",
|
||||
env,
|
||||
verbose=1,
|
||||
device=device,
|
||||
tensorboard_log="tensorboard_logs/",
|
||||
policy_kwargs=dict(
|
||||
net_arch=dict(pi=[256, 256], vf=[256, 256]),
|
||||
log_std_init=-2,
|
||||
),
|
||||
batch_size=1024,
|
||||
use_sde=True
|
||||
)
|
||||
|
||||
# 添加缺失的评估回调
|
||||
eval_callback = EvalCallback(
|
||||
env,
|
||||
eval_freq=1000,
|
||||
best_model_save_path="models/best_model/",
|
||||
deterministic=True,
|
||||
render=False
|
||||
)
|
||||
|
||||
# 训练模型
|
||||
model.learn(
|
||||
total_timesteps=50000,
|
||||
callback=eval_callback,
|
||||
tb_log_name="PPO_Cartesian"
|
||||
)
|
||||
|
||||
# 保存最终模型
|
||||
model.save("models/ppo_cartesian_final")
|
||||
|
||||
# 测试训练好的模型
|
||||
obs = env.reset()
|
||||
for _ in range(100):
|
||||
action, _states = model.predict(obs, deterministic=True)
|
||||
obs, reward, done, truncated, info = env.step(action)
|
||||
env.render()
|
||||
if done or truncated:
|
||||
obs = env.reset()
|
||||
|
||||
env.close()
|
||||
|
||||
if __name__ == "__main__":
|
||||
train_rl_model()
|
BIN
models/best_model/best_model.zip
Normal file
BIN
models/best_model/best_model.zip
Normal file
Binary file not shown.
BIN
models/ppo_cartesian_final.zip
Normal file
BIN
models/ppo_cartesian_final.zip
Normal file
Binary file not shown.
15
requirements.txt
Normal file
15
requirements.txt
Normal file
@ -0,0 +1,15 @@
|
||||
# 核心依赖
|
||||
gym>=0.26.2
|
||||
stable-baselines3>=2.0.0
|
||||
pybullet>=3.2.5
|
||||
numpy>=1.24.4
|
||||
opencv-python>=4.8.1.78
|
||||
matplotlib>=3.7.3
|
||||
pandas>=2.0.3
|
||||
scikit-learn>=1.3.0
|
||||
shimmy>=2.0 # OpenAI Gym与Gymnasium兼容层
|
||||
tensorboard>=2.19.0 # 强化学习训练可视化
|
||||
|
||||
# 文档生成
|
||||
docutils>=0.20.1
|
||||
sphinx>=7.2.1
|
3
src/rl_env/__init__.py
Normal file
3
src/rl_env/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
强化学习环境模块
|
||||
"""
|
128
src/rl_env/cartesian_env.py
Normal file
128
src/rl_env/cartesian_env.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""
|
||||
笛卡尔空间强化学习环境
|
||||
"""
|
||||
import gymnasium as gym
|
||||
from gymnasium import spaces
|
||||
import numpy as np
|
||||
import logging
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CartesianSpaceEnv(gym.Env):
|
||||
"""机械手臂在笛卡尔空间的运动规划环境"""
|
||||
metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': 50}
|
||||
|
||||
def __init__(self, render_mode=None):
|
||||
# 定义动作空间:x, y, z轴的速度控制
|
||||
self.action_space = spaces.Box(
|
||||
low=np.array([-1.0, -1.0, -1.0]),
|
||||
high=np.array([1.0, 1.0, 1.0]),
|
||||
shape=(3,),
|
||||
dtype=np.float32
|
||||
)
|
||||
|
||||
# 定义观测空间:当前位置、目标位置、误差
|
||||
self.observation_space = spaces.Box(
|
||||
low=np.array([-np.inf]*9),
|
||||
high=np.array([np.inf]*9),
|
||||
dtype=np.float32
|
||||
)
|
||||
|
||||
# 初始化环境参数
|
||||
self._target_pos = np.array([0.5, 0.5, 0.5])
|
||||
self._current_pos = np.array([0.0, 0.0, 0.0])
|
||||
self._dt = 0.01 # 时间步长
|
||||
self.render_mode = render_mode
|
||||
self.episode_step = 0
|
||||
self.episode_reward = 0.0
|
||||
|
||||
# 测试场景支持
|
||||
self.test_scenarios = None
|
||||
self.current_scenario_idx = 0
|
||||
|
||||
def _get_obs(self):
|
||||
# 获取当前观测值
|
||||
return np.concatenate([
|
||||
self._current_pos,
|
||||
self._target_pos,
|
||||
self._target_pos - self._current_pos
|
||||
])
|
||||
|
||||
# 如果使用旧版reset方法,需要更新为新的API格式
|
||||
def reset(self, seed=None, options=None):
|
||||
super().reset(seed=seed)
|
||||
|
||||
# 重置环境参数
|
||||
self.episode_step = 0
|
||||
self.episode_reward = 0.0
|
||||
|
||||
# 随机初始化当前位置(添加随机性)
|
||||
self._current_pos = self.np_random.uniform(low=-0.5, high=0.5, size=(3,)).astype(np.float32)
|
||||
|
||||
# 如果有测试场景,使用场景中的目标位置
|
||||
if self.test_scenarios and self.current_scenario_idx < len(self.test_scenarios):
|
||||
self._target_pos = np.array(self.test_scenarios[self.current_scenario_idx]['target_position'], dtype=np.float32)
|
||||
self.current_scenario_idx += 1
|
||||
else:
|
||||
# 否则随机生成目标位置
|
||||
self._target_pos = self.np_random.uniform(low=0.0, high=1.0, size=(3,)).astype(np.float32)
|
||||
|
||||
# 获取初始观测
|
||||
observation = self._get_obs()
|
||||
|
||||
# 创建info字典
|
||||
info = {"reset": "environment reset to initial state"}
|
||||
|
||||
return observation, info
|
||||
|
||||
# 更新step方法以返回新的五元组格式
|
||||
def step(self, action):
|
||||
# 执行一步环境交互
|
||||
self._current_pos += action * self._dt
|
||||
observation = self._get_obs().astype(np.float32)
|
||||
|
||||
# 计算奖励
|
||||
distance = np.linalg.norm(self._target_pos - self._current_pos)
|
||||
reward = -distance
|
||||
|
||||
# 累计episode统计
|
||||
self.episode_step += 1
|
||||
self.episode_reward += reward
|
||||
|
||||
# 判断是否终止
|
||||
terminated = distance < 0.01
|
||||
truncated = False
|
||||
info = {}
|
||||
|
||||
# 记录每一步的信息
|
||||
logger.debug(f"Step {self.episode_step}: Position={self._current_pos}, Reward={reward:.4f}, Distance={distance:.4f}")
|
||||
|
||||
# 如果episode结束,记录总结
|
||||
if terminated or truncated:
|
||||
logger.info(f"Episode end: Steps={self.episode_step}, Total reward={self.episode_reward:.4f}, Final distance={distance:.4f}")
|
||||
|
||||
return observation, reward, terminated, truncated, info
|
||||
|
||||
def render(self):
|
||||
# 渲染环境
|
||||
if self.render_mode == "human":
|
||||
print(f"Current position: {self._current_pos}, Target position: {self._target_pos}")
|
||||
|
||||
def close(self):
|
||||
# 关闭环境
|
||||
pass
|
||||
|
||||
def load_test_scenarios(self, scenarios):
|
||||
"""
|
||||
加载测试场景
|
||||
|
||||
参数:
|
||||
scenarios: 测试场景列表
|
||||
"""
|
||||
self.test_scenarios = scenarios
|
||||
self.current_scenario_idx = 0
|
3
src/robot_control/__init__.py
Normal file
3
src/robot_control/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
机械手臂控制模块
|
||||
"""
|
52
src/robot_control/arm_controller.py
Normal file
52
src/robot_control/arm_controller.py
Normal file
@ -0,0 +1,52 @@
|
||||
"""
|
||||
机械手臂控制器
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
class RoboticArmController:
|
||||
"""机械手臂控制接口类"""
|
||||
def __init__(self, port='COM3', baud_rate=115200):
|
||||
"""
|
||||
初始化机械手臂控制器
|
||||
|
||||
参数:
|
||||
port: 串口端口号
|
||||
baud_rate: 波特率
|
||||
"""
|
||||
self.port = port
|
||||
self.baud_rate = baud_rate
|
||||
# 模拟连接状态
|
||||
self.connected = False
|
||||
|
||||
def connect(self):
|
||||
"""连接机械手臂"""
|
||||
# 实际实现中应建立串口连接
|
||||
self.connected = True
|
||||
print(f"已连接到机械手臂,端口: {self.port}")
|
||||
|
||||
def disconnect(self):
|
||||
"""断开连接"""
|
||||
self.connected = False
|
||||
print("已断开机械手臂连接")
|
||||
|
||||
def move_to_position(self, position):
|
||||
"""
|
||||
移动到指定位置
|
||||
|
||||
参数:
|
||||
position: 包含x,y,z坐标的列表或numpy数组
|
||||
"""
|
||||
if not self.connected:
|
||||
raise Exception("未连接到机械手臂")
|
||||
|
||||
x, y, z = position
|
||||
print(f"移动到位置: X={x:.3f}, Y={y:.3f}, Z={z:.3f}")
|
||||
# 这里应添加实际的运动控制代码
|
||||
|
||||
def get_current_position(self):
|
||||
"""获取当前机械手臂位置"""
|
||||
if not self.connected:
|
||||
raise Exception("未连接到机械手臂")
|
||||
|
||||
# 实际实现中应从设备读取当前位置
|
||||
return np.array([0.0, 0.0, 0.0])
|
67
src/utils/data/test_scenarios.json
Normal file
67
src/utils/data/test_scenarios.json
Normal file
@ -0,0 +1,67 @@
|
||||
[
|
||||
{
|
||||
"scenario_id": 0,
|
||||
"init_pos": [
|
||||
0.12676142133140977,
|
||||
0.29058529336172323,
|
||||
-0.47831770529728
|
||||
],
|
||||
"target_pos": [
|
||||
-0.4635523833493924,
|
||||
0.27438275835742165,
|
||||
0.2353816483034984
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 1,
|
||||
"init_pos": [
|
||||
0.11685331999464632,
|
||||
-0.17857159753499696,
|
||||
-0.1587136584596952
|
||||
],
|
||||
"target_pos": [
|
||||
0.2245201353490277,
|
||||
0.10729710663916903,
|
||||
0.2840006653075371
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 2,
|
||||
"init_pos": [
|
||||
-0.23270897488779074,
|
||||
0.08247045283575827,
|
||||
-0.48044965574607523
|
||||
],
|
||||
"target_pos": [
|
||||
0.0640611291490617,
|
||||
-0.23399219689938255,
|
||||
-0.3227835865078418
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 3,
|
||||
"init_pos": [
|
||||
0.34884565936750833,
|
||||
0.43771500409129593,
|
||||
-0.3460318142792451
|
||||
],
|
||||
"target_pos": [
|
||||
0.36108425245942233,
|
||||
-0.41171310141719586,
|
||||
-0.3719897927157988
|
||||
]
|
||||
},
|
||||
{
|
||||
"scenario_id": 4,
|
||||
"init_pos": [
|
||||
-0.46642579644165305,
|
||||
0.15788805534463513,
|
||||
-0.3057637936771299
|
||||
],
|
||||
"target_pos": [
|
||||
-0.10038566176960095,
|
||||
-0.20307994408280605,
|
||||
0.348180686937688
|
||||
]
|
||||
}
|
||||
]
|
73
src/utils/test_data.py
Normal file
73
src/utils/test_data.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""
|
||||
测试数据生成工具
|
||||
"""
|
||||
import numpy as np
|
||||
import os
|
||||
import json
|
||||
|
||||
def generate_test_scenarios(num_scenarios=10, workspace_range=(-0.5, 0.5)):
|
||||
"""
|
||||
生成测试场景数据
|
||||
|
||||
参数:
|
||||
num_scenarios: 要生成的测试场景数量
|
||||
workspace_range: 工作空间范围
|
||||
返回:
|
||||
scenarios: 包含初始位置和目标位置的测试场景列表
|
||||
"""
|
||||
scenarios = []
|
||||
for i in range(num_scenarios):
|
||||
# 生成随机初始位置和目标位置
|
||||
init_pos = np.random.uniform(workspace_range[0], workspace_range[1], 3)
|
||||
target_pos = np.random.uniform(workspace_range[0], workspace_range[1], 3)
|
||||
|
||||
# 确保目标位置与初始位置有一定距离
|
||||
while np.linalg.norm(target_pos - init_pos) < 0.3:
|
||||
target_pos = np.random.uniform(workspace_range[0], workspace_range[1], 3)
|
||||
|
||||
scenarios.append({
|
||||
'scenario_id': i,
|
||||
'init_pos': init_pos.tolist(),
|
||||
'target_pos': target_pos.tolist()
|
||||
})
|
||||
|
||||
return scenarios
|
||||
|
||||
|
||||
def save_scenarios_to_json(scenarios, file_path='test_scenarios.json'):
|
||||
"""
|
||||
保存测试场景到JSON文件
|
||||
|
||||
参数:
|
||||
scenarios: 测试场景列表
|
||||
file_path: 输出文件路径
|
||||
"""
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
|
||||
with open(file_path, 'w') as f:
|
||||
json.dump(scenarios, f, indent=2)
|
||||
print(f"成功保存 {len(scenarios)} 个测试场景到 {file_path}")
|
||||
|
||||
def load_scenarios_from_json(file_path='test_scenarios.json'):
|
||||
"""
|
||||
从JSON文件加载测试场景
|
||||
|
||||
参数:
|
||||
file_path: 输入文件路径
|
||||
返回:
|
||||
scenarios: 加载的测试场景列表
|
||||
"""
|
||||
try:
|
||||
with open(file_path, 'r') as f:
|
||||
scenarios = json.load(f)
|
||||
print(f"成功加载 {len(scenarios)} 个测试场景")
|
||||
return scenarios
|
||||
except FileNotFoundError:
|
||||
print(f"错误:文件 {file_path} 未找到")
|
||||
return None
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 示例用法
|
||||
test_scenarios = generate_test_scenarios(5)
|
||||
save_scenarios_to_json(test_scenarios, 'data/test_scenarios.json')
|
3
src/vision/__init__.py
Normal file
3
src/vision/__init__.py
Normal file
@ -0,0 +1,3 @@
|
||||
"""
|
||||
视觉识别模块
|
||||
"""
|
78
src/vision/target_tracker.py
Normal file
78
src/vision/target_tracker.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""
|
||||
目标跟踪模块
|
||||
"""
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
||||
class TargetTracker:
|
||||
"""基于计算机视觉的目标识别与跟踪类"""
|
||||
def __init__(self, camera_index=0):
|
||||
"""
|
||||
初始化摄像头和跟踪参数
|
||||
|
||||
参数:
|
||||
camera_index: 摄像头设备索引号
|
||||
"""
|
||||
self.cap = cv2.VideoCapture(camera_index)
|
||||
if not self.cap.isOpened():
|
||||
raise Exception("无法打开摄像头")
|
||||
|
||||
# 初始化跟踪参数
|
||||
self.tracker = cv2.TrackerCSRT_create()
|
||||
self.bbox = None
|
||||
self.tracking = False
|
||||
|
||||
def get_frame(self):
|
||||
"""获取当前帧"""
|
||||
ret, frame = self.cap.read()
|
||||
if not ret:
|
||||
raise Exception("无法读取帧")
|
||||
|
||||
return frame
|
||||
|
||||
def detect_power_equipment(self, frame):
|
||||
"""
|
||||
检测电力设备(模拟实现)
|
||||
|
||||
参数:
|
||||
frame: 输入图像帧
|
||||
返回:
|
||||
检测到的设备位置列表
|
||||
"""
|
||||
# 这里应添加实际的电力设备检测算法
|
||||
# 模拟返回一个检测到的目标
|
||||
height, width = frame.shape[:2]
|
||||
center_x, center_y = width//2, height//2
|
||||
return [(center_x-50, center_y-50, 100, 100)] # 返回一个示例边界框
|
||||
|
||||
def start_tracking(self, frame, bbox):
|
||||
"""
|
||||
开始跟踪指定区域
|
||||
|
||||
参数:
|
||||
frame: 当前图像帧
|
||||
bbox: 要跟踪的边界框 (x, y, w, h)
|
||||
"""
|
||||
self.tracker.init(frame, bbox)
|
||||
self.bbox = bbox
|
||||
self.tracking = True
|
||||
|
||||
def update_tracking(self):
|
||||
"""更新跟踪结果"""
|
||||
if not self.tracking:
|
||||
return False, None
|
||||
|
||||
# 读取下一帧
|
||||
frame = self.get_frame()
|
||||
success, bbox = self.tracker.update(frame)
|
||||
|
||||
if success:
|
||||
self.bbox = bbox
|
||||
else:
|
||||
self.tracking = False
|
||||
|
||||
return success, bbox
|
||||
|
||||
def release(self):
|
||||
"""释放摄像头资源"""
|
||||
self.cap.release()
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
17
tensorboard_start.bat
Normal file
17
tensorboard_start.bat
Normal file
@ -0,0 +1,17 @@
|
||||
@echo off
|
||||
SETLOCAL
|
||||
|
||||
:: 设置日志目录(相对于项目根目录)
|
||||
set LOG_DIR=tensorboard_logs
|
||||
|
||||
:: 检查日志目录是否存在
|
||||
dir /b %LOG_DIR% >nul 2>&1
|
||||
if errorlevel 1 (
|
||||
echo 创建日志目录: %LOG_DIR%
|
||||
mkdir %LOG_DIR%
|
||||
)
|
||||
|
||||
:: 启动TensorBoard
|
||||
python -m tensorboard.main --logdir=%cd%\%LOG_DIR% --port=6006
|
||||
|
||||
ENDLOCAL
|
Loading…
Reference in New Issue
Block a user