69 lines
2.4 KiB
Python
69 lines
2.4 KiB
Python
|
"""
|
|||
|
YOLOv12 模型继续训练脚本
|
|||
|
功能:从指定pt模型恢复训练状态并继续训练
|
|||
|
使用方法:python continue_train.py --weights=runs/detect/train/weights --epochs=50 --batch=48
|
|||
|
"""
|
|||
|
|
|||
|
import torch
|
|||
|
from ultralytics import YOLO
|
|||
|
import argparse
|
|||
|
import os
|
|||
|
import glob
|
|||
|
|
|||
|
def parse_opt():
|
|||
|
"""解析命令行参数"""
|
|||
|
parser = argparse.ArgumentParser()
|
|||
|
parser.add_argument('--weights', type=str, required=True, help='预训练模型路径')
|
|||
|
parser.add_argument('--data', type=str, default='data.yaml', help='数据配置文件')
|
|||
|
parser.add_argument('--epochs', type=int, default=100, help='总训练轮数')
|
|||
|
parser.add_argument('--batch', type=int, default=16, help='批次大小')
|
|||
|
parser.add_argument('--imgsz', type=int, default=640, help='输入图像尺寸')
|
|||
|
parser.add_argument('--lr0', type=float, default=0.0001, help='初始学习率')
|
|||
|
parser.add_argument('--lrf', type=float, default=0.01, help='最终学习率')
|
|||
|
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'Adam', 'AdamW'], help='优化器类型')
|
|||
|
parser.add_argument('--resume_epoch', type=int, default=0, help='从第几个epoch开始继续训练')
|
|||
|
|
|||
|
return parser.parse_args()
|
|||
|
|
|||
|
def main(opt):
|
|||
|
"""主训练流程"""
|
|||
|
# 验证权重路径
|
|||
|
if os.path.isdir(opt.weights):
|
|||
|
pt_files = glob.glob(os.path.join(opt.weights, "*.pt"))
|
|||
|
if pt_files:
|
|||
|
opt.weights = pt_files[0] # 选择第一个找到的pt文件
|
|||
|
else:
|
|||
|
raise ValueError(f"未在指定目录找到.pt模型文件: {opt.weights}")
|
|||
|
|
|||
|
# 加载模型
|
|||
|
model = YOLO(opt.weights)
|
|||
|
|
|||
|
# 获取最佳权重路径
|
|||
|
weights_dir = os.path.dirname(opt.weights)
|
|||
|
best_weights_path = os.path.join(weights_dir, 'best.pt')
|
|||
|
|
|||
|
# 如果存在最佳权重则加载
|
|||
|
if os.path.exists(best_weights_path):
|
|||
|
model = YOLO(best_weights_path)
|
|||
|
print(f'已加载最佳权重: {best_weights_path}')
|
|||
|
|
|||
|
# 开始训练(强制指定使用CPU避免CUDA设备错误)
|
|||
|
results = model.train(
|
|||
|
data=opt.data,
|
|||
|
epochs=opt.epochs,
|
|||
|
batch=opt.batch,
|
|||
|
imgsz=opt.imgsz,
|
|||
|
lr0=opt.lr0,
|
|||
|
lrf=opt.lrf,
|
|||
|
optimizer=opt.optimizer,
|
|||
|
resume=True,
|
|||
|
device='cpu' # 强制指定使用CPU
|
|||
|
)
|
|||
|
|
|||
|
# 输出训练结果
|
|||
|
print('训练完成,最终指标:')
|
|||
|
print(results)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
opt = parse_opt()
|
|||
|
main(opt)
|