feat: 添加 YOLO 模型从 PyTorch 到 ONNX 的转换及验证功能
- 实现了将 YOLO模型从 PyTorch 格式转换为 ONNX 格式的功能 - 加载并验证了导出的 ONNX 模型 - 编写了使用验证集进行模型测试的代码,包括图像预处理和推理 - 添加了调试信息输出,便于分析模型性能和验证结果
This commit is contained in:
parent
5f7e66e97f
commit
7e38b3acfd
67
pt2onnx.py
Normal file
67
pt2onnx.py
Normal file
@ -0,0 +1,67 @@
|
||||
from ultralytics import YOLO
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
import yaml
|
||||
import cv2
|
||||
import numpy as np
|
||||
import glob
|
||||
import os
|
||||
|
||||
# 1. 加载并导出模型
|
||||
model = YOLO('runs/detect/train/weights/best.pt')
|
||||
model.to('cuda')
|
||||
|
||||
model.export(
|
||||
format="onnx", # ONNX格式
|
||||
dynamic=True,
|
||||
simplify=True
|
||||
)
|
||||
|
||||
# 2. 加载ONNX模型
|
||||
onnx_model = onnx.load('runs/detect/train/weights/best.onnx')
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession('runs/detect/train/weights/best.onnx')
|
||||
|
||||
# 3. 加载验证集路径
|
||||
with open('data.yaml', 'r') as f:
|
||||
data_config = yaml.safe_load(f)
|
||||
val_dir = os.path.dirname(data_config['val'])
|
||||
val_dir = val_dir.replace('../', '') # 直接删除"../"前缀
|
||||
print("val_dir:", val_dir) # 调试信息
|
||||
val_images = os.path.join(val_dir, 'images') if 'valid' in val_dir else os.path.join(val_dir, 'val', 'images')
|
||||
val_images = val_images.replace('\\', '/') # 统一使用正斜杠路径分隔符
|
||||
print(f"val_images: {val_images}") # 调试信息
|
||||
|
||||
# 4. 验证集测试
|
||||
correct = 0
|
||||
total = 0
|
||||
print("开始验证...")
|
||||
|
||||
matched_files = glob.glob(val_images + '/*.jpg')
|
||||
print(f"找到 {len(matched_files)} 张测试图片") # 文件匹配调试
|
||||
|
||||
for img_path in matched_files[:20]: # 测试前20张
|
||||
# 预处理
|
||||
img = cv2.imread(img_path)
|
||||
img = cv2.resize(img, (640, 640))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 确保RGB顺序
|
||||
img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
|
||||
img = np.expand_dims(img, axis=0).astype(np.float32)
|
||||
|
||||
# 推理
|
||||
outputs = ort_session.run(None, {"images": img})
|
||||
|
||||
# 添加输出调试信息
|
||||
print(f"输出形状: {outputs[0].shape}")
|
||||
print(f"输出示例数据: {outputs[0][0][:5]}...")
|
||||
|
||||
# 使用置信度阈值判断检测
|
||||
confidences = outputs[0][0, 4, :]
|
||||
print(f"最大置信度: {np.max(confidences):.3f}")
|
||||
|
||||
# 使用阈值判断是否检测到目标
|
||||
if np.any(confidences > 0.5):
|
||||
correct += 1
|
||||
total += 1
|
||||
|
||||
print(f"验证结果: {correct}/{total} 张图片检测到目标")
|
BIN
runs/detect/train/weights/best.onnx
Normal file
BIN
runs/detect/train/weights/best.onnx
Normal file
Binary file not shown.
Loading…
Reference in New Issue
Block a user