AI-exp-4/nerf/train_nerf.py
fly6516 32b5c54ce7 test_nerf.py:
- 新增数据路径解析支持相对路径
- 重构代码,提高可读性和维护性
- 添加数据加载和模型推理的错误处理
- 优化数据集类,增强数据路径的灵活性和健壮性
- 更新数据加载方式,支持更灵活的路径配置
- 改进结果可视化,添加Gamma校正

train_nerf.py:
- 新增数据路径解析支持相对路径
- 重构代码,提高可读性和维护性
- 优化数据集类,增强数据路径的灵活性和健壮性
- 更新数据加载方式,支持更灵活的路径配置
- 修正参数解析,统一代码风格
2025-05-31 22:44:14 +08:00

153 lines
5.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import os
import json
import numpy as np
import argparse
class NeRFDataset(Dataset):
    """Dataset over the frames listed in ``<data_dir>/transforms_<split>.json``.

    Only the frame file list and the camera poses are read from disk.
    ``__getitem__`` is currently a placeholder: it returns random coordinates
    and random targets rather than samples derived from the real images/poses.

    Args:
        data_dir: Dataset root directory containing ``transforms_<split>.json``.
        split: Selects the transforms file; ``'train'`` (the default) reads
            ``transforms_train.json``.

    Raises:
        FileNotFoundError: If the transforms file does not exist.
        ValueError: If the parsed transforms file has no ``'frames'`` field.
    """

    def __init__(self, data_dir, split='train'):
        self.data_dir = data_dir
        self.split = split
        # Honour `split` instead of always reading the train file.
        transforms_name = f'transforms_{split}.json'
        transforms_path = os.path.join(data_dir, transforms_name)
        # Check that the required file exists before trying to parse it.
        if not os.path.exists(transforms_path):
            raise FileNotFoundError(
                f"找不到必要文件: {transforms_path}\n"
                "请确保:\n"
                "1. 数据集已正确下载\n"
                f"2. 数据路径配置正确(当前路径:{data_dir})\n"
                f"3. 包含images目录和{transforms_name}文件"
            )
        # Explicit UTF-8 so parsing does not depend on the platform locale.
        with open(transforms_path, 'r', encoding='utf-8') as f:
            self.transforms = json.load(f)
        # Validate the one field the rest of this class relies on.
        if 'frames' not in self.transforms:
            raise ValueError(
                f"transforms.json格式错误缺少必要字段'frames'\n"
                "请确保:\n"
                "1. 数据集文件完整\n"
                "2. 使用正确的数据集版本\n"
                "3. transforms.json包含frames数组和intrinsic参数"
            )
        frames = self.transforms['frames']
        # NOTE(review): assumes each frame's 'file_path' is relative to
        # <data_dir>/images/ — TODO confirm against the dataset layout.
        self.image_paths = [os.path.join(data_dir, 'images', frame['file_path'])
                            for frame in frames]
        # One pose per frame, stacked from each frame's 'transform_matrix'.
        self.poses = np.array([frame['transform_matrix'] for frame in frames])

    def __len__(self):
        """Return the number of frames listed in the transforms file."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return a placeholder sample for frame ``idx``.

        Returns:
            dict with ``'coords'`` of shape ``(400 * 400, 3)`` and
            ``'targets'`` of shape ``(400 * 400, 4)``, both drawn uniformly
            from [0, 1) with ``torch.rand``.
        """
        # The real image path and pose are not used yet. Looking them up still
        # raises IndexError for an out-of-range idx, as the sequence protocol
        # expects.
        _image_path, _pose = self.image_paths[idx], self.poses[idx]
        height, width = 400, 400  # low resolution to limit compute
        coords = torch.rand(height * width, 3)   # intended as (x, y, z)
        targets = torch.rand(height * width, 4)  # intended as (r, g, b, a)
        return {'coords': coords, 'targets': targets}
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding that lifts coordinates to a higher dimension.

    For every frequency ``f`` in ``2 ** linspace(0, max_freq_log2, N_freqs)``
    the output contains ``sin(f * x)`` followed by ``cos(f * x)``. Blocks are
    ordered by frequency, then by function. The raw input is not included.

    Args:
        input_dim: Size of the last dimension of the input.
        max_freq_log2: log2 of the highest frequency.
        N_freqs: Number of frequencies, spaced evenly in log2 space.
    """

    def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10):
        super().__init__()
        self.N_freqs = N_freqs
        self.funcs = [torch.sin, torch.cos]
        exponents = torch.linspace(0, max_freq_log2, steps=N_freqs)
        self.freq_bands = 2 ** exponents
        # Each (frequency, function) pair contributes `input_dim` features.
        self.output_dim = input_dim * len(self.funcs) * N_freqs

    def forward(self, x):
        """Encode ``x`` of shape ``[N, input_dim]`` into ``[N, output_dim]`` features."""
        encoded = [fn(freq * x) for freq in self.freq_bands for fn in self.funcs]
        return torch.cat(encoded, dim=-1)
class NeRF(nn.Module):
    """Positional encoding followed by a three-layer ReLU MLP.

    Args:
        input_ch: Unused and kept only for backward compatibility. The width
            of the first linear layer comes from ``self.encoding.output_dim``.
        output_ch: Number of output channels per input point.
    """

    def __init__(self, input_ch=60, output_ch=4):
        super().__init__()
        self.encoding = PositionalEncoding(input_dim=3, N_freqs=10)
        hidden = 256
        # Layer order fixes the Sequential indices (net.0 / net.2 / net.4),
        # which are the keys of saved checkpoints.
        layers = [
            nn.Linear(self.encoding.output_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_ch),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map coordinates ``x`` (last dim 3) to ``output_ch`` raw outputs."""
        # `.float()` converts non-float32 input and returns float32 input as-is.
        features = self.encoding(x.float())
        return self.net(features)
