Skip to content

Commit

Permalink
simplify test rendering thanks to ngp_pl
Browse files Browse the repository at this point in the history
  • Loading branch information
ashawkey committed Jul 4, 2022
1 parent 033892b commit 1333f67
Show file tree
Hide file tree
Showing 6 changed files with 34 additions and 100 deletions.
24 changes: 8 additions & 16 deletions dnerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,24 +342,15 @@ def run_cuda(self, rays_o, rays_d, time, dt_gamma=0, bg_color=None, perturb=Fals
image = torch.zeros(N, 3, dtype=dtype, device=device)

n_alive = N
alive_counter = torch.zeros([1], dtype=torch.int32, device=device)

rays_alive = torch.zeros(2, n_alive, dtype=torch.int32, device=device) # 2 is used to loop old/new
rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device)
rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
rays_t = nears.clone() # [N]

step = 0
i = 0

while step < max_steps:

# count alive rays
if step == 0:
# init rays at first step.
torch.arange(n_alive, out=rays_alive[0])
rays_t[0] = nears
else:
alive_counter.zero_()
raymarching.compact_rays(n_alive, rays_alive[i % 2], rays_alive[(i + 1) % 2], rays_t[i % 2], rays_t[(i + 1) % 2], alive_counter)
n_alive = alive_counter.item() # must invoke D2H copy here
n_alive = rays_alive.shape[0]

# exit loop
if n_alive <= 0:
Expand All @@ -368,20 +359,21 @@ def run_cuda(self, rays_o, rays_d, time, dt_gamma=0, bg_color=None, perturb=Fals
# decide compact_steps
n_step = max(min(N // n_alive, 8), 1)

xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], rays_o, rays_d, self.bound, self.density_bitfield[t], self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps)
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield[t], self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps)

sigmas, rgbs, _ = self(xyzs, dirs, time)
# density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
# sigmas = density_outputs['sigma']
# rgbs = self.color(xyzs, dirs, **density_outputs)
sigmas = self.density_scale * sigmas

raymarching.composite_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image)
raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)

rays_alive = rays_alive[rays_alive >= 0]

#print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

step += n_step
i += 1

image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
Expand Down
24 changes: 8 additions & 16 deletions nerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,24 +328,15 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
image = torch.zeros(N, 3, dtype=dtype, device=device)

n_alive = N
alive_counter = torch.zeros([1], dtype=torch.int32, device=device)

rays_alive = torch.zeros(2, n_alive, dtype=torch.int32, device=device) # 2 is used to loop old/new
rays_t = torch.zeros(2, n_alive, dtype=dtype, device=device)
rays_alive = torch.arange(n_alive, dtype=torch.int32, device=device) # [N]
rays_t = nears.clone() # [N]

step = 0
i = 0

while step < max_steps:

# count alive rays
if step == 0:
# init rays at first step.
torch.arange(n_alive, out=rays_alive[0])
rays_t[0] = nears
else:
alive_counter.zero_()
raymarching.compact_rays(n_alive, rays_alive[i % 2], rays_alive[(i + 1) % 2], rays_t[i % 2], rays_t[(i + 1) % 2], alive_counter)
n_alive = alive_counter.item() # must invoke D2H copy here
n_alive = rays_alive.shape[0]

# exit loop
if n_alive <= 0:
Expand All @@ -354,20 +345,21 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
# decide compact_steps
n_step = max(min(N // n_alive, 8), 1)

xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps)
xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive, rays_t, rays_o, rays_d, self.bound, self.density_bitfield, self.cascade, self.grid_size, nears, fars, 128, perturb, dt_gamma, max_steps)

sigmas, rgbs = self(xyzs, dirs)
# density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb.
# sigmas = density_outputs['sigma']
# rgbs = self.color(xyzs, dirs, **density_outputs)
sigmas = self.density_scale * sigmas

