184 lines
6.0 KiB
Python
184 lines
6.0 KiB
Python
|
|
|||
|
"""
|
|||
|
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
|
|||
|
主要改进:
|
|||
|
1. 适配YOLOv12 参数规范
|
|||
|
2. 移除已弃用参数
|
|||
|
3. 优化训练配置
|
|||
|
"""
|
|||
|
|
|||
|
import torch
|
|||
|
from ultralytics import YOLO
|
|||
|
import shutil
|
|||
|
import os
|
|||
|
import glob
|
|||
|
from PIL import Image, ImageDraw
|
|||
|
import numpy as np
|
|||
|
|
|||
|
# +++ Effective training-parameter block +++
# Each entry is a set of keyword arguments that main() unpacks into
# ``model.train(...)`` for the corresponding training stage.
TRAIN_CONFIG = {
    # Stage 1: training with the first layers frozen.
    "stage1": dict(
        freeze=list(range(5)),   # freeze the first 5 layers (indices 0-4)
        epochs=150,              # tuned number of training epochs
        batch=48,                # enlarged batch size
        lr0=5e-4,                # tuned initial learning rate
        lrf=0.1,                 # raised final learning rate setting
        hsv_h=0.01,              # hue augmentation amplitude
        degrees=10.0,            # widened rotation range
        perspective=0.005,       # added perspective transform
        flipud=0.2,              # raised vertical-flip probability
        optimizer="SGD",         # optimizer type
        weight_decay=0.001,      # weight decay
        # NOTE(review): label_smoothing may be deprecated in recent
        # Ultralytics releases - confirm against the installed version.
        label_smoothing=0.1,     # label smoothing
        # "mosaic" is deliberately not set here (it was commented out).
        amp=True,                # mixed-precision training
    ),
    # Stage 2: fine-tuning run that follows stage 1.
    "stage2": dict(
        epochs=150,              # tuned number of training epochs
        batch=48,                # enlarged batch size
        lr0=5e-5,                # tuned initial learning rate
        lrf=0.01,                # final learning rate setting
        mixup=0.2,               # MixUp augmentation coefficient
        close_mosaic=10,         # turn Mosaic off for the last 10 epochs
        optimizer="AdamW",       # switch the optimizer to AdamW
        warmup_epochs=10,        # longer learning-rate warm-up
        amp=True,                # mixed-precision training
    ),
}
|
|||
|
|
|||
|
|
|||
|
def validate_bbox_coordinates(label_path):
    """Check that every box in a label file is normalized to the 0-1 range.

    Each non-empty line is parsed as whitespace-separated floats; the values
    at positions 1-4 are read as ``x_center y_center width height``. Centers
    must lie in [0, 1] and sizes in (0, 1]. Lines with at most one value are
    skipped. Checking stops at the first problem, which is printed.

    Args:
        label_path: Path of the label text file to check.

    Returns:
        True when the whole file passed, False when the file could not be
        read, a line could not be parsed, or a value was out of range.
    """
    try:
        with open(label_path, encoding="utf-8") as f:
            for line in f:
                parts = list(map(float, line.strip().split()))
                if len(parts) <= 1:
                    continue
                x_center, y_center, width, height = parts[1:5]
                # Explicit raises instead of ``assert`` so the checks still
                # run when Python is started with -O.
                if not (0 <= x_center <= 1 and 0 <= y_center <= 1):
                    raise ValueError(f"坐标超出范围:{label_path}")
                if not (0 < width <= 1 and 0 < height <= 1):
                    raise ValueError(f"尺寸异常:{label_path}")
    except (OSError, ValueError) as e:
        # OSError: unreadable file; ValueError: non-numeric token, too few
        # values on a line, undecodable bytes, or a failed range check above.
        print(f"验证失败 {label_path}: {str(e)}")
        return False
    return True
