184 lines
6.0 KiB
Python
184 lines
6.0 KiB
Python
|
||
"""
|
||
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
|
||
主要改进:
|
||
1. 适配YOLOv12 参数规范
|
||
2. 移除已弃用参数
|
||
3. 优化训练配置
|
||
"""
|
||
|
||
import torch
|
||
from ultralytics import YOLO
|
||
import shutil
|
||
import os
|
||
import glob
|
||
from PIL import Image, ImageDraw
|
||
import numpy as np
|
||
|
||
# +++ Training configuration per stage +++
# Each stage dict is unpacked into ``model.train(**...)`` by main(), so every
# key must be a training argument accepted by ultralytics.
TRAIN_CONFIG = {
    # Stage 1: train with the first backbone modules frozen.
    "stage1": {
        "freeze": [0, 1, 2, 3, 4],  # freeze model layers 0-4 (the first 5 modules)
        "epochs": 150,               # training epochs
        "batch": 48,                 # batch size
        "lr0": 0.0005,               # initial learning rate
        "lrf": 0.1,                  # final learning rate = lr0 * lrf
        "hsv_h": 0.01,               # hue augmentation fraction
        "degrees": 10.0,             # random rotation range (+/- degrees)
        "perspective": 0.005,        # perspective transform fraction
        "flipud": 0.2,               # vertical flip probability
        "optimizer": "SGD",          # optimizer type
        "weight_decay": 0.001,       # weight decay
        "label_smoothing": 0.1,      # label smoothing fraction
        # "mosaic": True,  # NOTE(review): ultralytics treats mosaic as a probability
        #                  # (float such as 1.0), not a bool — confirm before re-enabling.
        "amp": True,                 # mixed-precision training
    },
    # Stage 2: fine-tune the whole network starting from the stage-1 weights.
    "stage2": {
        # Explicitly unfreeze every layer. Without this, stage 2 could inherit
        # stage 1's freeze list in ultralytics versions that carry a model's
        # previous training arguments into the next train() call.
        "freeze": None,
        "epochs": 150,               # training epochs
        "batch": 48,                 # batch size
        "lr0": 0.00005,              # initial learning rate, 10x lower than stage 1
        "lrf": 0.01,                 # final learning rate = lr0 * lrf
        "mixup": 0.2,                # MixUp augmentation probability
        "close_mosaic": 10,          # disable mosaic for the last 10 epochs
        "optimizer": "AdamW",        # switch optimizer to AdamW for fine-tuning
        "warmup_epochs": 10,         # learning-rate warmup epochs
        "amp": True,                 # mixed-precision training
    },
}
|
||
|
||
|
||
def validate_bbox_coordinates(label_path):
    """Check that every box in a YOLO label file uses normalized (0-1) values.

    Each line is parsed as whitespace-separated floats. Lines with at most one
    value (e.g. blank lines) are skipped; otherwise values 1-4 after the class
    id are read as ``x_center y_center width height``. Centres must lie in
    [0, 1] and sizes in (0, 1].

    Problems are printed instead of raised, so one bad file does not abort a
    dataset-wide scan. Only the first problem in each file is reported.

    Args:
        label_path: Path to a ``.txt`` label file.

    Returns:
        True if the file was read and every checked box is in range, False if
        the file could not be read or parsed, or a box is out of range.
    """
    try:
        with open(label_path, encoding="utf-8") as f:
            for line in f:
                parts = list(map(float, line.split()))
                if len(parts) <= 1:
                    continue
                # Raises ValueError when fewer than 4 box values follow the class id.
                x_center, y_center, width, height = parts[1:5]
                # Explicit checks rather than ``assert``: assertions are removed
                # under ``python -O``, which would silently disable validation.
                if not (0 <= x_center <= 1 and 0 <= y_center <= 1):
                    raise ValueError(f"坐标超出范围:{label_path}")
                if not (0 < width <= 1 and 0 < height <= 1):
                    raise ValueError(f"尺寸异常:{label_path}")
    except (OSError, ValueError) as e:
        print(f"验证失败 {label_path}: {str(e)}")
        return False
    return True