raymarching.composite_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image)
raymarching.composite_rays(n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weights_sum, depth, image)

rays_alive = rays_alive[rays_alive >= 0]

#print(f'step = {step}, n_step = {n_step}, n_alive = {n_alive}, xyzs: {xyzs.shape}')

step += n_step
i += 1

image = image + (1 - weights_sum).unsqueeze(-1) * bg_color
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
Expand Down
26 changes: 3 additions & 23 deletions raymarching/raymarching.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,8 +345,8 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weig
Args:
n_alive: int, number of alive rays
n_step: int, how many steps we march
rays_alive: int, [N], the alive rays' IDs in N (N >= n_alive, but we only use first n_alive)
rays_t: float, [N], the alive rays' time, we only use the first n_alive.
rays_alive: int, [n_alive], the alive rays' IDs in N (N >= n_alive)
rays_t: float, [N], the alive rays' time
sigmas: float, [n_alive * n_step,]
rgbs: float, [n_alive * n_step, 3]
deltas: float, [n_alive * n_step, 2], all generated points' deltas (here we record two deltas, the first is for RGB, the second for depth).
Expand All @@ -359,24 +359,4 @@ def forward(ctx, n_alive, n_step, rays_alive, rays_t, sigmas, rgbs, deltas, weig
return tuple()


composite_rays = _composite_rays.apply


class _compact_rays(Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter):
''' compact rays, remove dead rays and reallocate alive rays, to accelerate next ray marching.
Args:
n_alive: int, number of alive rays
rays_alive_old: int, [N]
rays_t_old: float, [N], dead rays are marked by rays_t < 0
alive_counter: int, [1], used to count remained alive rays.
In-place Outputs:
rays_alive: int, [N]
rays_t: float, [N]
'''
_backend.compact_rays(n_alive, rays_alive, rays_alive_old, rays_t, rays_t_old, alive_counter)
return tuple()

compact_rays = _compact_rays.apply
composite_rays = _composite_rays.apply
1 change: 0 additions & 1 deletion raymarching/src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// infer
m.def("march_rays", &march_rays, "march rays (CUDA)");
m.def("composite_rays", &composite_rays, "composite rays (CUDA)");
m.def("compact_rays", &compact_rays, "compact rays (CUDA)");
}
56 changes: 14 additions & 42 deletions raymarching/src/raymarching.cu
Original file line number Diff line number Diff line change
Expand Up @@ -725,21 +725,21 @@ __global__ void kernel_march_rays(
if (n >= n_alive) return;

const int index = rays_alive[n]; // ray id
float t = rays_t[n]; // current ray's t


// locate
rays_o += index * 3;
rays_d += index * 3;
xyzs += n * n_step * 3;
dirs += n * n_step * 3;
deltas += n * n_step * 2;

const float ox = rays_o[0], oy = rays_o[1], oz = rays_o[2];
const float dx = rays_d[0], dy = rays_d[1], dz = rays_d[2];
const float rdx = 1 / dx, rdy = 1 / dy, rdz = 1 / dz;
const float rH = 1 / (float)H;
const float H3 = H * H * H;


float t = rays_t[index]; // current ray's t
const float near = nears[index], far = fars[index];

const float dt_min = 2 * SQRT3() / max_steps;
Expand Down Expand Up @@ -829,7 +829,7 @@ template <typename scalar_t>
__global__ void kernel_composite_rays(
const uint32_t n_alive,
const uint32_t n_step,
const int* __restrict__ rays_alive,
int* rays_alive,
scalar_t* rays_t,
const scalar_t* __restrict__ sigmas,
const scalar_t* __restrict__ rgbs,
Expand All @@ -840,16 +840,18 @@ __global__ void kernel_composite_rays(
if (n >= n_alive) return;

const int index = rays_alive[n]; // ray id
scalar_t t = rays_t[n]; // current ray's t


// locate
sigmas += n * n_step;
rgbs += n * n_step * 3;
deltas += n * n_step * 2;


rays_t += index;
weights_sum += index;
depth += index;
image += index * 3;

scalar_t t = rays_t[0]; // current ray's t

scalar_t weight_sum = weights_sum[0];
scalar_t d = depth[0];
Expand Down Expand Up @@ -896,11 +898,11 @@ __global__ void kernel_composite_rays(

//printf("[n=%d] rgb=(%f, %f, %f), d=%f\n", n, r, g, b, d);

// rays_t = -1 means ray is terminated early.
// rays_alive = -1 means ray is terminated early.
if (step < n_step) {
rays_t[n] = -1;
rays_alive[n] = -1;
} else {
rays_t[n] = t;
rays_t[0] = t;
}

weights_sum[0] = weight_sum; // this is the thing I needed!
Expand All @@ -911,40 +913,10 @@ __global__ void kernel_composite_rays(
}


void composite_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, at::Tensor rays_t, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, at::Tensor weights, at::Tensor depth, at::Tensor image) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
image.scalar_type(), "composite_rays", ([&] {
kernel_composite_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, n_step, rays_alive.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), sigmas.data_ptr<scalar_t>(), rgbs.data_ptr<scalar_t>(), deltas.data_ptr<scalar_t>(), weights.data_ptr<scalar_t>(), depth.data_ptr<scalar_t>(), image.data_ptr<scalar_t>());
}));
}


