import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import argparse
import matplotlib.pyplot as plt  # visualization dependency


class PositionalEncoding(nn.Module):
    """Positional encoding layer: lifts input coordinates to a higher-dimensional feature space."""

    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super(PositionalEncoding, self).__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        # Register as a buffer so the frequency bands follow the module across devices.
        self.register_buffer('freq_bands', 2 ** torch.linspace(0, max_freq_log2, steps=N_freqs))
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """
        Apply positional encoding to the input coordinates.
        x: [N, input_dim] input coordinates
        returns: [N, output_dim] encoded features
        """
        out = []
        for freq in self.freq_bands:
            for func in self.funcs:
                out.append(func(freq * x))
        return torch.cat(out, dim=-1)


class NeRF(nn.Module):
    def __init__(self, output_ch=4):
        super(NeRF, self).__init__()
        self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
        self.net = nn.Sequential(
            nn.Linear(self.encoding.output_dim, 256),  # input width follows the encoded dimension
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_ch)
        )

    def forward(self, x):
        # Ensure a consistent input dtype.
        if x.dtype != torch.float32:
            x = x.float()
        # Apply positional encoding before the MLP.
        x = self.encoding(x)
        return self.net(x)
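
# A minimal shape sanity check for the encoder and the MLP, assuming the
# defaults above (10 frequencies, sin/cos pairs, 3-D inputs). The helper name
# `_sanity_check_shapes` is illustrative only and is not called by the script.
def _sanity_check_shapes():
    enc = PositionalEncoding(input_dim=3, N_freqs=10)
    pts = torch.rand(2, 64, 3)                     # [batch, samples, xyz] in [0, 1)
    feats = enc(pts)                               # sin/cos at 10 frequencies each
    assert feats.shape == (2, 64, enc.output_dim)  # 3 * 2 * 10 = 60 channels
    out = NeRF()(pts)                              # encoded features -> 256-wide MLP
    assert out.shape == (2, 64, 4)                 # RGB (3) + density (1) per point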

class NeRFDataset(Dataset):
    def __init__(self, data_dir, split='test'):
        self.data_dir = data_dir
        self.split = split
        # Use transforms_test.json and the images directory under the dataset root.
        self.transforms_path = os.path.join(data_dir, 'transforms_test.json')
        self.images_dir = os.path.join(data_dir, 'images')

        # Check that the required files exist.
        if not os.path.exists(self.transforms_path):
            raise FileNotFoundError(
                f"Required file not found: {self.transforms_path}\n"
                "Please make sure that:\n"
                "1. The test data has been downloaded correctly\n"
                f"2. The data path is configured correctly (current path: {data_dir})\n"
                "3. The directory contains an images folder and transforms_test.json"
            )

        with open(self.transforms_path, 'r') as f:
            self.transforms = json.load(f)

        # Validate the file format.
        if 'frames' not in self.transforms:
            raise ValueError(
                "transforms_test.json is malformed: required field 'frames' is missing\n"
                "Please make sure that:\n"
                "1. The dataset files are complete\n"
                "2. You are using the correct dataset version\n"
                "3. transforms_test.json contains the 'frames' array and the intrinsic parameters"
            )

        self.image_paths = [os.path.join(self.images_dir, frame['file_path'])
                            for frame in self.transforms['frames']]
        self.poses = np.array([frame['transform_matrix'] for frame in self.transforms['frames']])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the camera pose.
        pose = self.poses[idx]
        # Generate random input coordinates.
        num_samples = 64
        x = torch.rand(num_samples, 3)  # (x, y, z)
        return {'coords': x, 'pose': torch.tensor(pose, dtype=torch.float32)}


if __name__ == '__main__':
    # Argument parsing.
    parser = argparse.ArgumentParser(description='NeRF test script')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='dataset path (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size (default: 1)')
    args = parser.parse_args()

    # Dataset configuration.
    dataset = NeRFDataset(args.data_path)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Load the model.
    model = NeRF()
    checkpoint_path = os.path.join(args.data_path, 'checkpoint', 'nerf_model.pth')
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    else:
        print("Warning: no pretrained model found; testing with random initialization")
    model.eval()

    # Create the output directories once, before the loop.
    results_dir = os.path.join(args.data_path, 'results', 'nerf')
    vis_dir = os.path.join(results_dir, 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    # Test loop.
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            coords = data['coords']  # [B, 64, 3]
            poses = data['pose']     # [B, 4, 4] (not used by this toy forward pass)

            # Forward pass: [B, 64, 3] -> [B, 64, 4] (RGB + density per point).
            outputs = model(coords)

            # Save the predictions under results/nerf; slice the channel axis
            # with `...` so the batch dimension is preserved.
            pred_path = os.path.join(results_dir, f'pred_{i}.json')
            pred_result = {
                "rgb": outputs[..., :3].tolist(),
                "density": outputs[..., 3].tolist()
            }
            with open(pred_path, 'w') as f:
                json.dump(pred_result, f)

            # Save a visualization of the first batch item under results/nerf/visualizations.
            vis_path = os.path.join(vis_dir, f'vis_{i}.png')
            plt.figure(figsize=(8, 4))
            plt.subplot(121)
            plt.imshow(coords[0].reshape(8, 8, 3).cpu().numpy())  # 64 samples as an 8x8 RGB grid
            plt.title('Input Coordinates')
            plt.subplot(122)
            # Clamp to [0, 1] for display: the raw network output is unbounded.
            plt.imshow(outputs[0, :, :3].clamp(0, 1).reshape(8, 8, 3).cpu().numpy())
            plt.title('Predicted RGB')
            plt.savefig(vis_path)
            plt.close()
            print(f'Test Sample {i+1}/{len(dataset)} completed and visualization saved to {vis_path}')
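
# Example invocation (the script filename is illustrative; --data_path must
# point at a directory containing images/ and transforms_test.json):
#
#   python test_nerf.py --data_path ../data/nerf/lego --batch_size 1
#
# Predictions are written to <data_path>/results/nerf/pred_<i>.json and the
# PNG previews to <data_path>/results/nerf/visualizations/vis_<i>.png.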