- 更新 README.md 中的 ONNX构建指南,增加新版本的安装步骤 - 在 runs/detect/final_model 和 runs/detect/train 目录下添加训练配置文件 args.yaml - 配置文件包含模型训练的各种参数设置,为后续训练提供详细配置
132 lines
4.2 KiB
Python
132 lines
4.2 KiB
Python
|
|
"""
|
|
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
|
|
主要改进:
|
|
1. 适配YOLOv12 参数规范
|
|
2. 移除已弃用参数
|
|
3. 优化训练配置
|
|
"""
|
|
|
|
import torch
|
|
from ultralytics import YOLO
|
|
import shutil
|
|
import os
|
|
import glob
|
|
from PIL import Image, ImageDraw
|
|
import numpy as np
|
|
|
|
# +++ 有效参数配置区块 +++
|
|
TRAIN_CONFIG = {
|
|
"stage1": {
|
|
"freeze": [0, 1, 2, 3, 4], # 冻结前5层
|
|
"epochs": 75, # 增加训练轮数
|
|
"batch": 32, # 增加批量大小
|
|
"lr0": 0.001, # 降低初始学习率
|
|
"lrf": 0.1, # 提高最终学习率
|
|
"hsv_h": 0.01, # 色相增强幅度
|
|
"degrees": 10.0, # 增加旋转角度范围
|
|
"perspective": 0.005, # 增加透视变换
|
|
"flipud": 0.2, # 增加上下翻转概率
|
|
"optimizer": "SGD", # 优化器类型
|
|
"weight_decay": 0.001, # 权重衰减
|
|
"label_smoothing": 0.1, # 标签平滑
|
|
"mosaic": True, # 启用Mosaic增强
|
|
# "class_weights": [2.0, 2.0, 1.0] # 设置类别权重
|
|
},
|
|
"stage2": {
|
|
"epochs": 150, # 增加训练轮数
|
|
"batch": 16, # 增加批量大小
|
|
"lr0": 0.0001, # 降低初始学习率
|
|
"lrf": 0.01, # 提高最终学习率
|
|
"mixup": 0.2, # 增加MixUp增强系数
|
|
"close_mosaic": 10, # 最后10epoch关闭Mosaic
|
|
"optimizer": "AdamW", # 切换优化器为AdamW
|
|
"warmup_epochs": 10 # 增大学习率预热轮数
|
|
}
|
|
}
|
|
|
|
def enhanced_data_check():
|
|
"""增强型数据验证(稳定版)"""
|
|
print("\n=== 执行增强数据验证 ===")
|
|
label_files = glob.glob("datasets/train/labels/*.txt")
|
|
class_dist = {}
|
|
|
|
# 类别分布分析
|
|
for lbl in label_files:
|
|
with open(lbl) as f:
|
|
for line in f:
|
|
class_id = int(line.strip().split()[0])
|
|
class_dist[class_id] = class_dist.get(class_id, 0) + 1
|
|
|
|
print(f"类别分布:{class_dist}")
|
|
|
|
# 样本可视化(示例保留结构)
|
|
sample_count = 3
|
|
for i in range(sample_count):
|
|
img_path = f"datasets/train/images/{i:04d}.jpg"
|
|
lbl_path = f"datasets/train/labels/{i:04d}.txt"
|
|
# 添加实际可视化逻辑
|
|
|
|
def main():
|
|
# 初始化模型(正确方式)
|
|
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
|
|
|
|
# === 阶段一:冻结训练 ===
|
|
print("\n=== 阶段一:特征提取层训练 ===")
|
|
stage1_results = model.train(
|
|
data="data.yaml",
|
|
**TRAIN_CONFIG["stage1"],
|
|
device=0,
|
|
workers=4,
|
|
box=5.0, # box损失权重
|
|
cls=1.0, # 分类损失权重
|
|
save_period=5, # 每5epoch保存检查点
|
|
patience=20, # 早停等待周期
|
|
deterministic=False # 禁用确定性模式
|
|
)
|
|
|
|
# === 阶段二:全网络微调 ===
|
|
print("\n=== 阶段二:全网络微调 ===")
|
|
stage2_results = model.train(
|
|
data="data.yaml",
|
|
**TRAIN_CONFIG["stage2"],
|
|
resume=True, # 从上一阶段继续
|
|
name="final_model", # 实验名称
|
|
shear=1.5, # 剪切变换幅度
|
|
copy_paste=0.05, # 复制粘贴增强
|
|
erasing=0.1, # 随机擦除概率
|
|
overlap_mask=False # 禁用掩码重叠
|
|
)
|
|
|
|
# === 模型验证 ===
|
|
print("\n=== 最终模型验证 ===")
|
|
metrics = model.val(
|
|
data="data.yaml",
|
|
conf=0.35, # 置信度阈值
|
|
iou=0.65, # IoU阈值
|
|
plots=True, # 生成评估图表
|
|
half=False # 禁用半精度验证
|
|
)
|
|
|
|
# === 模型导出 ===
|
|
print("\n=== 模型导出 ===")
|
|
model.export(
|
|
# opset_version=15,
|
|
format="onnx", # 优先导出ONNX
|
|
# format="torchscript",
|
|
dynamic=True, # 动态维度
|
|
simplify=True # 简化模型
|
|
)
|
|
|
|
if __name__ == "__main__":
|
|
# 硬件优化配置
|
|
torch.set_float32_matmul_precision('high')
|
|
torch.backends.cudnn.benchmark = True
|
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
|
|
|
# 数据检查
|
|
enhanced_data_check()
|
|
|
|
# 启动训练
|
|
main()
|