Yolo-12-everything/continue_train.py

69 lines
2.4 KiB
Python
Raw Permalink Normal View History

"""
YOLOv12 模型继续训练脚本
功能从指定pt模型恢复训练状态并继续训练
使用方法python continue_train.py --weights=runs/detect/train/weights --epochs=50 --batch=48
"""
import torch
from ultralytics import YOLO
import argparse
import os
import glob
def parse_opt():
"""解析命令行参数"""
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, required=True, help='预训练模型路径')
parser.add_argument('--data', type=str, default='data.yaml', help='数据配置文件')
parser.add_argument('--epochs', type=int, default=100, help='总训练轮数')
parser.add_argument('--batch', type=int, default=16, help='批次大小')
parser.add_argument('--imgsz', type=int, default=640, help='输入图像尺寸')
parser.add_argument('--lr0', type=float, default=0.0001, help='初始学习率')
parser.add_argument('--lrf', type=float, default=0.01, help='最终学习率')
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'Adam', 'AdamW'], help='优化器类型')
parser.add_argument('--resume_epoch', type=int, default=0, help='从第几个epoch开始继续训练')
return parser.parse_args()
def main(opt):
"""主训练流程"""
# 验证权重路径
if os.path.isdir(opt.weights):
pt_files = glob.glob(os.path.join(opt.weights, "*.pt"))
if pt_files:
opt.weights = pt_files[0] # 选择第一个找到的pt文件
else:
raise ValueError(f"未在指定目录找到.pt模型文件: {opt.weights}")
# 加载模型
model = YOLO(opt.weights)
# 获取最佳权重路径
weights_dir = os.path.dirname(opt.weights)
best_weights_path = os.path.join(weights_dir, 'best.pt')
# 如果存在最佳权重则加载
if os.path.exists(best_weights_path):
model = YOLO(best_weights_path)
print(f'已加载最佳权重: {best_weights_path}')
# 开始训练强制指定使用CPU避免CUDA设备错误
results = model.train(
data=opt.data,
epochs=opt.epochs,
batch=opt.batch,
imgsz=opt.imgsz,
lr0=opt.lr0,
lrf=opt.lrf,
optimizer=opt.optimizer,
resume=True,
device='cpu' # 强制指定使用CPU
)
# 输出训练结果
print('训练完成,最终指标:')
print(results)
if __name__ == '__main__':
opt = parse_opt()
main(opt)