isolators-transfer/train.py

129 lines
3.8 KiB
Python
Raw Normal View History

"""
YOLOv12 绝缘子检测优化训练脚本最终稳定版
主要改进
1. 适配YOLOv12 参数规范
2. 移除已弃用参数
3. 优化训练配置
"""
import torch
from ultralytics import YOLO
import shutil
import os
import glob
from PIL import Image, ImageDraw
import numpy as np
# +++ 有效参数配置区块 +++
TRAIN_CONFIG = {
"stage1": {
"freeze": [0,1,2,3,4], # 冻结前5层
"epochs": 50,
"batch": 16,
"lr0": 0.005, # 初始学习率
"lrf": 0.05, # 最终学习率
"hsv_h": 0.01, # 色相增强幅度
"degrees": 5.0, # 旋转角度范围
"perspective": 0.001, # 透视变换
"flipud": 0.1, # 上下翻转概率
"optimizer": "SGD", # 优化器类型
"weight_decay": 0.001, # 权重衰减
"label_smoothing": 0.1 # 标签平滑
},
"stage2": {
"epochs": 100,
"batch": 8,
"lr0": 0.0005,
"lrf": 0.005,
"mixup": 0.15, # MixUp增强系数
"close_mosaic": 10, # 最后10epoch关闭Mosaic
"optimizer": "AdamW", # 优化器切换
"warmup_epochs": 5 # 学习率预热
}
}
def enhanced_data_check():
"""增强型数据验证(稳定版)"""
print("\n=== 执行增强数据验证 ===")
label_files = glob.glob("datasets/insulator/labels/*.txt")
class_dist = {}
# 类别分布分析
for lbl in label_files:
with open(lbl) as f:
for line in f:
class_id = int(line.strip().split()[0])
class_dist[class_id] = class_dist.get(class_id, 0) + 1
print(f"类别分布:{class_dist}")
# 样本可视化(示例保留结构)
sample_count = 3
for i in range(sample_count):
img_path = f"datasets/insulator/images/{i:04d}.jpg"
lbl_path = f"datasets/insulator/labels/{i:04d}.txt"
# 添加实际可视化逻辑
def main():
# 初始化模型(正确方式)
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
# === 阶段一:冻结训练 ===
print("\n=== 阶段一:特征提取层训练 ===")
stage1_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage1"],
device=0,
workers=4,
box=5.0, # box损失权重
cls=1.0, # 分类损失权重
save_period=5, # 每5epoch保存检查点
patience=20, # 早停等待周期
deterministic=False # 禁用确定性模式
)
# === 阶段二:全网络微调 ===
print("\n=== 阶段二:全网络微调 ===")
stage2_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage2"],
resume=True, # 从上一阶段继续
name="final_model", # 实验名称
shear=1.5, # 剪切变换幅度
copy_paste=0.05, # 复制粘贴增强
erasing=0.1, # 随机擦除概率
overlap_mask=False # 禁用掩码重叠
)
# === 模型验证 ===
print("\n=== 最终模型验证 ===")
metrics = model.val(
data="data.yaml",
conf=0.35, # 置信度阈值
iou=0.65, # IoU阈值
plots=True, # 生成评估图表
half=False # 禁用半精度验证
)
# === 模型导出 ===
print("\n=== 模型导出 ===")
model.export(
opset_version=15,
format="onnx", # 优先导出ONNX
# format="torchscript",
dynamic=True, # 动态维度
simplify=True # 简化模型
)
if __name__ == "__main__":
# 硬件优化配置
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# 数据检查
enhanced_data_check()
# 启动训练
main()