ComView/train.py
fly6516 72a61d34cd build: 更新 ONNX 构建指南并添加训练配置文件
- 更新 README.md 中的 ONNX构建指南,增加新版本的安装步骤
- 在 runs/detect/final_model 和 runs/detect/train 目录下添加训练配置文件 args.yaml
- 配置文件包含模型训练的各种参数设置,为后续训练提供详细配置
2025-05-30 02:15:36 +08:00

132 lines
4.2 KiB
Python

"""
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
主要改进:
1. 适配YOLOv12 参数规范
2. 移除已弃用参数
3. 优化训练配置
"""
import torch
from ultralytics import YOLO
import shutil
import os
import glob
from PIL import Image, ImageDraw
import numpy as np
# +++ 有效参数配置区块 +++
TRAIN_CONFIG = {
"stage1": {
"freeze": [0, 1, 2, 3, 4], # 冻结前5层
"epochs": 75, # 增加训练轮数
"batch": 32, # 增加批量大小
"lr0": 0.001, # 降低初始学习率
"lrf": 0.1, # 提高最终学习率
"hsv_h": 0.01, # 色相增强幅度
"degrees": 10.0, # 增加旋转角度范围
"perspective": 0.005, # 增加透视变换
"flipud": 0.2, # 增加上下翻转概率
"optimizer": "SGD", # 优化器类型
"weight_decay": 0.001, # 权重衰减
"label_smoothing": 0.1, # 标签平滑
"mosaic": True, # 启用Mosaic增强
# "class_weights": [2.0, 2.0, 1.0] # 设置类别权重
},
"stage2": {
"epochs": 150, # 增加训练轮数
"batch": 16, # 增加批量大小
"lr0": 0.0001, # 降低初始学习率
"lrf": 0.01, # 提高最终学习率
"mixup": 0.2, # 增加MixUp增强系数
"close_mosaic": 10, # 最后10epoch关闭Mosaic
"optimizer": "AdamW", # 切换优化器为AdamW
"warmup_epochs": 10 # 增大学习率预热轮数
}
}
def enhanced_data_check():
"""增强型数据验证(稳定版)"""
print("\n=== 执行增强数据验证 ===")
label_files = glob.glob("datasets/train/labels/*.txt")
class_dist = {}
# 类别分布分析
for lbl in label_files:
with open(lbl) as f:
for line in f:
class_id = int(line.strip().split()[0])
class_dist[class_id] = class_dist.get(class_id, 0) + 1
print(f"类别分布:{class_dist}")
# 样本可视化(示例保留结构)
sample_count = 3
for i in range(sample_count):
img_path = f"datasets/train/images/{i:04d}.jpg"
lbl_path = f"datasets/train/labels/{i:04d}.txt"
# 添加实际可视化逻辑
def main():
# 初始化模型(正确方式)
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
# === 阶段一:冻结训练 ===
print("\n=== 阶段一:特征提取层训练 ===")
stage1_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage1"],
device=0,
workers=4,
box=5.0, # box损失权重
cls=1.0, # 分类损失权重
save_period=5, # 每5epoch保存检查点
patience=20, # 早停等待周期
deterministic=False # 禁用确定性模式
)
# === 阶段二:全网络微调 ===
print("\n=== 阶段二:全网络微调 ===")
stage2_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage2"],
resume=True, # 从上一阶段继续
name="final_model", # 实验名称
shear=1.5, # 剪切变换幅度
copy_paste=0.05, # 复制粘贴增强
erasing=0.1, # 随机擦除概率
overlap_mask=False # 禁用掩码重叠
)
# === 模型验证 ===
print("\n=== 最终模型验证 ===")
metrics = model.val(
data="data.yaml",
conf=0.35, # 置信度阈值
iou=0.65, # IoU阈值
plots=True, # 生成评估图表
half=False # 禁用半精度验证
)
# === 模型导出 ===
print("\n=== 模型导出 ===")
model.export(
# opset_version=15,
format="onnx", # 优先导出ONNX
# format="torchscript",
dynamic=True, # 动态维度
simplify=True # 简化模型
)
if __name__ == "__main__":
# 硬件优化配置
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# 数据检查
enhanced_data_check()
# 启动训练
main()