- 添加 U-Net 模型实现图像分割功能 - 实现 SIFT 特征提取算法 - 创建实验报告模板和环境配置指南 - 添加数据集下载脚本和目录结构设置脚本 - 实现模型训练和测试流程
80 lines
2.5 KiB
Python
80 lines
2.5 KiB
Python
import os
|
||
import torch
|
||
from torch.utils.data import DataLoader
|
||
from torchvision import transforms
|
||
from PIL import Image
|
||
import cv2
|
||
import numpy as np
|
||
|
||
# 导入训练好的分割数据集类和模型定义
|
||
from unet_coco_segmentation import CocoSegDataset
|
||
import segmentation_models_pytorch as smp # 如使用smp模型
|
||
|
||
# --- Configuration -----------------------------------------------------------
# Where results are written.
OUTPUT_DIR = "output/unet_results"

# Trained model weights to load.
MODEL_PATH = "unet_coco_segmentation.pth"

# Test split: image directory and its COCO-format annotation file.
TEST_DIR = "../data/test"
TEST_ANN = "../data/test/_annotations.coco.json"
|
||
|
||
# Make sure the output directory exists; an already-existing directory is fine.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Device selection: use CUDA when it is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
|
||
|
||
# Image preprocessing pipeline (the original author's note says it is meant to
# match the preprocessing used during training): resize to 256x256, convert to
# a tensor, then normalise each channel with the mean/std below.
# Only the image transform is defined; no mask transform is needed here.
# NOTE(review): nothing in the visible script references `transform`, and the
# dataset is constructed with transforms=None -- confirm whether this pipeline
# is supposed to be passed to the dataset.
_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
_NORMALIZE_STD = (0.229, 0.224, 0.225)

_preprocessing_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
|
||
|
||
# Load the test set.
# The same class is already imported near the top of the file; the import is
# repeated here exactly as in the original script.
from unet_coco_segmentation import CocoSegDataset

_test_dataset_kwargs = {
    'root_dir': TEST_DIR,
    'annotation_file': TEST_ANN,
    # Per the original note, preprocessing is handled inside the Dataset.
    # NOTE(review): confirm CocoSegDataset returns tensors when these are None.
    'transforms': None,
    'mask_transforms': None,
}
test_dataset = CocoSegDataset(**_test_dataset_kwargs)

# One sample per batch, iterated in dataset order (no shuffling).
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
|
||
|
||
# Build the network (the original note says this mirrors the training setup):
# a U-Net with a resnet34 encoder, 3 input channels and 1 output class.
# encoder_weights=None because all weights come from the checkpoint below.
_unet_config = {
    'encoder_name': 'resnet34',
    'encoder_weights': None,
    'in_channels': 3,
    'classes': 1,
}
model = smp.Unet(**_unet_config)

# Load the checkpoint onto the selected device and install it in the model.
_state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(_state_dict)

# Move the model to the device and switch to evaluation mode.
model.to(device)
model.eval()
print("成功加载模型权重并切换到评估模式")
|
||
|
||
# Run inference over the test set and write one binary mask PNG per image.
#
# Fix over the previous version: the image ID and the original image size were
# always taken from test_dataset.image_ids[0], so every iteration wrote to the
# same "<id>_mask.png" file and resized to the first image's dimensions. The ID
# is now derived from the position of the sample in the loader.

# Index the COCO image records by ID once, so the per-image size lookup does
# not rescan the whole list on every iteration.
image_info_by_id = {info['id']: info for info in test_dataset.coco['images']}

# Probability cut-off above which a pixel is written as foreground (255).
threshold = 0.5

# test_loader is built with batch_size=1 and shuffle=False, so the batch index
# equals the dataset index.
# NOTE(review): assumes CocoSegDataset[i] is the sample for
# test_dataset.image_ids[i] -- confirm against the dataset implementation.
for sample_idx, (img, mask_true) in enumerate(test_loader):
    # img is a batch of size 1; mask_true is not used during inference.
    img = img.to(device)
    img_id = test_dataset.image_ids[sample_idx]

    # Forward pass without gradient tracking.
    with torch.no_grad():
        output = model(img)

    # Sigmoid turns the model output into probabilities; squeeze() drops the
    # size-1 dimensions before the array is handed to OpenCV.
    output_prob = torch.sigmoid(output).squeeze().cpu().numpy()

    # Resize the prediction back to the original image size recorded in the
    # COCO annotations. cv2.resize takes the target size as (width, height).
    img_info = image_info_by_id[img_id]
    orig_w, orig_h = img_info['width'], img_info['height']
    output_prob = cv2.resize(
        output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR
    )

    # Binarise to a 0/255 uint8 mask.
    output_mask = (output_prob > threshold).astype(np.uint8) * 255

    # Save the result.
    output_path = os.path.join(OUTPUT_DIR, f"{img_id}_mask.png")
    cv2.imwrite(output_path, output_mask)
    print(f"Saved mask for image {img_id} to {output_path}")
|