# AI-exp-4/nerf/train_nerf.py

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import argparse
class NeRFDataset(Dataset):
    """Map-style dataset over a scene described by ``transforms_<split>.json``.

    The JSON file (the layout used by the NeRF synthetic scenes such as the
    default ``lego`` path) must contain a top-level ``'frames'`` list. Each
    frame needs a ``'file_path'`` and a ``'transform_matrix'`` entry.

    NOTE(review): ``__getitem__`` is currently a placeholder. It ignores the
    frame's image and pose and returns uniformly random coordinates and
    targets. No real pixels are loaded and no rays are generated.

    Args:
        data_dir: Scene directory containing ``transforms_<split>.json``.
        split: Which split's transforms file to read (``'train'`` by default).
        image_size: ``(height, width)`` of the placeholder sample; one random
            point is produced per "pixel".

    Raises:
        FileNotFoundError: If the transforms file does not exist.
        ValueError: If the JSON is not an object with a ``'frames'`` field.
    """

    def __init__(self, data_dir, split='train', image_size=(400, 400)):
        self.data_dir = data_dir
        self.split = split
        # Kept low by default to limit compute for the placeholder samples.
        self.image_size = image_size

        # Honour `split`. Previously it was stored but the train file was
        # always read.
        transforms_name = f'transforms_{split}.json'
        transforms_path = os.path.join(data_dir, transforms_name)
        if not os.path.exists(transforms_path):
            raise FileNotFoundError(
                f"找不到必要文件: {transforms_path}\n"
                "请确保:\n"
                "1. 数据集已正确下载\n"
                f"2. 数据路径配置正确(当前路径:{data_dir})\n"
                f"3. 包含images目录和{transforms_name}文件"
            )
        with open(transforms_path, 'r', encoding='utf-8') as f:
            self.transforms = json.load(f)

        # Validate the top-level structure before indexing into it. The
        # isinstance check avoids a TypeError when the JSON is e.g. a number.
        if not isinstance(self.transforms, dict) or 'frames' not in self.transforms:
            raise ValueError(
                f"{transforms_name}格式错误缺少必要字段'frames'\n"
                "请确保:\n"
                "1. 数据集文件完整\n"
                "2. 使用正确的数据集版本\n"
                f"3. {transforms_name}包含frames数组和intrinsic参数"
            )

        frames = self.transforms['frames']
        self.image_paths = [
            os.path.join(data_dir, 'images', frame['file_path'])
            for frame in frames
        ]
        # NOTE(review): presumably one 4x4 matrix per frame. The shape is not
        # validated here.
        self.poses = np.array([frame['transform_matrix'] for frame in frames])

    def __len__(self):
        """Return the number of frames listed in the transforms file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return one placeholder sample for frame ``idx``.

        Returns:
            dict with ``'coords'``, a ``(H*W, 3)`` tensor of uniform random
            values in [0, 1), and ``'targets'``, a ``(H*W, 4)`` tensor of
            uniform random values (the original code labels them r, g, b, a).
        """
        # The frame's image path and pose are not used yet. Indexing them still
        # raises IndexError for an out-of-range idx. That is what terminates
        # plain `for sample in dataset` iteration, since Dataset has no __iter__.
        _image_path, _pose = self.image_paths[idx], self.poses[idx]

        height, width = self.image_size
        num_points = height * width
        coords = torch.rand(num_points, 3)
        targets = torch.rand(num_points, 4)
        return {'coords': coords, 'targets': targets}
class PositionalEncoding(nn.Module):
    """Positional encoding layer that lifts input coordinates to a higher dimension.

    Every input value is multiplied by each frequency in
    ``freq_bands = 2 ** linspace(0, max_freq_log2, N_freqs)`` and passed through
    each function in ``funcs`` (sin, then cos). Features are concatenated
    frequency-major, i.e. ``[sin(f0*x), cos(f0*x), sin(f1*x), cos(f1*x), ...]``.

    Args:
        input_dim: Size of the input's last dimension.
        max_freq_log2: log2 of the highest frequency band.
        N_freqs: Number of frequency bands.
    """

    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super().__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        # Registered as a buffer so it follows module.to()/.cuda(). It is
        # non-persistent so state_dict keys (and existing checkpoints) are
        # unchanged.
        self.register_buffer(
            'freq_bands',
            2 ** torch.linspace(0, max_freq_log2, steps=N_freqs),
            persistent=False,
        )
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """Encode ``x``.

        x: [..., input_dim] input coordinates
        returns: [..., output_dim] encoded features
        """
        # Outer loop over frequencies and inner loop over functions preserve
        # the feature ordering documented above.
        return torch.cat(
            [func(freq * x) for freq in self.freq_bands for func in self.funcs],
            dim=-1,
        )
