32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
|
import torch
|
|||
|
from ultralytics import YOLO
|
|||
|
|
|||
|
# Create a new YOLO model from scratch
|
|||
|
model = YOLO("yolo12n.yaml")
|
|||
|
|
|||
|
# 控制 GPU 显存使用(例如限制为总显存的 70%)
|
|||
|
#torch.cuda.set_per_process_memory_fraction(0.7, device=0)
|
|||
|
|
|||
|
# 增加训练参数配置:图像大小、批次大小、学习率、数据增强等
|
|||
|
if __name__ == "__main__":
|
|||
|
results = model.train(
|
|||
|
data="data.yaml",
|
|||
|
epochs=50, # 增加训练轮次到50(线缆识别通常需要更多迭代)
|
|||
|
imgsz=640, # 设置输入图像尺寸
|
|||
|
batch=4, # 根据GPU内存调整批次大小
|
|||
|
lr0=0.01, # 初始学习率
|
|||
|
augment=True, # 启用数据增强
|
|||
|
name="cable_detection", # 训练结果保存目录
|
|||
|
device=0, # 使用GPU 0
|
|||
|
workers=0
|
|||
|
)
|
|||
|
|
|||
|
# 新增:训练完成后自动评估
|
|||
|
results = model.val(data="data.yaml")
|
|||
|
|
|||
|
# 新增:导出最佳模型到PyTorch TorchScript格式(移除ONNX依赖)
|
|||
|
success = model.export(format="torchscript") # 使用支持的PyTorch导出格式
|
|||
|
|
|||
|
# 新增:保存原始PyTorch模型权重到model.pt
|
|||
|
model.save('model.pt') # 直接保存.pt格式的原始模型文件
|