Skip to content

Commit

Permalink
fix gridencoder D=1, add patch-based rendering
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Aug 11, 2022
1 parent 3b066b6 commit e8db8fb
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 14 deletions.
6 changes: 4 additions & 2 deletions gridencoder/src/gridencoder.cu
Original file line number Diff line number Diff line change
Expand Up @@ -365,11 +365,12 @@ void kernel_grid_wrapper(const float *inputs, const scalar_t *embeddings, const
// Dispatch the templated forward grid-encoding kernel on the runtime input
// dimensionality D (1..5); all other arguments are forwarded unchanged to
// kernel_grid_wrapper<scalar_t, D>.
//
// Throws std::runtime_error if D is outside the supported range.
// NOTE: the scraped diff view duplicated the old (pre-commit) default label;
// only the corrected "D must be ..." default belongs in the switch.
template <typename scalar_t>
void grid_encode_forward_cuda(const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *outputs, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_wrapper<scalar_t, 1>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 2: kernel_grid_wrapper<scalar_t, 2>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 3: kernel_grid_wrapper<scalar_t, 3>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 4: kernel_grid_wrapper<scalar_t, 4>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        case 5: kernel_grid_wrapper<scalar_t, 5>(inputs, embeddings, offsets, outputs, B, C, L, S, H, calc_grad_inputs, dy_dx, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
    }
}
Expand Down Expand Up @@ -410,11 +411,12 @@ void kernel_grid_backward_wrapper(const scalar_t *grad, const float *inputs, con
// Dispatch the templated backward grid-encoding kernel on the runtime input
// dimensionality D (1..5); all other arguments are forwarded unchanged to
// kernel_grid_backward_wrapper<scalar_t, D>.
//
// Throws std::runtime_error if D is outside the supported range.
// NOTE: the scraped diff view duplicated the old (pre-commit) default label;
// only the corrected "D must be ..." default belongs in the switch.
template <typename scalar_t>
void grid_encode_backward_cuda(const scalar_t *grad, const float *inputs, const scalar_t *embeddings, const int *offsets, scalar_t *grad_embeddings, const uint32_t B, const uint32_t D, const uint32_t C, const uint32_t L, const float S, const uint32_t H, const bool calc_grad_inputs, scalar_t *dy_dx, scalar_t *grad_inputs, const uint32_t gridtype, const bool align_corners) {
    switch (D) {
        case 1: kernel_grid_backward_wrapper<scalar_t, 1>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 2: kernel_grid_backward_wrapper<scalar_t, 2>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 3: kernel_grid_backward_wrapper<scalar_t, 3>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 4: kernel_grid_backward_wrapper<scalar_t, 4>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        case 5: kernel_grid_backward_wrapper<scalar_t, 5>(grad, inputs, embeddings, offsets, grad_embeddings, B, C, L, S, H, calc_grad_inputs, dy_dx, grad_inputs, gridtype, align_corners); break;
        default: throw std::runtime_error{"GridEncoding: D must be 1, 2, 3, 4, or 5."};
    }
}

Expand Down
17 changes: 11 additions & 6 deletions main_nerf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
parser.add_argument('--upsample_steps', type=int, default=0, help="num steps up-sampled per ray (only valid when NOT using --cuda_ray)")
parser.add_argument('--update_extra_interval', type=int, default=16, help="iter interval to update extra status (only valid when using --cuda_ray)")
parser.add_argument('--max_ray_batch', type=int, default=4096, help="batch size of rays at inference to avoid OOM (only valid when NOT using --cuda_ray)")
parser.add_argument('--patch_size', type=int, default=1, help="[experimental] render patches in training, so as to apply LPIPS loss. 1 means disabled, use [64, 32, 16] to enable")

### network backbone options
parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training")
Expand Down Expand Up @@ -67,6 +68,12 @@
opt.fp16 = True
opt.cuda_ray = True
opt.preload = True

if opt.patch_size > 1:
opt.error_map = False # do not use error_map if use patch-based training
# assert opt.patch_size > 16, "patch_size should > 16 to run LPIPS loss."
assert opt.num_rays % (opt.patch_size ** 2) == 0, "patch_size ** 2 should be dividable by num_rays."


if opt.ff:
opt.fp16 = True
Expand Down Expand Up @@ -103,8 +110,7 @@

if opt.test:

metrics = [PSNRMeter(),]
# metrics.append([LPIPSMeter(device=device))
metrics = [PSNRMeter(), LPIPSMeter(device=device)]
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

if opt.gui:
Expand All @@ -119,7 +125,7 @@

trainer.test(test_loader, write_video=True) # test and save video

#trainer.save_mesh(resolution=256, threshold=10)
trainer.save_mesh(resolution=256, threshold=10)

else:

Expand All @@ -130,8 +136,7 @@
# decay to 0.1 * init_lr at last iter step
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))