class NeRF(nn.Module):
    """Positional encoding followed by a 3-layer ReLU MLP.

    Input points are expected to have 3 values in their last dimension (the
    encoding is built with ``input_dim=3``).

    Args:
        input_ch: Width of the encoded input fed to the MLP. It must be a
            positive multiple of 6 (3 coordinates x {sin, cos} per band), and
            ``input_ch // 6`` frequency bands are used. The default of 60 gives
            10 bands.
        output_ch: Width of the MLP output (4 by default; the training script
            regresses these against 4-channel targets).

    Raises:
        ValueError: If ``input_ch`` is not a positive multiple of 6.
    """

    def __init__(self, input_ch=60, output_ch=4):
        super().__init__()
        # input_ch used to be silently ignored (the encoding was hard-coded to
        # 10 bands). The band count is now derived from it, so the argument
        # really sets the MLP's input width.
        if input_ch <= 0 or input_ch % 6 != 0:
            raise ValueError(
                f"input_ch must be a positive multiple of 6 "
                f"(3 coordinates x sin/cos per band), got {input_ch}"
            )
        n_freqs = input_ch // 6
        # max_freq_log2 = n_freqs - 1 follows the reference NeRF convention. It
        # reproduces the previous encoding exactly for the default input_ch=60
        # (bands 2**0 .. 2**9).
        self.encoding = PositionalEncoding(
            input_dim=3, max_freq_log2=n_freqs - 1, N_freqs=n_freqs
        )
        self.net = nn.Sequential(
            nn.Linear(self.encoding.output_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, output_ch),
        )

    def forward(self, x):
        """Encode ``x`` ([..., 3]) and return the MLP output ([..., output_ch])."""
        # The Linear weights are float32 by default, so cast other dtypes
        # (e.g. float64 or integer inputs) to match.
        if x.dtype != torch.float32:
            x = x.float()
        x = self.encoding(x)
        return self.net(x)
def main(argv=None):
    """Train the NeRF MLP on ``NeRFDataset`` samples and save its weights.

    Args:
        argv: Command-line arguments to parse. ``None`` means ``sys.argv[1:]``.

    If the dataset cannot be loaded, prints hints to stderr and exits with
    status 1.
    """
    import sys  # local import: only the CLI entry point needs it

    parser = argparse.ArgumentParser(description='NeRF训练脚本')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='数据集路径 (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='批量大小 (default: 1)')
    parser.add_argument('--num_epochs', type=int, default=10,
                        help='训练轮数 (default: 10)')
    parser.add_argument('--lr', type=float, default=5e-4,
                        help='学习率 (default: 5e-4)')
    parser.add_argument('--checkpoint_dir', type=str, default=None,
                        help='检查点保存目录 (default: <data_path>/checkpoint)')
    parser.add_argument('--device', type=str, default=None,
                        help='训练设备, 如 cpu / cuda (default: cuda if available, else cpu)')
    args = parser.parse_args(argv)

    if args.device is not None:
        device = torch.device(args.device)
    else:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Catch only the errors a missing or malformed dataset produces. Genuine
    # bugs should still surface with a traceback.
    try:
        dataset = NeRFDataset(args.data_path)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True)
    except (OSError, ValueError, KeyError, TypeError) as e:
        print(f"数据加载失败: {str(e)}", file=sys.stderr)
        print("建议解决方案:", file=sys.stderr)
        print("1. 验证数据集路径是否正确", file=sys.stderr)
        print("2. 确保数据集已正确下载并解压", file=sys.stderr)
        print("3. 检查transforms.json文件格式", file=sys.stderr)
        # sys.exit, not the site-provided exit() builtin, which is absent
        # under `python -S` and in some embedded interpreters.
        sys.exit(1)

    model = NeRF().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    for epoch in range(args.num_epochs):
        for i, data in enumerate(dataloader):
            coords = data['coords'].to(device)
            targets = data['targets'].to(device)

            outputs = model(coords)
            loss = torch.mean((outputs - targets) ** 2)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'Epoch [{epoch + 1}/{args.num_epochs}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}')

    checkpoint_dir = args.checkpoint_dir or os.path.join(args.data_path, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'nerf_model.pth'))


if __name__ == '__main__':
    main()