import torch import torch.nn as nn from torch.utils.data import Dataset, DataLoader import os import json import numpy as np import argparse class NeRFDataset(Dataset): def __init__(self, data_dir, split='train'): self.data_dir = data_dir self.split = split # 检查必要文件是否存在 transforms_path = os.path.join(data_dir, 'transforms_train.json') if not os.path.exists(transforms_path): raise FileNotFoundError( f"找不到必要文件: {transforms_path}\n" "请确保:\n" "1. 数据集已正确下载\n" "2. 数据路径配置正确(当前路径:{data_dir})\n" "3. 包含images目录和poses_bounds.npy文件" ) with open(transforms_path, 'r') as f: self.transforms = json.load(f) # 验证文件格式 if 'frames' not in self.transforms: raise ValueError( f"transforms.json格式错误,缺少必要字段'frames'\n" "请确保:\n" "1. 数据集文件完整\n" "2. 使用正确的数据集版本\n" "3. transforms.json包含frames数组和intrinsic参数" ) self.image_paths = [os.path.join(data_dir, 'images', frame['file_path']) for frame in self.transforms['frames']] self.poses = np.array([frame['transform_matrix'] for frame in self.transforms['frames']]) def __len__(self): return len(self.image_paths) def __getitem__(self, idx): # 加载图像和位姿 image_path = self.image_paths[idx] pose = self.poses[idx] # 简单的数据预处理 # 生成随机坐标替代真实图像输入(NeRF典型做法) height, width = 400, 400 # 降低分辨率以减少计算量 x = torch.rand(height * width, 3) # (x, y, z) 坐标 # 生成随机颜色作为目标值 y = torch.rand(height * width, 4) # (r, g, b, a) 目标值 return {'coords': x, 'targets': y} class PositionalEncoding(nn.Module): """位置编码层,将输入坐标扩展到更高维度""" def __init__(self, input_dim=3, max_freq_log2=9, N_freqs=10): super(PositionalEncoding, self).__init__() self.N_freqs = N_freqs self.funcs = [torch.sin, torch.cos] self.freq_bands = 2 ** torch.linspace(0, max_freq_log2, steps=N_freqs) self.output_dim = input_dim * len(self.funcs) * N_freqs def forward(self, x): """ 对输入坐标进行位置编码 x: [N, input_dim] 输入坐标 returns: [N, output_dim] 编码后的特征 """ out = [] for freq in self.freq_bands: for func in self.funcs: out.append(func(freq * x)) return torch.cat(out, dim=-1) class NeRF(nn.Module): def __init__(self, input_ch=60, output_ch=4): super(NeRF, self).__init__() self.encoding = PositionalEncoding(input_dim=3, N_freqs=10) self.net = nn.Sequential( nn.Linear(self.encoding.output_dim, 256), nn.ReLU(), nn.Linear(256, 256), nn.ReLU(), nn.Linear(256, output_ch) ) def forward(self, x): # 确保输入类型一致 if x.dtype != torch.float32: x = x.float() # 应用位置编码 x = self.encoding(x) return self.net(x) if __name__ == '__main__': # 添加参数解析 parser = argparse.ArgumentParser(description='NeRF训练脚本') parser.add_argument('--data_path', type=str, default='../data/nerf/lego', help='数据集路径 (default: ../data/nerf/lego)') parser.add_argument('--batch_size', type=int, default=1, help='批量大小 (default: 1)') parser.add_argument('--num_epochs', type=int, default=10, help='训练轮数 (default: 10)') args = parser.parse_args() # 数据集路径配置 try: dataset = NeRFDataset(args.data_path) dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True) except Exception as e: print(f"数据加载失败: {str(e)}") print("建议解决方案:") print("1. 验证数据集路径是否正确") print("2. 确保数据集已正确下载并解压") print("3. 检查transforms.json文件格式") exit(1) # 初始化模型 model = NeRF() optimizer = torch.optim.Adam(model.parameters(), lr=5e-4) # 训练循环 for epoch in range(args.num_epochs): for i, data in enumerate(dataloader): coords = data['coords'] targets = data['targets'] # 前向传播 outputs = model(coords) loss = torch.mean((outputs - targets) ** 2) # 反向传播 optimizer.zero_grad() loss.backward() optimizer.step() print(f'Epoch [{epoch+1}/{args.num_epochs}], Step [{i+1}/{len(dataloader)}], Loss: {loss.item():.4f}') # 保存模型 os.makedirs(os.path.join(args.data_path, 'checkpoint'), exist_ok=True) torch.save(model.state_dict(), os.path.join(args.data_path, 'checkpoint', 'nerf_model.pth'))