Yolo-12-everything/train.py

184 lines
6.0 KiB
Python
Raw Normal View History

"""
YOLOv12 绝缘子检测优化训练脚本最终稳定版
主要改进
1. 适配YOLOv12 参数规范
2. 移除已弃用参数
3. 优化训练配置
"""
import torch
from ultralytics import YOLO
import shutil
import os
import glob
from PIL import Image, ImageDraw
import numpy as np
# +++ 有效参数配置区块 +++
TRAIN_CONFIG = {
"stage1": {
"freeze": [0, 1, 2, 3, 4], # 冻结前5层
"epochs": 150, # 优化训练轮数
"batch": 48, # 增加批量大小
"lr0": 0.0005, # 优化初始学习率
"lrf": 0.1, # 提高最终学习率
"hsv_h": 0.01, # 色相增强幅度
"degrees": 10.0, # 增加旋转角度范围
"perspective": 0.005, # 增加透视变换
"flipud": 0.2, # 增加上下翻转概率
"optimizer": "SGD", # 优化器类型
"weight_decay": 0.001, # 权重衰减
"label_smoothing": 0.1, # 标签平滑
# "mosaic": True, # 启用Mosaic增强
"amp": True # 启用混合精度训练
},
"stage2": {
"epochs": 150, # 优化训练轮数
"batch": 48, # 增加批量大小
"lr0": 0.00005, # 优化初始学习率
"lrf": 0.01, # 提高最终学习率
"mixup": 0.2, # 增加MixUp增强系数
"close_mosaic": 10, # 最后10epoch关闭Mosaic
"optimizer": "AdamW", # 切换优化器为AdamW
"warmup_epochs": 10, # 增大学习率预热轮数
"amp": True # 启用混合精度训练
}
}
def validate_bbox_coordinates(label_path):
"""验证标注文件边界框坐标是否在标准化范围0-1"""
try:
with open(label_path) as f:
for line in f:
parts = list(map(float, line.strip().split()))
if len(parts) > 1:
x_center, y_center, width, height = parts[1:5]
assert 0 <= x_center <= 1 and 0 <= y_center <= 1, f"坐标超出范围:{label_path}"
assert 0 < width <= 1 and 0 < height <= 1, f"尺寸异常:{label_path}"
except Exception as e:
print(f"验证失败 {label_path}: {str(e)}")
def enhanced_data_check():
"""增强型数据验证(稳定版)"""
print("\n=== 执行增强数据验证 ===")
label_files = glob.glob("datasets/train/labels/*.txt")
class_dist = {}
# 类别分布分析
for lbl in label_files:
with open(lbl) as f:
for line in f:
class_id = int(line.strip().split()[0])
class_dist[class_id] = class_dist.get(class_id, 0) + 1
print(f"类别分布:{class_dist}")
# 坐标验证
for lbl in label_files:
validate_bbox_coordinates(lbl)
# 样本可视化
sample_count = 3
for i in range(sample_count):
img_path = f"datasets/train/images/{i:04d}.jpg"
lbl_path = f"datasets/train/labels/{i:04d}.txt"
try:
img = Image.open(img_path)
draw = ImageDraw.Draw(img)
with open(lbl_path) as f:
for line in f:
parts = list(map(float, line.strip().split()))
if len(parts) > 1:
x_center, y_center, width, height = parts[1:5]
# 转换为绝对坐标
w, h = img.size
x1 = (x_center - width/2) * w
y1 = (y_center - height/2) * h
x2 = (x_center + width/2) * w
y2 = (y_center + height/2) * h
draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
img.save(f"sample_{i}.jpg")
print(f"样本可视化已保存sample_{i}.jpg")
except Exception as e:
print(f"可视化失败 {i}: {str(e)}")
def main():
# 初始化模型(正确方式)
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
# === 阶段一:冻结训练 ===
print("\n=== 阶段一:特征提取层训练 ===")
stage1_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage1"],
device=0,
workers=16,
box=5.0, # box损失权重
cls=1.0, # 分类损失权重
save_period=5, # 每5epoch保存检查点
patience=20, # 早停等待周期
deterministic=False # 禁用确定性模式
)
# === 阶段二:全网络微调 ===
print("\n=== 阶段二:全网络微调 ===")
stage2_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage2"],
resume=True, # 从上一阶段继续
name="final_model", # 实验名称
shear=1.5, # 剪切变换幅度
copy_paste=0.05, # 复制粘贴增强
erasing=0.1, # 随机擦除概率
overlap_mask=False # 禁用掩码重叠
)
# === 模型验证 ===
print("\n=== 最终模型验证 ===")
metrics = model.val(
data="data.yaml",
conf=0.35, # 置信度阈值
iou=0.65, # IoU阈值
plots=True, # 生成评估图表
half=False # 禁用半精度验证
)
# === 模型导出 ===
print("\n=== 模型导出 ===")
# 导出为PT格式TorchScript
model.export(
format="torchscript", # PT格式
# dynamic=True,
simplify=True
)
# 导出为ONNX格式
model.export(
format="onnx", # ONNX格式
dynamic=True,
simplify=True
)
# 导出为TFLite格式
model.export(
format="tflite", # TFLite格式
# dynamic=True,
simplify=True
)
if __name__ == "__main__":
# 硬件优化配置
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# 数据检查
enhanced_data_check()
# 启动训练
main()