isolators-transfer/train.py
fly6516 c208f07e28 train:调整训练参数以优化绝缘子检测性能
- 增加训练轮数和批量大小,提高模型收敛速度- 调整学习率范围,加快训练进程
-增强数据增强技术,提高模型泛化能力
- 引入类别权重,解决类别不平衡问题- 优化优化器参数,提高模型训练效果
2025-05-29 20:32:57 +08:00

132 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
主要改进:
1. 适配YOLOv12 参数规范
2. 移除已弃用参数
3. 优化训练配置
"""
import torch
from ultralytics import YOLO
import shutil
import os
import glob
from PIL import Image, ImageDraw
import numpy as np
# +++ 有效参数配置区块 +++
TRAIN_CONFIG = {
"stage1": {
"freeze": [0, 1, 2, 3, 4], # 冻结前5层
"epochs": 75, # 增加训练轮数
"batch": 32, # 增加批量大小
"lr0": 0.001, # 降低初始学习率
"lrf": 0.1, # 提高最终学习率
"hsv_h": 0.01, # 色相增强幅度
"degrees": 10.0, # 增加旋转角度范围
"perspective": 0.005, # 增加透视变换
"flipud": 0.2, # 增加上下翻转概率
"optimizer": "SGD", # 优化器类型
"weight_decay": 0.001, # 权重衰减
"label_smoothing": 0.1, # 标签平滑
"mosaic": True, # 启用Mosaic增强
"class_weights": [1.0, 2.0] # 设置类别权重insulator权重更高
},
"stage2": {
"epochs": 150, # 增加训练轮数
"batch": 16, # 增加批量大小
"lr0": 0.0001, # 降低初始学习率
"lrf": 0.01, # 提高最终学习率
"mixup": 0.2, # 增加MixUp增强系数
"close_mosaic": 10, # 最后10epoch关闭Mosaic
"optimizer": "AdamW", # 切换优化器为AdamW
"warmup_epochs": 10 # 增大学习率预热轮数
}
}
def enhanced_data_check():
"""增强型数据验证(稳定版)"""
print("\n=== 执行增强数据验证 ===")
label_files = glob.glob("datasets/insulator/labels/*.txt")
class_dist = {}
# 类别分布分析
for lbl in label_files:
with open(lbl) as f:
for line in f:
class_id = int(line.strip().split()[0])
class_dist[class_id] = class_dist.get(class_id, 0) + 1
print(f"类别分布:{class_dist}")
# 样本可视化(示例保留结构)
sample_count = 3
for i in range(sample_count):
img_path = f"datasets/insulator/images/{i:04d}.jpg"
lbl_path = f"datasets/insulator/labels/{i:04d}.txt"
# 添加实际可视化逻辑
def main():
# 初始化模型(正确方式)
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
# === 阶段一:冻结训练 ===
print("\n=== 阶段一:特征提取层训练 ===")
stage1_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage1"],
device=0,
workers=4,
box=5.0, # box损失权重
cls=1.0, # 分类损失权重
save_period=5, # 每5epoch保存检查点
patience=20, # 早停等待周期
deterministic=False # 禁用确定性模式
)
# === 阶段二:全网络微调 ===
print("\n=== 阶段二:全网络微调 ===")
stage2_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage2"],
resume=True, # 从上一阶段继续
name="final_model", # 实验名称
shear=1.5, # 剪切变换幅度
copy_paste=0.05, # 复制粘贴增强
erasing=0.1, # 随机擦除概率
overlap_mask=False # 禁用掩码重叠
)
# === 模型验证 ===
print("\n=== 最终模型验证 ===")
metrics = model.val(
data="data.yaml",
conf=0.35, # 置信度阈值
iou=0.65, # IoU阈值
plots=True, # 生成评估图表
half=False # 禁用半精度验证
)
# === 模型导出 ===
print("\n=== 模型导出 ===")
model.export(
opset_version=15,
format="onnx", # 优先导出ONNX
# format="torchscript",
dynamic=True, # 动态维度
simplify=True # 简化模型
)
if __name__ == "__main__":
# 硬件优化配置
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# 数据检查
enhanced_data_check()
# 启动训练
main()