2025-05-31 09:52:10 +00:00
|
|
|
|
import torch
|
|
|
|
|
import torch.nn as nn
|
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
import os
|
|
|
|
|
import json
|
|
|
|
|
import numpy as np
|
|
|
|
|
import argparse
|
|
|
|
|
import matplotlib.pyplot as plt # 新增可视化依赖
|
|
|
|
|
|
2025-05-31 14:44:14 +00:00
|
|
|
|
|
2025-05-31 09:52:10 +00:00
|
|
|
|
class PositionalEncoding(nn.Module):
    """Positional encoding layer that lifts input coordinates to a higher dimension.

    For every frequency ``f`` in ``2 ** linspace(0, max_freq_log2, N_freqs)`` the
    layer computes ``sin(f * x)`` and ``cos(f * x)`` and concatenates all results
    along the last axis in frequency-major order:
    ``[sin(f0 x), cos(f0 x), sin(f1 x), cos(f1 x), ...]``.
    The raw input itself is not part of the output.

    Args:
        input_dim: Size of the last axis of the input tensor.
        max_freq_log2: log2 of the highest frequency band.
        N_freqs: Number of frequency bands.
    """

    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super(PositionalEncoding, self).__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        # Registered as a buffer so module-wide conversions (.to(device),
        # .double(), .half(), ...) also apply to the frequencies.
        # persistent=False keeps it out of state_dict, so checkpoints saved
        # without this entry still load with strict key matching.
        self.register_buffer(
            'freq_bands',
            2 ** torch.linspace(0, max_freq_log2, steps=N_freqs),
            persistent=False,
        )
        # One sin and one cos per frequency, each over every input dimension.
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """Apply the positional encoding to ``x``.

        Args:
            x: Tensor whose last axis has size ``input_dim`` (e.g. ``[N, input_dim]``);
                any leading dimensions are preserved.

        Returns:
            Tensor of shape ``[..., output_dim]`` with the encoded features.
        """
        out = []
        for freq in self.freq_bands:
            for func in self.funcs:
                out.append(func(freq * x))
        return torch.cat(out, dim=-1)
|
|
|
|
|
|
2025-05-31 14:44:14 +00:00
|
|
|
|
|
2025-05-31 09:52:10 +00:00
|
|
|
|
class NeRF(nn.Module):
    """Small NeRF-style network: positional encoding followed by a 3-layer MLP.

    Args:
        input_ch: Not used by this implementation. The first linear layer's
            input width is taken from ``self.encoding.output_dim`` instead.
            The argument is kept so existing call sites keep working.
        output_ch: Number of output channels produced by the last linear layer.
    """

    def __init__(self, input_ch=60, output_ch=4):
        super(NeRF, self).__init__()
        # The attribute names and the layer order inside ``net`` define the
        # state_dict keys (e.g. ``net.0.weight``); keep them stable so saved
        # checkpoints remain loadable.
        self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
        self.net = nn.Sequential(
            nn.Linear(self.encoding.output_dim, 256),  # width of the encoded features
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_ch),
        )

    def forward(self, x):
        """Encode ``x`` and run it through the MLP.

        Args:
            x: Coordinates whose last axis has size 3; leading dims are kept.

        Returns:
            Tensor with the same leading dims and ``output_ch`` channels last.
        """
        # Cast the input to the parameters' dtype, which is float32 for a default
        # model. A hard-coded float32 cast would break after .double()/.half().
        target_dtype = self.net[0].weight.dtype
        if x.dtype != target_dtype:
            x = x.to(target_dtype)

        # Lift the coordinates with the positional encoding, then apply the MLP.
        x = self.encoding(x)
        return self.net(x)
