48 lines
1.5 KiB
Python
48 lines
1.5 KiB
Python
|
from ultralytics import YOLO
|
|||
|
import torch
|
|||
|
|
|||
|
def apply_data_augmentation():
|
|||
|
"""应用数据增强技术提升模型泛化能力"""
|
|||
|
# 检查是否支持CUDA
|
|||
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|||
|
print(f'Using device: {device}')
|
|||
|
|
|||
|
# 加载预训练模型
|
|||
|
model = YOLO('yolo12s.yaml')
|
|||
|
|
|||
|
# 开始训练
|
|||
|
model.train(
|
|||
|
data='datasets/data.yaml',
|
|||
|
epochs=600,
|
|||
|
imgsz=640,
|
|||
|
batch=32,
|
|||
|
device=0,
|
|||
|
optimizer='SGD',
|
|||
|
lr0=0.001,
|
|||
|
project='run',
|
|||
|
name='train',
|
|||
|
workers=4,
|
|||
|
# 设置模型保存周期(每50个epoch保存一次)
|
|||
|
save_period=50,
|
|||
|
# 数据增强参数
|
|||
|
hsv_h=0.015, # 色调增强
|
|||
|
hsv_s=0.7, # 饱和度增强
|
|||
|
hsv_v=0.4, # 明度增强
|
|||
|
degrees=30, # 随机旋转角度
|
|||
|
translate=0.1, # 平移比例
|
|||
|
scale=0.5, # 缩放比例
|
|||
|
shear=15, # 剪切变换角度
|
|||
|
perspective=0.0001, # 透视变换概率
|
|||
|
flipud=0.5, # 上下翻转概率
|
|||
|
fliplr=0.5, # 左右翻转概率
|
|||
|
mosaic=1.0, # 马赛克增强概率
|
|||
|
mixup=0.2, # MixUp增强概率
|
|||
|
copy_paste=0.3, # 复制粘贴增强概率
|
|||
|
auto_augment='randaugment', # 自动增强策略
|
|||
|
# erasing=0.4, # 随机擦除概率
|
|||
|
# fraction=0.9, # 训练验证集划分比例
|
|||
|
cache= 'disk' # 启用缓存加速训练
|
|||
|
)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
apply_data_augmentation()
|