ComView/train.py
Asuka 66d0877cfd refactor(data): 更新数据集路径并调整数据验证逻辑
- 修改 data.yaml 中的训练、验证和测试数据集路径
- 更新 train.py 中的数据验证逻辑,使用新的数据集路径
2025-05-29 20:16:02 +08:00

132 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLOv12 绝缘子检测优化训练脚本(最终稳定版)
主要改进:
1. 适配YOLOv12 参数规范
2. 移除已弃用参数
3. 优化训练配置
"""
import torch
from ultralytics import YOLO
import shutil
import os
import glob
from PIL import Image, ImageDraw
import numpy as np
# +++ 有效参数配置区块 +++
TRAIN_CONFIG = {
"stage1": {
"freeze": [0, 1, 2, 3, 4], # 冻结前5层
"epochs": 75, # 增加训练轮数
"batch": 32, # 增加批量大小
"lr0": 0.001, # 降低初始学习率
"lrf": 0.1, # 提高最终学习率
"hsv_h": 0.01, # 色相增强幅度
"degrees": 10.0, # 增加旋转角度范围
"perspective": 0.005, # 增加透视变换
"flipud": 0.2, # 增加上下翻转概率
"optimizer": "SGD", # 优化器类型
"weight_decay": 0.001, # 权重衰减
"label_smoothing": 0.1, # 标签平滑
"mosaic": True, # 启用Mosaic增强
# "class_weights": [1.0, 2.0] # 设置类别权重insulator权重更高
},
"stage2": {
"epochs": 150, # 增加训练轮数
"batch": 16, # 增加批量大小
"lr0": 0.0001, # 降低初始学习率
"lrf": 0.01, # 提高最终学习率
"mixup": 0.2, # 增加MixUp增强系数
"close_mosaic": 10, # 最后10epoch关闭Mosaic
"optimizer": "AdamW", # 切换优化器为AdamW
"warmup_epochs": 10 # 增大学习率预热轮数
}
}
def enhanced_data_check():
"""增强型数据验证(稳定版)"""
print("\n=== 执行增强数据验证 ===")
label_files = glob.glob("datasets/train/labels/*.txt")
class_dist = {}
# 类别分布分析
for lbl in label_files:
with open(lbl) as f:
for line in f:
class_id = int(line.strip().split()[0])
class_dist[class_id] = class_dist.get(class_id, 0) + 1
print(f"类别分布:{class_dist}")
# 样本可视化(示例保留结构)
sample_count = 3
for i in range(sample_count):
img_path = f"datasets/train/images/{i:04d}.jpg"
lbl_path = f"datasets/train/labels/{i:04d}.txt"
# 添加实际可视化逻辑
def main():
# 初始化模型(正确方式)
model = YOLO("yolo12s.yaml") # 确保yaml文件存在
# === 阶段一:冻结训练 ===
print("\n=== 阶段一:特征提取层训练 ===")
stage1_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage1"],
device=0,
workers=4,
box=5.0, # box损失权重
cls=1.0, # 分类损失权重
save_period=5, # 每5epoch保存检查点
patience=20, # 早停等待周期
deterministic=False # 禁用确定性模式
)
# === 阶段二:全网络微调 ===
print("\n=== 阶段二:全网络微调 ===")
stage2_results = model.train(
data="data.yaml",
**TRAIN_CONFIG["stage2"],
resume=True, # 从上一阶段继续
name="final_model", # 实验名称
shear=1.5, # 剪切变换幅度
copy_paste=0.05, # 复制粘贴增强
erasing=0.1, # 随机擦除概率
overlap_mask=False # 禁用掩码重叠
)
# === 模型验证 ===
print("\n=== 最终模型验证 ===")
metrics = model.val(
data="data.yaml",
conf=0.35, # 置信度阈值
iou=0.65, # IoU阈值
plots=True, # 生成评估图表
half=False # 禁用半精度验证
)
# === 模型导出 ===
print("\n=== 模型导出 ===")
model.export(
opset_version=15,
format="onnx", # 优先导出ONNX
# format="torchscript",
dynamic=True, # 动态维度
simplify=True # 简化模型
)
if __name__ == "__main__":
# 硬件优化配置
torch.set_float32_matmul_precision('high')
torch.backends.cudnn.benchmark = True
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8"
# 数据检查
enhanced_data_check()
# 启动训练
main()