- 实现了将 YOLO模型从 PyTorch 格式转换为 ONNX 格式的功能 - 加载并验证了导出的 ONNX 模型 - 编写了使用验证集进行模型测试的代码,包括图像预处理和推理 - 添加了调试信息输出,便于分析模型性能和验证结果
67 lines
2.0 KiB
Python
67 lines
2.0 KiB
Python
from ultralytics import YOLO
|
|
import onnx
|
|
import onnxruntime as ort
|
|
import yaml
|
|
import cv2
|
|
import numpy as np
|
|
import glob
|
|
import os
|
|
|
|
# 1. 加载并导出模型
|
|
model = YOLO('runs/detect/train/weights/best.pt')
|
|
model.to('cuda')
|
|
|
|
model.export(
|
|
format="onnx", # ONNX格式
|
|
dynamic=True,
|
|
simplify=True
|
|
)
|
|
|
|
# 2. 加载ONNX模型
|
|
onnx_model = onnx.load('runs/detect/train/weights/best.onnx')
|
|
onnx.checker.check_model(onnx_model)
|
|
ort_session = ort.InferenceSession('runs/detect/train/weights/best.onnx')
|
|
|
|
# 3. 加载验证集路径
|
|
with open('data.yaml', 'r') as f:
|
|
data_config = yaml.safe_load(f)
|
|
val_dir = os.path.dirname(data_config['val'])
|
|
val_dir = val_dir.replace('../', '') # 直接删除"../"前缀
|
|
print("val_dir:", val_dir) # 调试信息
|
|
val_images = os.path.join(val_dir, 'images') if 'valid' in val_dir else os.path.join(val_dir, 'val', 'images')
|
|
val_images = val_images.replace('\\', '/') # 统一使用正斜杠路径分隔符
|
|
print(f"val_images: {val_images}") # 调试信息
|
|
|
|
# 4. 验证集测试
|
|
correct = 0
|
|
total = 0
|
|
print("开始验证...")
|
|
|
|
matched_files = glob.glob(val_images + '/*.jpg')
|
|
print(f"找到 {len(matched_files)} 张测试图片") # 文件匹配调试
|
|
|
|
for img_path in matched_files[:20]: # 测试前20张
|
|
# 预处理
|
|
img = cv2.imread(img_path)
|
|
img = cv2.resize(img, (640, 640))
|
|
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) # 确保RGB顺序
|
|
img = img.transpose(2, 0, 1).astype(np.float32) / 255.0
|
|
img = np.expand_dims(img, axis=0).astype(np.float32)
|
|
|
|
# 推理
|
|
outputs = ort_session.run(None, {"images": img})
|
|
|
|
# 添加输出调试信息
|
|
print(f"输出形状: {outputs[0].shape}")
|
|
print(f"输出示例数据: {outputs[0][0][:5]}...")
|
|
|
|
# 使用置信度阈值判断检测
|
|
confidences = outputs[0][0, 4, :]
|
|
print(f"最大置信度: {np.max(confidences):.3f}")
|
|
|
|
# 使用阈值判断是否检测到目标
|
|
if np.any(confidences > 0.5):
|
|
correct += 1
|
|
total += 1
|
|
|
|
print(f"验证结果: {correct}/{total} 张图片检测到目标") |