feat: 添加 YOLO 模型从 PyTorch 到 ONNX 的转换及验证功能

- 实现了将 YOLO模型从 PyTorch 格式转换为 ONNX 格式的功能
- 加载并验证了导出的 ONNX 模型
- 编写了使用验证集进行模型测试的代码,包括图像预处理和推理
- 添加了调试信息输出,便于分析模型性能和验证结果
This commit is contained in:
fly6516 2025-06-03 23:49:21 +08:00
parent 5f7e66e97f
commit 7e38b3acfd
2 changed files with 67 additions and 0 deletions

67
pt2onnx.py Normal file
View File

@ -0,0 +1,67 @@
from ultralytics import YOLO
import onnx
import onnxruntime as ort
import yaml
import cv2
import numpy as np
import glob
import os
# 1. 加载并导出模型
model = YOLO('runs/detect/train/weights/best.pt')
model.to('cuda')
model.export(
format="onnx", # ONNX格式
dynamic=True,
simplify=True
)
# 2. 加载ONNX模型
onnx_model = onnx.load('runs/detect/train/weights/best.onnx')
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession('runs/detect/train/weights/best.onnx')
# 3. 加载验证集路径
with open('data.yaml', 'r') as f:
data_config = yaml.safe_load(f)
val_dir = os.path.dirname(data_config['val'])
val_dir = val_dir.replace('../', '') # 直接删除"../"前缀
print("val_dir:", val_dir) # 调试信息
val_images = os.path.join(val_dir, 'images') if 'valid' in val_dir else os.path.join(val_dir, 'val', 'images')
val_images = val_images.replace('\\', '/') # 统一使用正斜杠路径分隔符
print(f"val_images: {val_images}") # 调试信息
# 4. 验证集测试
correct = 0
total = 0
print("开始验证...")
matched_files = glob.glob(val_images + '/*.jpg')
print(f"找到 {len(matched_files)} 张测试图片") # 文件匹配调试
for img_path in matched_files[:20]: # 测试前20张
# 预处理
img = cv2.imread(img_path)
img = cv2.resize(img, (640, 640))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 确保RGB顺序
img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
img = np.expand_dims(img, axis=0).astype(np.float32)
# 推理
outputs = ort_session.run(None, {"images": img})
# 添加输出调试信息
print(f"输出形状: {outputs[0].shape}")
print(f"输出示例数据: {outputs[0][0][:5]}...")
# 使用置信度阈值判断检测
confidences = outputs[0][0, 4, :]
print(f"最大置信度: {np.max(confidences):.3f}")
# 使用阈值判断是否检测到目标
if np.any(confidences > 0.5):
correct += 1
total += 1
print(f"验证结果: {correct}/{total} 张图片检测到目标")

Binary file not shown.