AI-exp-2/code/test_unet.py
fly6516 b1e90c4c9d feat(AI-exp-2): 实现图像分割与 SIFT 特征提取
- 添加 U-Net 模型实现图像分割功能
- 实现 SIFT 特征提取算法
- 创建实验报告模板和环境配置指南
- 添加数据集下载脚本和目录结构设置脚本
- 实现模型训练和测试流程
2025-05-16 21:17:46 +08:00

80 lines
2.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
# 导入训练好的分割数据集类和模型定义
from unet_coco_segmentation import CocoSegDataset
import segmentation_models_pytorch as smp # 如使用smp模型
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
TEST_DIR = '../data/test'  # directory containing the test-set images
TEST_ANN = '../data/test/_annotations.coco.json'  # COCO annotation file for the test set
MODEL_PATH = 'unet_coco_segmentation.pth'  # file holding the trained model weights
OUTPUT_DIR = 'output/unet_results'  # where result files are written

# Create the output directory up front; an already existing one is not an error.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# Image preprocessing pipeline: resize to 256x256, convert to a tensor, then
# normalize each channel. The original author notes it is meant to mirror the
# preprocessing used at training time, and that no mask transform is needed.
# The mean/std values look like the usual ImageNet statistics -- TODO confirm
# against the training script.
# NOTE(review): `transform` is not passed to the dataset constructed later in
# this script (it receives transforms=None) -- verify that this is intended.
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
]
transform = transforms.Compose(_preprocess_steps)
# Build the test dataset and its loader.
# (This import repeats the one at the top of the file; re-importing is harmless.)
from unet_coco_segmentation import CocoSegDataset

# Both transform arguments are left as None; per the original author's note the
# image handling is expected to happen inside the Dataset itself.
# NOTE(review): CocoSegDataset is defined elsewhere -- confirm what it returns
# when no transforms are supplied.
_dataset_kwargs = dict(
    root_dir=TEST_DIR,
    annotation_file=TEST_ANN,
    transforms=None,
    mask_transforms=None,
)
test_dataset = CocoSegDataset(**_dataset_kwargs)

# One sample per batch, in dataset order (no shuffling).
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
# Build the network. The original author notes this must match the
# architecture used for training: a U-Net with a resnet34 encoder,
# 3 input channels and a single output class. encoder_weights=None because
# the weights are loaded from MODEL_PATH right after construction.
model = smp.Unet(encoder_name='resnet34', encoder_weights=None, in_channels=3, classes=1)

# Load the saved weights, mapping tensors onto the selected device.
_state_dict = torch.load(MODEL_PATH, map_location=device)
model.load_state_dict(_state_dict)

# Move the model to the device and switch it to evaluation mode.
model.to(device).eval()
# Status message: "model weights loaded and switched to evaluation mode".
print("成功加载模型权重并切换到评估模式")
# Run inference over the test set and write one binary mask PNG per image.
#
# The loader is built with batch_size=1 and shuffle=False, so the position of
# a batch in the loader is used as the index into test_dataset.image_ids.
# NOTE(review): this assumes CocoSegDataset.__getitem__(i) returns the sample
# for image_ids[i] -- confirm against the dataset implementation.

# Index the COCO image records by id once instead of rescanning the list for
# every image. setdefault keeps the first record seen for an id, which is what
# a linear "first match" search would return.
image_info_by_id = {}
for item in test_dataset.coco['images']:
    image_info_by_id.setdefault(item['id'], item)

# Probability cut-off used to turn the sigmoid output into a binary mask.
threshold = 0.5

for idx, (img, mask_true) in enumerate(test_loader):
    # img is a batch holding a single image; mask_true is not used here.
    img = img.to(device)

    # Id of the image currently being processed. Using the loop position
    # (rather than always the first entry of image_ids) gives every image its
    # own output file and its own target size.
    img_id = test_dataset.image_ids[idx]

    # Forward pass without gradient tracking.
    with torch.no_grad():
        output = model(img)
    # Logits -> probabilities; squeeze drops the size-1 dimensions so that
    # cv2.resize receives a plain 2-D array.
    output_prob = torch.sigmoid(output).squeeze().cpu().numpy()

    # Resize the probability map back to the original image size taken from
    # the COCO record of this image. cv2.resize expects (width, height).
    img_info = image_info_by_id[img_id]
    orig_w, orig_h = img_info['width'], img_info['height']
    output_prob = cv2.resize(output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # Binarize: 255 where the probability exceeds the threshold, else 0.
    output_mask = (output_prob > threshold).astype(np.uint8) * 255

    # Save the mask. cv2.imwrite reports failure through its boolean return
    # value instead of raising, so check it before claiming success.
    output_path = os.path.join(OUTPUT_DIR, f"{img_id}_mask.png")
    if cv2.imwrite(output_path, output_mask):
        print(f"Saved mask for image {img_id} to {output_path}")
    else:
        print(f"Failed to save mask for image {img_id} to {output_path}")