|
||
|
||
|
||
def _count_classes(label_files):
    """Return ``{class_id: line_count}`` over ``label_files``, reporting unreadable files."""
    class_dist = {}
    for lbl in label_files:
        try:
            with open(lbl, encoding="utf-8") as f:
                for line in f:
                    fields = line.split()
                    if not fields:
                        continue  # blank line: indexing fields[0] would raise IndexError
                    class_id = int(fields[0])
                    class_dist[class_id] = class_dist.get(class_id, 0) + 1
        except (OSError, ValueError) as e:
            # Report and keep going; counts already gathered from this file are kept.
            print(f"类别统计失败 {lbl}: {e}")
    return class_dist


def _find_image(label_path, images_dir,
                suffixes=(".jpg", ".jpeg", ".png", ".bmp", ".JPG", ".JPEG", ".PNG")):
    """Return the file in ``images_dir`` sharing ``label_path``'s stem, or None if absent."""
    stem = os.path.splitext(os.path.basename(label_path))[0]
    for suffix in suffixes:
        candidate = os.path.join(images_dir, stem + suffix)
        if os.path.isfile(candidate):
            return candidate
    return None


def _draw_boxes(img_path, lbl_path, out_path):
    """Draw every box of a YOLO label file in red on its image and save to ``out_path``."""
    # The context manager closes the source file; convert() returns an in-memory
    # RGB copy so the red outline is coloured and the result can be saved as JPEG
    # even for palette/RGBA/grayscale inputs.
    with Image.open(img_path) as src:
        img = src.convert("RGB")
    draw = ImageDraw.Draw(img)
    w, h = img.size
    with open(lbl_path, encoding="utf-8") as f:
        for line in f:
            parts = list(map(float, line.split()))
            if len(parts) <= 1:
                continue
            x_center, y_center, width, height = parts[1:5]
            # Normalized centre/size -> absolute pixel corner coordinates.
            x1 = (x_center - width / 2) * w
            y1 = (y_center - height / 2) * h
            x2 = (x_center + width / 2) * w
            y2 = (y_center + height / 2) * h
            draw.rectangle([x1, y1, x2, y2], outline="red", width=2)
    img.save(out_path)


def enhanced_data_check(sample_count=3,
                        labels_dir="datasets/train/labels",
                        images_dir="datasets/train/images"):
    """Sanity-check the training labels before training starts.

    Steps:
      1. Print the per-class box count over all ``*.txt`` files in ``labels_dir``.
      2. Validate every label file with ``validate_bbox_coordinates``.
      3. Draw the boxes of the first ``sample_count`` label files (sorted by
         name) onto their matching images and save them as ``sample_{i}.jpg``
         in the working directory.

    All problems are printed; nothing is raised, so training can still start.

    Args:
        sample_count: Number of label files to visualize.
        labels_dir: Directory holding YOLO ``.txt`` label files.
        images_dir: Directory holding the images, matched to labels by file stem.
    """
    print("\n=== 执行增强数据验证 ===")
    label_files = sorted(glob.glob(os.path.join(labels_dir, "*.txt")))
    if not label_files:
        # Usually means the script was started from the wrong working directory.
        print(f"未找到标注文件:{labels_dir}")
        return

    # Class distribution
    print(f"类别分布:{_count_classes(label_files)}")

    # Coordinate validation
    for lbl in label_files:
        validate_bbox_coordinates(lbl)

    # Sample visualization: use real label files instead of assuming 0000.jpg naming.
    for i, lbl_path in enumerate(label_files[:sample_count]):
        try:
            img_path = _find_image(lbl_path, images_dir)
            if img_path is None:
                raise FileNotFoundError(f"未找到对应图像:{lbl_path}")
            _draw_boxes(img_path, lbl_path, f"sample_{i}.jpg")
            print(f"样本可视化已保存:sample_{i}.jpg")
        except Exception as e:
            # Best-effort visualization: report and continue with the next sample.
            print(f"可视化失败 {i}: {str(e)}")
