feat: 添加 YOLOv12 模型继续训练脚本
- 新增 continue_train.py 文件,实现从指定模型恢复训练状态并继续训练的功能 - 支持通过命令行参数指定预训练模型路径、数据配置文件、训练轮数等参数 -增加了对最佳权重的加载逻辑,确保从最优模型开始继续训练 - 强制使用 CPU 进行训练,避免 CUDA 设备错误
This commit is contained in:
parent
f13e50aa6b
commit
6a772b23ed
69
continue_train.py
Normal file
69
continue_train.py
Normal file
@ -0,0 +1,69 @@
|
|||||||
|
"""
|
||||||
|
YOLOv12 模型继续训练脚本
|
||||||
|
功能:从指定pt模型恢复训练状态并继续训练
|
||||||
|
使用方法:python continue_train.py --weights=runs/detect/train/weights --epochs=50 --batch=48
|
||||||
|
"""
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from ultralytics import YOLO
|
||||||
|
import argparse
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
|
|
||||||
|
def parse_opt():
|
||||||
|
"""解析命令行参数"""
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument('--weights', type=str, required=True, help='预训练模型路径')
|
||||||
|
parser.add_argument('--data', type=str, default='data.yaml', help='数据配置文件')
|
||||||
|
parser.add_argument('--epochs', type=int, default=100, help='总训练轮数')
|
||||||
|
parser.add_argument('--batch', type=int, default=16, help='批次大小')
|
||||||
|
parser.add_argument('--imgsz', type=int, default=640, help='输入图像尺寸')
|
||||||
|
parser.add_argument('--lr0', type=float, default=0.0001, help='初始学习率')
|
||||||
|
parser.add_argument('--lrf', type=float, default=0.01, help='最终学习率')
|
||||||
|
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'Adam', 'AdamW'], help='优化器类型')
|
||||||
|
parser.add_argument('--resume_epoch', type=int, default=0, help='从第几个epoch开始继续训练')
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
def main(opt):
|
||||||
|
"""主训练流程"""
|
||||||
|
# 验证权重路径
|
||||||
|
if os.path.isdir(opt.weights):
|
||||||
|
pt_files = glob.glob(os.path.join(opt.weights, "*.pt"))
|
||||||
|
if pt_files:
|
||||||
|
opt.weights = pt_files[0] # 选择第一个找到的pt文件
|
||||||
|
else:
|
||||||
|
raise ValueError(f"未在指定目录找到.pt模型文件: {opt.weights}")
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
model = YOLO(opt.weights)
|
||||||
|
|
||||||
|
# 获取最佳权重路径
|
||||||
|
weights_dir = os.path.dirname(opt.weights)
|
||||||
|
best_weights_path = os.path.join(weights_dir, 'best.pt')
|
||||||
|
|
||||||
|
# 如果存在最佳权重则加载
|
||||||
|
if os.path.exists(best_weights_path):
|
||||||
|
model = YOLO(best_weights_path)
|
||||||
|
print(f'已加载最佳权重: {best_weights_path}')
|
||||||
|
|
||||||
|
# 开始训练(强制指定使用CPU避免CUDA设备错误)
|
||||||
|
results = model.train(
|
||||||
|
data=opt.data,
|
||||||
|
epochs=opt.epochs,
|
||||||
|
batch=opt.batch,
|
||||||
|
imgsz=opt.imgsz,
|
||||||
|
lr0=opt.lr0,
|
||||||
|
lrf=opt.lrf,
|
||||||
|
optimizer=opt.optimizer,
|
||||||
|
resume=True,
|
||||||
|
device='cpu' # 强制指定使用CPU
|
||||||
|
)
|
||||||
|
|
||||||
|
# 输出训练结果
|
||||||
|
print('训练完成,最终指标:')
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
opt = parse_opt()
|
||||||
|
main(opt)
|
Loading…
Reference in New Issue
Block a user