feat: 添加 YOLO 模型从 PyTorch 到 ONNX 的转换及验证功能
- 实现了将 YOLO模型从 PyTorch 格式转换为 ONNX 格式的功能 - 加载并验证了导出的 ONNX 模型 - 编写了使用验证集进行模型测试的代码,包括图像预处理和推理 - 添加了调试信息输出,便于分析模型性能和验证结果
This commit is contained in:
parent
5f7e66e97f
commit
7e38b3acfd
67
pt2onnx.py
Normal file
67
pt2onnx.py
Normal file
@ -0,0 +1,67 @@
|
|||||||
|
from ultralytics import YOLO
|
||||||
|
import onnx
|
||||||
|
import onnxruntime as ort
|
||||||
|
import yaml
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
|
||||||
|
# 1. 加载并导出模型
|
||||||
|
model = YOLO('runs/detect/train/weights/best.pt')
|
||||||
|
model.to('cuda')
|
||||||
|
|
||||||
|
model.export(
|
||||||
|
format="onnx", # ONNX格式
|
||||||
|
dynamic=True,
|
||||||
|
simplify=True
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 加载ONNX模型
|
||||||
|
onnx_model = onnx.load('runs/detect/train/weights/best.onnx')
|
||||||
|
onnx.checker.check_model(onnx_model)
|
||||||
|
ort_session = ort.InferenceSession('runs/detect/train/weights/best.onnx')
|
||||||
|
|
||||||
|
# 3. 加载验证集路径
|
||||||
|
with open('data.yaml', 'r') as f:
|
||||||
|
data_config = yaml.safe_load(f)
|
||||||
|
val_dir = os.path.dirname(data_config['val'])
|
||||||
|
val_dir = val_dir.replace('../', '') # 直接删除"../"前缀
|
||||||
|
print("val_dir:", val_dir) # 调试信息
|
||||||
|
val_images = os.path.join(val_dir, 'images') if 'valid' in val_dir else os.path.join(val_dir, 'val', 'images')
|
||||||
|
val_images = val_images.replace('\\', '/') # 统一使用正斜杠路径分隔符
|
||||||
|
print(f"val_images: {val_images}") # 调试信息
|
||||||
|
|
||||||
|
# 4. 验证集测试
|
||||||
|
correct = 0
|
||||||
|
total = 0
|
||||||
|
print("开始验证...")
|
||||||
|
|
||||||
|
matched_files = glob.glob(val_images + '/*.jpg')
|
||||||
|
print(f"找到 {len(matched_files)} 张测试图片") # 文件匹配调试
|
||||||
|
|
||||||
|
for img_path in matched_files[:20]: # 测试前20张
|
||||||
|
# 预处理
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
img = cv2.resize(img, (640, 640))
|
||||||
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 确保RGB顺序
|
||||||
|
img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
|
||||||
|
img = np.expand_dims(img, axis=0).astype(np.float32)
|
||||||
|
|
||||||
|
# 推理
|
||||||
|
outputs = ort_session.run(None, {"images": img})
|
||||||
|
|
||||||
|
# 添加输出调试信息
|
||||||
|
print(f"输出形状: {outputs[0].shape}")
|
||||||
|
print(f"输出示例数据: {outputs[0][0][:5]}...")
|
||||||
|
|
||||||
|
# 使用置信度阈值判断检测
|
||||||
|
confidences = outputs[0][0, 4, :]
|
||||||
|
print(f"最大置信度: {np.max(confidences):.3f}")
|
||||||
|
|
||||||
|
# 使用阈值判断是否检测到目标
|
||||||
|
if np.any(confidences > 0.5):
|
||||||
|
correct += 1
|
||||||
|
total += 1
|
||||||
|
|
||||||
|
print(f"验证结果: {correct}/{total} 张图片检测到目标")
|
BIN
runs/detect/train/weights/best.onnx
Normal file
BIN
runs/detect/train/weights/best.onnx
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user