# AI-exp-4/nerf/test_nerf.py
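"""Test script for a minimal NeRF-style model.

Loads camera poses from transforms_test.json, samples random 3-D query points,
feeds them through a small positionally encoded MLP, and writes per-point RGB /
density predictions plus simple visualizations under <data_path>/results/nerf.
"""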


import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import argparse
import matplotlib.pyplot as plt  # visualization dependency


class PositionalEncoding(nn.Module):
    """Positional encoding layer: expands input coordinates into a higher-dimensional feature space."""
    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super(PositionalEncoding, self).__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, steps=N_freqs)
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """
        Apply positional encoding to the input coordinates.
        x: [N, input_dim] input coordinates
        returns: [N, output_dim] encoded features
        """
        out = []
        for freq in self.freq_bands:
            for func in self.funcs:
                out.append(func(freq * x))
        return torch.cat(out, dim=-1)
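
# A minimal sanity-check sketch (not part of the original pipeline): with the
# defaults above, an [N, 3] batch of points expands to
# [N, 3 * len(funcs) * N_freqs] = [N, 60] features.
#   enc = PositionalEncoding(input_dim=3, N_freqs=10)
#   feats = enc(torch.rand(64, 3))
#   assert feats.shape == (64, 60)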


class NeRF(nn.Module):
    def __init__(self, input_ch=60, output_ch=4):
        super(NeRF, self).__init__()
        # input_ch is unused here; the first layer width comes from the
        # encoder's output_dim (3 * 2 * 10 = 60)
        self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
        self.net = nn.Sequential(
            nn.Linear(self.encoding.output_dim, 256),  # use the encoded feature dimension
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_ch)
        )

    def forward(self, x):
        # Ensure a consistent input dtype
        if x.dtype != torch.float32:
            x = x.float()
        # Apply positional encoding
        x = self.encoding(x)
        return self.net(x)
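
# A minimal usage sketch (assumes unit-cube query points, as sampled by the
# dataset below): the model maps raw 3-D coordinates to 4 values per point,
# read by the test loop as RGB (first 3 channels) plus density (last channel).
#   model = NeRF()
#   out = model(torch.rand(64, 3))
#   assert out.shape == (64, 4)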


class NeRFDataset(Dataset):
    def __init__(self, data_dir, split='test'):
        self.data_dir = data_dir
        self.split = split
        # Use transforms_test.json and the images directory under the dataset root
        self.transforms_path = os.path.join(data_dir, 'transforms_test.json')
        self.images_dir = os.path.join(data_dir, 'images')
        # Check that the required files exist
        if not os.path.exists(self.transforms_path):
            raise FileNotFoundError(
                f"Required file not found: {self.transforms_path}\n"
                "Please make sure that:\n"
                "1. the test data has been downloaded correctly\n"
                f"2. the data path is configured correctly (current path: {data_dir})\n"
                "3. the dataset contains an images directory and transforms_test.json"
            )
        with open(self.transforms_path, 'r') as f:
            self.transforms = json.load(f)
        # Validate the file format
        if 'frames' not in self.transforms:
            raise ValueError(
                "transforms_test.json is malformed: missing required field 'frames'\n"
                "Please make sure that:\n"
                "1. the dataset files are complete\n"
                "2. the correct dataset version is used\n"
                "3. transforms_test.json contains a frames array and intrinsic parameters"
            )
        self.image_paths = [os.path.join(self.images_dir, frame['file_path'])
                            for frame in self.transforms['frames']]
        self.poses = np.array([frame['transform_matrix'] for frame in self.transforms['frames']])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the camera pose
        pose = self.poses[idx]
        # Generate random input coordinates
        num_samples = 64
        x = torch.rand(num_samples, 3)  # (x, y, z)
        return {'coords': x, 'pose': torch.tensor(pose, dtype=torch.float32)}
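
# A minimal usage sketch (assumes the Blender-style layout expected by the
# default --data_path below, i.e. images/ plus transforms_test.json):
#   dataset = NeRFDataset('../data/nerf/lego')
#   sample = dataset[0]
#   # sample['coords']: [64, 3] random query points
#   # sample['pose']:   [4, 4] camera-to-world matrix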


if __name__ == '__main__':
    # Command-line argument parsing
    parser = argparse.ArgumentParser(description='NeRF test script')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='dataset path (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size (default: 1)')
    args = parser.parse_args()

    # Dataset and dataloader
    dataset = NeRFDataset(args.data_path)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Load the model
    model = NeRF()
    checkpoint_path = os.path.join(args.data_path, 'checkpoint', 'nerf_model.pth')
    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    else:
        print("Warning: no pretrained model found, testing with randomly initialized weights")
    model.eval()

    # Test loop
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            coords = data['coords']   # [B, 64, 3]
            poses = data['pose']      # [B, 4, 4]
            # Forward pass: [B, 64, 4] = RGB (3 channels) + density (1) per sample point
            outputs = model(coords)

            # Save predictions to the results/nerf directory
            os.makedirs(os.path.join(args.data_path, 'results', 'nerf'), exist_ok=True)
            pred_path = os.path.join(args.data_path, 'results', 'nerf', f'pred_{i}.json')
            pred_result = {
                "rgb": outputs[..., :3].tolist(),
                "density": outputs[..., 3].tolist()
            }
            with open(pred_path, 'w') as f:
                json.dump(pred_result, f)

            # Save visualizations to the results/nerf directory
            os.makedirs(os.path.join(args.data_path, 'results', 'nerf', 'visualizations'), exist_ok=True)
            vis_path = os.path.join(args.data_path, 'results', 'nerf', 'visualizations', f'vis_{i}.png')
            plt.figure(figsize=(8, 4))
            plt.subplot(121)
            plt.imshow(coords[0].reshape(8, 8, 3).cpu().numpy())
            plt.title('Input Coordinates')
            plt.subplot(122)
            # Reshape the 64 predicted RGB values of the first batch element into an
            # 8x8 image; clamp to [0, 1] since the raw network output is unbounded
            plt.imshow(outputs[0, :, :3].clamp(0, 1).reshape(8, 8, 3).cpu().numpy())
            plt.title('Predicted RGB')
            plt.savefig(vis_path)
            plt.close()
            print(f'Test batch {i+1}/{len(dataloader)} completed and visualization saved to {vis_path}')
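
# Example invocation (assuming the default dataset layout described above):
#   python test_nerf.py --data_path ../data/nerf/lego --batch_size 1
# Predictions are written to <data_path>/results/nerf/pred_<i>.json and
# visualizations to <data_path>/results/nerf/visualizations/vis_<i>.png.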