AI-exp-5/train.py
fly6516 a009b57c89 feat(datasets): 添加新的训练和验证数据标签
- 在 train 和 valid 文件夹中添加了多个新的标签文件
- 标签文件包含了不同类型的目标检测框坐标和类别
- 新增数据有助于模型训练和性能提升
2025-06-06 18:07:28 +08:00

48 lines
1.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from ultralytics import YOLO
import torch
def apply_data_augmentation():
"""应用数据增强技术提升模型泛化能力"""
# 检查是否支持CUDA
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')
# 加载预训练模型
model = YOLO('yolo12s.yaml')
# 开始训练
model.train(
data='datasets/data.yaml',
epochs=600,
imgsz=640,
batch=32,
device=0,
optimizer='SGD',
lr0=0.001,
project='run',
name='train',
workers=4,
# 设置模型保存周期每50个epoch保存一次
save_period=50,
# 数据增强参数
hsv_h=0.015, # 色调增强
hsv_s=0.7, # 饱和度增强
hsv_v=0.4, # 明度增强
degrees=30, # 随机旋转角度
translate=0.1, # 平移比例
scale=0.5, # 缩放比例
shear=15, # 剪切变换角度
perspective=0.0001, # 透视变换概率
flipud=0.5, # 上下翻转概率
fliplr=0.5, # 左右翻转概率
mosaic=1.0, # 马赛克增强概率
mixup=0.2, # MixUp增强概率
copy_paste=0.3, # 复制粘贴增强概率
auto_augment='randaugment', # 自动增强策略
# erasing=0.4, # 随机擦除概率
# fraction=0.9, # 训练验证集划分比例
cache= 'disk' # 启用缓存加速训练
)
if __name__ == '__main__':
apply_data_augmentation()