train_nerf.py:
- Add relative-path support to data path resolution
- Refactor code for readability and maintainability
- Add error handling for data loading and model inference
- Improve the dataset class to make data paths more flexible and robust
- Update data loading to support more flexible path configuration
- Improve result visualization with gamma correction
- Fix argument parsing and unify code style
170 lines
5.9 KiB
Python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import argparse
import matplotlib.pyplot as plt  # added for visualization

class PositionalEncoding(nn.Module):
    """Positional encoding layer that lifts input coordinates to a higher-dimensional space."""

    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super(PositionalEncoding, self).__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, steps=N_freqs)
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """
        Apply positional encoding to the input coordinates.
        x: [N, input_dim] input coordinates
        returns: [N, output_dim] encoded features
        """
        out = []
        for freq in self.freq_bands:
            for func in self.funcs:
                out.append(func(freq * x))
        return torch.cat(out, dim=-1)
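

# Note: with the defaults above (input_dim=3, two encoding functions, N_freqs=10),
# output_dim = 3 * 2 * 10 = 60, which matches the default input_ch=60 declared by
# the NeRF module below (the MLP itself uses encoding.output_dim directly).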
class NeRF(nn.Module):
    def __init__(self, input_ch=60, output_ch=4):
        super(NeRF, self).__init__()
        self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
        self.net = nn.Sequential(
            nn.Linear(self.encoding.output_dim, 256),  # use the encoded feature dimension
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_ch)
        )

    def forward(self, x):
        # Ensure a consistent input dtype
        if x.dtype != torch.float32:
            x = x.float()

        # Apply positional encoding
        x = self.encoding(x)
        return self.net(x)
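

# The 4-channel output is consumed downstream as RGB (first three channels,
# squashed with a sigmoid) plus one extra channel, presumably a density value
# as in standard NeRF; only the RGB part is visualized in the test loop below.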
class NeRFDataset(Dataset):
    def __init__(self, data_dir, split='test'):
        # Resolve relative paths
        if data_dir.startswith('../'):
            # Relative paths are resolved against the location of this script
            script_dir = os.path.dirname(os.path.abspath(__file__))
            self.data_dir = os.path.normpath(os.path.join(script_dir, data_dir))
        else:
            # Absolute path, or a path relative to the project root
            self.data_dir = os.path.abspath(data_dir)

        self.split = split

        # Use transforms_test.json and the images directory under the dataset root
        self.transforms_path = os.path.join(self.data_dir, 'transforms_test.json')
        self.images_dir = os.path.join(self.data_dir, 'images')

        # Check that the required files exist
        if not os.path.exists(self.transforms_path):
            raise FileNotFoundError(
                f"Required file not found: {self.transforms_path}\n"
                "Please make sure that:\n"
                "1. The test data has been downloaded\n"
                f"2. The data path is configured correctly (current path: {self.data_dir})\n"
                "3. The dataset contains an images directory and transforms_test.json"
            )

        with open(self.transforms_path, 'r') as f:
            self.transforms = json.load(f)

        # Validate the file format
        if 'frames' not in self.transforms:
            raise ValueError(
                "transforms_test.json is malformed: missing required field 'frames'\n"
                "Please make sure that:\n"
                "1. The dataset files are complete\n"
                "2. The correct dataset version is used\n"
                "3. transforms_test.json contains a frames array and intrinsic parameters"
            )

        self.image_paths = [os.path.join(self.images_dir, frame['file_path'])
                            for frame in self.transforms['frames']]
        self.poses = np.array([frame['transform_matrix'] for frame in self.transforms['frames']])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Load the camera pose
        pose = self.poses[idx]

        # Generate random input coordinates
        num_samples = 64
        x = torch.rand(num_samples, 3)  # (x, y, z)

        return {'coords': x, 'pose': torch.tensor(pose, dtype=torch.float32)}
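

# Expected dataset layout (only the keys this script actually reads are shown;
# your transforms_test.json may contain additional fields such as intrinsics):
#
#   <data_dir>/
#       images/
#       transforms_test.json  # {"frames": [{"file_path": ..., "transform_matrix": ...}, ...]}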
if __name__ == '__main__':
    # Argument parsing
    parser = argparse.ArgumentParser(description='NeRF test script')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='dataset path (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='batch size (default: 1)')
    args = parser.parse_args()

    # Dataset configuration
    dataset = NeRFDataset(args.data_path)
    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False)

    # Load the model
    model = NeRF()
    checkpoint_path = os.path.join(args.data_path, 'checkpoint', 'nerf_model.pth')

    if os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path))
    else:
        print("Warning: no pretrained model found, testing with random initialization")

    model.eval()

    # Test loop
    with torch.no_grad():
        for i, data in enumerate(dataloader):
            coords = data['coords']
            poses = data['pose']

            # Forward pass
            outputs = model(coords)

            # Store visualizations under results/nerf
            os.makedirs(os.path.join(args.data_path, 'results', 'nerf', 'visualizations'), exist_ok=True)
            vis_path = os.path.join(args.data_path, 'results', 'nerf', 'visualizations', f'vis_{i}.png')

            # Apply a sigmoid to map the RGB channels into [0, 1]
            # (slice the channel dimension, not the sample dimension)
            rgb_output = torch.sigmoid(outputs[..., :3])

            # Apply gamma correction (gamma is typically 2.2)
            gamma = 2.2
            rgb_output = torch.pow(rgb_output, 1 / gamma)

            plt.figure(figsize=(8, 4))
            plt.subplot(121)
            plt.imshow(coords[0].reshape(8, 8, 3).cpu().numpy())
            plt.title('Input Coordinates')

            plt.subplot(122)
            # Lay the 64 predicted samples out as an 8x8 RGB grid, mirroring the left panel
            plt.imshow(rgb_output[0].reshape(8, 8, 3).cpu().numpy())
            plt.title('Predicted RGB')

            plt.savefig(vis_path)
            plt.close()

            print(f'Test Sample {i + 1}/{len(dataset)} completed and visualization saved to {vis_path}')
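

# Example invocation (paths mirror the argparse defaults above; adjust to your setup):
#   python train_nerf.py --data_path ../data/nerf/lego --batch_size 1
# Visualizations are written to <data_path>/results/nerf/visualizations/vis_<i>.png.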