""" YOLOv12 绝缘子检测优化训练脚本(最终稳定版) 主要改进: 1. 适配YOLOv12 参数规范 2. 移除已弃用参数 3. 优化训练配置 """ import torch from ultralytics import YOLO import shutil import os import glob from PIL import Image, ImageDraw import numpy as np # +++ 有效参数配置区块 +++ TRAIN_CONFIG = { "stage1": { "freeze": [0, 1, 2, 3, 4], # 冻结前5层 "epochs": 75, # 增加训练轮数 "batch": 32, # 增加批量大小 "lr0": 0.001, # 降低初始学习率 "lrf": 0.1, # 提高最终学习率 "hsv_h": 0.01, # 色相增强幅度 "degrees": 10.0, # 增加旋转角度范围 "perspective": 0.005, # 增加透视变换 "flipud": 0.2, # 增加上下翻转概率 "optimizer": "SGD", # 优化器类型 "weight_decay": 0.001, # 权重衰减 "label_smoothing": 0.1, # 标签平滑 "mosaic": True, # 启用Mosaic增强 # "class_weights": [1.0, 2.0] # 设置类别权重,insulator权重更高 }, "stage2": { "epochs": 150, # 增加训练轮数 "batch": 16, # 增加批量大小 "lr0": 0.0001, # 降低初始学习率 "lrf": 0.01, # 提高最终学习率 "mixup": 0.2, # 增加MixUp增强系数 "close_mosaic": 10, # 最后10epoch关闭Mosaic "optimizer": "AdamW", # 切换优化器为AdamW "warmup_epochs": 10 # 增大学习率预热轮数 } } def enhanced_data_check(): """增强型数据验证(稳定版)""" print("\n=== 执行增强数据验证 ===") label_files = glob.glob("datasets/train/labels/*.txt") class_dist = {} # 类别分布分析 for lbl in label_files: with open(lbl) as f: for line in f: class_id = int(line.strip().split()[0]) class_dist[class_id] = class_dist.get(class_id, 0) + 1 print(f"类别分布:{class_dist}") # 样本可视化(示例保留结构) sample_count = 3 for i in range(sample_count): img_path = f"datasets/train/images/{i:04d}.jpg" lbl_path = f"datasets/train/labels/{i:04d}.txt" # 添加实际可视化逻辑 def main(): # 初始化模型(正确方式) model = YOLO("yolo12s.yaml") # 确保yaml文件存在 # === 阶段一:冻结训练 === print("\n=== 阶段一:特征提取层训练 ===") stage1_results = model.train( data="data.yaml", **TRAIN_CONFIG["stage1"], device=0, workers=4, box=5.0, # box损失权重 cls=1.0, # 分类损失权重 save_period=5, # 每5epoch保存检查点 patience=20, # 早停等待周期 deterministic=False # 禁用确定性模式 ) # === 阶段二:全网络微调 === print("\n=== 阶段二:全网络微调 ===") stage2_results = model.train( data="data.yaml", **TRAIN_CONFIG["stage2"], resume=True, # 从上一阶段继续 name="final_model", # 实验名称 shear=1.5, # 剪切变换幅度 copy_paste=0.05, # 复制粘贴增强 erasing=0.1, # 随机擦除概率 overlap_mask=False # 禁用掩码重叠 ) # === 模型验证 === print("\n=== 最终模型验证 ===") metrics = model.val( data="data.yaml", conf=0.35, # 置信度阈值 iou=0.65, # IoU阈值 plots=True, # 生成评估图表 half=False # 禁用半精度验证 ) # === 模型导出 === print("\n=== 模型导出 ===") model.export( opset_version=15, format="onnx", # 优先导出ONNX # format="torchscript", dynamic=True, # 动态维度 simplify=True # 简化模型 ) if __name__ == "__main__": # 硬件优化配置 torch.set_float32_matmul_precision('high') torch.backends.cudnn.benchmark = True os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" # 数据检查 enhanced_data_check() # 启动训练 main()