AI-exp-2/code/test_unet.py
fly6516 0cef8dc8e8 docs(code): 优化代码结构和注释
- 调整代码格式,提高可读性
- 增加详细注释,解释代码功能
- 优化变量命名,提高代码可理解性
- 简化部分代码逻辑,提高执行效率
2025-05-25 16:45:26 +08:00

78 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 导入必要的库
import os # 操作系统模块
import torch # PyTorch深度学习框架
from torch.utils.data import DataLoader # 数据加载器
from torchvision import transforms # 图像变换工具
from PIL import Image # 图像处理库
import cv2 # OpenCV库
import numpy as np # NumPy库
# 导入自定义模块
from unet_coco_segmentation import CocoSegDataset # COCO数据集类
import segmentation_models_pytorch as smp # 分割模型库
# ---- Configuration --------------------------------------------------------
TEST_DIR = "../data/test"  # directory holding the test images
TEST_ANN = "../data/test/_annotations.coco.json"  # annotation file for the test split
MODEL_PATH = "unet_coco_segmentation.pth"  # trained model weights
OUTPUT_DIR = "output/unet_results"  # predicted masks are written here

# Make sure the output directory exists; an existing directory is not an error.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Image preprocessing, meant to match the pipeline used at training time.
# NOTE(review): `transform` is not referenced anywhere else in this script (the
# dataset below is built with transforms=None); confirm whether it is needed.
# The normalisation values are the usual ImageNet per-channel mean / std.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)
_TARGET_SIZE = (256, 256)

transform = transforms.Compose(
    [
        transforms.Resize(_TARGET_SIZE),  # resize to a fixed 256x256 input
        transforms.ToTensor(),  # PIL image -> float tensor
        transforms.Normalize(mean=_NORM_MEAN, std=_NORM_STD),
    ]
)
# Test split, read from the image directory and its annotation file.
# Both transform hooks are left as None; per the original note, preprocessing is
# handled inside the Dataset itself.
# NOTE(review): confirm CocoSegDataset applies the training-time resize/normalisation
# when these are None.
test_dataset = CocoSegDataset(
    root_dir=TEST_DIR,
    annotation_file=TEST_ANN,
    transforms=None,
    mask_transforms=None,
)

# One image per batch in dataset order (no shuffling), so batch i is dataset index i.
test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False)
# Rebuild the architecture used for training: a U-Net with a ResNet-34 encoder,
# 3 input channels and a single output channel (one logit map per image).
model = smp.Unet(
    encoder_name="resnet34",
    encoder_weights=None,  # no pretrained init; trained weights are loaded below
    in_channels=3,
    classes=1,
)

# map_location places the checkpoint tensors on `device`, so a GPU-saved file also
# loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
# Move to the target device and switch to inference mode (dropout/batch-norm behaviour).
model.to(device).eval()
print("成功加载模型权重并切换到评估模式")
# Predict a binary mask for every test image and save it as
# OUTPUT_DIR/<image_id>_mask.png at the image's original resolution.

# Probability above which a pixel is marked as foreground (written as 255).
threshold = 0.5

# Index the image records by id once, instead of scanning the whole list
# for every image (which was O(n^2) over the test set).
images_by_id = {info['id']: info for info in test_dataset.coco['images']}

# The loader uses batch_size=1 and shuffle=False, so the idx-th batch is dataset index idx.
# NOTE(review): assumes CocoSegDataset[idx] is the image for image_ids[idx] — confirm.
for idx, (img, _mask_true) in enumerate(test_loader):
    # `img` is a batch containing a single image tensor.
    img = img.to(device)
    # Resolve the id of *this* image. Indexing with a fixed 0 would give every
    # prediction the first image's id and size, overwriting one output file.
    img_id = test_dataset.image_ids[idx]

    with torch.no_grad():
        output = model(img)
        # Logits -> probabilities. squeeze() drops the size-1 batch and channel
        # dims, leaving a 2-D probability map.
        output_prob = torch.sigmoid(output).squeeze().cpu().numpy()

    # Resize back to the original image size recorded in the annotations.
    img_info = images_by_id[img_id]
    orig_w, orig_h = img_info['width'], img_info['height']
    # cv2.resize takes the target size as (width, height).
    output_prob = cv2.resize(output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # Binarise: foreground -> 255, background -> 0, stored as uint8 for PNG output.
    output_mask = (output_prob > threshold).astype(np.uint8) * 255

    output_path = os.path.join(OUTPUT_DIR, f"{img_id}_mask.png")
    # cv2.imwrite signals failure through its return value rather than raising.
    if not cv2.imwrite(output_path, output_mask):
        raise OSError(f"Failed to write mask to {output_path}")
    print(f"Saved mask for image {img_id} to {output_path}")