Isolators-Detection/train.py
fly6516 c33f35d0de feat(Isolators-Detection): 添加训练和验证集标签
- 在 train/labels 和 valid/labels目录下添加了多个标签文件
- 标签文件包含物体检测的坐标信息
- 此次添加的标签主要用于绝缘子检测任务
2025-03-25 10:26:19 +08:00

32 lines
1.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
from ultralytics import YOLO
# Create a new YOLO model from scratch
model = YOLO("yolo12n.yaml")
# 控制 GPU 显存使用(例如限制为总显存的 70%
#torch.cuda.set_per_process_memory_fraction(0.7, device=0)
# 增加训练参数配置:图像大小、批次大小、学习率、数据增强等
if __name__ == "__main__":
results = model.train(
data="data.yaml",
epochs=50, # 增加训练轮次到50线缆识别通常需要更多迭代
imgsz=640, # 设置输入图像尺寸
batch=4, # 根据GPU内存调整批次大小
lr0=0.01, # 初始学习率
augment=True, # 启用数据增强
name="cable_detection", # 训练结果保存目录
device=0, # 使用GPU 0
workers=0
)
# 新增:训练完成后自动评估
results = model.val(data="data.yaml")
# 新增导出最佳模型到PyTorch TorchScript格式移除ONNX依赖
success = model.export(format="torchscript") # 使用支持的PyTorch导出格式
# 新增保存原始PyTorch模型权重到model.pt
model.save('model.pt') # 直接保存.pt格式的原始模型文件