template <typename scalar_t>
__global__ void kernel_compact_rays(
const uint32_t n_alive,
int* rays_alive,
const int* __restrict__ rays_alive_old,
scalar_t* rays_t,
const scalar_t* __restrict__ rays_t_old,
int* alive_counter
) {
const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x;
if (n >= n_alive) return;

// rays_t_old[n] < 0 means ray died in last composite kernel.
if (rays_t_old[n] >= 0) {
const int index = atomicAdd(alive_counter, 1);
rays_alive[index] = rays_alive_old[n];
rays_t[index] = rays_t_old[n];
}
}


void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, const at::Tensor rays_alive_old, at::Tensor rays_t, const at::Tensor rays_t_old, at::Tensor alive_counter) {
static constexpr uint32_t N_THREAD = 128;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
rays_t.scalar_type(), "compact_rays", ([&] {
kernel_compact_rays<<<div_round_up(n_alive, N_THREAD), N_THREAD>>>(n_alive, rays_alive.data_ptr<int>(), rays_alive_old.data_ptr<int>(), rays_t.data_ptr<scalar_t>(), rays_t_old.data_ptr<scalar_t>(), alive_counter.data_ptr<int>());
}));
}
3 changes: 1 addition & 2 deletions raymarching/src/raymarching.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,4 @@ void composite_rays_train_forward(const at::Tensor sigmas, const at::Tensor rgbs
void composite_rays_train_backward(const at::Tensor grad_weights_sum, const at::Tensor grad_image, const at::Tensor sigmas, const at::Tensor rgbs, const at::Tensor deltas, const at::Tensor rays, const at::Tensor weights_sum, const at::Tensor image, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs);

void march_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, const at::Tensor rays_t, const at::Tensor rays_o, const at::Tensor rays_d, const float bound, const float dt_gamma, const uint32_t max_steps, const uint32_t C, const uint32_t H, const at::Tensor grid, const at::Tensor nears, const at::Tensor fars, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, const uint32_t perturb);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, const at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);
void compact_rays(const uint32_t n_alive, at::Tensor rays_alive, const at::Tensor rays_alive_old, at::Tensor rays_t, const at::Tensor rays_t_old, at::Tensor alive_counter);
void composite_rays(const uint32_t n_alive, const uint32_t n_step, at::Tensor rays_alive, at::Tensor rays_t, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor weights_sum, at::Tensor depth, at::Tensor image);

0 comments on commit 1333f67

Please sign in to comment.