""" YOLOv12 绝缘子检测优化训练脚本(最终稳定版) 主要改进: 1. 适配YOLOv12 参数规范 2. 移除已弃用参数 3. 优化训练配置 """ import torch from ultralytics import YOLO import shutil import os import glob from PIL import Image, ImageDraw import numpy as np # +++ 有效参数配置区块 +++ TRAIN_CONFIG = { "stage1": { "freeze": [0, 1, 2, 3, 4], # 冻结前5层 "epochs": 150, # 优化训练轮数 "batch": 48, # 增加批量大小 "lr0": 0.0005, # 优化初始学习率 "lrf": 0.1, # 提高最终学习率 "hsv_h": 0.01, # 色相增强幅度 "degrees": 10.0, # 增加旋转角度范围 "perspective": 0.005, # 增加透视变换 "flipud": 0.2, # 增加上下翻转概率 "optimizer": "SGD", # 优化器类型 "weight_decay": 0.001, # 权重衰减 "label_smoothing": 0.1, # 标签平滑 # "mosaic": True, # 启用Mosaic增强 "amp": True # 启用混合精度训练 }, "stage2": { "epochs": 150, # 优化训练轮数 "batch": 48, # 增加批量大小 "lr0": 0.00005, # 优化初始学习率 "lrf": 0.01, # 提高最终学习率 "mixup": 0.2, # 增加MixUp增强系数 "close_mosaic": 10, # 最后10epoch关闭Mosaic "optimizer": "AdamW", # 切换优化器为AdamW "warmup_epochs": 10, # 增大学习率预热轮数 "amp": True # 启用混合精度训练 } } def validate_bbox_coordinates(label_path): """验证标注文件边界框坐标是否在标准化范围(0-1)""" try: with open(label_path) as f: for line in f: parts = list(map(float, line.strip().split())) if len(parts) > 1: x_center, y_center, width, height = parts[1:5] assert 0 <= x_center <= 1 and 0 <= y_center <= 1, f"坐标超出范围:{label_path}" assert 0 < width <= 1 and 0 < height <= 1, f"尺寸异常:{label_path}" except Exception as e: print(f"验证失败 {label_path}: {str(e)}") def enhanced_data_check(): """增强型数据验证(稳定版)""" print("\n=== 执行增强数据验证 ===") label_files = glob.glob("datasets/train/labels/*.txt") class_dist = {} # 类别分布分析 for lbl in label_files: with open(lbl) as f: for line in f: class_id = int(line.strip().split()[0]) class_dist[class_id] = class_dist.get(class_id, 0) + 1 print(f"类别分布:{class_dist}") # 坐标验证 for lbl in label_files: validate_bbox_coordinates(lbl) # 样本可视化 sample_count = 3 for i in range(sample_count): img_path = f"datasets/train/images/{i:04d}.jpg" lbl_path = f"datasets/train/labels/{i:04d}.txt" try: img = Image.open(img_path) draw = ImageDraw.Draw(img) with open(lbl_path) as f: for line in f: parts = list(map(float, line.strip().split())) if len(parts) > 1: x_center, y_center, width, height = parts[1:5] # 转换为绝对坐标 w, h = img.size x1 = (x_center - width/2) * w y1 = (y_center - height/2) * h x2 = (x_center + width/2) * w y2 = (y_center + height/2) * h draw.rectangle([x1, y1, x2, y2], outline="red", width=2) img.save(f"sample_{i}.jpg") print(f"样本可视化已保存:sample_{i}.jpg") except Exception as e: print(f"可视化失败 {i}: {str(e)}") def main(): # 初始化模型(正确方式) model = YOLO("yolo12s.yaml") # 确保yaml文件存在 # === 阶段一:冻结训练 === print("\n=== 阶段一:特征提取层训练 ===") stage1_results = model.train( data="data.yaml", **TRAIN_CONFIG["stage1"], device=0, workers=16, box=5.0, # box损失权重 cls=1.0, # 分类损失权重 save_period=5, # 每5epoch保存检查点 patience=20, # 早停等待周期 deterministic=False # 禁用确定性模式 ) # === 阶段二:全网络微调 === print("\n=== 阶段二:全网络微调 ===") stage2_results = model.train( data="data.yaml", **TRAIN_CONFIG["stage2"], resume=True, # 从上一阶段继续 name="final_model", # 实验名称 shear=1.5, # 剪切变换幅度 copy_paste=0.05, # 复制粘贴增强 erasing=0.1, # 随机擦除概率 overlap_mask=False # 禁用掩码重叠 ) # === 模型验证 === print("\n=== 最终模型验证 ===") metrics = model.val( data="data.yaml", conf=0.35, # 置信度阈值 iou=0.65, # IoU阈值 plots=True, # 生成评估图表 half=False # 禁用半精度验证 ) # === 模型导出 === print("\n=== 模型导出 ===") # 导出为PT格式(TorchScript) model.export( format="torchscript", # PT格式 # dynamic=True, simplify=True ) # 导出为ONNX格式 model.export( format="onnx", # ONNX格式 dynamic=True, simplify=True ) # 导出为TFLite格式 model.export( format="tflite", # TFLite格式 # dynamic=True, simplify=True ) if __name__ == "__main__": # 硬件优化配置 torch.set_float32_matmul_precision('high') torch.backends.cudnn.benchmark = True os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" # 数据检查 enhanced_data_check() # 启动训练 main()