metrics = [PSNRMeter(),]
# metrics.append([LPIPSMeter(device=device))
metrics = [PSNRMeter(), LPIPSMeter(device=device)]
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=0.95, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=metrics, use_checkpoint=opt.ckpt, eval_interval=50)

if opt.gui:
Expand All @@ -152,4 +157,4 @@

trainer.test(test_loader, write_video=True) # test and save video

#trainer.save_mesh(resolution=256, threshold=10)
trainer.save_mesh(resolution=256, threshold=10)
4 changes: 2 additions & 2 deletions nerf/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,8 +299,8 @@ def collate(self, index):

error_map = None if self.error_map is None else self.error_map[index]

rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map)
rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, error_map, self.opt.patch_size)

results = {
'H': self.H,
'W': self.W,
Expand Down
46 changes: 42 additions & 4 deletions nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def srgb_to_linear(x):


@torch.cuda.amp.autocast(enabled=False)
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
def get_rays(poses, intrinsics, H, W, N=-1, error_map=None, patch_size=1):
''' get rays
Args:
poses: [B, 4, 4], cam2world
Expand All @@ -66,7 +66,7 @@ def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
B = poses.shape[0]
fx, fy, cx, cy = intrinsics

i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device))
i, j = custom_meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device)) # float
i = i.t().reshape([1, H*W]).expand([B, H*W]) + 0.5
j = j.t().reshape([1, H*W]).expand([B, H*W]) + 0.5

Expand All @@ -75,7 +75,27 @@ def get_rays(poses, intrinsics, H, W, N=-1, error_map=None):
if N > 0:
N = min(N, H*W)

if error_map is None:
# if use patch-based sampling, ignore error_map
if patch_size > 1:

# random sample left-top cores.
# NOTE: this impl will lead to less sampling on the image corner pixels... but I don't have other ideas.
num_patch = N // (patch_size ** 2)
inds_x = torch.randint(0, H - patch_size, size=[num_patch], device=device)
inds_y = torch.randint(0, W - patch_size, size=[num_patch], device=device)
inds = torch.stack([inds_x, inds_y], dim=-1) # [np, 2]

# create meshgrid for each patch
pi, pj = custom_meshgrid(torch.arange(patch_size, device=device), torch.arange(patch_size, device=device))
offsets = torch.stack([pi.reshape(-1), pj.reshape(-1)], dim=-1) # [p^2, 2]

inds = inds.unsqueeze(1) + offsets.unsqueeze(0) # [np, p^2, 2]
inds = inds.view(-1, 2) # [N, 2]
inds = inds[:, 0] * W + inds[:, 1] # [N], flatten

inds = inds.expand([B, N])

elif error_map is None:
inds = torch.randint(0, H*W, size=[N], device=device) # may duplicate
inds = inds.expand([B, N])
else:
Expand Down Expand Up @@ -311,6 +331,11 @@ def __init__(self,
criterion.to(self.device)
self.criterion = criterion

# optionally use LPIPS loss for patch-based training
if self.opt.patch_size > 1:
import lpips
self.criterion_lpips = lpips.LPIPS(net='alex').to(self.device)

if optimizer is None:
self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam
else:
Expand Down Expand Up @@ -443,12 +468,25 @@ def train_step(self, data):
else:
gt_rgb = images

outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False, **vars(self.opt))
# outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=False if self.opt.patch_size == 1 else True, **vars(self.opt))
outputs = self.model.render(rays_o, rays_d, staged=False, bg_color=bg_color, perturb=True, force_all_rays=True, **vars(self.opt))

pred_rgb = outputs['image']

# MSE loss
loss = self.criterion(pred_rgb, gt_rgb).mean(-1) # [B, N, 3] --> [B, N]

# patch-based rendering
if self.opt.patch_size > 1:
gt_rgb = gt_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous()
pred_rgb = pred_rgb.view(-1, self.opt.patch_size, self.opt.patch_size, 3).permute(0, 3, 1, 2).contiguous()

# torch_vis_2d(gt_rgb[0])
# torch_vis_2d(pred_rgb[0])

# LPIPS loss [not useful...]
loss = loss + 1e-3 * self.criterion_lpips(pred_rgb, gt_rgb)

# special case for CCNeRF's rank-residual training
if len(loss.shape) == 3: # [K, B, N]
loss = loss.mean(0)
Expand Down

0 comments on commit e8db8fb

Please sign in to comment.