From 8288a9d35cc50591d8b0ca2509ec942f8fd47985 Mon Sep 17 00:00:00 2001 From: ashawkey Date: Sat, 19 Feb 2022 00:31:27 +0800 Subject: [PATCH] update cuda raymarching --- hashencoder/hashgrid.py | 13 +-- nerf/network.py | 89 +++++++++--------- nerf/network_ff.py | 87 +++++++++--------- nerf/network_tcnn.py | 89 +++++++++--------- nerf/utils.py | 19 ++-- raymarching/backend.py | 4 +- raymarching/raymarching.py | 50 ++++++---- raymarching/src/raymarching.cu | 162 ++++++++++++--------------------- raymarching/src/raymarching.h | 6 +- testing/test_ffmlp.py | 64 ++++++++++--- train_nerf.py | 2 +- 11 files changed, 301 insertions(+), 284 deletions(-) diff --git a/hashencoder/hashgrid.py b/hashencoder/hashgrid.py index 2a1d6baf..625683bc 100644 --- a/hashencoder/hashgrid.py +++ b/hashencoder/hashgrid.py @@ -19,7 +19,7 @@ def forward(ctx, inputs, embeddings, offsets, per_level_scale, base_resolution, inputs = inputs.contiguous() embeddings = embeddings.contiguous() - offsets = offsets.contiguous().to(inputs.device) + offsets = offsets.contiguous() B, D = inputs.shape # batch size, coord dim L = offsets.shape[0] - 1 # level @@ -95,19 +95,20 @@ def __init__(self, input_dim=3, num_levels=16, level_dim=2, per_level_scale=2, b print('[WARN] detected HashGrid level_dim % 2 != 0, which will cause very slow backward is also enabled fp16! (maybe fix later)') # allocate parameters - self.offsets = [] + offsets = [] offset = 0 self.max_params = 2 ** log2_hashmap_size for i in range(num_levels): resolution = int(np.ceil(base_resolution * per_level_scale ** i)) params_in_level = min(self.max_params, (resolution + 1) ** input_dim) # limit max number params_in_level = int(params_in_level / 8) * 8 # make divisible - self.offsets.append(offset) + offsets.append(offset) offset += params_in_level - self.offsets.append(offset) - self.offsets = torch.from_numpy(np.array(self.offsets, dtype=np.int32)) + offsets.append(offset) + offsets = torch.from_numpy(np.array(offsets, dtype=np.int32)) + self.register_buffer('offsets', offsets) - self.n_params = self.offsets[-1] * level_dim + self.n_params = offsets[-1] * level_dim # parameters self.embeddings = nn.Parameter(torch.empty(offset, level_dim)) diff --git a/nerf/network.py b/nerf/network.py index 477c151e..3d924160 100644 --- a/nerf/network.py +++ b/nerf/network.py @@ -102,7 +102,7 @@ def __init__(self, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, - density_grid_size=-1, # density grid size + cuda_raymarching=False, ): super().__init__() @@ -150,15 +150,19 @@ def __init__(self, self.color_net = nn.ModuleList(color_net) - # density grid - if density_grid_size > 0: - # buffer is like parameter but never requires_grad - density_grid = torch.zeros([density_grid_size + 1] * 3) # +1 because we save values at grid + # extra state for cuda raymarching + self.cuda_raymarching = cuda_raymarching + if cuda_raymarching: + # density grid + density_grid = torch.zeros([128 + 1] * 3) # +1 because we save values at grid self.register_buffer('density_grid', density_grid) self.mean_density = 0 self.iter_density = 0 - else: - self.density_grid = None + # step counter + step_counter = torch.zeros(64, 2, dtype=torch.int32) # 64 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 def forward(self, x, d, bound): # x: [B, N, 3], in [-bound, bound] @@ -299,7 +303,7 @@ def run(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color): # mix background color if bg_color is None: - bg_color = 1 # let it broadcast + bg_color = 1 image = image + (1 - weights_sum).unsqueeze(-1) * bg_color return depth, image @@ -314,15 +318,24 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color): bg_color = torch.ones(3, dtype=rays_o.dtype, device=rays_o.device) ### generate points (forward only) - points, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, self.training) + if self.training: + counter = self.step_counter[self.local_step % 64] + counter.zero_() # set to 0 + self.local_step += 1 + force_all_rays = False + else: + counter = None + force_all_rays = True + + xyzs, dirs, deltas, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, self.training, 128, force_all_rays) ### call network inference - sigmas, rgbs = self(points[:, :3], points[:, 3:6], bound=bound) + sigmas, rgbs = self(xyzs, dirs, bound=bound) ### accumulate rays (need backward) # inputs: sigmas: [M], rgbs: [M, 3], offsets: [N+1] # outputs: depth: [N], image: [N, 3] - depth, image = raymarching.accumulate_rays(sigmas, rgbs, points, rays, bound, bg_color) + depth, image = raymarching.accumulate_rays(sigmas, rgbs, deltas, rays, bound, bg_color) depth = depth.reshape(B, N) image = image.reshape(B, N, 3) @@ -330,22 +343,18 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color): return depth, image - def update_density_grid(self, bound, decay=0.95, split_size=128): - # call before run_cuda, prepare a coarse density grid. + def update_extra_state(self, bound, decay=0.95): + # call before each epoch to update extra states. - if self.density_grid is None: + if not self.cuda_raymarching: return + ### update density grid resolution = self.density_grid.shape[0] - - N = split_size # chunk to avoid OOM - X = torch.linspace(-bound, bound, resolution).split(N) - Y = torch.linspace(-bound, bound, resolution).split(N) - Z = torch.linspace(-bound, bound, resolution).split(N) - - # all_pts = [] - # all_density = [] + X = torch.linspace(-bound, bound, resolution).split(128) + Y = torch.linspace(-bound, bound, resolution).split(128) + Z = torch.linspace(-bound, bound, resolution).split(128) tmp_grid = torch.zeros_like(self.density_grid) with torch.no_grad(): @@ -354,39 +363,31 @@ def update_density_grid(self, bound, decay=0.95, split_size=128): for zi, zs in enumerate(Z): lx, ly, lz = len(xs), len(ys), len(zs) xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij') - pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).to(tmp_grid.device).unsqueeze(0) # [1, N, 3] - density = self.density(pts, bound).reshape(lx, ly, lz).detach() - tmp_grid[xi * N: xi * N + lx, yi * N: yi * N + ly, zi * N: zi * N + lz] = density - - # all_pts.append(pts[0]) - # all_density.append(density.reshape(-1)) + pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3] + # manual padding for ffmlp + n = pts.shape[0] + pad_n = 128 - (n % 128) + if pad_n != 0: + pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0) + density = self.density(pts.to(tmp_grid.device), bound)[:n].reshape(lx, ly, lz).detach() + tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = density - # smooth by maxpooling tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1)) tmp_grid = F.max_pool3d(tmp_grid.unsqueeze(0).unsqueeze(0), kernel_size=2, stride=1).squeeze(0).squeeze(0) # ema update - #self.density_grid = tmp_grid self.density_grid = torch.maximum(self.density_grid * decay, tmp_grid) - self.mean_density = torch.mean(self.density_grid).item() self.iter_density += 1 - # TMP: save as voxel volume (point cloud format...) - # all_pts = torch.cat(all_pts, dim=0).detach().cpu().numpy() # [N, 3] - # all_density = torch.cat(all_density, dim=0).detach().cpu().numpy() # [N] - # mask = all_density > 10 - # all_pts = all_pts[mask] - # all_density = all_density[mask] - # plot_pointcloud(all_pts, map_color(all_density)) - - #vertices, triangles = mcubes.marching_cubes(tmp_grid.detach().cpu().numpy(), 5) - #vertices = vertices / (resolution - 1.0) * 2 * bound - bound - #mesh = trimesh.Trimesh(vertices, triangles) - #mesh.export(f'./{self.iter_density}.ply') + ### update step counter + total_step = min(64, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 - print(f'[density grid] iter={self.iter_density} min={self.density_grid.min().item()}, max={self.density_grid.max().item()}, mean={self.mean_density}') + print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f} | [step counter] mean={self.mean_count}') def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_ray_batch=4096, bg_color=None, **kwargs): diff --git a/nerf/network_ff.py b/nerf/network_ff.py index 471ab2f1..2efde109 100644 --- a/nerf/network_ff.py +++ b/nerf/network_ff.py @@ -91,7 +91,7 @@ def __init__(self, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, - density_grid_size=-1, # density grid size + cuda_raymarching=False, ): super().__init__() @@ -121,23 +121,27 @@ def __init__(self, num_layers=self.num_layers_color, ) - # density grid - if density_grid_size > 0: - # buffer is like parameter but never requires_grad - density_grid = torch.zeros([density_grid_size + 1] * 3) # +1 because we save values at grid + # extra state for cuda raymarching + self.cuda_raymarching = cuda_raymarching + if cuda_raymarching: + # density grid + density_grid = torch.zeros([128 + 1] * 3) # +1 because we save values at grid self.register_buffer('density_grid', density_grid) self.mean_density = 0 self.iter_density = 0 - else: - self.density_grid = None + # step counter + step_counter = torch.zeros(64, 2, dtype=torch.int32) # 64 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 def forward(self, x, d, bound): # x: [B, N, 3], in [-bound, bound] # d: [B, N, 3], nomalized in [-1, 1] prefix = x.shape[:-1] - x = x.reshape(-1, 3) - d = d.reshape(-1, 3) + x = x.view(-1, 3) + d = d.view(-1, 3) # sigma x = self.encoder(x, size=bound) @@ -156,8 +160,8 @@ def forward(self, x, d, bound): # sigmoid activation for rgb color = torch.sigmoid(h) - sigma = sigma.reshape(*prefix) - color = color.reshape(*prefix, -1) + sigma = sigma.view(*prefix) + color = color.view(*prefix, -1) return sigma, color @@ -165,7 +169,7 @@ def density(self, x, bound): # x: [B, N, 3], in [-bound, bound] prefix = x.shape[:-1] - x = x.reshape(-1, 3) + x = x.view(-1, 3) x = self.encoder(x, size=bound) h = self.sigma_net(x) @@ -173,7 +177,7 @@ def density(self, x, bound): #sigma = torch.exp(torch.clamp(h[..., 0], -15, 15)) sigma = F.relu(h[..., 0]) - sigma = sigma.reshape(*prefix) + sigma = sigma.view(*prefix) return sigma @@ -286,26 +290,24 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color): bg_color = torch.ones(3, dtype=rays_o.dtype, device=rays_o.device) ### generate points (forward only) - points, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, self.training) - - ### call network inference - # manual pad for ffmlp (slow, should be avoided...) - n = points.shape[0] - pad_n = 128 - (n % 128) - if pad_n > 0: - points = torch.cat([points, torch.zeros(pad_n, points.shape[1], device=points.device, dtype=points.dtype)], dim=0) - - sigmas, rgbs = self(points[:, :3], points[:, 3:6], bound=bound) + if self.training: + counter = self.step_counter[self.local_step % 64] + counter.zero_() # set to 0 + self.local_step += 1 + force_all_rays = False + else: + counter = None + force_all_rays = True - if pad_n > 0: - sigmas = sigmas[:n] - rgbs = rgbs[:n] + xyzs, dirs, deltas, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, self.training, 128, force_all_rays) + ### call network inference + sigmas, rgbs = self(xyzs, dirs, bound=bound) ### accumulate rays (need backward) # inputs: sigmas: [M], rgbs: [M, 3], offsets: [N+1] # outputs: depth: [N], image: [N, 3] - depth, image = raymarching.accumulate_rays(sigmas, rgbs, points, rays, bound, bg_color) + depth, image = raymarching.accumulate_rays(sigmas, rgbs, deltas, rays, bound, bg_color) depth = depth.reshape(B, N) image = image.reshape(B, N, 3) @@ -313,19 +315,18 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color): return depth, image - def update_density_grid(self, bound, decay=0.95, split_size=128): - # call before run_cuda, prepare a coarse density grid. + def update_extra_state(self, bound, decay=0.95): + # call before each epoch to update extra states. - if self.density_grid is None: + if not self.cuda_raymarching: return + ### update density grid resolution = self.density_grid.shape[0] - - N = split_size # chunk to avoid OOM - X = torch.linspace(-bound, bound, resolution).split(N) - Y = torch.linspace(-bound, bound, resolution).split(N) - Z = torch.linspace(-bound, bound, resolution).split(N) + X = torch.linspace(-bound, bound, resolution).split(128) + Y = torch.linspace(-bound, bound, resolution).split(128) + Z = torch.linspace(-bound, bound, resolution).split(128) tmp_grid = torch.zeros_like(self.density_grid) with torch.no_grad(): @@ -341,26 +342,24 @@ def update_density_grid(self, bound, decay=0.95, split_size=128): if pad_n != 0: pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0) density = self.density(pts.to(tmp_grid.device), bound)[:n].reshape(lx, ly, lz).detach() - tmp_grid[xi * N: xi * N + lx, yi * N: yi * N + ly, zi * N: zi * N + lz] = density + tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = density # smooth by maxpooling tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1)) tmp_grid = F.max_pool3d(tmp_grid.unsqueeze(0).unsqueeze(0), kernel_size=2, stride=1).squeeze(0).squeeze(0) # ema update - #self.density_grid = tmp_grid self.density_grid = torch.maximum(self.density_grid * decay, tmp_grid) - self.mean_density = torch.mean(self.density_grid).item() self.iter_density += 1 - # TMP: save mesh for debug - # vertices, triangles = mcubes.marching_cubes(tmp_grid.detach().cpu().numpy(), 5) - # vertices = vertices / (resolution - 1.0) * 2 * bound - bound - # mesh = trimesh.Trimesh(vertices, triangles) - # mesh.export(f'./tmp/{self.iter_density}.ply') + ### update step counter + total_step = min(64, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 - print(f'[density grid] iter={self.iter_density} min={self.density_grid.min().item()}, max={self.density_grid.max().item()}, mean={self.mean_density}') + print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f} | [step counter] mean={self.mean_count}') def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_ray_batch=4096, bg_color=None, **kwargs): diff --git a/nerf/network_tcnn.py b/nerf/network_tcnn.py index 471f7434..71febf7c 100644 --- a/nerf/network_tcnn.py +++ b/nerf/network_tcnn.py @@ -90,7 +90,7 @@ def __init__(self, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, - density_grid_size=-1, # density grid size + cuda_raymarching=False, ): super().__init__() @@ -149,23 +149,27 @@ def __init__(self, }, ) - # density grid - if density_grid_size > 0: - # buffer is like parameter but never requires_grad - density_grid = torch.zeros([density_grid_size + 1] * 3) # +1 because we save values at grid +# extra state for cuda raymarching + self.cuda_raymarching = cuda_raymarching + if cuda_raymarching: + # density grid + density_grid = torch.zeros([128 + 1] * 3) # +1 because we save values at grid self.register_buffer('density_grid', density_grid) self.mean_density = 0 self.iter_density = 0 - else: - self.density_grid = None + # step counter + step_counter = torch.zeros(64, 2, dtype=torch.int32) # 64 is hardcoded for averaging... + self.register_buffer('step_counter', step_counter) + self.mean_count = 0 + self.local_step = 0 def forward(self, x, d, bound): # x: [B, N, 3], in [-bound, bound] # d: [B, N, 3], nomalized in [-1, 1] prefix = x.shape[:-1] - x = x.reshape(-1, 3) - d = d.reshape(-1, 3) + x = x.view(-1, 3) + d = d.view(-1, 3) # sigma x = (x + bound) / (2 * bound) # to [0, 1] @@ -186,8 +190,8 @@ def forward(self, x, d, bound): # sigmoid activation for rgb color = torch.sigmoid(h) - sigma = sigma.reshape(*prefix) - color = color.reshape(*prefix, -1) + sigma = sigma.view(*prefix) + color = color.view(*prefix, -1) return sigma, color @@ -195,7 +199,7 @@ def density(self, x, bound): # x: [B, N, 3], in [-bound, bound] prefix = x.shape[:-1] - x = x.reshape(-1, 3) + x = x.view(-1, 3) x = (x + bound) / (2 * bound) # to [0, 1] x = self.encoder(x) @@ -204,7 +208,7 @@ def density(self, x, bound): #sigma = torch.exp(torch.clamp(h[..., 0], -15, 15)) sigma = F.relu(h[..., 0]) - sigma = sigma.reshape(*prefix) + sigma = sigma.view(*prefix) return sigma @@ -317,46 +321,43 @@ def run_cuda(self, rays_o, rays_d, num_steps, bound, upsample_steps, bg_color): bg_color = torch.ones(3, dtype=rays_o.dtype, device=rays_o.device) ### generate points (forward only) - points, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, self.training) - - ### call network inference - # manual pad for ffmlp - n = points.shape[0] - pad_n = 128 - (n % 128) - if pad_n > 0: - points = torch.cat([points, torch.zeros(pad_n, points.shape[1], device=points.device, dtype=points.dtype)], dim=0) - - sigmas, rgbs = self(points[:, :3], points[:, 3:6], bound=bound) + if self.training: + counter = self.step_counter[self.local_step % 64] + counter.zero_() # set to 0 + self.local_step += 1 + force_all_rays = False + else: + counter = None + force_all_rays = True - if pad_n > 0: - sigmas = sigmas[:n] - rgbs = rgbs[:n] + xyzs, dirs, deltas, rays = raymarching.generate_points(rays_o, rays_d, bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, self.training, 128, force_all_rays) + ### call network inference + sigmas, rgbs = self(xyzs, dirs, bound=bound) ### accumulate rays (need backward) # inputs: sigmas: [M], rgbs: [M, 3], offsets: [N+1] # outputs: depth: [N], image: [N, 3] - depth, image = raymarching.accumulate_rays(sigmas, rgbs, points, rays, bound, bg_color) + depth, image = raymarching.accumulate_rays(sigmas, rgbs, deltas, rays, bound, bg_color) depth = depth.reshape(B, N) image = image.reshape(B, N, 3) return depth, image - - def update_density_grid(self, bound, decay=0.95, split_size=128): - # call before run_cuda, prepare a coarse density grid. + + def update_extra_state(self, bound, decay=0.95): + # call before each epoch to update extra states. - if self.density_grid is None: + if not self.cuda_raymarching: return + ### update density grid resolution = self.density_grid.shape[0] - - N = split_size # chunk to avoid OOM - X = torch.linspace(-bound, bound, resolution).split(N) - Y = torch.linspace(-bound, bound, resolution).split(N) - Z = torch.linspace(-bound, bound, resolution).split(N) + X = torch.linspace(-bound, bound, resolution).split(128) + Y = torch.linspace(-bound, bound, resolution).split(128) + Z = torch.linspace(-bound, bound, resolution).split(128) tmp_grid = torch.zeros_like(self.density_grid) with torch.no_grad(): @@ -372,26 +373,24 @@ def update_density_grid(self, bound, decay=0.95, split_size=128): if pad_n != 0: pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0) density = self.density(pts.to(tmp_grid.device), bound)[:n].reshape(lx, ly, lz).detach() - tmp_grid[xi * N: xi * N + lx, yi * N: yi * N + ly, zi * N: zi * N + lz] = density + tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = density # smooth by maxpooling tmp_grid = F.pad(tmp_grid, (0, 1, 0, 1, 0, 1)) tmp_grid = F.max_pool3d(tmp_grid.unsqueeze(0).unsqueeze(0), kernel_size=2, stride=1).squeeze(0).squeeze(0) # ema update - #self.density_grid = tmp_grid self.density_grid = torch.maximum(self.density_grid * decay, tmp_grid) - self.mean_density = torch.mean(self.density_grid).item() self.iter_density += 1 - # TMP: save mesh for debug - # vertices, triangles = mcubes.marching_cubes(tmp_grid.detach().cpu().numpy(), 5) - # vertices = vertices / (resolution - 1.0) * 2 * bound - bound - # mesh = trimesh.Trimesh(vertices, triangles) - # mesh.export(f'./tmp/{self.iter_density}.ply') + ### update step counter + total_step = min(64, self.local_step) + if total_step > 0: + self.mean_count = int(self.step_counter[:total_step, 0].sum().item() / total_step) + self.local_step = 0 - print(f'[density grid] iter={self.iter_density} min={self.density_grid.min().item()}, max={self.density_grid.max().item()}, mean={self.mean_density}') + print(f'[density grid] min={self.density_grid.min().item():.4f}, max={self.density_grid.max().item():.4f}, mean={self.mean_density:.4f} | [step counter] mean={self.mean_count}') def render(self, rays_o, rays_d, num_steps, bound, upsample_steps, staged=False, max_ray_batch=4096, bg_color=None, **kwargs): diff --git a/nerf/utils.py b/nerf/utils.py index 616636e5..50b50823 100644 --- a/nerf/utils.py +++ b/nerf/utils.py @@ -319,7 +319,7 @@ def test_step(self, data): B = poses.shape[0] rays_o, rays_d, _ = get_rays(poses, intrinsics, H, W, -1) outputs = self.model.render(rays_o, rays_d, staged=True, **self.conf) - pred_rgb = outputs['rgb'].reshape(B, H, W, 3) + pred_rgb = outputs['rgb'].reshape(B, H, W, -1) pred_depth = outputs['depth'].reshape(B, H, W) return pred_rgb, pred_depth @@ -389,11 +389,6 @@ def test(self, loader, save_path=None): self.log(f"==> Start Test, save results to {save_path}") - # update grid - if self.model.density_grid is not None: - with torch.cuda.amp.autocast(enabled=self.fp16): - self.model.update_density_grid(self.conf['bound']) - pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') self.model.eval() with torch.no_grad(): @@ -447,9 +442,9 @@ def train_one_epoch(self, loader): self.model.train() # update grid - if self.model.density_grid is not None: + if self.model.cuda_raymarching: with torch.cuda.amp.autocast(enabled=self.fp16): - self.model.update_density_grid(self.conf['bound']) + self.model.update_extra_state(self.conf['bound']) # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs # ref: https://pytorch.org/docs/stable/data.html @@ -614,6 +609,10 @@ def save_checkpoint(self, full=False, best=False): 'stats': self.stats, } + if self.model.cuda_raymarching: + state['mean_count'] = self.model.mean_count + state['mean_density'] = self.model.mean_density + if full: state['optimizer'] = self.optimizer.state_dict() state['lr_scheduler'] = self.lr_scheduler.state_dict() @@ -685,6 +684,10 @@ def load_checkpoint(self, checkpoint=None): self.stats = checkpoint_dict['stats'] self.epoch = checkpoint_dict['epoch'] + + if self.model.cuda_raymarching: + self.model.mean_count = checkpoint_dict['mean_count'] + self.model.mean_density = checkpoint_dict['mean_density'] if self.optimizer and 'optimizer' in checkpoint_dict: try: diff --git a/raymarching/backend.py b/raymarching/backend.py index 5b71784a..e3ce9ce2 100644 --- a/raymarching/backend.py +++ b/raymarching/backend.py @@ -4,8 +4,8 @@ _src_path = os.path.dirname(os.path.abspath(__file__)) _backend = load(name='_raymarching', - extra_cflags=['-O3'], # '-std=c++17' - extra_cuda_cflags=['-O3'], # '-arch=sm_70' + extra_cflags=['-O3', '-std=c++14'], + extra_cuda_cflags=['-O3', '-std=c++14'], sources=[os.path.join(_src_path, 'src', f) for f in [ 'raymarching.cu', 'bindings.cpp', diff --git a/raymarching/raymarching.py b/raymarching/raymarching.py index d69d73e8..6e34118c 100644 --- a/raymarching/raymarching.py +++ b/raymarching/raymarching.py @@ -17,7 +17,7 @@ class _generate_points(Function): @staticmethod @custom_fwd(cast_inputs=torch.half) - def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, iter_density, perturb=False): + def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, iter_density, step_counter=None, mean_count=-1, perturb=False, align=-1, force_all_rays=False): rays_o = rays_o.reshape(-1, 3).contiguous() rays_d = rays_d.reshape(-1, 3).contiguous() @@ -25,21 +25,37 @@ def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, iter_density N = rays_o.shape[0] # num rays H = density_grid.shape[0] # grid resolution - M = N * 1024 # max points number in total, hardcoded + M = N * 512 # init max points number in total, hardcoded + + # running average based on previous epoch (mimic `measured_batch_size_before_compaction` in instant-ngp) + # It estimate the max points number to enable faster training, but will lead to random rays to be ignored. + if not force_all_rays and mean_count > 0: + if align > 0: + mean_count += align - mean_count % align + M = mean_count - points = torch.empty(M, 7, dtype=rays_o.dtype, device=rays_o.device) + xyzs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + dirs = torch.zeros(M, 3, dtype=rays_o.dtype, device=rays_o.device) + deltas = torch.zeros(M, dtype=rays_o.dtype, device=rays_o.device) rays = torch.empty(N, 3, dtype=torch.int32, device=rays_o.device) # id, offset, num_steps - counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter - - _backend.generate_points(rays_o, rays_d, density_grid, mean_density, iter_density, bound, N, H, M, points, rays, counter, perturb) # m is the actually used points number - # inplace resize - # TODO: this cause D2H copy... should avoid... - #points.resize_(counter[0], 7) + if step_counter is None: + step_counter = torch.zeros(2, dtype=torch.int32, device=rays_o.device) # point counter, ray counter + + _backend.generate_points(rays_o, rays_d, density_grid, mean_density, iter_density, bound, N, H, M, xyzs, dirs, deltas, rays, step_counter, perturb) # m is the actually used points number + + #print(step_counter, M) - #print(f"generated points count m = {counter[0]} << {M}") + # only used at the first epoch. + if force_all_rays or mean_count <= 0: + m = step_counter[0].item() # cause copy to CPU, will slow down a bit + if align > 0: + m += align - m % align + xyzs = xyzs[:m] + dirs = dirs[:m] + deltas = deltas[:m] - return points, rays + return xyzs, dirs, deltas, rays generate_points = _generate_points.apply @@ -50,11 +66,11 @@ def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, iter_density class _accumulate_rays(Function): @staticmethod @custom_fwd(cast_inputs=torch.half) - def forward(ctx, sigmas, rgbs, points, rays, bound, bg_color): + def forward(ctx, sigmas, rgbs, deltas, rays, bound, bg_color): sigmas = sigmas.contiguous() rgbs = rgbs.contiguous() - points = points.contiguous() + deltas = deltas.contiguous() rays = rays.contiguous() M = sigmas.shape[0] @@ -63,9 +79,9 @@ def forward(ctx, sigmas, rgbs, points, rays, bound, bg_color): depth = torch.empty(N, dtype=sigmas.dtype, device=sigmas.device) image = torch.empty(N, 3, dtype=sigmas.dtype, device=sigmas.device) - _backend.accumulate_rays_forward(sigmas, rgbs, points, rays, bound, bg_color, M, N, depth, image) + _backend.accumulate_rays_forward(sigmas, rgbs, deltas, rays, bound, bg_color, M, N, depth, image) - ctx.save_for_backward(sigmas, rgbs, points, rays, image, bg_color) + ctx.save_for_backward(sigmas, rgbs, deltas, rays, image, bg_color) ctx.dims = [M, N, bound] return depth, image @@ -77,13 +93,13 @@ def backward(ctx, grad_depth, grad_image): grad_image = grad_image.contiguous() - sigmas, rgbs, points, rays, image, bg_color = ctx.saved_tensors + sigmas, rgbs, deltas, rays, image, bg_color = ctx.saved_tensors M, N, bound = ctx.dims grad_sigmas = torch.zeros_like(sigmas) grad_rgbs = torch.zeros_like(rgbs) - _backend.accumulate_rays_backward(grad_image, sigmas, rgbs, points, rays, image, bound, M, N, grad_sigmas, grad_rgbs) + _backend.accumulate_rays_backward(grad_image, sigmas, rgbs, deltas, rays, image, bound, M, N, grad_sigmas, grad_rgbs) return grad_sigmas, grad_rgbs, None, None, None, None diff --git a/raymarching/src/raymarching.cu b/raymarching/src/raymarching.cu index 6fe3dac3..181dd125 100644 --- a/raymarching/src/raymarching.cu +++ b/raymarching/src/raymarching.cu @@ -54,7 +54,7 @@ __global__ void kernel_generate_points( const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, - scalar_t * points, + scalar_t * xyzs, scalar_t * dirs, scalar_t * deltas, int * rays, int * counter, const bool perturb @@ -63,7 +63,7 @@ __global__ void kernel_generate_points( const uint32_t n = threadIdx.x + blockIdx.x * blockDim.x; if (n >= N) return; - const uint32_t max_steps = M / N; // fixed to 1024 + const uint32_t max_steps = 512; // M / N; const float rbound = 1 / bound; const float density_thresh = fminf(DENSITY_THRESH(), mean_density); @@ -97,50 +97,6 @@ __global__ void kernel_generate_points( const float t0 = near; // + dt_small * rng.next_float(); - // if iter_density too low (thus grid is unreliable), only generate coarse points. - //if (iter_density < 50) { - // if (false) { - - // uint32_t num_steps = H - 1; - - // uint32_t point_index = atomicAdd(counter, num_steps); - // uint32_t ray_index = atomicAdd(counter + 1, 1); - - // if (point_index + num_steps > M) return; - - // points += point_index * 7; - - // // write rays - // rays[ray_index * 3] = n; - // rays[ray_index * 3 + 1] = point_index; - // rays[ray_index * 3 + 2] = num_steps; - - // float t = t0; - // float last_t = t; - // uint32_t step = 0; - - // while (t <= far && step < num_steps) { - // // current point - // const float x = ox + t * dx; - // const float y = oy + t * dy; - // const float z = oz + t * dz; - // // write step - // points[0] = x; - // points[1] = y; - // points[2] = z; - // points[3] = dx; - // points[4] = dy; - // points[5] = dz; - // step++; - // t += dt_large * rng.next_float() * 2; // random perturb - // points[6] = t - last_t; - // points += 7; - // last_t = t; - // } - // return; - // } - - // else use two passes to generate fine samples // first pass: estimation of num_steps float t = t0; uint32_t num_steps = 0; @@ -183,10 +139,11 @@ __global__ void kernel_generate_points( //printf("[n=%d] num_steps=%d\n", n, num_steps); //printf("[n=%d] num_steps=%d, pc=%d, rc=%d\n", n, num_steps, counter[0], counter[1]); + // second pass: really locate and write points & dirs uint32_t point_index = atomicAdd(counter, num_steps); uint32_t ray_index = atomicAdd(counter + 1, 1); - + //printf("[n=%d] num_steps=%d, point_index=%d, ray_index=%d\n", n, num_steps, point_index, ray_index); // write rays @@ -195,9 +152,11 @@ __global__ void kernel_generate_points( rays[ray_index * 3 + 2] = num_steps; if (num_steps == 0) return; - if (point_index + num_steps > M) return; + if (point_index + num_steps >= M) return; - points += point_index * 7; + xyzs += point_index * 3; + dirs += point_index * 3; + deltas += point_index; t = t0; dt = dt_small; @@ -223,23 +182,26 @@ __global__ void kernel_generate_points( // if occpuied, advance a small step, and write to output if (density > density_thresh) { // write step - points[0] = x; - points[1] = y; - points[2] = z; - points[3] = dx; - points[4] = dy; - points[5] = dz; + xyzs[0] = x; + xyzs[1] = y; + xyzs[2] = z; + dirs[0] = dx; + dirs[1] = dy; + dirs[2] = dz; step++; dt = fminf(dt * dt_gamma, dt_large); if (perturb) { const float p_dt = dt * rng.next_float() * 2; t += p_dt; - points[6] = p_dt; + deltas[0] = p_dt; } else { t += dt; - points[6] = dt; + deltas[0] = dt; } - points += 7; + xyzs += 3; + dirs += 3; + deltas++; + // else, skip a large step (basically skip a voxel grid) } else { // calc distance to next voxel @@ -255,13 +217,13 @@ __global__ void kernel_generate_points( // rays_o/d: [N, 3] // grid: [H, H, H] -// points: [M, 3] +// xyzs, dirs, deltas: [M, 3], [M, 3], [M] // dirs: [M, 3] // rays: [N, 3], idx, offset, num_steps template -void generate_points_cuda(const scalar_t *rays_o, const scalar_t *rays_d, const scalar_t *grid, const float mean_density, const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, scalar_t *points, int *rays, int *counter, const bool perturb) { +void generate_points_cuda(const scalar_t *rays_o, const scalar_t *rays_d, const scalar_t *grid, const float mean_density, const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, scalar_t *xyzs, scalar_t *dirs, scalar_t *deltas, int *rays, int *counter, const bool perturb) { static constexpr uint32_t N_THREAD = 256; - kernel_generate_points<<>>(rays_o, rays_d, grid, mean_density, iter_density, bound, N, H, M, points, rays, counter, perturb); + kernel_generate_points<<>>(rays_o, rays_d, grid, mean_density, iter_density, bound, N, H, M, xyzs, dirs, deltas, rays, counter, perturb); } @@ -269,7 +231,7 @@ template __global__ void kernel_accumulate_rays_forward( const scalar_t * __restrict__ sigmas, const scalar_t * __restrict__ rgbs, - const scalar_t * __restrict__ points, + const scalar_t * __restrict__ deltas, const scalar_t * __restrict__ bg_color, const int * __restrict__ rays, const float bound, @@ -286,9 +248,18 @@ __global__ void kernel_accumulate_rays_forward( uint32_t offset = rays[n * 3 + 1]; uint32_t num_steps = rays[n * 3 + 2]; + // empty ray, or ray that exceed max step count. + if (num_steps == 0 || offset + num_steps >= M) { + depth[index] = 0; + image[index * 3] = bg_color[0]; + image[index * 3 + 1] = bg_color[1]; + image[index * 3 + 2] = bg_color[2]; + return; + } + sigmas += offset; rgbs += offset * 3; - points += offset * 7; + deltas += offset; // accumulate uint32_t step = 0; @@ -302,7 +273,7 @@ __global__ void kernel_accumulate_rays_forward( if (T < 1e-4f) break; const scalar_t sigma = sigmas[0]; - const scalar_t delta = points[6]; + const scalar_t delta = deltas[0]; const scalar_t rr = rgbs[0], gg = rgbs[1], bb = rgbs[2]; const scalar_t alpha = 1.0f - __expf(- sigma * delta); const scalar_t weight = alpha * T; @@ -320,7 +291,7 @@ __global__ void kernel_accumulate_rays_forward( // locate sigmas++; rgbs += 3; - points += 7; + deltas++; step++; } @@ -345,15 +316,15 @@ __global__ void kernel_accumulate_rays_forward( // sigmas: [M] // rgbs: [M, 3] -// points: [M, 7] +// deltas: [M] // bg_color: [3] // rays: [N, 3], idx, offset, num_steps // depth: [N] // image: [N, 3] template -void accumulate_rays_forward_cuda(const scalar_t *sigmas, const scalar_t *rgbs, const scalar_t *points, const scalar_t *bg_color, const int *rays, const float bound, const uint32_t M, const uint32_t N, scalar_t *depth, scalar_t *image) { +void accumulate_rays_forward_cuda(const scalar_t *sigmas, const scalar_t *rgbs, const scalar_t *deltas, const scalar_t *bg_color, const int *rays, const float bound, const uint32_t M, const uint32_t N, scalar_t *depth, scalar_t *image) { static constexpr uint32_t N_THREAD = 256; - kernel_accumulate_rays_forward<<>>(sigmas, rgbs, points, bg_color, rays, bound, M, N, depth, image); + kernel_accumulate_rays_forward<<>>(sigmas, rgbs, deltas, bg_color, rays, bound, M, N, depth, image); } @@ -362,7 +333,7 @@ __global__ void kernel_accumulate_rays_backward( const scalar_t * __restrict__ grad, const scalar_t * __restrict__ sigmas, const scalar_t * __restrict__ rgbs, - const scalar_t * __restrict__ points, + const scalar_t * __restrict__ deltas, const int * __restrict__ rays, const scalar_t * __restrict__ image, const float bound, @@ -379,11 +350,13 @@ __global__ void kernel_accumulate_rays_backward( uint32_t offset = rays[n * 3 + 1]; uint32_t num_steps = rays[n * 3 + 2]; + if (num_steps == 0 || offset + num_steps >= M) return; + grad += index * 3; image += index * 3; sigmas += offset; rgbs += offset * 3; - points += offset * 7; + deltas += offset; grad_sigmas += offset; grad_rgbs += offset * 3; @@ -399,7 +372,7 @@ __global__ void kernel_accumulate_rays_backward( if (T < 1e-4f) break; const scalar_t sigma = sigmas[0]; - const scalar_t delta = points[6]; + const scalar_t delta = deltas[0]; const scalar_t rr = rgbs[0], gg = rgbs[1], bb = rgbs[2]; const scalar_t alpha = 1.0f - __expf(- sigma * delta); const scalar_t weight = alpha * T; @@ -426,7 +399,7 @@ __global__ void kernel_accumulate_rays_backward( rgbs += 3; grad_sigmas++; grad_rgbs += 3; - points += 7; + deltas++; step++; } @@ -436,58 +409,43 @@ __global__ void kernel_accumulate_rays_backward( // grad: [N, 3] // sigmas: [M] // rgbs: [M, 3] -// points: [M, 7] +// deltas: [M] // rays: [N, 3], idx, offset, num_steps // image: [N, 3] // grad_sigmas: [M] // grad_rgbs: [M, 3] template -void accumulate_rays_backward_cuda(const scalar_t *grad, const scalar_t *sigmas, const scalar_t *rgbs, const scalar_t *points, const int *rays, const scalar_t *image, const float bound, const uint32_t M, const uint32_t N, scalar_t *grad_sigmas, scalar_t *grad_rgbs) { +void accumulate_rays_backward_cuda(const scalar_t *grad, const scalar_t *sigmas, const scalar_t *rgbs, const scalar_t *deltas, const int *rays, const scalar_t *image, const float bound, const uint32_t M, const uint32_t N, scalar_t *grad_sigmas, scalar_t *grad_rgbs) { static constexpr uint32_t N_THREAD = 256; - kernel_accumulate_rays_backward<<>>(grad, sigmas, rgbs, points, rays, image, bound, M, N, grad_sigmas, grad_rgbs); + kernel_accumulate_rays_backward<<>>(grad, sigmas, rgbs, deltas, rays, image, bound, M, N, grad_sigmas, grad_rgbs); } - - -void generate_points(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, at::Tensor points, at::Tensor rays, at::Tensor counter, const bool perturb) { +void generate_points(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const bool perturb) { CHECK_CUDA(rays_o); CHECK_CUDA(rays_d); CHECK_CUDA(grid); - CHECK_CUDA(points); - CHECK_CUDA(rays); - CHECK_CUDA(counter); CHECK_CONTIGUOUS(rays_o); CHECK_CONTIGUOUS(rays_d); CHECK_CONTIGUOUS(grid); - CHECK_CONTIGUOUS(points); - CHECK_CONTIGUOUS(rays); - CHECK_CONTIGUOUS(counter); CHECK_IS_FLOATING(rays_o); CHECK_IS_FLOATING(rays_d); CHECK_IS_FLOATING(grid); - CHECK_IS_FLOATING(points); - CHECK_IS_INT(rays); - CHECK_IS_INT(counter); - AT_DISPATCH_FLOATING_TYPES_AND_HALF( rays_o.scalar_type(), "generate_points", ([&] { - generate_points_cuda(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), mean_density, iter_density, bound, N, H, M, points.data_ptr(), rays.data_ptr(), counter.data_ptr(), perturb); + generate_points_cuda(rays_o.data_ptr(), rays_d.data_ptr(), grid.data_ptr(), mean_density, iter_density, bound, N, H, M, xyzs.data_ptr(), dirs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), counter.data_ptr(), perturb); })); - - // resize in c++ - points.resize_({counter[0].item().to(), points.size(1)}); } -void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor points, at::Tensor rays, const float bound, at::Tensor bg_color, const uint32_t M, const uint32_t N, at::Tensor depth, at::Tensor image) { +void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const float bound, at::Tensor bg_color, const uint32_t M, const uint32_t N, at::Tensor depth, at::Tensor image) { CHECK_CUDA(sigmas); CHECK_CUDA(rgbs); - CHECK_CUDA(points); + CHECK_CUDA(deltas); CHECK_CUDA(rays); CHECK_CUDA(depth); CHECK_CUDA(image); @@ -495,7 +453,7 @@ void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor poin CHECK_CONTIGUOUS(sigmas); CHECK_CONTIGUOUS(rgbs); - CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(deltas); CHECK_CONTIGUOUS(rays); CHECK_CONTIGUOUS(depth); CHECK_CONTIGUOUS(image); @@ -503,7 +461,7 @@ void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor poin CHECK_IS_FLOATING(sigmas); CHECK_IS_FLOATING(rgbs); - CHECK_IS_FLOATING(points); + CHECK_IS_FLOATING(deltas); CHECK_IS_INT(rays); CHECK_IS_FLOATING(depth); CHECK_IS_FLOATING(image); @@ -511,17 +469,17 @@ void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor poin AT_DISPATCH_FLOATING_TYPES_AND_HALF( sigmas.scalar_type(), "accumulate_rays_forward", ([&] { - accumulate_rays_forward_cuda(sigmas.data_ptr(), rgbs.data_ptr(), points.data_ptr(), bg_color.data_ptr(), rays.data_ptr(), bound, M, N, depth.data_ptr(), image.data_ptr()); + accumulate_rays_forward_cuda(sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), bg_color.data_ptr(), rays.data_ptr(), bound, M, N, depth.data_ptr(), image.data_ptr()); })); } -void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgbs, at::Tensor points, at::Tensor rays, at::Tensor image, const float bound, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { +void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor image, const float bound, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs) { CHECK_CUDA(grad); CHECK_CUDA(sigmas); CHECK_CUDA(rgbs); - CHECK_CUDA(points); + CHECK_CUDA(deltas); CHECK_CUDA(rays); CHECK_CUDA(image); CHECK_CUDA(grad_sigmas); @@ -530,7 +488,7 @@ void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgb CHECK_CONTIGUOUS(grad); CHECK_CONTIGUOUS(sigmas); CHECK_CONTIGUOUS(rgbs); - CHECK_CONTIGUOUS(points); + CHECK_CONTIGUOUS(deltas); CHECK_CONTIGUOUS(rays); CHECK_CONTIGUOUS(image); CHECK_CONTIGUOUS(grad_sigmas); @@ -539,7 +497,7 @@ void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgb CHECK_IS_FLOATING(grad); CHECK_IS_FLOATING(sigmas); CHECK_IS_FLOATING(rgbs); - CHECK_IS_FLOATING(points); + CHECK_IS_FLOATING(deltas); CHECK_IS_INT(rays); CHECK_IS_FLOATING(image); CHECK_IS_FLOATING(grad_sigmas); @@ -547,6 +505,6 @@ void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgb AT_DISPATCH_FLOATING_TYPES_AND_HALF( grad.scalar_type(), "accumulate_rays_backward", ([&] { - accumulate_rays_backward_cuda(grad.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), points.data_ptr(), rays.data_ptr(), image.data_ptr(), bound, M, N, grad_sigmas.data_ptr(), grad_rgbs.data_ptr()); + accumulate_rays_backward_cuda(grad.data_ptr(), sigmas.data_ptr(), rgbs.data_ptr(), deltas.data_ptr(), rays.data_ptr(), image.data_ptr(), bound, M, N, grad_sigmas.data_ptr(), grad_rgbs.data_ptr()); })); } \ No newline at end of file diff --git a/raymarching/src/raymarching.h b/raymarching/src/raymarching.h index 8bf8bdc4..751f4c2c 100644 --- a/raymarching/src/raymarching.h +++ b/raymarching/src/raymarching.h @@ -5,8 +5,8 @@ #include // _backend.generate_points(rays_o, rays_d, density_grid, bound, N, H, M, points, offsets) -void generate_points(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, at::Tensor points, at::Tensor rays, at::Tensor counter, const bool perturb); +void generate_points(at::Tensor rays_o, at::Tensor rays_d, at::Tensor grid, const float mean_density, const int iter_density, const float bound, const uint32_t N, const uint32_t H, const uint32_t M, at::Tensor xyzs, at::Tensor dirs, at::Tensor deltas, at::Tensor rays, at::Tensor counter, const bool perturb); // _backend.accumulate_rays_forward(sigmas, rgbs, rays, bound, M, N, depth, image) -void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor points, at::Tensor rays, const float bound, at::Tensor bg_color, const uint32_t M, const uint32_t N, at::Tensor depth, at::Tensor image); -void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgbs, at::Tensor points, at::Tensor rays, at::Tensor image, const float bound, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs); +void accumulate_rays_forward(at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, const float bound, at::Tensor bg_color, const uint32_t M, const uint32_t N, at::Tensor depth, at::Tensor image); +void accumulate_rays_backward(at::Tensor grad, at::Tensor sigmas, at::Tensor rgbs, at::Tensor deltas, at::Tensor rays, at::Tensor image, const float bound, const uint32_t M, const uint32_t N, at::Tensor grad_sigmas, at::Tensor grad_rgbs); diff --git a/testing/test_ffmlp.py b/testing/test_ffmlp.py index 855117f0..5268411a 100644 --- a/testing/test_ffmlp.py +++ b/testing/test_ffmlp.py @@ -6,6 +6,8 @@ from ffmlp import FFMLP import math +import tinycudann as tcnn + class MLP(nn.Module): def __init__(self, input_dim, output_dim, hidden_dim, num_layers, activation=F.relu): super().__init__() @@ -95,7 +97,7 @@ def forward(self, x): # # Speed # ################################## -BATCH_SIZE = 2**21 # the least batch to lauch a full block ! +BATCH_SIZE = 2**21 INPUT_DIM = 16 OUTPUT_DIM = 16 HIDDEN_DIM = 64 @@ -103,10 +105,18 @@ def forward(self, x): net0 = FFMLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() net1 = MLP(INPUT_DIM, OUTPUT_DIM, HIDDEN_DIM, NUM_LAYERS).cuda() +net2 = tcnn.Network(n_input_dims=INPUT_DIM, n_output_dims=OUTPUT_DIM, network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": HIDDEN_DIM, + "n_hidden_layers": NUM_LAYERS, + }) x = torch.rand(BATCH_SIZE, INPUT_DIM).cuda() * 10 x1 = x.detach().clone() x2 = x.detach().clone() +x3 = x.detach().clone() starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) @@ -115,14 +125,17 @@ def forward(self, x): starter.record() y2 = net1(x2) -ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'time1 (fp32 train) = {curr_time}') +ender.record() +torch.cuda.synchronize() +curr_time = starter.elapsed_time(ender) +print(f'pytorch MLP (fp32 train) = {curr_time}') starter.record() y2.sum().backward() ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) -print(f'time1 (fp32 back) = {curr_time}') +print(f'pytorch MLP (fp32 back) = {curr_time}') #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) @@ -134,14 +147,14 @@ def forward(self, x): ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time0 (forward) = {curr_time}') + print(f'FFMLP (forward) = {curr_time}') starter.record() y0.sum().backward() ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time0 (backward) = {curr_time}') + print(f'FFMLP (backward) = {curr_time}') #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) @@ -151,16 +164,31 @@ def forward(self, x): ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time1 (forward) = {curr_time}') + print(f'pytorch MLP (forward) = {curr_time}') starter.record() y1.sum().backward() ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time1 (backward) = {curr_time}') + print(f'pytorch MLP (backward) = {curr_time}') #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + starter.record() + y3 = net2(x3) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'TCNN (forward) = {curr_time}') + + starter.record() + y3.sum().backward() + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'TCNN (backward) = {curr_time}') + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) with torch.no_grad(): @@ -169,7 +197,7 @@ def forward(self, x): ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time1 (fp32 infer) = {curr_time}') + print(f'pytorch MLP (fp32 infer) = {curr_time}') with torch.cuda.amp.autocast(enabled=True): @@ -181,7 +209,7 @@ def forward(self, x): ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time0 (infer) = {curr_time}') + print(f'FFMLP (infer) = {curr_time}') #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) @@ -192,10 +220,22 @@ def forward(self, x): ender.record() torch.cuda.synchronize() curr_time = starter.elapsed_time(ender) - print(f'time1 (infer) = {curr_time}') + print(f'pytorch MLP (infer) = {curr_time}') + + #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) + + #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: + + starter.record() + y2 = net2(x) + ender.record() + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + print(f'TCNN (infer) = {curr_time}') #print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1)) -print(y0) -print(y1) + +# print(y0) +# print(y1) \ No newline at end of file diff --git a/train_nerf.py b/train_nerf.py index 65c188f7..cfe13b5c 100644 --- a/train_nerf.py +++ b/train_nerf.py @@ -49,7 +49,7 @@ model = NeRFNetwork( encoding="hashgrid", encoding_dir="sphere_harmonics", num_layers=2, hidden_dim=64, geo_feat_dim=15, num_layers_color=3, hidden_dim_color=64, - density_grid_size=128 if opt.cuda_raymarching else -1, + cuda_raymarching=opt.cuda_raymarching, ) #model = NeRFNetwork(encoding="frequency", encoding_dir="frequency", num_layers=4, hidden_dim=256, geo_feat_dim=256, num_layers_color=4, hidden_dim_color=128)