Skip to content

Commit

Permalink
correct data split for hypernerf & llff datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Jun 30, 2022
1 parent 02e4381 commit f08a0c6
Show file tree
Hide file tree
Showing 8 changed files with 142 additions and 17 deletions.
4 changes: 2 additions & 2 deletions dnerf/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,9 +206,9 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):
self.times = []

# assume frames are already sorted by time!
for f in tqdm.tqdm(frames, desc=f'Loading {type} data:'):
for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):
f_path = os.path.join(self.root_path, f['file_path'])
if self.mode == 'blender' and f_path[-4:] != '.png':
if self.mode == 'blender' and '.' not in f_path:
f_path += '.png' # so silly...

# there are non-exist paths in fox...
Expand Down
54 changes: 53 additions & 1 deletion loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

def mape_loss(pred, target, reduction='mean'):
# pred, target: [B, 1], torch tenspr
difference = (pred - target).abs()
Expand All @@ -21,4 +23,54 @@ def huber_loss(pred, target, delta=0.1, reduction='mean'):
if reduction == 'mean':
loss = loss.mean()

return loss
return loss


# ref: https://github.com/sunset1995/torch_efficient_distloss/blob/main/torch_efficient_distloss/eff_distloss.py
class EffDistLoss(torch.autograd.Function):
@staticmethod
def forward(ctx, w, m, interval):
'''
Efficient O(N) realization of distortion loss.
There are B rays each with N sampled points.
w: Float tensor in shape [B,N]. Volume rendering weights of each point.
m: Float tensor in shape [B,N]. Midpoint distance to camera of each point.
interval: Scalar or float tensor in shape [B,N]. The query interval of each point.
'''
n_rays = np.prod(w.shape[:-1])
wm = (w * m)
w_cumsum = w.cumsum(dim=-1)
wm_cumsum = wm.cumsum(dim=-1)

w_total = w_cumsum[..., [-1]]
wm_total = wm_cumsum[..., [-1]]
w_prefix = torch.cat([torch.zeros_like(w_total), w_cumsum[..., :-1]], dim=-1)
wm_prefix = torch.cat([torch.zeros_like(wm_total), wm_cumsum[..., :-1]], dim=-1)
loss_uni = (1/3) * interval * w.pow(2)
loss_bi = 2 * w * (m * w_prefix - wm_prefix)
if torch.is_tensor(interval):
ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval)
ctx.interval = None
else:
ctx.save_for_backward(w, m, wm, w_prefix, w_total, wm_prefix, wm_total)
ctx.interval = interval
ctx.n_rays = n_rays
return (loss_bi.sum() + loss_uni.sum()) / n_rays

@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_back):
interval = ctx.interval
n_rays = ctx.n_rays
if interval is None:
w, m, wm, w_prefix, w_total, wm_prefix, wm_total, interval = ctx.saved_tensors
else:
w, m, wm, w_prefix, w_total, wm_prefix, wm_total = ctx.saved_tensors
grad_uni = (1/3) * interval * 2 * w
w_suffix = w_total - (w_prefix + w)
wm_suffix = wm_total - (wm_prefix + wm)
grad_bi = 2 * (m * (w_prefix - w_suffix) + (wm_suffix - wm_prefix))
grad = grad_back * (grad_bi + grad_uni) / n_rays
return grad, None, None, None

eff_distloss = EffDistLoss.apply
4 changes: 2 additions & 2 deletions main_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
parser.add_argument('--seed', type=int, default=0)

### training options
parser.add_argument('--iters', type=int, default=40000, help="training iters")
parser.add_argument('--iters', type=int, default=30000, help="training iters")
parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
parser.add_argument('--ckpt', type=str, default='latest')
parser.add_argument('--num_rays', type=int, default=4096, help="num rays sampled per image for each training step")
Expand Down Expand Up @@ -123,7 +123,7 @@

optimizer = lambda model: torch.optim.Adam(model.get_params(opt.lr), betas=(0.9, 0.99), eps=1e-15)

train_loader = NeRFDataset(opt, device=device, type='trainval').dataloader()
train_loader = NeRFDataset(opt, device=device, type='train').dataloader()

# decay to 0.1 * init_lr at last iter step
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))
Expand Down
4 changes: 2 additions & 2 deletions nerf/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ def __init__(self, opt, device, type='train', downscale=1, n_test=10):

self.poses = []
self.images = []
for f in tqdm.tqdm(frames, desc=f'Loading {type} data:'):
for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):
f_path = os.path.join(self.root_path, f['file_path'])
if self.mode == 'blender' and f_path[-4:] != '.png':
if self.mode == 'blender' and '.' not in f_path:
f_path += '.png' # so silly...

