- 在 train/labels 和 valid/labels目录下添加了多个标签文件 - 标签文件包含物体检测的坐标信息 - 此次添加的标签主要用于绝缘子检测任务
32 lines
1.2 KiB
Python
32 lines
1.2 KiB
Python
import torch
|
||
from ultralytics import YOLO
|
||
|
||
# Create a new YOLO model from scratch
|
||
model = YOLO("yolo12n.yaml")
|
||
|
||
# 控制 GPU 显存使用(例如限制为总显存的 70%)
|
||
#torch.cuda.set_per_process_memory_fraction(0.7, device=0)
|
||
|
||
# 增加训练参数配置:图像大小、批次大小、学习率、数据增强等
|
||
if __name__ == "__main__":
|
||
results = model.train(
|
||
data="data.yaml",
|
||
epochs=50, # 增加训练轮次到50(线缆识别通常需要更多迭代)
|
||
imgsz=640, # 设置输入图像尺寸
|
||
batch=4, # 根据GPU内存调整批次大小
|
||
lr0=0.01, # 初始学习率
|
||
augment=True, # 启用数据增强
|
||
name="cable_detection", # 训练结果保存目录
|
||
device=0, # 使用GPU 0
|
||
workers=0
|
||
)
|
||
|
||
# 新增:训练完成后自动评估
|
||
results = model.val(data="data.yaml")
|
||
|
||
# 新增:导出最佳模型到PyTorch TorchScript格式(移除ONNX依赖)
|
||
success = model.export(format="torchscript") # 使用支持的PyTorch导出格式
|
||
|
||
# 新增:保存原始PyTorch模型权重到model.pt
|
||
model.save('model.pt') # 直接保存.pt格式的原始模型文件
|