From 6a772b23eddd63a5631eb6e24ad0505894e9aa15 Mon Sep 17 00:00:00 2001 From: fly6516 Date: Wed, 4 Jun 2025 23:24:53 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=20YOLOv12=20?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=BB=A7=E7=BB=AD=E8=AE=AD=E7=BB=83=E8=84=9A?= =?UTF-8?q?=E6=9C=AC?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 continue_train.py 文件,实现从指定模型恢复训练状态并继续训练的功能 - 支持通过命令行参数指定预训练模型路径、数据配置文件、训练轮数等参数 -增加了对最佳权重的加载逻辑,确保从最优模型开始继续训练 - 强制使用 CPU 进行训练,避免 CUDA 设备错误 --- continue_train.py | 69 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 continue_train.py diff --git a/continue_train.py b/continue_train.py new file mode 100644 index 00000000..df38a2e3 --- /dev/null +++ b/continue_train.py @@ -0,0 +1,69 @@ +""" +YOLOv12 模型继续训练脚本 +功能:从指定pt模型恢复训练状态并继续训练 +使用方法:python continue_train.py --weights=runs/detect/train/weights --epochs=50 --batch=48 +""" + +import torch +from ultralytics import YOLO +import argparse +import os +import glob + +def parse_opt(): + """解析命令行参数""" + parser = argparse.ArgumentParser() + parser.add_argument('--weights', type=str, required=True, help='预训练模型路径') + parser.add_argument('--data', type=str, default='data.yaml', help='数据配置文件') + parser.add_argument('--epochs', type=int, default=100, help='总训练轮数') + parser.add_argument('--batch', type=int, default=16, help='批次大小') + parser.add_argument('--imgsz', type=int, default=640, help='输入图像尺寸') + parser.add_argument('--lr0', type=float, default=0.0001, help='初始学习率') + parser.add_argument('--lrf', type=float, default=0.01, help='最终学习率') + parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'Adam', 'AdamW'], help='优化器类型') + parser.add_argument('--resume_epoch', type=int, default=0, help='从第几个epoch开始继续训练') + + return parser.parse_args() + +def main(opt): + """主训练流程""" + # 验证权重路径 + if os.path.isdir(opt.weights): + pt_files = glob.glob(os.path.join(opt.weights, "*.pt")) + if pt_files: + opt.weights = pt_files[0] # 选择第一个找到的pt文件 + else: + raise ValueError(f"未在指定目录找到.pt模型文件: {opt.weights}") + + # 加载模型 + model = YOLO(opt.weights) + + # 获取最佳权重路径 + weights_dir = os.path.dirname(opt.weights) + best_weights_path = os.path.join(weights_dir, 'best.pt') + + # 如果存在最佳权重则加载 + if os.path.exists(best_weights_path): + model = YOLO(best_weights_path) + print(f'已加载最佳权重: {best_weights_path}') + + # 开始训练(强制指定使用CPU避免CUDA设备错误) + results = model.train( + data=opt.data, + epochs=opt.epochs, + batch=opt.batch, + imgsz=opt.imgsz, + lr0=opt.lr0, + lrf=opt.lrf, + optimizer=opt.optimizer, + resume=True, + device='cpu' # 强制指定使用CPU + ) + + # 输出训练结果 + print('训练完成,最终指标:') + print(results) + +if __name__ == '__main__': + opt = parse_opt() + main(opt) \ No newline at end of file