diff --git a/continue_train.py b/continue_train.py new file mode 100644 index 00000000..df38a2e3 --- /dev/null +++ b/continue_train.py @@ -0,0 +1,69 @@ +""" +YOLOv12 模型继续训练脚本 +功能:从指定pt模型恢复训练状态并继续训练 +使用方法:python continue_train.py --weights=runs/detect/train/weights --epochs=50 --batch=48 +""" + +import torch +from ultralytics import YOLO +import argparse +import os +import glob + +def parse_opt(): + """解析命令行参数""" + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str, required=True, help='预训练模型路径') + parser.add_argument('--data', type=str, default='data.yaml', help='数据配置文件') + parser.add_argument('--epochs', type=int, default=100, help='总训练轮数') + parser.add_argument('--batch', type=int, default=16, help='批次大小') + parser.add_argument('--imgsz', type=int, default=640, help='输入图像尺寸') + parser.add_argument('--lr0', type=float, default=0.0001, help='初始学习率') + parser.add_argument('--lrf', type=float, default=0.01, help='最终学习率') + parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'Adam', 'AdamW'], help='优化器类型') + parser.add_argument('--resume_epoch', type=int, default=0, help='从第几个epoch开始继续训练') + + return parser.parse_args() + +def main(opt): + """主训练流程""" + # 验证权重路径 + if os.path.isdir(opt.weights): + pt_files = glob.glob(os.path.join(opt.weights, "*.pt")) + if pt_files: + opt.weights = pt_files[0] # 选择第一个找到的pt文件 + else: + raise ValueError(f"未在指定目录找到.pt模型文件: {opt.weights}") + + # 加载模型 + model = YOLO(opt.weights) + + # 获取最佳权重路径 + weights_dir = os.path.dirname(opt.weights) + best_weights_path = os.path.join(weights_dir, 'best.pt') + + # 如果存在最佳权重则加载 + if os.path.exists(best_weights_path): + model = YOLO(best_weights_path) + print(f'已加载最佳权重: {best_weights_path}') + + # 开始训练(强制指定使用CPU避免CUDA设备错误) + results = model.train( + data=opt.data, + epochs=opt.epochs, + batch=opt.batch, + imgsz=opt.imgsz, + lr0=opt.lr0, + lrf=opt.lrf, + optimizer=opt.optimizer, + resume=True, + device='cpu' # 强制指定使用CPU + ) + + # 输出训练结果 + print('训练完成,最终指标:') + print(results) + +if __name__ == '__main__': + opt = parse_opt() + main(opt) \ No newline at end of file