|
|||
|
|
|||
|
|
|||
|
def enhanced_data_check():
    """Run pre-training sanity checks on the training split (stable version).

    Three passes over ``datasets/train``:

    1. Count the boxes per class id across all label files and print the
       resulting distribution.
    2. Run ``validate_bbox_coordinates`` on every label file.
    3. For samples ``0000``-``0002``, draw the labelled boxes on the image
       and save the result as ``sample_<i>.jpg`` in the working directory.

    Nothing is returned; all findings are printed. The visualization pass is
    best-effort: any failure for a sample is printed and the loop continues.
    """
    print("\n=== 执行增强数据验证 ===")
    label_files = glob.glob("datasets/train/labels/*.txt")
    class_dist = {}

    # Class distribution analysis.
    for lbl in label_files:
        with open(lbl, encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # A blank line carries no box. Indexing [0] on it would
                    # raise IndexError and abort the whole check.
                    continue
                class_id = int(fields[0])
                class_dist[class_id] = class_dist.get(class_id, 0) + 1

    print(f"类别分布:{class_dist}")

    # Coordinate validation.
    for lbl in label_files:
        validate_bbox_coordinates(lbl)

    # Sample visualization.
    sample_count = 3
    for i in range(sample_count):
        img_path = f"datasets/train/images/{i:04d}.jpg"
        lbl_path = f"datasets/train/labels/{i:04d}.txt"
        try:
            # Context manager so the image file handle is always released.
            with Image.open(img_path) as img:
                draw = ImageDraw.Draw(img)
                # The image size is constant for all boxes; read it once.
                w, h = img.size
                with open(lbl_path, encoding="utf-8") as f:
                    for line in f:
                        parts = list(map(float, line.strip().split()))
                        if len(parts) > 1:
                            x_center, y_center, width, height = parts[1:5]
                            # Convert normalized center/size to absolute
                            # corner coordinates.
                            x1 = (x_center - width / 2) * w
                            y1 = (y_center - height / 2) * h
                            x2 = (x_center + width / 2) * w
                            y2 = (y_center + height / 2) * h
                            draw.rectangle(
                                [x1, y1, x2, y2], outline="red", width=2
                            )
                img.save(f"sample_{i}.jpg")
            print(f"样本可视化已保存:sample_{i}.jpg")
        except Exception as e:
            # Deliberately broad: visualization is only a convenience and
            # must never stop the data check or the training that follows.
            print(f"可视化失败 {i}: {str(e)}")
|
|||
|
|
|||
|
def main():
    """Train, validate and export the detector.

    Steps, all on one ``YOLO`` model object:

    1. Stage 1: train with ``TRAIN_CONFIG["stage1"]`` (first layers frozen).
    2. Stage 2: train again with ``TRAIN_CONFIG["stage2"]`` under the run
       name ``final_model``.
    3. Validate on ``data.yaml``.
    4. Export to TorchScript, ONNX and TFLite.

    Takes no arguments and returns nothing; progress banners are printed.
    """
    # Build the model from its architecture YAML (the file must be resolvable).
    model = YOLO("yolo12s.yaml")

    # === Stage 1: frozen-layer training ===
    print("\n=== 阶段一:特征提取层训练 ===")
    model.train(
        data="data.yaml",
        **TRAIN_CONFIG["stage1"],
        device=0,
        workers=16,
        box=5.0,              # box loss weight
        cls=1.0,              # classification loss weight
        save_period=5,        # save a checkpoint every 5 epochs
        patience=20,          # early-stopping patience
        deterministic=False,  # deterministic mode off
    )

    # === Stage 2: full-network fine-tuning ===
    print("\n=== 阶段二:全网络微调 ===")
    # ``resume=True`` is intentionally NOT passed: per the Ultralytics docs it
    # continues an interrupted run from its last checkpoint, whereas stage 2
    # is a new run with its own optimizer and learning rate on the same model
    # object.
    # ``freeze`` defaults to None here so the stage-1 freeze list cannot carry
    # over into what is meant to be full-network training; a "freeze" key in
    # TRAIN_CONFIG["stage2"] still takes precedence.
    # NOTE(review): whether train() arguments persist between calls depends
    # on the installed Ultralytics version - confirm.
    stage2_args = {"freeze": None, **TRAIN_CONFIG["stage2"]}
    model.train(
        data="data.yaml",
        **stage2_args,
        name="final_model",   # experiment name
        shear=1.5,            # shear augmentation amplitude
        copy_paste=0.05,      # copy-paste augmentation
        erasing=0.1,          # random-erasing probability
        overlap_mask=False,   # mask overlap off
    )

    # === Model validation ===
    print("\n=== 最终模型验证 ===")
    model.val(
        data="data.yaml",
        conf=0.35,   # confidence threshold
        iou=0.65,    # IoU threshold
        plots=True,  # generate evaluation plots
        half=False,  # half-precision validation off
    )

    # === Model export ===
    print("\n=== 模型导出 ===")

    # TorchScript export (``dynamic`` left at its default).
    model.export(
        format="torchscript",
        simplify=True,
    )

    # ONNX export.
    model.export(
        format="onnx",
        dynamic=True,
        simplify=True,
    )

    # TFLite export (``dynamic`` left at its default).
    model.export(
        format="tflite",
        simplify=True,
    )
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
# 硬件优化配置
|
|||
|
torch.set_float32_matmul_precision('high')
|
|||
|
torch.backends.cudnn.benchmark = True
|
|||
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
|
|||
|
|
|||
|
# 数据检查
|
|||
|
enhanced_data_check()
|
|||
|
|
|||
|
# 启动训练
|
|||
|
main()
|