48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
from ultralytics import YOLO
|
||
import torch
|
||
|
||
def apply_data_augmentation():
|
||
"""应用数据增强技术提升模型泛化能力"""
|
||
# 检查是否支持CUDA
|
||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||
print(f'Using device: {device}')
|
||
|
||
# 加载预训练模型
|
||
model = YOLO('yolo12s.yaml')
|
||
|
||
# 开始训练
|
||
model.train(
|
||
data='datasets/data.yaml',
|
||
epochs=600,
|
||
imgsz=640,
|
||
batch=32,
|
||
device=0,
|
||
optimizer='SGD',
|
||
lr0=0.001,
|
||
project='run',
|
||
name='train',
|
||
workers=4,
|
||
# 设置模型保存周期(每50个epoch保存一次)
|
||
save_period=50,
|
||
# 数据增强参数
|
||
hsv_h=0.015, # 色调增强
|
||
hsv_s=0.7, # 饱和度增强
|
||
hsv_v=0.4, # 明度增强
|
||
degrees=30, # 随机旋转角度
|
||
translate=0.1, # 平移比例
|
||
scale=0.5, # 缩放比例
|
||
shear=15, # 剪切变换角度
|
||
perspective=0.0001, # 透视变换概率
|
||
flipud=0.5, # 上下翻转概率
|
||
fliplr=0.5, # 左右翻转概率
|
||
mosaic=1.0, # 马赛克增强概率
|
||
mixup=0.2, # MixUp增强概率
|
||
copy_paste=0.3, # 复制粘贴增强概率
|
||
auto_augment='randaugment', # 自动增强策略
|
||
# erasing=0.4, # 随机擦除概率
|
||
# fraction=0.9, # 训练验证集划分比例
|
||
cache= 'disk' # 启用缓存加速训练
|
||
)
|
||
|
||
if __name__ == '__main__':
|
||
apply_data_augmentation() |