|
|
|
|
|
|
2025-05-31 14:44:14 +00:00
|
|
|
|
|
2025-05-31 09:52:10 +00:00
|
|
|
|
class NeRFDataset(Dataset):
    """Dataset that reads camera poses from ``transforms_test.json``.

    Each item pairs one frame's pose with freshly sampled random coordinates.
    Image files are only recorded as paths in ``image_paths``; they are not loaded.

    Args:
        data_dir: Dataset root directory. A path starting with ``'../'`` is
            resolved relative to this script's directory; any other path goes
            through ``os.path.abspath`` (so relative paths use the current
            working directory).
        split: Stored as ``self.split``. The file read is always
            ``transforms_test.json`` regardless of this value.

    Raises:
        FileNotFoundError: ``transforms_test.json`` is missing under the root.
        ValueError: the JSON has no ``'frames'`` entry.
    """

    def __init__(self, data_dir, split='test'):
        # Resolve the directory the caller passed in (not a global CLI value),
        # so the class works when imported and honours its argument.
        self.data_dir = self._resolve_data_dir(data_dir)
        self.split = split

        # The test transforms file and the images directory live at the root.
        self.transforms_path = os.path.join(self.data_dir, 'transforms_test.json')
        self.images_dir = os.path.join(self.data_dir, 'images')

        if not os.path.exists(self.transforms_path):
            raise FileNotFoundError(
                f"找不到必要文件: {self.transforms_path}\n"
                "请确保:\n"
                "1. 测试数据已正确下载\n"
                f"2. 数据路径配置正确(当前路径:{self.data_dir})\n"
                "3. 包含images目录和transforms_test.json文件"
            )

        # JSON text is UTF-8; don't depend on the platform's locale encoding.
        with open(self.transforms_path, 'r', encoding='utf-8') as f:
            self.transforms = json.load(f)

        if 'frames' not in self.transforms:
            raise ValueError(
                "transforms_test.json格式错误,缺少必要字段'frames'\n"
                "请确保:\n"
                "1. 数据集文件完整\n"
                "2. 使用正确的数据集版本\n"
                "3. transforms_test.json包含frames数组和intrinsic参数"
            )

        frames = self.transforms['frames']
        self.image_paths = [os.path.join(self.images_dir, frame['file_path'])
                            for frame in frames]
        self.poses = np.array([frame['transform_matrix'] for frame in frames])

    @staticmethod
    def _resolve_data_dir(data_dir):
        """Return an absolute dataset root for ``data_dir`` ('../' is script-relative)."""
        if data_dir.startswith('../'):
            script_dir = os.path.dirname(os.path.abspath(__file__))
            return os.path.normpath(os.path.join(script_dir, data_dir))
        return os.path.abspath(data_dir)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{'coords': [64, 3] random in [0, 1), 'pose': float32 tensor}``."""
        pose = self.poses[idx]

        # Random input coordinates (x, y, z); they are independent of the pose.
        num_samples = 64
        x = torch.rand(num_samples, 3)

        return {'coords': x, 'pose': torch.tensor(pose, dtype=torch.float32)}
|
|
|
|
|
|
2025-05-31 14:44:14 +00:00
|
|
|
|
|
2025-05-31 09:52:10 +00:00
|
|
|
|
if __name__ == '__main__':
    # Parse command-line arguments.
    parser = argparse.ArgumentParser(description='NeRF测试脚本')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='数据集路径 (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='批量大小 (default: 1)')
    args = parser.parse_args()

    # Dataset and loader.
    dataset = NeRFDataset(args.data_path)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Use the root the dataset actually resolved (dataset.data_dir) for the
    # checkpoint and outputs too. Otherwise a '../' path would be
    # script-relative for the data but cwd-relative for these files.
    data_root = dataset.data_dir

    # Load the model.
    model = NeRF()
    checkpoint_path = os.path.join(data_root, 'checkpoint', 'nerf_model.pth')
    if os.path.exists(checkpoint_path):
        # The model stays on the CPU in this script, so map every tensor there
        # even if the checkpoint was saved from a GPU.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    else:
        print("警告:未找到预训练模型,使用随机初始化进行测试")
    model.eval()

    # Visualizations are stored under results/nerf/visualizations; create it once.
    vis_dir = os.path.join(data_root, 'results', 'nerf', 'visualizations')
    os.makedirs(vis_dir, exist_ok=True)

    gamma = 2.2  # display gamma for the predicted colours

    # Test loop.
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            # data['pose'] is also available but is not consumed by this model.
            coords = data['coords']  # [B, num_samples, 3]

            # Forward pass: output_ch channels on the last axis.
            outputs = model(coords)

            # Channels live on the LAST axis: take the first three as RGB and
            # squash them to [0, 1], then apply gamma correction.
            rgb_output = torch.sigmoid(outputs[..., :3])
            rgb_output = torch.pow(rgb_output, 1 / gamma)

            vis_path = os.path.join(vis_dir, f'vis_{i}.png')

            # Show the first item of the batch. Inputs and predictions use the
            # same grid layout (rows of 8), so each pixel lines up with its coordinate.
            fig, (ax_in, ax_out) = plt.subplots(1, 2, figsize=(8, 4))
            ax_in.imshow(coords[0].reshape(-1, 8, 3).cpu().numpy())
            ax_in.set_title('Input Coordinates')
            ax_out.imshow(rgb_output[0].reshape(-1, 8, 3).cpu().numpy())
            ax_out.set_title('Predicted RGB')
            fig.savefig(vis_path)
            plt.close(fig)

            # i indexes batches, so count against the number of batches.
            print(f'Test Sample {i + 1}/{len(dataloader)} completed and visualization saved to {vis_path}')
|