132 lines
4.2 KiB
Python
132 lines
4.2 KiB
Python
|
||
"""
|
||
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
|
||
主要改进:
|
||
1. 适配YOLOv12 参数规范
|
||
2. 移除已弃用参数
|
||
3. 优化训练配置
|
||
"""
|
||
|
||
import torch
|
||
from ultralytics import YOLO
|
||
import shutil
|
||
import os
|
||
import glob
|
||
from PIL import Image, ImageDraw
|
||
import numpy as np
|
||
|
||
# +++ 有效参数配置区块 +++
|
||
TRAIN_CONFIG = {
|
||
"stage1": {
|
||
"freeze": [0, 1, 2, 3, 4], # 冻结前5层
|
||
"epochs": 75, # 增加训练轮数
|
||
"batch": 32, # 增加批量大小
|
||
"lr0": 0.001, # 降低初始学习率
|
||
"lrf": 0.1, # 提高最终学习率
|
||
"hsv_h": 0.01, # 色相增强幅度
|
||
"degrees": 10.0, # 增加旋转角度范围
|
||
"perspective": 0.005, # 增加透视变换
|
||
"flipud": 0.2, # 增加上下翻转概率
|
||
"optimizer": "SGD", # 优化器类型
|
||
"weight_decay": 0.001, # 权重衰减
|
||
"label_smoothing": 0.1, # 标签平滑
|
||
"mosaic": True, # 启用Mosaic增强
|
||
# "class_weights": [1.0, 2.0] # 设置类别权重,insulator权重更高
|
||
},
|
||
"stage2": {
|
||
"epochs": 150, # 增加训练轮数
|
||
"batch": 16, # 增加批量大小
|
||
"lr0": 0.0001, # 降低初始学习率
|
||
"lrf": 0.01, # 提高最终学习率
|
||
"mixup": 0.2, # 增加MixUp增强系数
|
||
"close_mosaic": 10, # 最后10epoch关闭Mosaic
|
||
"optimizer": "AdamW", # 切换优化器为AdamW
|
||
"warmup_epochs": 10 # 增大学习率预热轮数
|
||
}
|
||
}
|
||
|
||
def enhanced_data_check():
|
||
"""增强型数据验证(稳定版)"""
|
||
print("\n=== 执行增强数据验证 ===")
|
||
label_files = glob.glob("datasets/train/labels/*.txt")
|
||
class_dist = {}
|
||
|
||
# 类别分布分析
|
||
for lbl in label_files:
|
||
with open(lbl) as f:
|
||
for line in f:
|
||
class_id = int(line.strip().split()[0])
|
||
class_dist[class_id] = class_dist.get(class_id, 0) + 1
|
||
|
||
print(f"类别分布:{class_dist}")
|
||
|
||
# 样本可视化(示例保留结构)
|
||
sample_count = 3
|
||
for i in range(sample_count):
|
||
img_path = f"datasets/train/images/{i:04d}.jpg"
|
||
lbl_path = f"datasets/train/labels/{i:04d}.txt"
|
||
# 添加实际可视化逻辑
|
||
|
||
def main():
|
||
# 初始化模型(正确方式)
|
||
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
|
||
|
||
# === 阶段一:冻结训练 ===
|
||
print("\n=== 阶段一:特征提取层训练 ===")
|
||
stage1_results = model.train(
|
||
data="data.yaml",
|
||
**TRAIN_CONFIG["stage1"],
|
||
device=0,
|
||
workers=4,
|
||
box=5.0, # box损失权重
|
||
cls=1.0, # 分类损失权重
|
||
save_period=5, # 每5epoch保存检查点
|
||
patience=20, # 早停等待周期
|
||
deterministic=False # 禁用确定性模式
|
||
)
|
||
|
||
# === 阶段二:全网络微调 ===
|
||
print("\n=== 阶段二:全网络微调 ===")
|
||
stage2_results = model.train(
|
||
data="data.yaml",
|
||
**TRAIN_CONFIG["stage2"],
|
||
resume=True, # 从上一阶段继续
|
||
name="final_model", # 实验名称
|
||
shear=1.5, # 剪切变换幅度
|
||
copy_paste=0.05, # 复制粘贴增强
|
||
erasing=0.1, # 随机擦除概率
|
||
overlap_mask=False # 禁用掩码重叠
|
||
)
|
||
|
||
# === 模型验证 ===
|
||
print("\n=== 最终模型验证 ===")
|
||
metrics = model.val(
|
||
data="data.yaml",
|
||
conf=0.35, # 置信度阈值
|
||
iou=0.65, # IoU阈值
|
||
plots=True, # 生成评估图表
|
||
half=False # 禁用半精度验证
|
||
)
|
||
|
||
# === 模型导出 ===
|
||
print("\n=== 模型导出 ===")
|
||
model.export(
|
||
opset_version=15,
|
||
format="onnx", # 优先导出ONNX
|
||
# format="torchscript",
|
||
dynamic=True, # 动态维度
|
||
simplify=True # 简化模型
|
||
)
|
||
|
||
if __name__ == "__main__":
|
||
# 硬件优化配置
|
||
torch.set_float32_matmul_precision('high')
|
||
torch.backends.cudnn.benchmark = True
|
||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
||
|
||
# 数据检查
|
||
enhanced_data_check()
|
||
|
||
# 启动训练
|
||
main()
|