AI-exp-4/docs/datasets.md

120 lines
3.6 KiB
Markdown
Raw Normal View History

# NeRF和CNN实验数据集说明
## 1. NeRF数据集
NeRF算法需要多视角图像数据推荐使用以下公开数据集
### Blender合成数据集
```bash
# 安装gdown下载工具
pip install gdown
# 下载Lego场景示例
gdown https://drive.google.com/uc?id=18JxhpWD-4ZmuEK22MTzM9c5YZZjr3RC6
# 解压数据
tar -xf lego.tar -C ../data/nerf/
```
数据目录结构应如下:
```
data/nerf/
└── lego/
├── images/ # 包含所有训练图片
│ ├── r_00.png
│ ├── r_01.png
│ └── ...
├── poses_bounds.npy # 位姿和边界信息
├── transforms.json # 训练用配置文件
├── transforms_test.json # 测试专用配置文件
└── results/ # 结果存储目录
└── nerf/ # NeRF算法结果
├── pred_0.json # 预测结果
└── visualizations/ # 可视化图像
└── vis_0.png
```
### 测试数据准备
测试时请确保包含以下内容:
1. `images/`:包含测试用图像文件
2. `transforms_test.json`:必须包含以下字段:
```json
{
"intrinsic": [[...]], # 相机内参矩阵
"frames": [ # 图像帧信息数组
{
"file_path": "test_00.png", # 图片路径
"transform_matrix": [...] # 位姿变换矩阵
},
// ...其他帧数据...
]
}
```
### 模型输入说明
NeRF模型的输入是三维坐标(x, y, z),通过位置编码扩展到更高维度。当前实现:
1. 输入维度3 (x, y, z)
2. 位置编码使用10个频率的正弦/余弦函数进行编码输出维度为60
3. 网络结构:三个全连接层(60 -> 256 -> 256 -> 4)
4. 注意事项:确保测试数据生成的坐标维度与训练时一致
### 数据集验证
下载完成后请验证transforms.json文件内容确保包含以下必要字段
```json
{
"intrinsic": [[...]], # 相机内参矩阵
"frames": [ # 图像帧信息数组
{
"file_path": "r_00.png", # 图片路径
"transform_matrix": [...] # 位姿变换矩阵
},
// ...其他帧数据...
]
}
```
### 测试与可视化
测试完成后将生成以下内容:
1. 预测结果保存在`results/nerf/`目录下的JSON文件中
2. 可视化图像保存在`results/nerf/visualizations/`目录下,包含:
- 点云可视化xy投影
- RGB颜色分布
示例可视化输出:
```
data/nerf/
└── lego/
└── results/
└── nerf/
└── visualizations/
├── vis_0.png # 第一个测试样本的可视化
└── vis_1.png # 第二个测试样本的可视化
```
### 训练与测试脚本
```bash
# 训练NeRF模型可选参数
python nerf/train_nerf.py \
--data_path ../data/nerf/lego \
--batch_size 2 \
--num_epochs 20
# 测试NeRF模型可选参数
python nerf/test_nerf.py \
--data_path ../data/nerf/lego \
--batch_size 2
```
训练完成后会在`data/nerf/lego/checkpoint/`目录下生成模型文件,预测结果会保存在`data/nerf/lego/results/nerf/`目录。
## 2. CNN数据集
CNN图像识别推荐使用CIFAR-10数据集代码中已自动下载
```python
# 在cnn.py中会自动下载到data/cnn/目录
torchvision.datasets.CIFAR10(root='../data/cnn/', train=True, download=True)
```
## 数据预处理
```bash
# 图像尺寸统一(示例)
pip install opencv-python
python -c "import cv2; import os; [cv2.resize(cv2.imread(f), (256,256)) for f in os.listdir('../data/nerf/lego/images/') if f.endswith('.png')]"