"""Run a trained U-Net on the COCO-format test split and save binary masks.

For every image in the test dataset, the model's sigmoid output is resized back
to that image's original resolution (read from the COCO ``images`` entries),
thresholded at ``threshold`` and written to ``OUTPUT_DIR`` as
``<image_id>_mask.png`` with values 0 / 255 (uint8).
"""
import os

import cv2
import numpy as np
import segmentation_models_pytorch as smp  # needed when using an smp model
import torch
from PIL import Image
from torch.utils.data import DataLoader
from torchvision import transforms

# Dataset class used during training
from unet_coco_segmentation import CocoSegDataset

# Configuration
TEST_DIR = '../data/test'  # test-set image directory
TEST_ANN = '../data/test/_annotations.coco.json'  # test-set COCO annotation file
MODEL_PATH = 'unet_coco_segmentation.pth'  # trained model weights
OUTPUT_DIR = 'output/unet_results'

# Create the output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Device selection
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Preprocessing (intended to match training); no mask transform is needed here.
# NOTE(review): this pipeline is never used -- the dataset below is built with
# transforms=None. Confirm CocoSegDataset applies the training-time
# resize/normalization itself in that case; otherwise pass transforms=transform.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
])

# Load the test set
test_dataset = CocoSegDataset(
    root_dir=TEST_DIR,
    annotation_file=TEST_ANN,
    transforms=None,  # preprocessing is expected to happen inside the Dataset
    mask_transforms=None,
)
# shuffle=False keeps batches in dataset-index order, which the loop relies on.
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# image id -> (width, height), built once instead of a linear search per image.
image_sizes = {
    info['id']: (info['width'], info['height'])
    for info in test_dataset.coco['images']
}

# Build the model exactly as during training
model = smp.Unet(
    encoder_name='resnet34',
    encoder_weights=None,
    in_channels=3,
    classes=1,
)
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
# On torch >= 1.13 consider passing weights_only=True.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))
model.to(device)
model.eval()
print("成功加载模型权重并切换到评估模式")

# Probability cutoff for foreground pixels
threshold = 0.5

# Iterate over the test set.
# Assumes CocoSegDataset.__getitem__(i) loads image_ids[i] -- TODO confirm.
for idx, (img, _mask_true) in enumerate(test_loader):
    img_id = test_dataset.image_ids[idx]
    img = img.to(device)

    # Predict
    with torch.no_grad():
        output = model(img)
    # batch_size=1 and classes=1, so squeeze() leaves a 2-D probability map.
    output_prob = torch.sigmoid(output).squeeze().cpu().numpy()

    # Resize back to the original resolution; cv2 takes dsize as (width, height).
    orig_w, orig_h = image_sizes[img_id]
    output_prob = cv2.resize(output_prob, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)

    # Binarize to a 0/255 uint8 mask
    output_mask = (output_prob > threshold).astype(np.uint8) * 255

    # Save the result; cv2.imwrite signals failure via its return value, not an exception.
    output_path = os.path.join(OUTPUT_DIR, f"{img_id}_mask.png")
    if not cv2.imwrite(output_path, output_mask):
        raise OSError(f"Failed to write mask to {output_path}")
    print(f"Saved mask for image {img_id} to {output_path}")