Skip to content

Commit

Permalink
perturb in torch side, expose T_threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Jul 31, 2022
1 parent 37a1c8c commit 3b066b6
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 263 deletions.
2 changes: 1 addition & 1 deletion dnerf/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True

dpg.add_combo(('image', 'depth'), label='image_mode', default_value=self.mode, callback=callback_change_mode)
dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

# time slider
def callback_set_time(sender, app_data):
Expand Down
2 changes: 1 addition & 1 deletion dnerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def run_cuda(self, rays_o, rays_d, time, dt_gamma=0, bg_color=None, perturb=Fals
# decide compact_steps
n_step = max(min(N // n_alive, 8), 1)

xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield[t], self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps)
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield[t], self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

sigmas, rgbs, _ = self(xyzs, dirs, time)
# density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
Expand Down
2 changes: 1 addition & 1 deletion nerf/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def callback_change_mode(sender, app_data):
self.mode = app_data
self.need_update = True

dpg.add_combo(('image', 'depth'), label='image_mode', default_value=self.mode, callback=callback_change_mode)
dpg.add_combo(('image', 'depth'), label='mode', default_value=self.mode, callback=callback_change_mode)

# bg_color picker
def callback_change_bg(sender, app_data):
Expand Down
2 changes: 1 addition & 1 deletion nerf/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import torch
from torch.utils.data import DataLoader

from .utils import get_rays, srgb_to_linear, torch_vis_2d
from .utils import get_rays


# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
Expand Down
10 changes: 5 additions & 5 deletions nerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ def run(self, rays_o, rays_d, num_steps=128, upsample_steps=128, bg_color=None,
}


def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, **kwargs):
def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, force_all_rays=False, max_steps=1024, T_thresh=1e-4, **kwargs):
# rays_o, rays_d: [B, N, 3], assumes B == 1
# return: image: [B, N, 3], depth: [B, N]

Expand Down Expand Up @@ -301,7 +301,7 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
depths = []
images = []
for k in range(K):
weights_sum, depth, image = raymarching.composite_rays_train(sigmas[k], rgbs[k], deltas, rays)
weights_sum, depth, image = raymarching.composite_rays_train(sigmas[k], rgbs[k], deltas, rays, T_thresh)
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
images.append(image.view(*prefix, 3))
Expand All @@ -312,7 +312,7 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for

else:

weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays)
weights_sum, depth, image = raymarching.composite_rays_train(sigmas, rgbs, deltas, rays, T_thresh)
image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
image = image.view(*prefix, 3)
Expand Down Expand Up @@ -350,15 +350,15 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
# decide compact_steps
n_step = max(min(N // n_alive, 8), 1)

xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps)
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb if step == 0 else False, dt_gamma, max_steps)

sigmas, rgbs = self(xyzs, dirs)
# density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
# sigmas = density_outputs['sigma']
# rgbs = self.color(xyzs, dirs, **density_outputs)
sigmas = self.density_scale * sigmas

raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh)

rays_alive = rays_alive[rays_alive >= 0]

Expand Down
31 changes: 21 additions & 10 deletions raymarching/raymarching.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,12 @@ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, ste
if step_counter is None:
step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter

_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, perturb) # m is the actually used points number
if perturb:
noises = torch.rand(N, dtype=rays_o.dtype, device=rays_o.device)
else:
noises = torch.zeros(N, dtype=rays_o.dtype, device=rays_o.device)

_backend.march_rays_train(rays_o, rays_d, density_bitfield, bound, dt_gamma, max_steps, N, C, H, M, nears, fars, xyzs, dirs, deltas, rays, step_counter, noises) # m is the actually used points number

#print(step_counter, M)

Expand All @@ -233,7 +238,7 @@ def forward(ctx, rays_o, rays_d, bound, density_bitfield, C, H, nears, fars, ste
class _composite_rays_train(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, sigmas, rgbs, deltas, rays):
def forward(ctx, sigmas, rgbs, deltas, rays, T_thresh=1e-4):
''' composite rays' rgbs, according to the ray marching formula.
Args:
rgbs: float, [M, 3]
Expand All @@ -256,10 +261,10 @@ def forward(ctx, sigmas, rgbs, deltas, rays):
depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device)
image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device)

_backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, weights_sum, depth, image)
_backend.composite_rays_train_forward(sigmas, rgbs, deltas, rays, M, N, T_thresh, weights_sum, depth, image)

ctx.save_for_backward(sigmas, rgbs, deltas, rays, weights_sum, depth, image)
ctx.dims = [M, N]
ctx.dims = [M, N, T_thresh]

return weights_sum, depth, image

Expand All @@ -273,14 +278,14 @@ def backward(ctx, grad_weights_sum, grad_depth, grad_image):
grad_image = grad_image.contiguous()

sigmas, rgbs, deltas, rays, weights_sum, depth, image = ctx.saved_tensors
M, N = ctx.dims
M, N, T_thresh = ctx.dims

grad_sigmas = torch.zeros_like(sigmas)
grad_rgbs = torch.zeros_like(rgbs)

_backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, grad_sigmas, grad_rgbs)
_backend.composite_rays_train_backward(grad_weights_sum, grad_image, sigmas, rgbs, deltas, rays, weights_sum, image, M, N, T_thresh, grad_sigmas, grad_rgbs)

return grad_sigmas, grad_rgbs, None, None
return grad_sigmas, grad_rgbs, None, None, None


composite_rays_train = _composite_rays_train.apply
Expand Down Expand Up @@ -330,7 +335,13 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, den
dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device)
deltas = torch.zeros(M, 2, dtype=rays_o.dtype, device=rays_o.device) # 2 vals, one for rgb, one for depth

_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, perturb)
if perturb:
# torch.manual_seed(perturb) # test_gui uses spp index as seed
noises = torch.rand(n_alive, dtype=rays_o.dtype, device=rays_o.device)
else:
noises = torch.zeros(n_alive, dtype=rays_o.dtype, device=rays_o.device)

_backend.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, dt_gamma, max_steps, C, H, density_bitfield, near, far, xyzs, dirs, deltas, noises)

return xyzs, dirs, deltas

Expand All @@ -340,7 +351,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, bound, den
class _composite_rays(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32) # need to cast sigmas & rgbs to float
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image):
def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image, T_thresh=1e-2):
''' composite rays' rgbs, according to the ray marching formula. (for inference)
Args:
n_alive: int, number of alive rays
Expand All @@ -355,7 +366,7 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weig
depth: float, [N,], the depth value
image: float, [N, 3], the RGB channel (after multiplying alpha!)
'''
_backend.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
_backend.composite_rays(n_alive, n_step, T_thresh, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)
return tuple()


Expand Down
205 changes: 0 additions & 205 deletions raymarching/src/pcg32.h

This file was deleted.

Loading

0 comments on commit 3b066b6

Please sign in to comment.