docs(code): 优化代码结构和注释
- 调整代码格式,提高可读性
- 增加详细注释,解释代码功能
- 优化变量命名,提高代码可理解性
- 简化部分代码逻辑,提高执行效率
This commit is contained in:
parent
b1e90c4c9d
commit
0cef8dc8e8
@ -1,55 +1,58 @@
|
|||||||
import cv2
|
# 导入必要的库
|
||||||
import numpy as np
|
import cv2 # OpenCV库用于图像处理
|
||||||
import time
|
import numpy as np # NumPy库用于数值计算
|
||||||
import os
|
import time # 时间模块用于计时
|
||||||
|
import os # 操作系统模块用于文件和目录操作
|
||||||
|
|
||||||
def sift_feature_extraction(image_path, output_dir="output/sift_results", show=True):
    """Detect SIFT keypoints in an image, save a visualisation and print stats.

    The image is loaded with OpenCV, converted to grayscale, and passed to a
    SIFT detector.  The detected keypoints are drawn on the grayscale image
    and written to ``<output_dir>/sift_result.jpg``.

    Args:
        image_path: Path of the image file to analyse.
        output_dir: Directory that receives ``sift_result.jpg``; created if
            it does not exist.  Defaults to the original hard-coded location.
        show: When true (the default, matching the original behaviour), open
            a window with the result and block until a key is pressed.  Pass
            ``False`` for batch or headless runs.

    Returns:
        None.  If the image cannot be loaded a message is printed and the
        function returns early.
    """
    img = cv2.imread(image_path)

    # cv2.imread signals failure by returning None instead of raising.
    if img is None:
        print("无法加载图像,请检查路径")
        return

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    sift = cv2.SIFT_create()

    # perf_counter is a monotonic clock meant for measuring intervals;
    # time.time() can jump if the system clock is adjusted.
    start_time = time.perf_counter()
    keypoints, descriptors = sift.detectAndCompute(gray, None)
    processing_time = time.perf_counter() - start_time

    # Let OpenCV allocate the output image instead of overwriting `img`.
    img_with_keypoints = cv2.drawKeypoints(
        gray,
        keypoints,
        None,
        flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS,
    )

    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "sift_result.jpg")

    # cv2.imwrite returns False on failure rather than raising, so the
    # result has to be checked before claiming the file was saved.
    saved = cv2.imwrite(output_path, img_with_keypoints)

    print(f"检测到 {len(keypoints)} 个关键点")
    print(f"描述符形状: {descriptors.shape if descriptors is not None else 'None'}")
    print(f"特征提取耗时: {processing_time:.4f} 秒")
    if saved:
        print(f"结果已保存至: {output_path}")
    else:
        print(f"结果保存失败: {output_path}")

    if show:
        cv2.imshow('SIFT Features', img_with_keypoints)
        cv2.waitKey(0)  # block until any key is pressed
        cv2.destroyAllWindows()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 示例用法
