Yolo-12-everything/continue_train.py
fly6516 6a772b23ed feat: 添加 YOLOv12 模型继续训练脚本
- 新增 continue_train.py 文件,实现从指定模型恢复训练状态并继续训练的功能
- 支持通过命令行参数指定预训练模型路径、数据配置文件、训练轮数等参数
-增加了对最佳权重的加载逻辑,确保从最优模型开始继续训练
- 强制使用 CPU 进行训练,避免 CUDA 设备错误
2025-06-04 23:24:53 +08:00

69 lines
2.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
YOLOv12 模型继续训练脚本
功能从指定pt模型恢复训练状态并继续训练
使用方法python continue_train.py --weights=runs/detect/train/weights --epochs=50 --batch=48
"""
import torch
from ultralytics import YOLO
import argparse
import os
import glob
def parse_opt():
"""解析命令行参数"""
parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, required=True, help='预训练模型路径')
parser.add_argument('--data', type=str, default='data.yaml', help='数据配置文件')
parser.add_argument('--epochs', type=int, default=100, help='总训练轮数')
parser.add_argument('--batch', type=int, default=16, help='批次大小')
parser.add_argument('--imgsz', type=int, default=640, help='输入图像尺寸')
parser.add_argument('--lr0', type=float, default=0.0001, help='初始学习率')
parser.add_argument('--lrf', type=float, default=0.01, help='最终学习率')
parser.add_argument('--optimizer', type=str, default='AdamW', choices=['SGD', 'Adam', 'AdamW'], help='优化器类型')
parser.add_argument('--resume_epoch', type=int, default=0, help='从第几个epoch开始继续训练')
return parser.parse_args()
def main(opt):
"""主训练流程"""
# 验证权重路径
if os.path.isdir(opt.weights):
pt_files = glob.glob(os.path.join(opt.weights, "*.pt"))
if pt_files:
opt.weights = pt_files[0] # 选择第一个找到的pt文件
else:
raise ValueError(f"未在指定目录找到.pt模型文件: {opt.weights}")
# 加载模型
model = YOLO(opt.weights)
# 获取最佳权重路径
weights_dir = os.path.dirname(opt.weights)
best_weights_path = os.path.join(weights_dir, 'best.pt')
# 如果存在最佳权重则加载
if os.path.exists(best_weights_path):
model = YOLO(best_weights_path)
print(f'已加载最佳权重: {best_weights_path}')
# 开始训练强制指定使用CPU避免CUDA设备错误
results = model.train(
data=opt.data,
epochs=opt.epochs,
batch=opt.batch,
imgsz=opt.imgsz,
lr0=opt.lr0,
lrf=opt.lrf,
optimizer=opt.optimizer,
resume=True,
device='cpu' # 强制指定使用CPU
)
# 输出训练结果
print('训练完成,最终指标:')
print(results)
if __name__ == '__main__':
opt = parse_opt()
main(opt)