# U-Net binary segmentation training script (COCO-format annotations).
import os  # file and path management
import time  # timing helpers (not used in this file's visible code — TODO confirm before removing)
import logging  # run/progress logging
import json  # parsing the COCO annotation JSON

import numpy as np  # numeric ops (mask arrays)
from PIL import Image  # image loading and conversion

import torch  # PyTorch core
import torch.nn as nn  # neural-network modules
import torch.optim as optim  # optimizers
from torch.utils.data import Dataset, DataLoader  # dataset and loader
import torchvision.utils as vutils  # image-grid visualization utilities

from pycocotools import mask as maskUtils  # COCO RLE/polygon mask utilities
import albumentations as A  # data augmentation
from albumentations.pytorch import ToTensorV2  # albumentations -> torch tensor conversion

torch.manual_seed(42)  # fix the random seed for reproducibility
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')  # log format configuration
# ------------------ Custom U-Net model (no external segmentation package) ------------------
class UNet(nn.Module):
    """Classic 4-level U-Net producing a logit map.

    Input is expected to be (B, in_channels, H, W) with H and W divisible
    by 16 (four 2x pool stages). Output is (B, out_channels, H, W) of raw
    logits — apply sigmoid externally for probabilities.
    """

    def __init__(self, in_channels=3, base_channels=64, out_channels=1):
        super(UNet, self).__init__()
        c = base_channels

        # Encoder: channel width doubles at each level.
        self.enc1 = self._double_conv(in_channels, c)
        self.enc2 = self._double_conv(c, c * 2)
        self.enc3 = self._double_conv(c * 2, c * 4)
        self.enc4 = self._double_conv(c * 4, c * 8)
        self.pool = nn.MaxPool2d(2)

        # Bottleneck.
        self.center = self._double_conv(c * 8, c * 16)

        # Decoder: transpose-conv upsampling, then double-conv after the
        # skip concatenation (hence the doubled input width of each dec).
        self.up4 = nn.ConvTranspose2d(c * 16, c * 8, kernel_size=2, stride=2)
        self.dec4 = self._double_conv(c * 16, c * 8)
        self.up3 = nn.ConvTranspose2d(c * 8, c * 4, kernel_size=2, stride=2)
        self.dec3 = self._double_conv(c * 8, c * 4)
        self.up2 = nn.ConvTranspose2d(c * 4, c * 2, kernel_size=2, stride=2)
        self.dec2 = self._double_conv(c * 4, c * 2)
        self.up1 = nn.ConvTranspose2d(c * 2, c, kernel_size=2, stride=2)
        self.dec1 = self._double_conv(c * 2, c)

        # Final 1x1 projection to the requested number of output channels.
        self.final = nn.Conv2d(c, out_channels, kernel_size=1)

    @staticmethod
    def _double_conv(in_c, out_c):
        """Two (Conv3x3 -> BatchNorm -> ReLU) stages at a fixed width."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        # Contracting path — keep each level's activation for the skips.
        skip1 = self.enc1(x)
        skip2 = self.enc2(self.pool(skip1))
        skip3 = self.enc3(self.pool(skip2))
        skip4 = self.enc4(self.pool(skip3))

        bottleneck = self.center(self.pool(skip4))

        # Expanding path — upsample, concatenate the matching skip, refine.
        out = self.dec4(torch.cat([self.up4(bottleneck), skip4], dim=1))
        out = self.dec3(torch.cat([self.up3(out), skip3], dim=1))
        out = self.dec2(torch.cat([self.up2(out), skip2], dim=1))
        out = self.dec1(torch.cat([self.up1(out), skip1], dim=1))
        return self.final(out)
# ------------------ Dice loss and IoU metric ------------------
class DiceLoss(nn.Module):
    """Soft Dice loss over sigmoid probabilities.

    Expects raw logits and targets of shape (B, C, H, W); the Dice
    coefficient is computed per (batch, channel) over the spatial dims
    and the loss is 1 minus its mean. `eps` guards the empty-mask case.
    """

    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps

    def forward(self, logits, targets):
        probs = torch.sigmoid(logits)
        spatial = (2, 3)
        intersection = (probs * targets).sum(dim=spatial)
        numerator = 2 * intersection + self.eps
        denominator = probs.sum(dim=spatial) + targets.sum(dim=spatial) + self.eps
        dice_per_sample = numerator / denominator
        return 1 - dice_per_sample.mean()
def iou_score(preds, targets, eps=1e-6):
    """Mean intersection-over-union after thresholding `preds` at 0.5.

    `preds` are probabilities in [0, 1] of shape (B, C, H, W); `targets`
    are binary masks of the same shape. Returns a Python float (mean IoU
    over batch and channels); `eps` keeps empty masks from dividing by 0.
    """
    hard = (preds > 0.5).float()
    spatial = (2, 3)
    intersection = (hard * targets).sum(dim=spatial)
    union = hard.sum(dim=spatial) + targets.sum(dim=spatial) - intersection
    iou = (intersection + eps) / (union + eps)
    return iou.mean().item()
# ------------------ COCO segmentation dataset ------------------
class CocoSegDataset(Dataset):
    """Binary-mask segmentation dataset over a COCO-format annotation file.

    Only images that have at least one annotation are indexed. All of an
    image's instance masks are merged (logical OR) into a single binary
    foreground mask.
    """

    def __init__(self, root_dir, annotation_file, transforms=None, mask_transforms=None):
        # root_dir: directory containing the image files named in the JSON.
        # annotation_file: COCO JSON with 'images' and 'annotations' lists.
        # transforms: albumentations Compose applied jointly to image + mask.
        # mask_transforms: kept for interface compatibility; the joint
        #   `transforms` pipeline is what actually transforms the mask.
        self.root_dir = root_dir
        self.transforms = transforms
        self.mask_transforms = mask_transforms
        with open(annotation_file, 'r') as f:
            self.coco = json.load(f)
        # Group annotations by image id.
        self.annotations = {}
        for ann in self.coco['annotations']:
            self.annotations.setdefault(ann['image_id'], []).append(ann)
        # O(1) image-info lookup. (Previously __getitem__ scanned
        # self.coco['images'] linearly on every access.)
        self.img_info = {img['id']: img for img in self.coco['images']}
        self.image_ids = list(self.annotations.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        info = self.img_info[img_id]
        img = Image.open(os.path.join(self.root_dir, info['file_name'])).convert('RGB')
        h, w = info['height'], info['width']
        mask = np.zeros((h, w), dtype=np.uint8)
        for ann in self.annotations[img_id]:
            seg = ann['segmentation']
            if isinstance(seg, list):
                # Polygon list -> single RLE.
                rle = maskUtils.merge(maskUtils.frPyObjects(seg, h, w))
            else:
                # Already RLE-encoded.
                rle = seg
            # Logical OR instead of '+=': uint8 accumulation can wrap past
            # 255 when many instance masks overlap the same pixel.
            mask = np.maximum(mask, maskUtils.decode(rle))
        mask = (mask > 0).astype(np.float32)
        mask = Image.fromarray(mask)
        if self.transforms and self.mask_transforms:
            aug = self.transforms(image=np.array(img), mask=np.array(mask))
            img_t = aug['image']
            m_t = aug['mask'].unsqueeze(0)  # add channel dim -> (1, H, W)
        else:
            img_t = ToTensorV2()(image=np.array(img))['image']
            m_t = ToTensorV2()(image=np.array(mask))['image']
        return img_t, m_t
# ------------------ Main training pipeline ------------------
if __name__ == '__main__':
    # Path configuration (Roboflow-style COCO export layout).
    train_dir, val_dir = '../data/train', '../data/valid'
    train_ann, val_ann = os.path.join(train_dir,'_annotations.coco.json'), os.path.join(val_dir,'_annotations.coco.json')
    # Augmentation configuration.
    # NOTE(review): HorizontalFlip's first positional arg is p, but for
    # RandomBrightnessContrast it is brightness_limit and for
    # ShiftScaleRotate it is shift_limit (both default p=0.5) — the bare
    # 0.2 / 0.5 here may not mean probability as presumably intended; confirm.
    train_tf = A.Compose([A.Resize(256,256),A.HorizontalFlip(0.5),A.RandomBrightnessContrast(0.2),A.ShiftScaleRotate(0.5),A.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),ToTensorV2()])
    val_tf = A.Compose([A.Resize(256,256),A.Normalize((0.485,0.456,0.406),(0.229,0.224,0.225)),ToTensorV2()])
    # Data loading: the same Compose is passed as both transforms and
    # mask_transforms (the dataset only applies the joint pipeline).
    train_ds = CocoSegDataset(train_dir,train_ann,train_tf,train_tf)
    val_ds = CocoSegDataset(val_dir, val_ann, val_tf, val_tf)
    train_ld = DataLoader(train_ds,batch_size=8,shuffle=True,num_workers=4)
    val_ld = DataLoader(val_ds, batch_size=8,shuffle=False,num_workers=4)
    logging.info(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")
    # Model and training configuration.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = UNet().to(device)
    opt = optim.AdamW(model.parameters(),lr=1e-3,weight_decay=1e-4)
    # LR halves after 3 epochs without val-loss improvement.
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt,'min',patience=3,factor=0.5)
    # Combined objective: BCE (on logits) + soft Dice.
    bce = nn.BCEWithLogitsLoss(); dice = DiceLoss()
    epochs=20; vis_dir='output/val_visuals'; os.makedirs(vis_dir,exist_ok=True)
    # Training loop.
    for ep in range(1,epochs+1):
        model.train(); run_loss=0
        for imgs,msks in train_ld:
            imgs,msks=imgs.to(device),msks.to(device)
            opt.zero_grad(); out=model(imgs)
            l=(bce(out,msks)+dice(out,msks)); l.backward(); opt.step()
            # Accumulate sample-weighted loss so the epoch mean is exact
            # even with a ragged final batch.
            run_loss+=l.item()*imgs.size(0)
        tr_loss=run_loss/len(train_ds)
        # Validation pass.
        model.eval(); v_loss=0; v_iou=0; v_dice=0
        with torch.no_grad():
            for imgs,msks in val_ld:
                imgs,msks=imgs.to(device),msks.to(device)
                out=model(imgs)
                v_loss+=(bce(out,msks)+dice(out,msks)).item()*imgs.size(0)
                pr=torch.sigmoid(out)
                v_iou+=iou_score(pr,msks)*imgs.size(0)
                # Dice metric recovered as 1 - DiceLoss.
                v_dice+=(1 - dice(out,msks)).item()*imgs.size(0)
        v_loss/=len(val_ds); v_iou/=len(val_ds); v_dice/=len(val_ds)
        logging.info(f"Epoch {ep}/{epochs} - Tr:{tr_loss:.4f} Val:{v_loss:.4f} IoU:{v_iou:.4f} Dice:{v_dice:.4f}")
        # Visualization: first val batch as [inputs | targets | predictions]
        # rows; single-channel masks are tiled to 3 channels for the grid.
        # NOTE(review): inputs are normalized, so the saved image will look
        # washed out unless denormalized first — confirm acceptable.
        si,sm=next(iter(val_ld)); si=si.to(device)
        with torch.no_grad(): sp=torch.sigmoid(model(si))
        grid=vutils.make_grid(torch.cat([si.cpu(),sm.repeat(1,3,1,1).cpu(),sp.repeat(1,3,1,1).cpu()],0),nrow=si.size(0))
        vpth=os.path.join(vis_dir,f'ep{ep}.png'); vutils.save_image(grid,vpth)
        logging.info(f"Saved visual: {vpth}")
        sched.step(v_loss)
    # Save the final model weights (state_dict only).
    torch.save(model.state_dict(),'unet_coco_segmentation.pth')
    logging.info('训练完成,模型已保存')