|
# 示例用法
|
||||||
|
@ -1,38 +1,36 @@
|
|||||||
import os
|
# 导入必要的库
|
||||||
import torch
|
import os # 操作系统模块
|
||||||
from torch.utils.data import DataLoader
|
import torch # PyTorch深度学习框架
|
||||||
from torchvision import transforms
|
from torch.utils.data import DataLoader # 数据加载器
|
||||||
from PIL import Image
|
from torchvision import transforms # 图像变换工具
|
||||||
import cv2
|
from PIL import Image # 图像处理库
|
||||||
import numpy as np
|
import cv2 # OpenCV库
|
||||||
|
import numpy as np # NumPy库
|
||||||
|
|
||||||
# 导入训练好的分割数据集类和模型定义
|
# 导入自定义模块
|
||||||
from unet_coco_segmentation import CocoSegDataset
|
from unet_coco_segmentation import CocoSegDataset # COCO数据集类
|
||||||
import segmentation_models_pytorch as smp # 如使用smp模型
|
import segmentation_models_pytorch as smp # 分割模型库
|
||||||
|
|
||||||
# 配置
|
# 配置参数
|
||||||
TEST_DIR = '../data/test' # 测试集图像目录
|
TEST_DIR = '../data/test' # 测试集目录
|
||||||
TEST_ANN = '../data/test/_annotations.coco.json' # 测试集COCO注释文件
|
TEST_ANN = '../data/test/_annotations.coco.json' # 测试集标注文件
|
||||||
MODEL_PATH = 'unet_coco_segmentation.pth' # 预训练模型权重
|
MODEL_PATH = 'unet_coco_segmentation.pth' # 模型权重路径
|
||||||
OUTPUT_DIR = 'output/unet_results'
|
OUTPUT_DIR = 'output/unet_results' # 输出目录
|
||||||
|
|
||||||
# 建立输出目录
|
# 创建输出目录(如果不存在)
|
||||||
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 设备配置
|
# 设备配置(GPU或CPU)
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
# 数据预处理(与训练时保持一致)
|
# 数据预处理配置(与训练时保持一致)
|
||||||
transform = transforms.Compose([
|
transform = transforms.Compose([
|
||||||
transforms.Resize((256, 256)),
|
transforms.Resize((256, 256)), # 调整尺寸
|
||||||
transforms.ToTensor(),
|
transforms.ToTensor(), # 转换为张量
|
||||||
transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)),
|
transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)), # 归一化
|
||||||
])
|
])
|
||||||
# 仅需同样的变换,不需要mask变换
|
|
||||||
|
|
||||||
# 加载测试集
|
# 加载测试集
|
||||||
from unet_coco_segmentation import CocoSegDataset
|
|
||||||
|
|
||||||
test_dataset = CocoSegDataset(
|
test_dataset = CocoSegDataset(
|
||||||
root_dir=TEST_DIR,
|
root_dir=TEST_DIR,
|
||||||
annotation_file=TEST_ANN,
|
annotation_file=TEST_ANN,
|
||||||
@ -41,35 +39,36 @@ test_dataset = CocoSegDataset(
|
|||||||
)
|
)
|
||||||
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
|
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
|
||||||
|
|
||||||
# 初始化模型(同训练时)
|
# 初始化模型(与训练时相同)
|
||||||
model = smp.Unet(
|
model = smp.Unet(
|
||||||
encoder_name='resnet34',
|
encoder_name='resnet34', # 编码器名称
|
||||||
encoder_weights=None,
|
encoder_weights=None, # 不使用预训练权重
|
||||||
in_channels=3,
|
in_channels=3, # 输入通道数
|
||||||
classes=1
|
classes=1 # 输出类别数
|
||||||
)
|
)
|
||||||
model.load_state_dict(torch.load(MODEL_PATH, map_location=device)) # 加载权重
|
# 加载预训练模型权重
|
||||||
model.to(device)
|
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
|
||||||
model.eval()
|
model.to(device) # 将模型移动到指定设备
|
||||||
|
model.eval() # 设置为评估模式
|
||||||
print("成功加载模型权重并切换到评估模式")
|
print("成功加载模型权重并切换到评估模式")
|
||||||
|
|
||||||
# 遍历测试集
|
# 遍历测试集
|
||||||
for img, mask_true in test_loader:
|
for img, mask_true in test_loader:
|
||||||
# img 保存为Tensor batch=1
|
# 图像保存为Tensor,batch=1
|
||||||
img = img.to(device)
|
img = img.to(device)
|
||||||
img_id = test_dataset.image_ids[test_loader.dataset.image_ids.index(test_dataset.image_ids[0])] # 这里获取ID
|
# 获取图像ID
|
||||||
# 预测
|
img_id = test_dataset.image_ids[test_loader.dataset.image_ids.index(test_dataset.image_ids[0])]
|
||||||
|
# 进行预测
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
output = model(img)
|
output = model(img)
|
||||||
output_prob = torch.sigmoid(output).squeeze().cpu().numpy()
|
output_prob = torch.sigmoid(output).squeeze().cpu().numpy()
|
||||||
# 恢复到原始尺寸
|
# 获取原始图像尺寸
|
||||||
# 获取原图尺寸
|
|
||||||
# 重新加载原始图像获取尺寸
|
|
||||||
img_info = next(item for item in test_dataset.coco['images'] if item['id']==test_dataset.image_ids[0])
|
img_info = next(item for item in test_dataset.coco['images'] if item['id']==test_dataset.image_ids[0])
|
||||||
orig_w, orig_h = img_info['width'], img_info['height']
|
orig_w, orig_h = img_info['width'], img_info['height']
|
||||||
|
# 调整输出尺寸到原始图像大小
|
||||||
output_prob = cv2.resize(output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
output_prob = cv2.resize(output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
||||||
|
|
||||||
# 二值化
|
# 二值化处理
|
||||||
threshold = 0.5
|
threshold = 0.5
|
||||||
output_mask = (output_prob > threshold).astype(np.uint8) * 255
|
output_mask = (output_prob > threshold).astype(np.uint8) * 255
|
||||||
|
|
||||||
|
@ -1,29 +1,34 @@
|
|||||||
|
# 导入必要的库
|
||||||
import os # 文件和路径管理
|
import os # 文件和路径管理
|
||||||
import time # 计时
|
import time # 计时功能
|
||||||
import logging # 日志记录
|
import logging # 日志记录
|
||||||
import json # 处理COCO注释JSON
|
import json # 处理JSON数据
|
||||||
|
|
||||||
import numpy as np # 数值运算
|
import numpy as np # 数值运算
|
||||||
from PIL import Image # 图像处理
|
from PIL import Image # 图像处理
|
||||||
|
|
||||||
import torch # PyTorch核心
|
import torch # PyTorch深度学习框架
|
||||||
import torch.nn as nn # 神经网络模块
|
import torch.nn as nn # 神经网络模块
|
||||||
import torch.optim as optim # 优化器
|
import torch.optim as optim # 优化器
|
||||||
from torch.utils.data import Dataset, DataLoader # 数据集和加载器
|
from torch.utils.data import Dataset, DataLoader # 数据集和加载器
|
||||||
import torchvision.utils as vutils # 可视化工具
|
import torchvision.utils as vutils # 可视化工具
|
||||||
|
|
||||||
|
# COCO数据集相关库
|
||||||
from pycocotools import mask as maskUtils # COCO掩码处理工具
|
from pycocotools import mask as maskUtils # COCO掩码处理工具
|
||||||
import albumentations as A # 数据增强
|
import albumentations as A # 数据增强库
|
||||||
from albumentations.pytorch import ToTensorV2 # Albumentations到Tensor的转换
|
from albumentations.pytorch import ToTensorV2 # Albumentations到Tensor的转换
|
||||||
|
|
||||||
# Log records carry a timestamp and level; INFO and above are emitted.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Seed PyTorch's global RNG so runs start from the same random state.
torch.manual_seed(42)

# Enable cuDNN benchmark mode.
# NOTE(review): benchmark mode may choose different convolution algorithms
# from run to run, which can work against the fixed seed above if exact
# reproducibility is required -- confirm this trade-off is intended.
torch.backends.cudnn.benchmark = True
|
||||||
|
|
||||||
# ------------------ 定义 U-Net 模型(自定义,不依赖外部分割包) ------------------
|
# ------------------ 定义 U-Net 模型(自定义) ------------------
|
||||||
class UNet(nn.Module):
|
class UNet(nn.Module):
|
||||||
def __init__(self, in_channels=3, base_channels=64, out_channels=1):
|
def __init__(self, in_channels=3, base_channels=64, out_channels=1):
|
||||||
super(UNet, self).__init__()
|
super(UNet, self).__init__()
|
||||||
# 双卷积块
|
# 双卷积块定义
|
||||||
def double_conv(in_c, out_c):
|
def double_conv(in_c, out_c):
|
||||||
return nn.Sequential(
|
return nn.Sequential(
|
||||||
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
|
||||||
@ -38,8 +43,8 @@ class UNet(nn.Module):
|
|||||||
self.enc2 = double_conv(base_channels, base_channels*2)
|
self.enc2 = double_conv(base_channels, base_channels*2)
|
||||||
self.enc3 = double_conv(base_channels*2, base_channels*4)
|
self.enc3 = double_conv(base_channels*2, base_channels*4)
|
||||||
self.enc4 = double_conv(base_channels*4, base_channels*8)
|
self.enc4 = double_conv(base_channels*4, base_channels*8)
|
||||||
self.pool = nn.MaxPool2d(2)
|
self.pool = nn.MaxPool2d(2) # 最大池化层
|
||||||
# 中心
|
# 中心层
|
||||||
self.center = double_conv(base_channels*8, base_channels*16)
|
self.center = double_conv(base_channels*8, base_channels*16)
|
||||||
# 上采样路径
|
# 上采样路径
|
||||||
self.up4 = nn.ConvTranspose2d(base_channels*16, base_channels*8, kernel_size=2, stride=2)
|
self.up4 = nn.ConvTranspose2d(base_channels*16, base_channels*8, kernel_size=2, stride=2)
|
||||||
@ -50,20 +55,20 @@ class UNet(nn.Module):
|
|||||||
self.dec2 = double_conv(base_channels*4, base_channels*2)
|
self.dec2 = double_conv(base_channels*4, base_channels*2)
|
||||||
self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=2, stride=2)
|
self.up1 = nn.ConvTranspose2d(base_channels*2, base_channels, kernel_size=2, stride=2)
|
||||||
self.dec1 = double_conv(base_channels*2, base_channels)
|
self.dec1 = double_conv(base_channels*2, base_channels)
|
||||||
# 最终1x1卷积
|
# 最终1x1卷积层
|
||||||
self.final = nn.Conv2d(base_channels, out_channels, kernel_size=1)
|
self.final = nn.Conv2d(base_channels, out_channels, kernel_size=1)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
# 编码
|
# 编码路径
|
||||||
e1 = self.enc1(x)
|
e1 = self.enc1(x)
|
||||||
e2 = self.enc2(self.pool(e1))
|
e2 = self.enc2(self.pool(e1))
|
||||||
e3 = self.enc3(self.pool(e2))
|
e3 = self.enc3(self.pool(e2))
|
||||||
e4 = self.enc4(self.pool(e3))
|
e4 = self.enc4(self.pool(e3))
|
||||||
# 中心
|
# 中心层
|
||||||
c = self.center(self.pool(e4))
|
c = self.center(self.pool(e4))
|
||||||
# 解码
|
# 解码路径
|
||||||
d4 = self.up4(c)
|
d4 = self.up4(c)
|
||||||
d4 = torch.cat([d4, e4], dim=1)
|
d4 = torch.cat([d4, e4], dim=1) # 特征拼接
|
||||||
d4 = self.dec4(d4)
|
d4 = self.dec4(d4)
|
||||||
d3 = self.up3(d4)
|
d3 = self.up3(d4)
|
||||||
d3 = torch.cat([d3, e3], dim=1)
|
d3 = torch.cat([d3, e3], dim=1)
|
||||||
@ -74,108 +79,121 @@ class UNet(nn.Module):
|
|||||||
d1 = self.up1(d2)
|
d1 = self.up1(d2)
|
||||||
d1 = torch.cat([d1, e1], dim=1)
|
d1 = torch.cat([d1, e1], dim=1)
|
||||||
d1 = self.dec1(d1)
|
d1 = self.dec1(d1)
|
||||||
return self.final(d1)
|
return self.final(d1) # 最终输出
|
||||||
|
|
||||||
# ------------------ Dice Loss 与 IoU 指标 ------------------
|
# ------------------ Dice Loss 与 IoU 指标 ------------------
|
||||||
class DiceLoss(nn.Module):
    """Soft Dice loss computed from raw logits.

    The logits are passed through a sigmoid, a Dice coefficient is computed
    by summing over dimensions 2 and 3 of the inputs, and the loss is one
    minus the mean coefficient.

    Args:
        eps: Smoothing term added to both numerator and denominator so the
            ratio stays defined when both inputs sum to zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, logits, targets):
        """Return ``1 - mean(dice)`` for ``logits`` against ``targets``."""
        # assumes 4-D inputs laid out as (batch, channel, H, W) -- TODO confirm
        spatial_axes = (2, 3)
        probabilities = logits.sigmoid()

        overlap = torch.sum(probabilities * targets, dim=spatial_axes)
        total = (
            torch.sum(probabilities, dim=spatial_axes)
            + torch.sum(targets, dim=spatial_axes)
        )

        dice_coeff = (2 * overlap + self.eps) / (total + self.eps)
        return 1 - dice_coeff.mean()
|
||||||
|
|
||||||
|
def iou_score(preds, targets, eps=1e-6):
    """Return the mean intersection-over-union of thresholded predictions.

    Args:
        preds: Tensor of scores; values strictly above 0.5 count as
            foreground.
        targets: Tensor of reference masks with the same layout as ``preds``.
        eps: Smoothing term added to intersection and union so an empty
            prediction against an empty target does not divide by zero.

    Returns:
        The IoU averaged over all remaining dimensions, as a Python float.
    """
    # assumes 4-D inputs laid out as (batch, channel, H, W) -- TODO confirm
    spatial_axes = (2, 3)
    binary = preds.gt(0.5).float()

    intersection = torch.sum(binary * targets, dim=spatial_axes)
    union = (
        torch.sum(binary, dim=spatial_axes)
        + torch.sum(targets, dim=spatial_axes)
        - intersection
    )

    ratio = (intersection + eps) / (union + eps)
    return ratio.mean().item()
|
||||||
|
|
||||||
# ------------------ COCO 分割数据集 ------------------
|
# ------------------ COCO 分割数据集 ------------------
|
||||||
class CocoSegDataset(Dataset):
    """Binary segmentation dataset backed by a COCO-format annotation file.

    Every annotation of an image is rasterised and merged into a single
    foreground/background mask.  Only images that have at least one
    annotation are indexed, because ``image_ids`` is built from the
    annotation list.

    Args:
        root_dir: Directory that contains the image files.
        annotation_file: Path to the COCO JSON annotation file.
        transforms: Callable invoked as ``transforms(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
        mask_transforms: Only used as a flag: the ``transforms`` pipeline is
            applied when both ``transforms`` and ``mask_transforms`` are
            truthy; otherwise image and mask are converted with
            ``ToTensorV2`` alone.
    """

    def __init__(self, root_dir, annotation_file, transforms=None, mask_transforms=None):
        self.root_dir = root_dir
        self.transforms = transforms
        self.mask_transforms = mask_transforms

        # Read JSON as UTF-8 explicitly instead of the platform default.
        with open(annotation_file, 'r', encoding='utf-8') as f:
            self.coco = json.load(f)

        # image_id -> list of annotations belonging to that image
        self.annotations = {}
        for ann in self.coco['annotations']:
            self.annotations.setdefault(ann['image_id'], []).append(ann)
        self.image_ids = list(self.annotations.keys())

        # image_id -> image record, built once so __getitem__ does not scan
        # the whole image list on every access.  setdefault keeps the first
        # record for a duplicated id, matching the previous linear search.
        self._image_info = {}
        for entry in self.coco.get('images', []):
            self._image_info.setdefault(entry['id'], entry)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.image_ids)

    @staticmethod
    def _to_rle(seg, h, w):
        """Convert a COCO segmentation (polygons or RLE) to compressed RLE."""
        if isinstance(seg, list):
            # Polygon format: one object may consist of several parts.
            return maskUtils.merge(maskUtils.frPyObjects(seg, h, w))
        if isinstance(seg.get('counts'), list):
            # Uncompressed RLE has to be encoded before it can be decoded.
            return maskUtils.frPyObjects(seg, h, w)
        return seg

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for the sample at ``idx``."""
        img_id = self.image_ids[idx]
        info = self._image_info[img_id]

        # Close the file handle as soon as the pixels are in memory.
        with Image.open(os.path.join(self.root_dir, info['file_name'])) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        h, w = info['height'], info['width']
        mask = np.zeros((h, w), dtype=np.uint8)
        for ann in self.annotations[img_id]:
            seg = ann['segmentation']
            if not seg:
                # An empty polygon list carries no mask and would make
                # frPyObjects fail.
                continue
            # OR the masks together: summing uint8 masks could wrap around
            # when many annotations overlap.
            mask |= maskUtils.decode(self._to_rle(seg, h, w))
        mask = (mask > 0).astype(np.float32)

        if self.transforms and self.mask_transforms:
            aug = self.transforms(image=img, mask=mask)
            img_t = aug['image']
            # assumes the pipeline returns the mask as a 2-D tensor; a
            # leading dimension is added here -- TODO confirm
            m_t = aug['mask'].unsqueeze(0)
        else:
            img_t = ToTensorV2()(image=img)['image']
            m_t = ToTensorV2()(image=mask)['image']
        return img_t, m_t
|
||||||
|
|
||||||
# ------------------ 主训练流程 ------------------
|
# ------------------ 主训练流程 ------------------
|
||||||
if __name__ == '__main__':
    # ---- Paths ----
    train_dir, val_dir = '../data/train', '../data/valid'
    train_ann = os.path.join(train_dir, '_annotations.coco.json')
    val_ann = os.path.join(val_dir, '_annotations.coco.json')

    # ---- Augmentation pipelines ----
    norm_mean = (0.485, 0.456, 0.406)
    norm_std = (0.229, 0.224, 0.225)
    train_tf = A.Compose([
        A.Resize(256, 256),
        # The probability is passed by keyword: the meaning of the first
        # positional argument differs between albumentations releases.
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(brightness_limit=0.2),
        # NOTE(review): the original passed 0.5 positionally, which is
        # shift_limit (kept here unchanged).  If a probability of 0.5 was
        # intended, replace with p=0.5.
        A.ShiftScaleRotate(shift_limit=0.5),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])
    val_tf = A.Compose([
        A.Resize(256, 256),
        A.Normalize(mean=norm_mean, std=norm_std),
        ToTensorV2(),
    ])

    # ---- Datasets and loaders ----
    # The same pipeline is passed twice; the second argument only switches
    # the dataset into its "apply transforms" branch.
    train_ds = CocoSegDataset(train_dir, train_ann, train_tf, train_tf)
    val_ds = CocoSegDataset(val_dir, val_ann, val_tf, val_tf)
    train_ld = DataLoader(train_ds, batch_size=8, shuffle=True, num_workers=4)
    val_ld = DataLoader(val_ds, batch_size=8, shuffle=False, num_workers=4)
    logging.info(f"Train samples: {len(train_ds)}, Val samples: {len(val_ds)}")

    # ---- Model, optimiser, scheduler, losses ----
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"当前设备:{device}")
    model = UNet().to(device)
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    # Halve the learning rate after 3 epochs without validation-loss improvement.
    sched = optim.lr_scheduler.ReduceLROnPlateau(opt, 'min', patience=3, factor=0.5)
    bce = nn.BCEWithLogitsLoss()
    dice = DiceLoss()

    epochs = 20
    vis_dir = 'output/val_visuals'
    os.makedirs(vis_dir, exist_ok=True)

    for ep in range(1, epochs + 1):
        # ---- Training ----
        model.train()
        run_loss = 0
        for imgs, msks in train_ld:
            imgs, msks = imgs.to(device), msks.to(device)
            opt.zero_grad()
            out = model(imgs)
            loss = bce(out, msks) + dice(out, msks)
            loss.backward()
            opt.step()
            # Weight by batch size so the epoch average is per sample.
            run_loss += loss.item() * imgs.size(0)
        tr_loss = run_loss / len(train_ds)

        # ---- Validation ----
        model.eval()
        v_loss = 0
        v_iou = 0
        v_dice = 0
        with torch.no_grad():
            for imgs, msks in val_ld:
                imgs, msks = imgs.to(device), msks.to(device)
                batch = imgs.size(0)
                out = model(imgs)
                # Compute the Dice loss once; it feeds both the total loss
                # and the Dice score.
                dice_loss = dice(out, msks)
                v_loss += (bce(out, msks) + dice_loss).item() * batch
                v_iou += iou_score(torch.sigmoid(out), msks) * batch
                v_dice += (1 - dice_loss).item() * batch
        v_loss /= len(val_ds)
        v_iou /= len(val_ds)
        v_dice /= len(val_ds)

        logging.info(f"Epoch {ep}/{epochs} - Tr:{tr_loss:.4f} Val:{v_loss:.4f} IoU:{v_iou:.4f} Dice:{v_dice:.4f}")

        # ---- Visualisation: inputs, ground truth and predictions ----
        si, sm = next(iter(val_ld))
        si = si.to(device)
        with torch.no_grad():
            sp = torch.sigmoid(model(si))
        # Masks are repeated to 3 channels so they can share a grid with
        # the RGB inputs; one grid row per kind.
        rows = torch.cat(
            [si.cpu(), sm.repeat(1, 3, 1, 1).cpu(), sp.repeat(1, 3, 1, 1).cpu()],
            0,
        )
        grid = vutils.make_grid(rows, nrow=si.size(0))
        vpth = os.path.join(vis_dir, f'ep{ep}.png')
        vutils.save_image(grid, vpth)
        logging.info(f"Saved visual: {vpth}")

        sched.step(v_loss)

    # ---- Persist the final weights ----
    torch.save(model.state_dict(), 'unet_coco_segmentation.pth')
    logging.info('训练完成,模型已保存')
|
@ -1,6 +1,8 @@
|
|||||||
torch>=1.13.1
|
torch
|
||||||
torchvision>=0.14.1
|
torchvision
|
||||||
opencv-python>=4.5.5.64
|
opencv-python>=4.5.5.64
|
||||||
Pillow>=9.2.0
|
Pillow>=9.2.0
|
||||||
kagglehub>=0.2.2
|
kagglehub>=0.2.2
|
||||||
numpy~=2.2.5
|
numpy~=2.2.5
|
||||||
|
pycocotools
|
||||||
|
albumentations
|
Loading…
Reference in New Issue
Block a user