2025-05-25 08:45:26 +00:00
|
|
|
|
# 导入必要的库
|
|
|
|
|
import os # 操作系统模块
|
|
|
|
|
import torch # PyTorch深度学习框架
|
|
|
|
|
from torch.utils.data import DataLoader # 数据加载器
|
|
|
|
|
from torchvision import transforms # 图像变换工具
|
|
|
|
|
from PIL import Image # 图像处理库
|
|
|
|
|
import cv2 # OpenCV库
|
|
|
|
|
import numpy as np # NumPy库
|
2025-05-16 13:17:46 +00:00
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 导入自定义模块
|
|
|
|
|
from unet_coco_segmentation import CocoSegDataset # COCO数据集类
|
|
|
|
|
import segmentation_models_pytorch as smp # 分割模型库
|
2025-05-16 13:17:46 +00:00
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 配置参数
|
|
|
|
|
TEST_DIR = '../data/test' # 测试集目录
|
|
|
|
|
TEST_ANN = '../data/test/_annotations.coco.json' # 测试集标注文件
|
|
|
|
|
MODEL_PATH = 'unet_coco_segmentation.pth' # 模型权重路径
|
|
|
|
|
OUTPUT_DIR = 'output/unet_results' # 输出目录
|
2025-05-16 13:17:46 +00:00
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 创建输出目录(如果不存在)
|
2025-05-16 13:17:46 +00:00
|
|
|
|
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 设备配置(GPU或CPU)
|
2025-05-16 13:17:46 +00:00
|
|
|
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 数据预处理配置(与训练时保持一致)
|
2025-05-16 13:17:46 +00:00
|
|
|
|
transform = transforms.Compose([
|
2025-05-25 08:45:26 +00:00
|
|
|
|
transforms.Resize((256, 256)), # 调整尺寸
|
|
|
|
|
transforms.ToTensor(), # 转换为张量
|
|
|
|
|
transforms.Normalize(mean=(0.485,0.456,0.406), std=(0.229,0.224,0.225)), # 归一化
|
2025-05-16 13:17:46 +00:00
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
# 加载测试集
|
|
|
|
|
test_dataset = CocoSegDataset(
|
|
|
|
|
root_dir=TEST_DIR,
|
|
|
|
|
annotation_file=TEST_ANN,
|
|
|
|
|
transforms=None, # 在Dataset里单独处理
|
|
|
|
|
mask_transforms=None
|
|
|
|
|
)
|
|
|
|
|
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
|
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 初始化模型(与训练时相同)
|
2025-05-16 13:17:46 +00:00
|
|
|
|
model = smp.Unet(
|
2025-05-25 08:45:26 +00:00
|
|
|
|
encoder_name='resnet34', # 编码器名称
|
|
|
|
|
encoder_weights=None, # 不使用预训练权重
|
|
|
|
|
in_channels=3, # 输入通道数
|
|
|
|
|
classes=1 # 输出类别数
|
2025-05-16 13:17:46 +00:00
|
|
|
|
)
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 加载预训练模型权重
|
|
|
|
|
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
|
|
|
|
|
model.to(device) # 将模型移动到指定设备
|
|
|
|
|
model.eval() # 设置为评估模式
|
2025-05-16 13:17:46 +00:00
|
|
|
|
print("成功加载模型权重并切换到评估模式")
|
|
|
|
|
|
|
|
|
|
# 遍历测试集
|
|
|
|
|
for img, mask_true in test_loader:
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 图像保存为Tensor,batch=1
|
2025-05-16 13:17:46 +00:00
|
|
|
|
img = img.to(device)
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 获取图像ID
|
|
|
|
|
img_id = test_dataset.image_ids[test_loader.dataset.image_ids.index(test_dataset.image_ids[0])]
|
|
|
|
|
# 进行预测
|
2025-05-16 13:17:46 +00:00
|
|
|
|
with torch.no_grad():
|
|
|
|
|
output = model(img)
|
|
|
|
|
output_prob = torch.sigmoid(output).squeeze().cpu().numpy()
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 获取原始图像尺寸
|
2025-05-16 13:17:46 +00:00
|
|
|
|
img_info = next(item for item in test_dataset.coco['images'] if item['id']==test_dataset.image_ids[0])
|
|
|
|
|
orig_w, orig_h = img_info['width'], img_info['height']
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 调整输出尺寸到原始图像大小
|
2025-05-16 13:17:46 +00:00
|
|
|
|
output_prob = cv2.resize(output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
|
|
|
|
|
2025-05-25 08:45:26 +00:00
|
|
|
|
# 二值化处理
|
2025-05-16 13:17:46 +00:00
|
|
|
|
threshold = 0.5
|
|
|
|
|
output_mask = (output_prob > threshold).astype(np.uint8) * 255
|
|
|
|
|
|
|
|
|
|
# 保存结果
|
|
|
|
|
output_path = os.path.join(OUTPUT_DIR, f"{img_id}_mask.png")
|
|
|
|
|
cv2.imwrite(output_path, output_mask)
|
2025-05-25 08:45:26 +00:00
|
|
|
|
print(f"Saved mask for image {img_id} to {output_path}")
|