|
||
|
||
def main():
    """Train, validate and export the YOLOv12 insulator detector.

    Pipeline:
      1. Stage 1: build the model from ``yolo12s.yaml`` and train it with
         ``TRAIN_CONFIG["stage1"]`` plus the loss/checkpoint settings below.
      2. Stage 2: train again with ``TRAIN_CONFIG["stage2"]`` and extra
         augmentation, starting from the weights stage 1 left in ``model``
         (run name ``final_model``).
      3. Validate on ``data.yaml`` and export to TorchScript, ONNX and TFLite.

    Returns:
        Whatever ``model.val()`` returns (previously computed and discarded).
    """
    # Settings that must be identical in both stages. They are passed on every
    # call because whether ultralytics carries arguments from one train() call
    # into the next depends on its version.
    shared = {
        "data": "data.yaml",
        "device": 0,
        "workers": 16,
        "deterministic": False,  # favour speed over bit-exact reproducibility
    }

    # Build the model from its architecture YAML (the file must be resolvable).
    model = YOLO("yolo12s.yaml")

    # === Stage 1: training with frozen early layers ===
    print("\n=== 阶段一:特征提取层训练 ===")
    model.train(
        **shared,
        **TRAIN_CONFIG["stage1"],
        box=5.0,        # box loss gain
        cls=1.0,        # classification loss gain
        save_period=5,  # save a checkpoint every 5 epochs
        patience=20,    # early-stopping patience (epochs)
    )

    # === Stage 2: full-network fine-tuning ===
    print("\n=== 阶段二:全网络微调 ===")
    # Stage 2 continues from the weights already held by ``model`` after stage 1.
    # It deliberately does NOT pass resume=True: resume restarts an interrupted
    # run from its last checkpoint, restoring that run's arguments and epoch
    # counter, so the stage-2 settings would be ignored, or training would abort
    # because the stage-1 run already finished.
    # NOTE(review): box/cls/save_period/patience are only set for stage 1, so
    # stage 2 may run with different values — confirm that is intended.
    # NOTE(review): ultralytics documents erasing (classification) and
    # copy_paste/overlap_mask (segmentation) for other tasks; they may have no
    # effect on a detection dataset — confirm.
    model.train(
        **shared,
        **TRAIN_CONFIG["stage2"],
        name="final_model",  # run / experiment name
        shear=1.5,           # shear augmentation (degrees)
        copy_paste=0.05,     # copy-paste augmentation probability
        erasing=0.1,         # random erasing probability
        overlap_mask=False,  # disable mask overlap
    )

    # === Final validation ===
    print("\n=== 最终模型验证 ===")
    # NOTE(review): conf=0.35 discards low-confidence predictions before mAP is
    # computed, which usually reports a lower mAP than the ultralytics
    # validation default (0.001) — confirm this threshold is intended.
    metrics = model.val(
        data=shared["data"],
        conf=0.35,      # confidence threshold
        iou=0.65,       # NMS IoU threshold
        plots=True,     # write evaluation plots
        half=False,     # validate in full precision
    )

    # === Export ===
    print("\n=== 模型导出 ===")
    # TorchScript, ONNX with dynamic input shapes, then TFLite. Exports run in
    # this order, so a failing TFLite export (e.g. missing TensorFlow) happens
    # after the first two formats were already written.
    for fmt, extra in (
        ("torchscript", {}),
        ("onnx", {"dynamic": True}),
        ("tflite", {}),
    ):
        model.export(format=fmt, simplify=True, **extra)

    return metrics
|
||
|
||
if __name__ == "__main__":
    # --- Runtime / hardware tuning, applied before the data check and training ---
    # cuBLAS workspace size. NOTE(review): this mainly matters for deterministic
    # cuBLAS, and training here runs with deterministic=False — confirm it is needed.
    os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
    # Let cuDNN benchmark convolution algorithms and keep the fastest ones.
    torch.backends.cudnn.benchmark = True
    # Allow reduced-precision (e.g. TF32) kernels for float32 matmuls.
    torch.set_float32_matmul_precision('high')

    # --- Pipeline: sanity-check the dataset, then train / validate / export ---
    for step in (enhanced_data_check, main):
        step()
|