def _parse_args():
    """Parse and sanity-check the command-line options of the training script."""
    parser = argparse.ArgumentParser(description='NeRF训练脚本')
    parser.add_argument('--data_path', type=str, default='../data/nerf/lego',
                        help='数据集路径 (default: ../data/nerf/lego)')
    parser.add_argument('--batch_size', type=int, default=1,
                        help='批量大小 (default: 1)')
    parser.add_argument('--num_epochs', type=int, default=10,
                        help='训练轮数 (default: 10)')
    args = parser.parse_args()
    # Without these checks, a bad batch size was misreported as a data-loading
    # failure, and a negative epoch count silently saved an untrained model.
    if args.batch_size < 1:
        parser.error('--batch_size 必须是正整数')
    if args.num_epochs < 0:
        parser.error('--num_epochs 不能为负数')
    return args


def _make_dataloader(data_path, batch_size):
    """Build the shuffled training DataLoader; print guidance and exit(1) on failure."""
    import sys

    try:
        dataset = NeRFDataset(data_path)
        # An empty dataset makes the shuffling sampler raise ValueError here.
        return DataLoader(dataset, batch_size=batch_size, shuffle=True)
    except (OSError, ValueError, KeyError, TypeError) as e:
        print(f"数据加载失败: {e}", file=sys.stderr)
        print("建议解决方案:", file=sys.stderr)
        print("1. 验证数据集路径是否正确", file=sys.stderr)
        print("2. 确保数据集已正确下载并解压", file=sys.stderr)
        print("3. 检查transforms.json文件格式", file=sys.stderr)
        # SystemExit works even when the `site` builtin `exit` is unavailable.
        raise SystemExit(1) from e


def _train(model, dataloader, num_epochs, lr=5e-4):
    """Optimise `model` with Adam on the MSE between outputs and targets, logging every step."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    steps_per_epoch = len(dataloader)
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader):
            coords = data['coords']
            targets = data['targets']
            # Forward pass and mean-squared error.
            outputs = model(coords)
            loss = torch.mean((outputs - targets) ** 2)
            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            print(f'Epoch [{epoch + 1}/{num_epochs}], Step [{i + 1}/{steps_per_epoch}], Loss: {loss.item():.4f}')


def _save_checkpoint(model, data_path):
    """Save the model's state_dict to `<data_path>/checkpoint/nerf_model.pth`."""
    checkpoint_dir = os.path.join(data_path, 'checkpoint')
    os.makedirs(checkpoint_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(checkpoint_dir, 'nerf_model.pth'))


def main():
    """Entry point: parse args, load data, train a NeRF, then save a checkpoint."""
    args = _parse_args()
    dataloader = _make_dataloader(args.data_path, args.batch_size)
    model = NeRF()
    _train(model, dataloader, args.num_epochs)
    _save_checkpoint(model, args.data_path)


if __name__ == '__main__':
    main()