148 lines
5.4 KiB
Python
148 lines
5.4 KiB
Python
|
import torch
|
|||
|
import torch.nn as nn
|
|||
|
from torch.utils.data import Dataset, DataLoader
|
|||
|
import os
|
|||
|
import json
|
|||
|
import numpy as np
|
|||
|
import argparse
|
|||
|
|
|||
|
class NeRFDataset(Dataset):
|
|||
|
def __init__(self, data_dir, split='train'):
|
|||
|
self.data_dir = data_dir
|
|||
|
self.split = split
|
|||
|
|
|||
|
# 检查必要文件是否存在
|
|||
|
transforms_path = os.path.join(data_dir, 'transforms_train.json')
|
|||
|
if not os.path.exists(transforms_path):
|
|||
|
raise FileNotFoundError(
|
|||
|
f"找不到必要文件: {transforms_path}\n"
|
|||
|
"请确保:\n"
|
|||
|
"1. 数据集已正确下载\n"
|
|||
|
"2. 数据路径配置正确(当前路径:{data_dir})\n"
|
|||
|
"3. 包含images目录和poses_bounds.npy文件"
|
|||
|
)
|
|||
|
|
|||
|
with open(transforms_path, 'r') as f:
|
|||
|
self.transforms = json.load(f)
|
|||
|
|
|||
|
# 验证文件格式
|
|||
|
if 'frames' not in self.transforms:
|
|||
|
raise ValueError(
|
|||
|
f"transforms.json格式错误,缺少必要字段'frames'\n"
|
|||
|
"请确保:\n"
|
|||
|
"1. 数据集文件完整\n"
|
|||
|
"2. 使用正确的数据集版本\n"
|
|||
|
"3. transforms.json包含frames数组和intrinsic参数"
|
|||
|
)
|
|||
|
|
|||
|
self.image_paths = [os.path.join(data_dir, 'images', frame['file_path'])
|
|||
|
for frame in self.transforms['frames']]
|
|||
|
self.poses = np.array([frame['transform_matrix'] for frame in self.transforms['frames']])
|
|||
|
|
|||
|
def __len__(self):
|
|||
|
return len(self.image_paths)
|
|||
|
|
|||
|
def __getitem__(self, idx):
|
|||
|
# 加载图像和位姿
|
|||
|
image_path = self.image_paths[idx]
|
|||
|
pose = self.poses[idx]
|
|||
|
|
|||
|
# 简单的数据预处理
|
|||
|
# 生成随机坐标替代真实图像输入(NeRF典型做法)
|
|||
|
height, width = 400, 400 # 降低分辨率以减少计算量
|
|||
|
x = torch.rand(height * width, 3) # (x, y, z) 坐标
|
|||
|
|
|||
|
# 生成随机颜色作为目标值
|
|||
|
y = torch.rand(height * width, 4) # (r, g, b, a) 目标值
|
|||
|
|
|||
|
return {'coords': x, 'targets': y}
|
|||
|
|
|||
|
class PositionalEncoding(nn.Module):
|
|||
|
"""位置编码层,将输入坐标扩展到更高维度"""
|
|||
|
def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
|
|||
|
super(PositionalEncoding, self).__init__()
|
|||
|
self.N_freqs = N_freqs
|
|||
|
self.funcs = [torch.sin, torch.cos]
|
|||
|
self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, steps=N_freqs)
|
|||
|
self.output_dim = input_dim * len(self.funcs) * N_freqs
|
|||
|
|
|||
|
def forward(self, x):
|
|||
|
"""
|
|||
|
对输入坐标进行位置编码
|
|||
|
x: [N, input_dim] 输入坐标
|
|||
|
returns: [N, output_dim] 编码后的特征
|
|||
|
"""
|
|||
|
out = []
|
|||
|
for freq in self.freq_bands:
|
|||
|
for func in self.funcs:
|
|||
|
out.append(func(freq * x))
|
|||
|
return torch.cat(out, dim=-1)
|
|||
|
|
|||
|
class NeRF(nn.Module):
|
|||
|
def __init__(self, input_ch=60, output_ch=4):
|
|||
|
super(NeRF, self).__init__()
|
|||
|
self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
|
|||
|
self.net = nn.Sequential(
|
|||
|
nn.Linear(self.encoding.output_dim, 256),
|
|||
|
nn.ReLU(),
|
|||
|
nn.Linear(256, 256),
|
|||
|
nn.ReLU(),
|
|||
|
nn.Linear(256, output_ch)
|
|||
|
)
|
|||
|
|
|||
|
def forward(self, x):
|
|||
|
# 确保输入类型一致
|
|||
|
if x.dtype != torch.float32:
|
|||
|
x = x.float()
|
|||
|
|
|||
|
# 应用位置编码
|
|||
|
x = self.encoding(x)
|
|||
|
return self.net(x)
|
|||
|
|
|||
|
if __name__ == '__main__':
|
|||
|
# 添加参数解析
|
|||
|
parser = argparse.ArgumentParser(description='NeRF训练脚本')
|
|||
|
parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
|
|||
|
help='数据集路径 (default: ../data/nerf/lego)')
|
|||
|
parser.add_argument('--batch_size', type=int, default=1,
|
|||
|
help='批量大小 (default: 1)')
|
|||
|
parser.add_argument('--num_epochs', type=int, default=10,
|
|||
|
help='训练轮数 (default: 10)')
|
|||
|
args = parser.parse_args()
|
|||
|
|
|||
|
# 数据集路径配置
|
|||
|
try:
|
|||
|
dataset = NeRFDataset(args.data_path)
|
|||
|
dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
|
|||
|
except Exception as e:
|
|||
|
print(f"数据加载失败: {str(e)}")
|
|||
|
print("建议解决方案:")
|
|||
|
print("1. 验证数据集路径是否正确")
|
|||
|
print("2. 确保数据集已正确下载并解压")
|
|||
|
print("3. 检查transforms.json文件格式")
|
|||
|
exit(1)
|
|||
|
|
|||
|
# 初始化模型
|
|||
|
model = NeRF()
|
|||
|
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
|
|||
|
|
|||
|
# 训练循环
|
|||
|
for epoch in range(args.num_epochs):
|
|||
|
for i, data in enumerate(dataloader):
|
|||
|
coords = data['coords']
|
|||
|
targets = data['targets']
|
|||
|
|
|||
|
# 前向传播
|
|||
|
outputs = model(coords)
|
|||
|
loss = torch.mean((outputs - targets) ** 2)
|
|||
|
|
|||
|
# 反向传播
|
|||
|
optimizer.zero_grad()
|
|||
|
loss.backward()
|
|||
|
optimizer.step()
|
|||
|
|
|||
|
print(f'Epoch [{epoch+1}/{args.num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}')
|
|||
|
|
|||
|
# 保存模型
|
|||
|
os.makedirs(os.path.join(args.data_path, 'checkpoint'), exist_ok=True)
|
|||
|
torch.save(model.state_dict(), os.path.join(args.data_path, 'checkpoint', 'nerf_model.pth'))
|