# there are non-exist paths in fox...
Expand Down
19 changes: 19 additions & 0 deletions readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ For GPUs with lower architecture, `--tcnn` can still be used, but the speed will
We use the same data format as instant-ngp, e.g., [armadillo](https://github.com/NVlabs/instant-ngp/blob/master/data/sdf/armadillo.obj) and [fox](https://github.com/NVlabs/instant-ngp/tree/master/data/nerf/fox).
Please download and put them under `./data`.

We also support self-captured dataset and converting other formats (e.g., LLFF, Tanks&Temples, Mip-NeRF 360) to the nerf-compatible format, with details in the following code block.

<details>
<summary> Supported datasets </summary>

* [nerf_synthetic](https://drive.google.com/drive/folders/128yBriW1IG_3NJ5Rp7APSTZsJqdJdfc1)

* [Tanks&Temples](https://dl.fbaipublicfiles.com/nsvf/dataset/TanksAndTemple.zip): [[conversion script]](./scripts/tanks2nerf.py)

* [LLFF](https://drive.google.com/drive/folders/14boI-o5hGO9srnWaaogTU5_ji7wkX2S7): [[conversion script]](./scripts/llff2nerf.py)

* [Mip-NeRF 360](http://storage.googleapis.com/gresearch/refraw360/360_v2.zip): [[conversion script]](./scripts/llff2nerf.py)

* (dynamic) [D-NeRF](https://www.dropbox.com/s/0bf6fl0ye2vz3vr/data.zip?dl=0)

* (dynamic) [Hyper-NeRF](https://github.com/google/hypernerf/releases/tag/v0.1): [[conversion script]](./scripts/hyper2nerf.py)

</details>

First time running will take some time to compile the CUDA extensions.

```bash
Expand Down
11 changes: 10 additions & 1 deletion scripts/hyper2nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,15 @@ def rotmat(a, b):
'frames': train_frames,
}
transforms_val = {
'w': W,
'h': H,
'fl_x': fl,
'fl_y': fl,
'cx': cx,
'cy': cy,
'frames': val_frames[::10], # only use 1/10 frames for val
}
transforms_test = {
'w': W,
'h': H,
'fl_x': fl,
Expand All @@ -235,5 +244,5 @@ def rotmat(a, b):
output_path = os.path.join(opt.path, 'transforms_test.json')
print(f'[INFO] write to {output_path}')
with open(output_path, 'w') as f:
json.dump(transforms_val, f, indent=2)
json.dump(transforms_test, f, indent=2)

61 changes: 53 additions & 8 deletions scripts/llff2nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def visualize_poses(poses, size=0.1):
parser.add_argument('path', type=str, help="root directory to the LLFF dataset (contains images/ and pose_bounds.npy)")
parser.add_argument('--images', type=str, default='images', help="images folder (do not include full path, e.g., just use `images_4`)")
parser.add_argument('--downscale', type=float, default=1, help="image size down scale, e.g., 4")
parser.add_argument('--hold', type=int, default=8, help="hold out for validation every $ images")

opt = parser.parse_args()
print(f'[INFO] process {opt.path}')
Expand Down Expand Up @@ -99,7 +100,7 @@ def visualize_poses(poses, size=0.1):

# visualize_poses(poses)

# the following stuff are from colmap2nerf... [flower fails, the camera must be in-ward focusing...]
# the following stuff are from colmap2nerf... [flower fails, the camera must be in-ward...]
poses[:, 0:3, 1] *= -1
poses[:, 0:3, 2] *= -1
poses = poses[:, [1, 0, 2, 3], :] # swap y and z
Expand Down Expand Up @@ -134,28 +135,72 @@ def visualize_poses(poses, size=0.1):
# visualize_poses(poses)

# construct frames
frames = []
for i in range(N):
frames.append({

all_ids = np.arange(N)
test_ids = all_ids[::opt.hold]
train_ids = np.array([i for i in all_ids if i not in test_ids])

frames_train = []
frames_test = []
for i in train_ids:
frames_train.append({
'file_path': images[i],
'transform_matrix': poses[i].tolist(),
})
for i in test_ids:
frames_test.append({
'file_path': images[i],
'transform_matrix': poses[i].tolist(),
})


# construct a transforms.json
transforms = {
transforms_train = {
'w': W,
'h': H,
'fl_x': fl,
'fl_y': fl,
'cx': W // 2,
'cy': H // 2,
'aabb_scale': 2,
'frames': frames_train,
}

transforms_val = {
'w': W,
'h': H,
'fl_x': fl,
'fl_y': fl,
'cx': W // 2,
'cy': H // 2,
'aabb_scale': 2,
'frames': frames,
'frames': frames_test[::10],
}

transforms_test = {
'w': W,
'h': H,
'fl_x': fl,
'fl_y': fl,
'cx': W // 2,
'cy': H // 2,
'aabb_scale': 2,
'frames': frames_test,
}

# write
output_path = os.path.join(opt.path, 'transforms.json')
output_path = os.path.join(opt.path, 'transforms_train.json')
print(f'[INFO] write to {output_path}')
with open(output_path, 'w') as f:
json.dump(transforms_train, f, indent=2)

output_path = os.path.join(opt.path, 'transforms_test.json')
print(f'[INFO] write to {output_path}')
with open(output_path, 'w') as f:
json.dump(transforms_test, f, indent=2)

output_path = os.path.join(opt.path, 'transforms_val.json')
print(f'[INFO] write to {output_path}')
with open(output_path, 'w') as f:
json.dump(transforms, f, indent=2)
json.dump(transforms_val, f, indent=2)

2 changes: 1 addition & 1 deletion scripts/run_dnerf.sh
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@
#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies -O --bound 1 --scale 0.3 --dt_gamma 0
#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_dnerf.py data/split-cookie/ --workspace trial_dnerf_cookies_ncr --preload --fp16 --bound 1 --scale 0.3 --dt_gamma 0

OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=4 python main_dnerf.py data/vrig-3dprinter/ --workspace trial_dnerf_printer -O --bound 1 --scale 0.3 --dt_gamma 0
OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=4 python main_dnerf.py data/vrig-3dprinter/ --workspace trial_dnerf_printer -O --bound 2 --scale 0.33 --dt_gamma 0

0 comments on commit f08a0c6

Please sign in to comment.