import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import argparse
import matplotlib.pyplot as plt  # added for visualization


class PositionalEncoding(nn.Module):
    """Positional encoding layer: maps input coordinates to a higher-dimensional feature space."""
    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super(PositionalEncoding, self).__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, steps=N_freqs)
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """
        Apply positional encoding to the input coordinates.
        x: [N, input_dim] input coordinates
        returns: [N, output_dim] encoded features
        """
        out = []
        for freq in self.freq_bands:
            for func in self.funcs:
                out.append(func(freq * x))
        return torch.cat(out, dim=-1)
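
# Quick shape sanity check for the encoding above (illustrative sketch, assuming the
# default arguments): with input_dim=3, N_freqs=10 and two functions (sin, cos),
# output_dim = 3 * 2 * 10 = 60, which matches the NeRF model's default input_ch below.
#   enc = PositionalEncoding()
#   enc(torch.rand(5, 3)).shape  # -> torch.Size([5, 60])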


class NeRF(nn.Module):
    def __init__(self, input_ch=60, output_ch=4):
        super(NeRF, self).__init__()
        self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
        self.net = nn.Sequential(
            nn.Linear(self.encoding.output_dim, 256),  # use the encoded feature dimension
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_ch)
        )

    def forward(self, x):
        # Ensure a consistent input dtype
        if x.dtype != torch.float32:
            x = x.float()

        # Apply positional encoding
        x = self.encoding(x)
        return self.net(x)
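
# Usage sketch (illustrative, based on the classes above): raw 3D points go in, and the
# model returns 4 values per point (RGB + density) along the last dimension.
#   model = NeRF()
#   out = model(torch.rand(64, 3))        # -> torch.Size([64, 4])
#   rgb, density = out[..., :3], out[..., 3]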


class NeRFDataset(Dataset):
    def __init__(self, data_dir, split='test'):
        self.data_dir = data_dir
        self.split = split

        # Use transforms_test.json and the images directory under the dataset root
        self.transforms_path = os.path.join(data_dir, 'transforms_test.json')
        self.images_dir = os.path.join(data_dir, 'images')

        # Check that the required files exist
        if not os.path.exists(self.transforms_path):
            raise FileNotFoundError(
                f"Required file not found: {self.transforms_path}\n"
                "Please make sure that:\n"
                "1. the test data has been downloaded correctly\n"
                f"2. the data path is configured correctly (current path: {data_dir})\n"
                "3. the dataset contains an images directory and a transforms_test.json file"
            )

        with open(self.transforms_path, 'r') as f:
            self.transforms = json.load(f)

        # Validate the file format
        if 'frames' not in self.transforms:
            raise ValueError(
                "transforms_test.json is malformed: missing required field 'frames'\n"
                "Please make sure that:\n"
                "1. the dataset files are complete\n"
                "2. the correct dataset version is used\n"
                "3. transforms_test.json contains a frames array and intrinsic parameters"
            )

        self.image_paths = [os.path.join(self.images_dir, frame['file_path'])
                            for frame in self.transforms['frames']]
        self.poses = np.array([frame['transform_matrix'] for frame in self.transforms['frames']])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the camera pose for this frame
        pose = self.poses[idx]

        # Generate random input coordinates
        num_samples = 64
        x = torch.rand(num_samples, 3)  # (x, y, z)

        return {'coords': x, 'pose': torch.tensor(pose, dtype=torch.float32)}
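
# Note: __getitem__ returns random sample points rather than camera rays; a full NeRF
# evaluation would cast rays from 'pose' and the camera intrinsics. For reference, a
# minimal transforms_test.json accepted by the checks above might look like the
# following (hypothetical example; only 'frames', 'file_path' and 'transform_matrix'
# are actually read by this script):
#   {
#     "frames": [
#       {
#         "file_path": "r_0.png",
#         "transform_matrix": [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]]
#       }
#     ]
#   }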


if __name__ == '__main__':
    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='NeRF test script')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='dataset path (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size (default: 1)')
    args = parser.parse_args()
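
    # Example invocation (the script filename below is only a placeholder):
    #   python nerf_test.py --data_path ../data/nerf/lego --batch_size 1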

    # Dataset configuration
    dataset = NeRFDataset(args.data_path)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Load the model
    model = NeRF()
    checkpoint_path = os.path.join(args.data_path, 'checkpoint', 'nerf_model.pth')

    if os.path.exists(checkpoint_path):
        # map_location keeps the load working on CPU-only machines
        model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    else:
        print("Warning: no pretrained model found, testing with randomly initialized weights")

    model.eval()

    # Test loop
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            coords = data['coords']
            poses = data['pose']

            # Forward pass
            outputs = model(coords)  # [batch_size, num_samples, 4]

            # Save predictions to the results/nerf directory
            os.makedirs(os.path.join(args.data_path, 'results', 'nerf'), exist_ok=True)
            pred_path = os.path.join(args.data_path, 'results', 'nerf', f'pred_{i}.json')

            # Index the last (channel) dimension: the first three values are RGB,
            # the fourth is density.
            pred_result = {
                "rgb": outputs[..., :3].tolist(),
                "density": outputs[..., 3].tolist()
            }

            with open(pred_path, 'w') as f:
                json.dump(pred_result, f)

            # Save visualizations to the results/nerf directory
            os.makedirs(os.path.join(args.data_path, 'results', 'nerf', 'visualizations'), exist_ok=True)
            vis_path = os.path.join(args.data_path, 'results', 'nerf', 'visualizations', f'vis_{i}.png')

            plt.figure(figsize=(8, 4))
            plt.subplot(121)
            plt.imshow(coords[0].reshape(8, 8, 3).cpu().numpy())
            plt.title('Input Coordinates')

            plt.subplot(122)
            # Reshape the 64 predicted RGB values of the first batch element into an
            # 8x8 image; clamp to [0, 1] since the raw network outputs are unbounded.
            plt.imshow(outputs[0, :, :3].reshape(8, 8, 3).clamp(0, 1).cpu().numpy())
            plt.title('Predicted RGB')

            plt.savefig(vis_path)
            plt.close()

            print(f'Test Sample {i+1}/{len(dataset)} completed and visualization saved to {vis_path}')