From dd4378cc9d1a490fbe6b81e54316291ca6eb5412 Mon Sep 17 00:00:00 2001 From: ashawkey Date: Sun, 27 Mar 2022 20:02:12 +0800 Subject: [PATCH] major update --- .gitignore | 4 +- ffmlp/ffmlp.py | 16 +- main_nerf.py | 4 +- main_tensoRF.py | 4 +- main_tensorf.py | 6 +- nerf/network.py | 43 +- nerf/network_ff.py | 22 +- nerf/network_tcnn.py | 54 +- nerf/renderer.py | 135 ++-- nerf/utils.py | 20 +- raymarching/raymarching.py | 2 +- raymarching/src/raymarching.cu | 7 +- readme.md | 43 +- scripts/run_gui_nerf.sh | 6 +- scripts/run_gui_tensoRF.sh | 7 + scripts/run_nerf.sh | 16 +- scripts/run_tensoRF.sh | 10 +- scripts/run_tensorf.sh | 3 - tensoRF/network.py | 16 +- tensoRF/utils.py | 23 +- tensorf/network.py | 1100 -------------------------------- tensorf/provider.py | 194 ------ tensorf/utils.py | 923 --------------------------- testing/test_ffmlp.py | 2 +- 24 files changed, 293 insertions(+), 2367 deletions(-) create mode 100755 scripts/run_gui_tensoRF.sh delete mode 100644 scripts/run_tensorf.sh delete mode 100644 tensorf/network.py delete mode 100644 tensorf/provider.py delete mode 100644 tensorf/utils.py diff --git a/.gitignore b/.gitignore index 0117a8a8..139d0d1b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ tmp* data/ trial*/ volsdf/ -**volsdf* \ No newline at end of file +**volsdf* +tensorf/ +**tensorf* \ No newline at end of file diff --git a/ffmlp/ffmlp.py b/ffmlp/ffmlp.py index 93968c52..3d2177cd 100644 --- a/ffmlp/ffmlp.py +++ b/ffmlp/ffmlp.py @@ -17,8 +17,6 @@ def forward(ctx, inputs, weights, input_dim, output_dim, hidden_dim, num_layers, B = inputs.shape[0] - assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}." - inputs = inputs.contiguous() weights = weights.contiguous() @@ -148,12 +146,20 @@ def forward(self, inputs): # return: [B, outupt_dim] #print('inputs', inputs.shape, inputs.dtype, inputs.min().item(), inputs.max().item(), inputs.requires_grad) - + + B, C = inputs.shape + #assert B >= 128 and B % 128 == 0, f"ffmlp batch size must be 128 * m (m > 0), but got {B}." 
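The hard batch-size assert is dropped in favor of transparent padding (added just below): the fully-fused MLP kernel consumes rows in blocks of 128, so the batch is zero-padded up to the next multiple and the output truncated back. A standalone sketch of that scheme; note it uses `(-B) % 128`, which yields 0 for an already-aligned batch (the in-tree version pads a full extra block when `B % 128 == 0`, harmless since the output is truncated, but wasteful):

```python
import torch

def pad_batch(x: torch.Tensor, multiple: int = 128):
    """Zero-pad the batch dimension up to the next multiple of `multiple`.

    Illustrative helper, not part of the patch: returns the padded tensor
    and the original batch size so callers can truncate the kernel output.
    """
    B = x.shape[0]
    pad = (-B) % multiple          # 0 when B is already aligned
    if pad > 0:
        x = torch.cat([x, x.new_zeros(pad, *x.shape[1:])], dim=0)
    return x, B

# usage sketch:
#   inputs, B = pad_batch(inputs)
#   outputs = ffmlp_forward(inputs, ...)[:B]
```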
+ + # pad input + pad = 128 - (B % 128) + if pad > 0: + inputs = torch.cat([inputs, torch.zeros(pad, C, dtype=inputs.dtype, device=inputs.device)], dim=0) + outputs = ffmlp_forward(inputs, self.weights, self.input_dim, self.padded_output_dim, self.hidden_dim, self.num_layers, self.activation, self.output_activation, not self.training, inputs.requires_grad) # unpad output - if self.padded_output_dim != self.output_dim: - outputs = outputs[:, :self.output_dim] + if B != outputs.shape[0] or self.padded_output_dim != self.output_dim: + outputs = outputs[:B, :self.output_dim] #print('outputs', outputs.shape, outputs.dtype, outputs.min().item(), outputs.max().item()) diff --git a/main_nerf.py b/main_nerf.py index 4a8812e5..26d98463 100644 --- a/main_nerf.py +++ b/main_nerf.py @@ -19,8 +19,8 @@ parser.add_argument('--num_rays', type=int, default=4096) parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") # (only valid when not using --cuda_ray) - parser.add_argument('--num_steps', type=int, default=128) - parser.add_argument('--upsample_steps', type=int, default=128) + parser.add_argument('--num_steps', type=int, default=512) + parser.add_argument('--upsample_steps', type=int, default=0) parser.add_argument('--max_ray_batch', type=int, default=4096) ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") diff --git a/main_tensoRF.py b/main_tensoRF.py index 2234f8b3..4a46d036 100644 --- a/main_tensoRF.py +++ b/main_tensoRF.py @@ -20,8 +20,8 @@ parser.add_argument('--cuda_ray', action='store_true', help="use CUDA raymarching instead of pytorch") parser.add_argument('--l1_reg_weight', type=float, default=4e-5) # (only valid when not using --cuda_ray) - parser.add_argument('--num_steps', type=int, default=128) - parser.add_argument('--upsample_steps', type=int, default=128) + parser.add_argument('--num_steps', type=int, default=512) + parser.add_argument('--upsample_steps', type=int, default=0) parser.add_argument('--max_ray_batch', type=int, default=4096) ### network backbone options parser.add_argument('--fp16', action='store_true', help="use amp mixed precision training") diff --git a/main_tensorf.py b/main_tensorf.py index 614f0131..c374415e 100644 --- a/main_tensorf.py +++ b/main_tensorf.py @@ -32,7 +32,7 @@ parser.add_argument('--N_voxel_init', type=int, default=128**3) parser.add_argument('--N_voxel_final', type=int, default=300**3) parser.add_argument("--upsamp_list", type=int, action="append", default=[2000,3000,4000,5500,7000]) - parser.add_argument("--update_AlphaMask_list", type=int, action="append", default=[2000,4000]) + parser.add_argument("--update_AlphaMask_list", type=int, action="append", default=[]) # [2000,4000] parser.add_argument('--lindisp', default=False, action="store_true", help='use disparity depth sampling') parser.add_argument("--perturb", type=float, default=1., help='set to 0. for no jitter, 1. 
for jitter') parser.add_argument("--accumulate_decay", type=float, default=0.998) @@ -71,7 +71,7 @@ aabb = (torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]]) * opt.bound).to(device) reso_cur = N_to_reso(opt.N_voxel_init, aabb) - nSamples = min(opt.nSamples, cal_n_samples(reso_cur, opt.step_ratio)) + nSamples = 512 # min(opt.nSamples, cal_n_samples(reso_cur, opt.step_ratio)) near_far = [2.0, 6.0] # fixed for blender N_voxel_list = (torch.round(torch.exp(torch.linspace(np.log(opt.N_voxel_init), np.log(opt.N_voxel_final), len(opt.upsamp_list)+1))).long()).tolist()[1:] @@ -109,7 +109,7 @@ scheduler = lambda optimizer: optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100, 200], gamma=0.33) - trainer = Trainer('tensorf', vars(opt), model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, metrics=[PSNRMeter()], use_checkpoint='latest', eval_interval=50) + trainer = Trainer('tensorf', vars(opt), model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, metrics=[PSNRMeter()], use_checkpoint='scratch', eval_interval=50) # attach extra things trainer.aabb = aabb diff --git a/nerf/network.py b/nerf/network.py index f4d8d8b5..13a0fe71 100644 --- a/nerf/network.py +++ b/nerf/network.py @@ -66,8 +66,8 @@ def __init__(self, def forward(self, x, d): - # x: [B, N, 3], in [-bound, bound] - # d: [B, N, 3], nomalized in [-1, 1] + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] # sigma x = self.encoder(x, bound=self.bound) @@ -96,7 +96,7 @@ def forward(self, x, d): return sigma, color def density(self, x): - # x: [B, N, 3], in [-bound, bound] + # x: [N, 3], in [-bound, bound] x = self.encoder(x, bound=self.bound) h = x @@ -106,5 +106,40 @@ def density(self, x): h = F.relu(h, inplace=True) sigma = F.relu(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + } + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + for l in range(self.num_layers_color): + h = self.color_net[l](h) + if l != self.num_layers_color - 1: + h = F.relu(h, inplace=True) + + # sigmoid activation for rgb + h = torch.sigmoid(h) + + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + else: + rgbs = h - return sigma \ No newline at end of file + return rgbs \ No newline at end of file diff --git a/nerf/network_ff.py b/nerf/network_ff.py index 8d251713..db8db9c8 100644 --- a/nerf/network_ff.py +++ b/nerf/network_ff.py @@ -93,24 +93,42 @@ def color(self, x, d, mask=None, geo_feat=None, **kwargs): # x: [N, 3] in [-bound, bound] # mask: [N,], bool, indicates where we actually needs to compute rgb. 
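Every `color()` added in this patch (nerf, ff, tcnn, and tensoRF backbones) shares one masked-query pattern: allocate a zero buffer in the caller's dtype, early-out when the mask is empty, run the MLP only on selected rows, and cast possibly-fp16 outputs back on scatter. A generic sketch of that pattern (`fn` is a stand-in for the per-backbone color network, not a function in the repo):

```python
import torch

def masked_query(fn, x, d, mask, out_dim=3):
    # Zero buffer in the input dtype/device; untouched rows stay black.
    out = torch.zeros(mask.shape[0], out_dim, dtype=x.dtype, device=x.device)
    if not mask.any():             # nothing visible on these rays
        return out
    h = fn(x[mask], d[mask])       # run the network on selected rows only
    out[mask] = h.to(out.dtype)    # fp16 -> fp32 cast on scatter, as above
    return out
```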
+ #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) + #starter.record() + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs x = x[mask] d = d[mask] geo_feat = geo_feat[mask] + #print(x.shape, rgbs.shape) + + #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'mask = {curr_time}') + #starter.record() + d = self.encoder_dir(d) p = torch.zeros_like(geo_feat[..., :1]) # manual input padding h = torch.cat([d, geo_feat, p], dim=-1) + h = self.color_net(h) # sigmoid activation for rgb h = torch.sigmoid(h) + #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'call = {curr_time}') + #starter.record() + if mask is not None: - rgbs = torch.zeros(np.prod(prefix), 3, dtype=h.dtype, device=h.device) # [N, 3] - rgbs[mask] = h + rgbs[mask] = h.to(rgbs.dtype) else: rgbs = h + #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'unmask = {curr_time}') + #starter.record() + return rgbs diff --git a/nerf/network_tcnn.py b/nerf/network_tcnn.py index 981d0084..584b57a2 100644 --- a/nerf/network_tcnn.py +++ b/nerf/network_tcnn.py @@ -81,12 +81,9 @@ def __init__(self, def forward(self, x, d): - # x: [B, N, 3], in [-bound, bound] - # d: [B, N, 3], nomalized in [-1, 1] + # x: [N, 3], in [-bound, bound] + # d: [N, 3], nomalized in [-1, 1] - prefix = x.shape[:-1] - x = x.view(-1, 3) - d = d.view(-1, 3) # sigma x = (x + self.bound) / (2 * self.bound) # to [0, 1] @@ -106,17 +103,11 @@ def forward(self, x, d): # sigmoid activation for rgb color = torch.sigmoid(h) - - sigma = sigma.view(*prefix) - color = color.view(*prefix, -1) return sigma, color def density(self, x): - # x: [B, N, 3], in [-bound, bound] - - prefix = x.shape[:-1] - x = x.view(-1, 3) + # x: [N, 3], in [-bound, bound] x = (x + self.bound) / (2 * self.bound) # to [0, 1] x = self.encoder(x) @@ -124,7 +115,42 @@ def density(self, x): #sigma = torch.exp(torch.clamp(h[..., 0], -15, 15)) sigma = F.relu(h[..., 0]) + geo_feat = h[..., 1:] + + return { + 'sigma': sigma, + 'geo_feat': geo_feat, + } + + # allow masked inference + def color(self, x, d, mask=None, geo_feat=None, **kwargs): + # x: [N, 3] in [-bound, bound] + # mask: [N,], bool, indicates where we actually needs to compute rgb. 
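Both normalizations in this tcnn `color()` are the same affine remap into tinycudann's expected [0, 1] input range (the SH requirement is noted in the comment below; the hash-grid remap mirrors `forward`/`density` above). A sketch of the shared map:

```python
def to_unit(v, lo, hi):
    # tinycudann's hash-grid and SH encoders take inputs in [0, 1]; the two
    # remaps in this method are both instances of this affine map.
    return (v - lo) / (hi - lo)

# positions: to_unit(x, -bound, bound)  ==  (x + bound) / (2 * bound)
# view dirs: to_unit(d, -1, 1)          ==  (d + 1) / 2
```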
+ + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs + x = x[mask] + d = d[mask] + geo_feat = geo_feat[mask] + + # color + d = (d + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] + d = self.encoder_dir(d) + + h = torch.cat([d, geo_feat], dim=-1) + h = self.color_net(h) + + # sigmoid activation for rgb + h = torch.sigmoid(h) - sigma = sigma.view(*prefix) + if mask is not None: + rgbs[mask] = h.to(rgbs.dtype) # fp16 --> fp32 + else: + rgbs = h - return sigma \ No newline at end of file + return rgbs \ No newline at end of file diff --git a/nerf/renderer.py b/nerf/renderer.py index 92383647..41b17f19 100644 --- a/nerf/renderer.py +++ b/nerf/renderer.py @@ -97,7 +97,7 @@ def __init__(self, self.cuda_ray = cuda_ray if cuda_ray: # density grid - density_grid = torch.ones([128] * 3) + density_grid = torch.zeros([128] * 3) self.register_buffer('density_grid', density_grid) self.mean_density = 0 self.iter_density = 0 @@ -130,7 +130,11 @@ def run(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb): # bg_color: [3] in range [0, 1] # return: image: [B, N, 3], depth: [B, N] - B, N = rays_o.shape[:2] + prefix = rays_o.shape[:-1] + rays_o = rays_o.contiguous().view(-1, 3) + rays_d = rays_d.contiguous().view(-1, 3) + + N = rays_o.shape[0] # N = B * N, in fact device = rays_o.device # sample steps @@ -138,84 +142,92 @@ def run(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb): #print(f'near = {near.min().item()} ~ {near.max().item()}, far = {far.min().item()} ~ {far.max().item()}') - z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0).unsqueeze(0) # [1, 1, T] - z_vals = z_vals.expand((B, N, num_steps)) # [B, N, T] - z_vals = near + (far - near) * z_vals # [B, N, T], in [near, far] + z_vals = torch.linspace(0.0, 1.0, num_steps, device=device).unsqueeze(0) # [1, T] + z_vals = z_vals.expand((N, num_steps)) # [N, T] + z_vals = near + (far - near) * z_vals # [N, T], in [near, far] # perturb z_vals sample_dist = (far - near) / num_steps if perturb: z_vals = z_vals + (torch.rand(z_vals.shape, device=device) - 0.5) * sample_dist - #z_vals = z_vals.clamp(near, far) # avoid out of bounds pts. + #z_vals = z_vals.clamp(near, far) # avoid out of bounds xyzs. - # generate pts - pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [B, N, 1, 3] * [B, N, T, 3] -> [B, N, T, 3] - pts = pts.clamp(-self.bound, self.bound) # must be strictly inside the bounds, else lead to nan in hashgrid encoder! + # generate xyzs + xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1) # [N, 1, 3] * [N, T, 1] -> [N, T, 3] + xyzs = xyzs.clamp(-self.bound, self.bound) # must be strictly inside the bounds, else lead to nan in hashgrid encoder! 
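For reference, a condensed sketch of the flattened sampling scheme this hunk switches to: rays are collapsed to [N, 3], z values are stratified in [near, far], and jitter of up to half a step is added when perturbing (shapes follow the comments above):

```python
import torch

def stratified_points(rays_o, rays_d, near, far, num_steps, perturb=True):
    # rays_o, rays_d: [N, 3]; near, far: [N, 1]
    z = torch.linspace(0.0, 1.0, num_steps, device=rays_o.device)  # [T]
    z_vals = near + (far - near) * z                               # [N, T]
    sample_dist = (far - near) / num_steps                         # [N, 1]
    if perturb:
        z_vals = z_vals + (torch.rand_like(z_vals) - 0.5) * sample_dist
    # [N, 1, 3] + [N, 1, 3] * [N, T, 1] -> [N, T, 3]
    xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * z_vals.unsqueeze(-1)
    return xyzs, z_vals
```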
- #print(f'pts {pts.shape} {pts.min().item()} ~ {pts.max().item()}') + # print('[xyzs]', xyzs.shape, xyzs.dtype, xyzs.min().item(), xyzs.max().item()) - #plot_pointcloud(pts.reshape(-1, 3).detach().cpu().numpy()) + #plot_pointcloud(xyzs.reshape(-1, 3).detach().cpu().numpy()) # query SDF and RGB - dirs = rays_d.unsqueeze(-2).expand_as(pts) - - sigmas, rgbs = self(pts.reshape(-1, 3), dirs.reshape(-1, 3)) + density_outputs = self.density(xyzs.reshape(-1, 3)) - rgbs = rgbs.reshape(B, N, num_steps, 3) # [B, N, T, 3] - sigmas = sigmas.reshape(B, N, num_steps) # [B, N, T] + #sigmas = density_outputs['sigma'].view(N, num_steps) # [N, T] + for k, v in density_outputs.items(): + density_outputs[k] = v.view(N, num_steps, -1) # upsample z_vals (nerf-like) if upsample_steps > 0: with torch.no_grad(): - deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1] # [B, N, T-1] - deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[:, :, :1])], dim=-1) + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) - alphas = 1 - torch.exp(-deltas * self.density_scale * sigmas) # [B, N, T] - alphas_shifted = torch.cat([torch.ones_like(alphas[:, :, :1]), 1 - alphas + 1e-15], dim=-1) # [B, N, T+1] - weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[:, :, :-1] # [B, N, T] + alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T] # sample new z_vals - z_vals_mid = (z_vals[:, :, :-1] + 0.5 * deltas[:, :, :-1]) # [B, N, T-1] - new_z_vals = sample_pdf(z_vals_mid.reshape(B*N, -1), weights.reshape(B*N, -1)[:, 1:-1], upsample_steps, det=not self.training).detach() # [BN, t] - new_z_vals = new_z_vals.reshape(B, N, upsample_steps) + z_vals_mid = (z_vals[..., :-1] + 0.5 * deltas[..., :-1]) # [N, T-1] + new_z_vals = sample_pdf(z_vals_mid, weights[:, 1:-1], upsample_steps, det=not self.training).detach() # [N, t] - new_pts = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [B, N, 1, 3] * [B, N, t, 3] -> [B, N, t, 3] - new_pts = new_pts.clamp(-self.bound, self.bound) + new_xyzs = rays_o.unsqueeze(-2) + rays_d.unsqueeze(-2) * new_z_vals.unsqueeze(-1) # [N, 1, 3] * [N, t, 1] -> [N, t, 3] + new_xyzs = new_xyzs.clamp(-self.bound, self.bound) # only forward new points to save computation - new_dirs = rays_d.unsqueeze(-2).expand_as(new_pts) - new_sigmas, new_rgbs = self(new_pts.reshape(-1, 3), new_dirs.reshape(-1, 3)) - new_rgbs = new_rgbs.reshape(B, N, upsample_steps, 3) # [B, N, t, 3] - new_sigmas = new_sigmas.reshape(B, N, upsample_steps) # [B, N, t] + new_density_outputs = self.density(new_xyzs.reshape(-1, 3)) + #new_sigmas = new_density_outputs['sigma'].view(N, upsample_steps) # [N, t] + for k, v in new_density_outputs.items(): + new_density_outputs[k] = v.view(N, upsample_steps, -1) # re-order - z_vals = torch.cat([z_vals, new_z_vals], dim=-1) # [B, N, T+t] - z_vals, z_index = torch.sort(z_vals, dim=-1) + z_vals = torch.cat([z_vals, new_z_vals], dim=1) # [N, T+t] + z_vals, z_index = torch.sort(z_vals, dim=1) + + xyzs = torch.cat([xyzs, new_xyzs], dim=1) # [N, T+t, 3] + xyzs = torch.gather(xyzs, dim=1, index=z_index.unsqueeze(-1).expand_as(xyzs)) + + for k in density_outputs: + tmp_output = torch.cat([density_outputs[k], new_density_outputs[k]], dim=1) + density_outputs[k] = 
torch.gather(tmp_output, dim=1, index=z_index.unsqueeze(-1).expand_as(tmp_output)) + + deltas = z_vals[..., 1:] - z_vals[..., :-1] # [N, T+t-1] + deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[..., :1])], dim=-1) + alphas = 1 - torch.exp(-deltas * self.density_scale * density_outputs['sigma'].squeeze(-1)) # [N, T+t] + alphas_shifted = torch.cat([torch.ones_like(alphas[..., :1]), 1 - alphas + 1e-15], dim=-1) # [N, T+t+1] + weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[..., :-1] # [N, T+t] - sigmas = torch.cat([sigmas, new_sigmas], dim=-1) # [B, N, T+t] - sigmas = torch.gather(sigmas, dim=-1, index=z_index) + mask = weights > 1e-4 # hard coded - rgbs = torch.cat([rgbs, new_rgbs], dim=-2) # [B, N, T+t, 3] - rgbs = torch.gather(rgbs, dim=-2, index=z_index.unsqueeze(-1).expand_as(rgbs)) + dirs = rays_d.view(-1, 1, 3).expand_as(xyzs) + for k, v in density_outputs.items(): + density_outputs[k] = v.view(-1, v.shape[-1]) - ### render core - deltas = z_vals[:, :, 1:] - z_vals[:, :, :-1] # [B, N, T-1] - deltas = torch.cat([deltas, sample_dist * torch.ones_like(deltas[:, :, :1])], dim=-1) + rgbs = self.color(xyzs.reshape(-1, 3), dirs.reshape(-1, 3), mask=mask.reshape(-1), **density_outputs) + rgbs = rgbs.view(N, -1, 3) # [N, T+t, 3] - alphas = 1 - torch.exp(-deltas * self.density_scale * sigmas) # [B, N, T] - alphas_shifted = torch.cat([torch.ones_like(alphas[:, :, :1]), 1 - alphas + 1e-15], dim=-1) # [B, N, T+1] - weights = alphas * torch.cumprod(alphas_shifted, dim=-1)[:, :, :-1] # [B, N, T] + #print(xyzs.shape, 'valid_rgb:', mask.sum().item()) # calculate weight_sum (mask) - weights_sum = weights.sum(dim=-1) # [B, N] + weights_sum = weights.sum(dim=-1) # [N] # calculate depth ori_z_vals = ((z_vals - near) / (far - near)).clamp(0, 1) depth = torch.sum(weights * ori_z_vals, dim=-1) # calculate color - image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [B, N, 3], in [0, 1] + image = torch.sum(weights.unsqueeze(-1) * rgbs, dim=-2) # [N, 3], in [0, 1] # mix background color if bg_color is None: @@ -223,6 +235,9 @@ def run(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb): image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) + depth = depth.view(*prefix) + return depth, image @@ -246,16 +261,19 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb) counter.zero_() # set to 0 self.local_step += 1 - xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, perturb, 128, False) - - deltas = self.density_scale * deltas + xyzs, dirs, deltas, rays = raymarching.march_rays_train(rays_o, rays_d, self.bound, self.density_grid, self.mean_density, self.iter_density, counter, self.mean_count, perturb, 128, True) density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. sigmas = density_outputs['sigma'] + sigmas = self.density_scale * sigmas weights = raymarching.composite_weights_train(sigmas, deltas, rays) # [M,] - mask = weights > 1e-4 # hard coded + + # masked rgb cannot accelerate cuda_ray training, disabled! (mask ratio is only ~50%, cannot beat the mask/unmask overhead.) 
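The CUDA `composite_weights_train` used below computes the same volume-rendering weights as the PyTorch path above; restated compactly, with `density_scale` folded in:

```python
import torch

def render_weights(sigmas, deltas, density_scale=1.0):
    # alpha_i = 1 - exp(-scale * sigma_i * delta_i)
    # w_i     = alpha_i * prod_{j < i} (1 - alpha_j)   (transmittance)
    alphas = 1 - torch.exp(-density_scale * sigmas * deltas)            # [N, T]
    shifted = torch.cat([torch.ones_like(alphas[..., :1]),
                         1 - alphas + 1e-15], dim=-1)                   # [N, T+1]
    return alphas * torch.cumprod(shifted, dim=-1)[..., :-1]            # [N, T]
```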
+ mask = None # weights > 1e-4 rgbs = self.color(xyzs, dirs, mask=mask, **density_outputs) + #print(f'valid RGB query ratio: {mask.sum().item() / mask.shape[0]} (total = {mask.sum().item()})') + weights_sum, image = raymarching.composite_rays_train(weights, rgbs, rays, self.bound) depth = None # currently training do not requires depth @@ -285,7 +303,7 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb) step = 0 i = 0 - while step < 1024: # max step + while step < 1024: # hard coded max step # count alive rays if step == 0: @@ -306,11 +324,11 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb) xyzs, dirs, deltas = raymarching.march_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], rays_o, rays_d, self.bound, self.density_grid, self.mean_density, near, far, 128, perturb) - deltas = self.density_scale * deltas - #sigmas, rgbs = self(xyzs, dirs) density_outputs = self.density(xyzs) # [M,], use a dict since it may include extra things, like geo_feat for rgb. sigmas = density_outputs['sigma'] + sigmas = self.density_scale * sigmas + # no need for weights mask, since we already terminated those rays. rgbs = self.color(xyzs, dirs, **density_outputs) raymarching.composite_rays(n_alive, n_step, rays_alive[i % 2], rays_t[i % 2], sigmas, rgbs, deltas, weights_sum, depth, image) @@ -320,12 +338,12 @@ def run_cuda(self, rays_o, rays_d, num_steps, upsample_steps, bg_color, perturb) step += n_step i += 1 - image = image.view(*prefix, 3) image = image + (1 - weights_sum).unsqueeze(-1) * bg_color + image = image.view(*prefix, 3) if depth is not None: - depth = depth.view(*prefix) depth = torch.clamp(depth - near, min=0) / (far - near) + depth = depth.view(*prefix) return depth, image @@ -353,16 +371,11 @@ def update_extra_state(self, decay=0.95): lx, ly, lz = len(xs), len(ys), len(zs) # construct points xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij') - pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3] + xyzs = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1) # [N, 3] # add noise in [-hgs, hgs] - pts += (torch.rand_like(pts) * 2 - 1) * half_grid_size - # manual padding for ffmlp - n = pts.shape[0] - pad_n = 128 - (n % 128) - if pad_n != 0: - pts = torch.cat([pts, torch.zeros(pad_n, 3)], dim=0) + xyzs += (torch.rand_like(xyzs) * 2 - 1) * half_grid_size # query density - sigmas = self.density(pts.to(tmp_grid.device))[:n].reshape(lx, ly, lz)['sigma'].detach() + sigmas = self.density(xyzs.to(tmp_grid.device))['sigma'].reshape(lx, ly, lz).detach() # change density to alpha in [0, 1] alphas = 1 - torch.exp(-self.density_scale * sigmas) # [B, N, T], fake deltas to 1 (it doesn't really matter) tmp_grid[xi * 128: xi * 128 + lx, yi * 128: yi * 128 + ly, zi * 128: zi * 128 + lz] = alphas diff --git a/nerf/utils.py b/nerf/utils.py index 452e11d8..a0fd03da 100644 --- a/nerf/utils.py +++ b/nerf/utils.py @@ -459,10 +459,6 @@ def train_gui(self, train_loader, step=16): self.model.train() - # update grid - if self.model.cuda_ray: - with torch.cuda.amp.autocast(enabled=self.fp16): - self.model.update_extra_state() total_loss = torch.tensor([0], dtype=torch.float32, device=self.device) @@ -476,6 +472,11 @@ def train_gui(self, train_loader, step=16): except StopIteration: loader = iter(train_loader) data = next(loader) + + # update grid + if self.model.cuda_ray and self.global_step % 100 == 0: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() 
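This hunk moves the occupancy-grid refresh inside the GUI step loop, presumably to amortize its cost: `train_gui()` runs in short bursts, so refreshing once per call was far more frequent than the per-100-steps cadence used elsewhere. An illustrative skeleton of the new pattern (with `self.` dropped):

```python
# minimal sketch of train_gui()'s stepping pattern after this change
for _ in range(step):                     # step defaults to 16 per GUI frame
    try:
        data = next(loader)
    except StopIteration:
        loader = iter(train_loader)
        data = next(loader)

    # refresh the density grid every 100 global steps instead of once per call
    if model.cuda_ray and global_step % 100 == 0:
        with torch.cuda.amp.autocast(enabled=fp16):
            model.update_extra_state()

    global_step += 1
    # ... forward / backward / optimizer step ...
```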
self.global_step += 1 @@ -578,9 +579,9 @@ def train_one_epoch(self, loader): self.model.train() # update grid - # if self.model.cuda_ray: - # with torch.cuda.amp.autocast(enabled=self.fp16): - # self.model.update_extra_state() + if self.model.cuda_ray: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs # ref: https://pytorch.org/docs/stable/data.html @@ -593,11 +594,6 @@ def train_one_epoch(self, loader): self.local_step = 0 for data in loader: - - # # update grid - # if self.global_step % 16 == 0 and self.model.cuda_ray: - # with torch.cuda.amp.autocast(enabled=self.fp16): - # self.model.update_extra_state() self.local_step += 1 self.global_step += 1 diff --git a/raymarching/raymarching.py b/raymarching/raymarching.py index d7dbbbd6..0e0396ef 100644 --- a/raymarching/raymarching.py +++ b/raymarching/raymarching.py @@ -48,7 +48,7 @@ def forward(ctx, rays_o, rays_d, bound, density_grid, mean_density, iter_density _backend.march_rays_train(rays_o, rays_d, density_grid, mean_density, iter_density, bound, N, H, M, xyzs, dirs, deltas, rays, step_counter, perturb) # m is the actually used points number - ##print(step_counter, M) + #print(step_counter, M) # only used at the first (few) epochs. if force_all_rays or mean_count <= 0: diff --git a/raymarching/src/raymarching.cu b/raymarching/src/raymarching.cu index 8916c181..bbad1b15 100644 --- a/raymarching/src/raymarching.cu +++ b/raymarching/src/raymarching.cu @@ -23,7 +23,7 @@ inline constexpr __device__ float SQRT3() { return 1.73205080757f; } inline constexpr __device__ int MAX_STEPS() { return 1024; } inline constexpr __device__ float MIN_STEPSIZE() { return 2 * SQRT3() / MAX_STEPS(); } // still need to mul bound to get dt_min inline constexpr __device__ float MIN_NEAR() { return 0.05f; } -inline constexpr __device__ float DT_GAMMA() { return 1.0f / 256.0f; } +inline constexpr __device__ float DT_GAMMA() { return 0.0f / 256.0f; } // util functions template @@ -98,6 +98,7 @@ __global__ void kernel_march_rays_train( const float far = fminf(far_x, fminf(far_y, far_z)); const float dt_min = MIN_STEPSIZE() * bound; + //const float dt_min = (far - near) / MAX_STEPS(); const float dt_max = 2 * bound / (H - 1); const float dt_gamma = bound > 1 ? DT_GAMMA() : 0.0f; @@ -149,9 +150,7 @@ __global__ void kernel_march_rays_train( } } - //printf("[n=%d] num_steps=%d\n", n, num_steps); - //printf("[n=%d] num_steps=%d, pc=%d, rc=%d\n", n, num_steps, counter[0], counter[1]); - + //printf("[n=%d] num_steps=%d, near=%f, far=%f, dt=%f, max_steps=%f\n", n, num_steps, near, far, dt_min, (far - near) / dt_min); // second pass: really locate and write points & dirs uint32_t point_index = atomicAdd(counter, num_steps); diff --git a/readme.md b/readme.md index 59844d49..18925482 100644 --- a/readme.md +++ b/readme.md @@ -44,7 +44,6 @@ Later development will be focused on reproducing the NeRF inference speed. # Install ```bash - git clone --recursive https://github.com/ashawkey/torch-ngp.git cd torch-ngp @@ -67,6 +66,7 @@ Please download and put them under `./data`. First time running will take some time to compile the CUDA extensions. ```bash +### HashNeRF # train with different backbones (with slower pytorch ray marching) # for the colmap dataset, the default dataset setting `--mode colmap --bound 2 --scale 0.33` is used. 
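# (for example, the first command below is equivalent to spelling those
#  defaults out: python main_nerf.py data/fox --workspace trial_nerf \
#    --mode colmap --bound 2 --scale 0.33)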
python main_nerf.py data/fox --workspace trial_nerf # fp32 mode @@ -93,10 +93,37 @@ python main_nerf.py data/fox --workspace trial_nerf --fp16 --ff --cuda_ray --gui # --scale adjusts the camera locaction to make sure it falls inside the above bounding box. python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf --fp16 --ff --cuda_ray --mode blender --bound 1.5 --scale 1.0 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf --fp16 --ff --cuda_ray --mode blender --bound 1.5 --scale 1.0 --gui + +### SDF +python main_sdf.py data/armadillo.obj --workspace trial_sdf +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --ff +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --tcnn + +python main_sdf.py data/armadillo.obj --workspace trial_sdf --fp16 --ff --test + +### TensoRF +# almost the same as HashNeRF, just replace the main script. +python main_tensoRF.py data/fox --workspace trial_tensoRF --fp16 --ff --cuda_ray +python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF --fp16 --ff --cuda_ray --mode blender --bound 1.5 --scale 1.0 + ``` check the `scripts` directory for more provided examples. +# Performance Reference +Tested with the default settings on the Lego test dataset. Here the speed refers to the `iterations per second` on a TITAN RTX. +| Model | PSNR | Train Speed | Test Speed | +| - | - | - | - | +| HashNeRF (`fp16`) | 32.22 | 24 | 0.56 | +| HashNeRF (`fp16 + ff`) | 32.81 | 24 | 0.79 | +| HashNeRF (`fp16 + tcnn`) | 32.72 | 20 | 0.37 | +| HashNeRF (`fp16 + cuda_ray`) | 32.54 | 65 | 6.4 | +| HashNeRF (`fp16 + cuda_ray + ff`) | 33.24 | 72 | 6.9 | +| HashNeRF (`fp16 + cuda_ray + tcnn`) | 33.11 | 60 | 5.8 | +| TensoRF (`fp16`) | 33.79 | 18 | 0.53 | +| TensoRF (`fp16 + cuda_ray`) | 34.05 | 13 | 0.43 | + # Difference from the original implementation * Instead of assuming the scene is bounded in the unit box `[0, 1]` and centered at `(0.5, 0.5, 0.5)`, this repo assumes **the scene is bounded in box `[-bound, bound]`, and centered at `(0, 0, 0)`**. Therefore, the functionality of `aabb_scale` is replaced by `bound` here. * For the hashgrid encoder, this repo only implement the linear interpolation mode. @@ -104,6 +131,7 @@ check the `scripts` directory for more provided examples. * For the blender dataest, the default mode in instant-ngp is to load all data (train/val/test) for training. Instead, we only use the specified split to train in CMD mode for easy evaluation. However, for GUI mode, we follow instant-ngp and use all data to train (check `type='all'` for `NeRFDataset`). # Update Logs +* 3.27: major update. basically improve performance, and support tensoRF model. * 3.22: reverted from pre-generating rays as it takes too much CPU memory, still the PSNR for Lego can reach ~33 now. * 3.14: fixed the precision related issue for `fp16` mode, and it renders much better quality. Added PSNR metric for NeRF. * known issue: PSNR is worse, for Lego test dataset is only ~30. @@ -157,4 +185,17 @@ check the `scripts` directory for more provided examples. 
year = {2020}, } ``` + +* The official TensoRF [implementation](https://github.com/apchenstu/TensoRF): + ``` + @misc{TensoRF, + title={TensoRF: Tensorial Radiance Fields}, + author={Anpei Chen and Zexiang Xu and Andreas Geiger and and Jingyi Yu and Hao Su}, + year={2022}, + eprint={2203.09517}, + archivePrefix={arXiv}, + primaryClass={cs.CV} + } + ``` + * The NeRF GUI is developed with [DearPyGui](https://github.com/hoffstadt/DearPyGui). diff --git a/scripts/run_gui_nerf.sh b/scripts/run_gui_nerf.sh index 9783acd2..541a4eee 100755 --- a/scripts/run_gui_nerf.sh +++ b/scripts/run_gui_nerf.sh @@ -1,7 +1,7 @@ #! /bin/bash -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_gui --cuda_ray --gui -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_gui_lego --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf --cuda_ray --gui +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray --gui -OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui \ No newline at end of file +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2_gui --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui \ No newline at end of file diff --git a/scripts/run_gui_tensoRF.sh b/scripts/run_gui_tensoRF.sh new file mode 100755 index 00000000..70cf1a64 --- /dev/null +++ b/scripts/run_gui_tensoRF.sh @@ -0,0 +1,7 @@ +#! /bin/bash + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensoRF --cuda_ray --gui +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui + +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensoRF_ff2 --fp16 --ff --cuda_ray --gui +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender --gui \ No newline at end of file diff --git a/scripts/run_nerf.sh b/scripts/run_nerf.sh index 5b602774..772d48a5 100755 --- a/scripts/run_nerf.sh +++ b/scripts/run_nerf.sh @@ -1,17 +1,17 @@ #! 
/bin/bash -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf --fp16 -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff --fp16 --ff -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_tcnn --fp16 --tcnn +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf --fp16 +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff --fp16 --ff +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_tcnn --fp16 --tcnn -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf2 --fp16 --cuda_ray -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_tcnn2 --fp16 --tcnn --cuda_ray +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf2 --fp16 --cuda_ray +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_ff2 --fp16 --ff --cuda_ray +# OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/fox --workspace trial_nerf_tcnn2 --fp16 --tcnn --cuda_ray #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego --fp16 --bound 1.5 --scale 1.0 --mode blender #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff --fp16 --ff --bound 1.5 --scale 1.0 --mode blender #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_tcnn --fp16 --tcnn --bound 1.5 --scale 1.0 --mode blender -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender -OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender --test +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_ff2 --fp16 --ff --cuda_ray --bound 1.5 --scale 1.0 --mode blender #OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_nerf.py data/nerf_synthetic/lego --workspace trial_nerf_lego_tcnn2 --fp16 --tcnn --cuda_ray --bound 1.5 --scale 1.0 --mode blender \ No newline at end of file diff --git a/scripts/run_tensoRF.sh b/scripts/run_tensoRF.sh index 180d6ff3..e8f0e4c0 100755 --- a/scripts/run_tensoRF.sh +++ b/scripts/run_tensoRF.sh @@ -1,9 +1,7 @@ #! 
/bin/bash -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensoRF --fp16 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/fox --workspace trial_tensoRF --fp16 +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/fox --workspace trial_tensoRF2 --fp16 --cuda_ray -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/fox --workspace trial_tensoRF2 --fp16 --cuda_ray - -#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego --fp16 --bound 1.5 --scale 1.0 --mode blender - -OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=1 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender \ No newline at end of file +#OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego --fp16 --bound 1.5 --scale 1.0 --mode blender +OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=2 python main_tensoRF.py data/nerf_synthetic/lego --workspace trial_tensoRF_lego2 --fp16 --cuda_ray --bound 1.5 --scale 1.0 --mode blender \ No newline at end of file diff --git a/scripts/run_tensorf.sh b/scripts/run_tensorf.sh deleted file mode 100644 index 4a3b3726..00000000 --- a/scripts/run_tensorf.sh +++ /dev/null @@ -1,3 +0,0 @@ -#! /bin/bash - -OMP_NUM_THREADS=8 CUDA_VISIBLE_DEVICES=3 python main_tensorf.py data/nerf_synthetic/lego --workspace trial_tensorf \ No newline at end of file diff --git a/tensoRF/network.py b/tensoRF/network.py index 8437b49b..e445527f 100644 --- a/tensoRF/network.py +++ b/tensoRF/network.py @@ -164,15 +164,22 @@ def color(self, x, d, mask=None, **kwargs): # x: [N, 3] in [-bound, bound] # mask: [N,], bool, indicates where we actually needs to compute rgb. 
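The `x = x / self.bound` added just below exists because tensoRF features live on 2D planes and 1D lines read back with `F.grid_sample`, which expects coordinates in [-1, 1] (the reference TensoRF code deleted further down shows the same convention). A minimal sketch of one plane lookup under that convention; `plane` and the reshape are illustrative, the real layout lives in `get_color_feat`:

```python
import torch
import torch.nn.functional as F

def sample_plane(plane: torch.Tensor, xy: torch.Tensor) -> torch.Tensor:
    # plane: [1, C, H, W] feature plane; xy: [M, 2] coords already in [-1, 1].
    # grid_sample's grid must be [N, H_out, W_out, 2] in [-1, 1], hence the
    # reshape; returns per-point features of shape [C, M].
    return F.grid_sample(plane, xy.view(1, -1, 1, 2),
                         align_corners=True).view(plane.shape[1], -1)
```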
+ # normalize to [-1, 1] + x = x / self.bound + if mask is not None: + rgbs = torch.zeros(mask.shape[0], 3, dtype=x.dtype, device=x.device) # [N, 3] + # in case of empty mask + if not mask.any(): + return rgbs x = x[mask] d = d[mask] color_feat = self.get_color_feat(x) - enc_color_feat = self.encoder(color_feat) - enc_d = self.encoder_dir(d) + color_feat = self.encoder(color_feat) + d = self.encoder_dir(d) - h = torch.cat([enc_color_feat, enc_d], dim=-1) + h = torch.cat([color_feat, d], dim=-1) for l in range(self.num_layers): h = self.color_net[l](h) if l != self.num_layers - 1: @@ -182,8 +189,7 @@ def color(self, x, d, mask=None, **kwargs): h = torch.sigmoid(h) if mask is not None: - rgbs = torch.zeros(mask.shape[0], 3, dtype=h.dtype, device=h.device) # [N, 3] - rgbs[mask] = h + rgbs[mask] = h.to(rgbs.dtype) else: rgbs = h diff --git a/tensoRF/utils.py b/tensoRF/utils.py index b68a7e7d..f89adc89 100644 --- a/tensoRF/utils.py +++ b/tensoRF/utils.py @@ -214,12 +214,14 @@ def __init__(self, self.criterion = criterion if optimizer is None: + self.optimizer_fn = None self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam else: self.optimizer_fn = optimizer self.optimizer = optimizer(self.model) if lr_scheduler is None: + self.lr_scheduler_fn = None self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler else: self.lr_scheduler_fn = lr_scheduler @@ -305,9 +307,9 @@ def train_step(self, data): images = torch.gather(images.reshape(B, -1, C), 1, torch.stack(C*[inds], -1)) # [B, N, 3/4] # train with random background color if using alpha mixing - bg_color = torch.ones(3, device=self.device) # [3], fixed white background + #bg_color = torch.ones(3, device=self.device) # [3], fixed white background #bg_color = torch.rand(3, device=self.device) # [3], frame-wise random. - #bg_color = torch.rand_like(images[..., :3]) # [N, 3], pixel-wise random. + bg_color = torch.rand_like(images[..., :3]) # [N, 3], pixel-wise random. 
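For reference, the three background strategies toggled above differ only in how much randomness the alpha compositing sees; per-pixel noise (the new default) leaves the model no constant background to bake into density. Using `images` and `device` from the surrounding `train_step`:

```python
import torch

bg_white = torch.ones(3, device=device)        # fixed white, [3]
bg_frame = torch.rand(3, device=device)        # one random color per frame, [3]
bg_pixel = torch.rand_like(images[..., :3])    # per-pixel random, [B, N, 3] (new default)
```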
if C == 4: gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) @@ -583,9 +585,9 @@ def train_one_epoch(self, loader): self.model.train() # update grid - # if self.model.cuda_ray: - # with torch.cuda.amp.autocast(enabled=self.fp16): - # self.model.update_extra_state() + if self.model.cuda_ray: + with torch.cuda.amp.autocast(enabled=self.fp16): + self.model.update_extra_state() # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs # ref: https://pytorch.org/docs/stable/data.html @@ -598,11 +600,6 @@ def train_one_epoch(self, loader): self.local_step = 0 for data in loader: - - # # update grid - # if self.global_step % 16 == 0 and self.model.cuda_ray: - # with torch.cuda.amp.autocast(enabled=self.fp16): - # self.model.update_extra_state() self.local_step += 1 self.global_step += 1 @@ -847,8 +844,10 @@ def load_checkpoint(self, checkpoint=None): # return self.model.upsample_model(checkpoint_dict['resolution']) - self.optimizer = self.optimizer_fn(self.model) - self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) + if self.optimizer_fn is not None: + self.optimizer = self.optimizer_fn(self.model) + if self.lr_scheduler_fn is not None: + self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) missing_keys, unexpected_keys = self.model.load_state_dict(checkpoint_dict['model'], strict=False) self.log("[INFO] loaded model.") diff --git a/tensorf/network.py b/tensorf/network.py deleted file mode 100644 index 151c5add..00000000 --- a/tensorf/network.py +++ /dev/null @@ -1,1100 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import time - - -################## sh function ################## -C0 = 0.28209479177387814 -C1 = 0.4886025119029199 -C2 = [ - 1.0925484305920792, - -1.0925484305920792, - 0.31539156525252005, - -1.0925484305920792, - 0.5462742152960396 -] -C3 = [ - -0.5900435899266435, - 2.890611442640554, - -0.4570457994644658, - 0.3731763325901154, - -0.4570457994644658, - 1.445305721320277, - -0.5900435899266435 -] -C4 = [ - 2.5033429417967046, - -1.7701307697799304, - 0.9461746957575601, - -0.6690465435572892, - 0.10578554691520431, - -0.6690465435572892, - 0.47308734787878004, - -1.7701307697799304, - 0.6258357354491761, -] - -def eval_sh(deg, sh, dirs): - """ - Evaluate spherical harmonics at unit directions - using hardcoded SH polynomials. - Works with torch/np/jnp. - ... Can be 0 or more batch dimensions. - :param deg: int SH max degree. 
Currently, 0-4 supported - :param sh: torch.Tensor SH coeffs (..., C, (max degree + 1) ** 2) - :param dirs: torch.Tensor unit directions (..., 3) - :return: (..., C) - """ - assert deg <= 4 and deg >= 0 - assert (deg + 1) ** 2 == sh.shape[-1] - C = sh.shape[-2] - - result = C0 * sh[..., 0] - if deg > 0: - x, y, z = dirs[..., 0:1], dirs[..., 1:2], dirs[..., 2:3] - result = (result - - C1 * y * sh[..., 1] + - C1 * z * sh[..., 2] - - C1 * x * sh[..., 3]) - if deg > 1: - xx, yy, zz = x * x, y * y, z * z - xy, yz, xz = x * y, y * z, x * z - result = (result + - C2[0] * xy * sh[..., 4] + - C2[1] * yz * sh[..., 5] + - C2[2] * (2.0 * zz - xx - yy) * sh[..., 6] + - C2[3] * xz * sh[..., 7] + - C2[4] * (xx - yy) * sh[..., 8]) - - if deg > 2: - result = (result + - C3[0] * y * (3 * xx - yy) * sh[..., 9] + - C3[1] * xy * z * sh[..., 10] + - C3[2] * y * (4 * zz - xx - yy)* sh[..., 11] + - C3[3] * z * (2 * zz - 3 * xx - 3 * yy) * sh[..., 12] + - C3[4] * x * (4 * zz - xx - yy) * sh[..., 13] + - C3[5] * z * (xx - yy) * sh[..., 14] + - C3[6] * x * (xx - 3 * yy) * sh[..., 15]) - if deg > 3: - result = (result + C4[0] * xy * (xx - yy) * sh[..., 16] + - C4[1] * yz * (3 * xx - yy) * sh[..., 17] + - C4[2] * xy * (7 * zz - 1) * sh[..., 18] + - C4[3] * yz * (7 * zz - 3) * sh[..., 19] + - C4[4] * (zz * (35 * zz - 30) + 3) * sh[..., 20] + - C4[5] * xz * (7 * zz - 3) * sh[..., 21] + - C4[6] * (xx - yy) * (7 * zz - 1) * sh[..., 22] + - C4[7] * xz * (xx - 3 * yy) * sh[..., 23] + - C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)) * sh[..., 24]) - return result - -def eval_sh_bases(deg, dirs): - """ - Evaluate spherical harmonics bases at unit directions, - without taking linear combination. - At each point, the final result may the be - obtained through simple multiplication. - :param deg: int SH max degree. 
Currently, 0-4 supported - :param dirs: torch.Tensor (..., 3) unit directions - :return: torch.Tensor (..., (deg+1) ** 2) - """ - assert deg <= 4 and deg >= 0 - result = torch.empty((*dirs.shape[:-1], (deg + 1) ** 2), dtype=dirs.dtype, device=dirs.device) - result[..., 0] = C0 - if deg > 0: - x, y, z = dirs.unbind(-1) - result[..., 1] = -C1 * y; - result[..., 2] = C1 * z; - result[..., 3] = -C1 * x; - if deg > 1: - xx, yy, zz = x * x, y * y, z * z - xy, yz, xz = x * y, y * z, x * z - result[..., 4] = C2[0] * xy; - result[..., 5] = C2[1] * yz; - result[..., 6] = C2[2] * (2.0 * zz - xx - yy); - result[..., 7] = C2[3] * xz; - result[..., 8] = C2[4] * (xx - yy); - - if deg > 2: - result[..., 9] = C3[0] * y * (3 * xx - yy); - result[..., 10] = C3[1] * xy * z; - result[..., 11] = C3[2] * y * (4 * zz - xx - yy); - result[..., 12] = C3[3] * z * (2 * zz - 3 * xx - 3 * yy); - result[..., 13] = C3[4] * x * (4 * zz - xx - yy); - result[..., 14] = C3[5] * z * (xx - yy); - result[..., 15] = C3[6] * x * (xx - 3 * yy); - - if deg > 3: - result[..., 16] = C4[0] * xy * (xx - yy); - result[..., 17] = C4[1] * yz * (3 * xx - yy); - result[..., 18] = C4[2] * xy * (7 * zz - 1); - result[..., 19] = C4[3] * yz * (7 * zz - 3); - result[..., 20] = C4[4] * (zz * (35 * zz - 30) + 3); - result[..., 21] = C4[5] * xz * (7 * zz - 3); - result[..., 22] = C4[6] * (xx - yy) * (7 * zz - 1); - result[..., 23] = C4[7] * xz * (xx - 3 * yy); - result[..., 24] = C4[8] * (xx * (xx - 3 * yy) - yy * (3 * xx - yy)); - return result - - - -def positional_encoding(positions, freqs): - - freq_bands = (2**torch.arange(freqs).float()).to(positions.device) # (F,) - pts = (positions[..., None] * freq_bands).reshape( - positions.shape[:-1] + (freqs * positions.shape[-1], )) # (..., DF) - pts = torch.cat([torch.sin(pts), torch.cos(pts)], dim=-1) - return pts - -def raw2alpha(sigma, dist): - # sigma, dist [N_rays, N_samples] - alpha = 1. - torch.exp(-sigma*dist) - - T = torch.cumprod(torch.cat([torch.ones(alpha.shape[0], 1).to(alpha.device), 1. 
- alpha + 1e-10], -1), -1) - - weights = alpha * T[:, :-1] # [N_rays, N_samples] - return alpha, weights, T[:,-1:] - - -def SHRender(xyz_sampled, viewdirs, features): - sh_mult = eval_sh_bases(2, viewdirs)[:, None] - rgb_sh = features.view(-1, 3, sh_mult.shape[-1]) - rgb = torch.relu(torch.sum(sh_mult * rgb_sh, dim=-1) + 0.5) - return rgb - - -def RGBRender(xyz_sampled, viewdirs, features): - - rgb = features - return rgb - -class AlphaGridMask(torch.nn.Module): - def __init__(self, device, aabb, alpha_volume): - super(AlphaGridMask, self).__init__() - self.device = device - - self.aabb=aabb.to(self.device) - self.aabbSize = self.aabb[1] - self.aabb[0] - self.invgridSize = 1.0/self.aabbSize * 2 - self.alpha_volume = alpha_volume.view(1,1,*alpha_volume.shape[-3:]) - self.gridSize = torch.LongTensor([alpha_volume.shape[-1],alpha_volume.shape[-2],alpha_volume.shape[-3]]).to(self.device) - - def sample_alpha(self, xyz_sampled): - xyz_sampled = self.normalize_coord(xyz_sampled) - alpha_vals = F.grid_sample(self.alpha_volume, xyz_sampled.view(1,-1,1,1,3), align_corners=True).view(-1) - - return alpha_vals - - def normalize_coord(self, xyz_sampled): - return (xyz_sampled-self.aabb[0]) * self.invgridSize - 1 - - -class MLPRender_Fea(torch.nn.Module): - def __init__(self,inChanel, viewpe=6, feape=6, featureC=128): - super(MLPRender_Fea, self).__init__() - - self.in_mlpC = 2*viewpe*3 + 2*feape*inChanel + 3 + inChanel - self.viewpe = viewpe - self.feape = feape - layer1 = torch.nn.Linear(self.in_mlpC, featureC) - layer2 = torch.nn.Linear(featureC, featureC) - layer3 = torch.nn.Linear(featureC,3) - - self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) - torch.nn.init.constant_(self.mlp[-1].bias, 0) - - def forward(self, pts, viewdirs, features): - indata = [features, viewdirs] - if self.feape > 0: - indata += [positional_encoding(features, self.feape)] - if self.viewpe > 0: - indata += [positional_encoding(viewdirs, self.viewpe)] - mlp_in = torch.cat(indata, dim=-1) - rgb = self.mlp(mlp_in) - rgb = torch.sigmoid(rgb) - - return rgb - -class MLPRender_PE(torch.nn.Module): - def __init__(self,inChanel, viewpe=6, pospe=6, featureC=128): - super(MLPRender_PE, self).__init__() - - self.in_mlpC = (3+2*viewpe*3)+ (3+2*pospe*3) + inChanel # - self.viewpe = viewpe - self.pospe = pospe - layer1 = torch.nn.Linear(self.in_mlpC, featureC) - layer2 = torch.nn.Linear(featureC, featureC) - layer3 = torch.nn.Linear(featureC,3) - - self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) - torch.nn.init.constant_(self.mlp[-1].bias, 0) - - def forward(self, pts, viewdirs, features): - indata = [features, viewdirs] - if self.pospe > 0: - indata += [positional_encoding(pts, self.pospe)] - if self.viewpe > 0: - indata += [positional_encoding(viewdirs, self.viewpe)] - mlp_in = torch.cat(indata, dim=-1) - rgb = self.mlp(mlp_in) - rgb = torch.sigmoid(rgb) - - return rgb - -class MLPRender(torch.nn.Module): - def __init__(self,inChanel, viewpe=6, featureC=128): - super(MLPRender, self).__init__() - - self.in_mlpC = (3+2*viewpe*3) + inChanel - self.viewpe = viewpe - - layer1 = torch.nn.Linear(self.in_mlpC, featureC) - layer2 = torch.nn.Linear(featureC, featureC) - layer3 = torch.nn.Linear(featureC,3) - - self.mlp = torch.nn.Sequential(layer1, torch.nn.ReLU(inplace=True), layer2, torch.nn.ReLU(inplace=True), layer3) - torch.nn.init.constant_(self.mlp[-1].bias, 0) - - def forward(self, pts, viewdirs, features): - 
indata = [features, viewdirs] - if self.viewpe > 0: - indata += [positional_encoding(viewdirs, self.viewpe)] - mlp_in = torch.cat(indata, dim=-1) - rgb = self.mlp(mlp_in) - rgb = torch.sigmoid(rgb) - - return rgb - - -def near_far_from_bound(rays_o, rays_d, bound, type='cube'): - # rays: [B, N, 3], [B, N, 3] - # bound: int, radius for ball or half-edge-length for cube - # return near [B, N, 1], far [B, N, 1] - - radius = rays_o.norm(dim=-1, keepdim=True) - - if type == 'sphere': - near = radius - bound # [B, N, 1] - far = radius + bound - - elif type == 'cube': - tmin = (-bound - rays_o) / (rays_d + 1e-15) # [B, N, 3] - tmax = (bound - rays_o) / (rays_d + 1e-15) - near = torch.where(tmin < tmax, tmin, tmax).max(dim=-1, keepdim=True)[0] - far = torch.where(tmin > tmax, tmin, tmax).min(dim=-1, keepdim=True)[0] - # if far < near, means no intersection, set both near and far to inf (1e9 here) - mask = far < near - near[mask] = 1e9 - far[mask] = 1e9 - # restrict near to a minimal value - near = torch.clamp(near, min=0.05) - - return near, far - - -class TensorBase(torch.nn.Module): - def __init__(self, aabb, gridSize, device, density_n_comp = 8, appearance_n_comp = 24, app_dim = 27, - shadingMode = 'MLP_PE', alphaMask = None, near_far=[2.0,6.0], - density_shift = -10, alphaMask_thres=0.08, distance_scale=25, rayMarch_weight_thres=0.0001, - pos_pe = 6, view_pe = 6, fea_pe = 6, featureC=128, step_ratio=2.0, - fea2denseAct = 'softplus', - cuda_ray = False, - ): - super(TensorBase, self).__init__() - - self.density_n_comp = density_n_comp - self.app_n_comp = appearance_n_comp - self.app_dim = app_dim - self.aabb = aabb - self.alphaMask = alphaMask - self.device=device - - self.density_shift = density_shift - self.alphaMask_thres = alphaMask_thres - self.distance_scale = distance_scale - self.rayMarch_weight_thres = rayMarch_weight_thres - self.fea2denseAct = fea2denseAct - - self.near_far = near_far - self.step_ratio = step_ratio - - - self.update_stepSize(gridSize) - - self.matMode = [[0,1], [0,2], [1,2]] - self.vecMode = [2, 1, 0] - self.comp_w = [1,1,1] - - - self.init_svd_volume(gridSize[0], device) - - self.shadingMode, self.pos_pe, self.view_pe, self.fea_pe, self.featureC = shadingMode, pos_pe, view_pe, fea_pe, featureC - self.init_render_func(shadingMode, pos_pe, view_pe, fea_pe, featureC, device) - - self.cuda_ray = cuda_ray - if cuda_ray: - # density grid - density_grid = torch.zeros([128] * 3) - self.register_buffer('density_grid', density_grid) - self.mean_density = 0 - self.iter_density = 0 - # step counter - step_counter = torch.zeros(16, 2, dtype=torch.int32) # 16 is hardcoded for averaging... 
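An aside on `near_far_from_bound`, defined just above in this deleted copy (the live renderer uses the same helper for ray/AABB intersection): it is the classic slab test, restated compactly for the cube case:

```python
import torch

def ray_aabb(rays_o, rays_d, bound, min_near=0.05):
    # Slab test for the cube [-bound, bound]^3: per-axis entry/exit times,
    # entry = max over axes, exit = min over axes; far < near means a miss.
    tmin = (-bound - rays_o) / (rays_d + 1e-15)          # [..., 3]
    tmax = (bound - rays_o) / (rays_d + 1e-15)
    near = torch.minimum(tmin, tmax).max(dim=-1, keepdim=True)[0]
    far = torch.maximum(tmin, tmax).min(dim=-1, keepdim=True)[0]
    miss = far < near
    near[miss] = 1e9                                     # no intersection
    far[miss] = 1e9
    return near.clamp(min=min_near), far
```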
- self.register_buffer('step_counter', step_counter) - self.mean_count = 0 - self.local_step = 0 - - def reset_extra_state(self): - if not self.cuda_ray: - return - # density grid - self.density_grid.zero_() - self.mean_density = 0 - self.iter_density = 0 - # step counter - self.step_counter.zero_() - self.mean_count = 0 - self.local_step = 0 - - def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device): - if shadingMode == 'MLP_PE': - self.renderModule = MLPRender_PE(self.app_dim, view_pe, pos_pe, featureC).to(device) - elif shadingMode == 'MLP_Fea': - self.renderModule = MLPRender_Fea(self.app_dim, view_pe, fea_pe, featureC).to(device) - elif shadingMode == 'MLP': - self.renderModule = MLPRender(self.app_dim, view_pe, featureC).to(device) - elif shadingMode == 'SH': - self.renderModule = SHRender - elif shadingMode == 'RGB': - assert self.app_dim == 3 - self.renderModule = RGBRender - else: - print("Unrecognized shading module") - exit() - print("pos_pe", pos_pe, "view_pe", view_pe, "fea_pe", fea_pe) - print(self.renderModule) - - def update_stepSize(self, gridSize): - print("aabb", self.aabb.view(-1)) - print("grid size", gridSize) - self.aabbSize = self.aabb[1] - self.aabb[0] - self.invaabbSize = 2.0/self.aabbSize - self.gridSize= torch.LongTensor(gridSize).to(self.device) - self.units=self.aabbSize / (self.gridSize-1) - self.stepSize=torch.mean(self.units)*self.step_ratio - self.aabbDiag = torch.sqrt(torch.sum(torch.square(self.aabbSize))) - self.nSamples=int((self.aabbDiag / self.stepSize).item()) + 1 - print("sampling step size: ", self.stepSize) - print("sampling number: ", self.nSamples) - - def init_svd_volume(self, res, device): - pass - - def compute_features(self, xyz_sampled): - pass - - def compute_densityfeature(self, xyz_sampled): - pass - - def compute_appfeature(self, xyz_sampled): - pass - - def normalize_coord(self, xyz_sampled): - return (xyz_sampled-self.aabb[0]) * self.invaabbSize - 1 - - def get_optparam_groups(self, lr_init_spatial = 0.02, lr_init_network = 0.001): - pass - - def get_kwargs(self): - return { - 'aabb': self.aabb, - 'gridSize':self.gridSize.tolist(), - 'density_n_comp': self.density_n_comp, - 'appearance_n_comp': self.app_n_comp, - 'app_dim': self.app_dim, - - 'density_shift': self.density_shift, - 'alphaMask_thres': self.alphaMask_thres, - 'distance_scale': self.distance_scale, - 'rayMarch_weight_thres': self.rayMarch_weight_thres, - 'fea2denseAct': self.fea2denseAct, - - 'near_far': self.near_far, - 'step_ratio': self.step_ratio, - - 'shadingMode': self.shadingMode, - 'pos_pe': self.pos_pe, - 'view_pe': self.view_pe, - 'fea_pe': self.fea_pe, - 'featureC': self.featureC - } - - def get_state_dict(self): - kwargs = self.get_kwargs() - ckpt = {'kwargs': kwargs, 'state_dict': self.state_dict()} - if self.alphaMask is not None: - alpha_volume = self.alphaMask.alpha_volume.bool().cpu().numpy() - ckpt.update({'alphaMask.shape':alpha_volume.shape}) - ckpt.update({'alphaMask.mask':np.packbits(alpha_volume.reshape(-1))}) - ckpt.update({'alphaMask.aabb': self.alphaMask.aabb.cpu()}) - return ckpt - - def load(self, ckpt): - if 'alphaMask.aabb' in ckpt.keys(): - length = np.prod(ckpt['alphaMask.shape']) - alpha_volume = torch.from_numpy(np.unpackbits(ckpt['alphaMask.mask'])[:length].reshape(ckpt['alphaMask.shape'])) - self.alphaMask = AlphaGridMask(self.device, ckpt['alphaMask.aabb'].to(self.device), alpha_volume.float().to(self.device)) - self.load_state_dict(ckpt['state_dict']) - - - def sample_ray_ndc(self, rays_o, rays_d, 
is_train=True, N_samples=-1): - N_samples = N_samples if N_samples > 0 else self.nSamples - near, far = self.near_far - interpx = torch.linspace(near, far, N_samples).unsqueeze(0).to(rays_o) - if is_train: - interpx += torch.rand_like(interpx).to(rays_o) * ((far - near) / N_samples) - - rays_pts = rays_o[..., None, :] + rays_d[..., None, :] * interpx[..., None] - mask_outbbox = ((self.aabb[0] > rays_pts) | (rays_pts > self.aabb[1])).any(dim=-1) - return rays_pts, interpx, ~mask_outbbox - - # this is the place to insert cuda ray. - def sample_ray(self, rays_o, rays_d, is_train=True, N_samples=-1): - N_samples = N_samples if N_samples>0 else self.nSamples - stepsize = self.stepSize - near, far = self.near_far - vec = torch.where(rays_d==0, torch.full_like(rays_d, 1e-6), rays_d) - rate_a = (self.aabb[1] - rays_o) / vec - rate_b = (self.aabb[0] - rays_o) / vec - t_min = torch.minimum(rate_a, rate_b).amax(-1).clamp(min=near, max=far) - - rng = torch.arange(N_samples)[None].float() - if is_train: - rng = rng.repeat(rays_d.shape[-2],1) - rng += torch.rand_like(rng[:,[0]]) - step = stepsize * rng.to(rays_o.device) - interpx = (t_min[...,None] + step) - - rays_pts = rays_o[...,None,:] + rays_d[...,None,:] * interpx[...,None] - mask_outbbox = ((self.aabb[0]>rays_pts) | (rays_pts>self.aabb[1])).any(dim=-1) - - return rays_pts, interpx, ~mask_outbbox - - - def shrink(self, new_aabb, voxel_size): - pass - - @torch.no_grad() - def updateAlphaMask(self, gridSize=(200,200,200)): - - total_voxels = gridSize[0] * gridSize[1] * gridSize[2] - - samples = torch.stack(torch.meshgrid( - torch.linspace(0, 1, gridSize[0]), - torch.linspace(0, 1, gridSize[1]), - torch.linspace(0, 1, gridSize[2]), - ), -1).to(self.device) - dense_xyz = self.aabb[0] * (1-samples) + self.aabb[1] * samples - - dense_xyz = dense_xyz.transpose(0,2).contiguous() - alpha = torch.zeros_like(dense_xyz[...,0]) - for i in range(gridSize[2]): - alpha[i] = self.compute_alpha(dense_xyz[i].view(-1,3), self.distance_scale*self.aabbDiag).view((gridSize[1], gridSize[0])) - alpha = alpha.clamp(0,1)[None,None] - - - ks = 3 - alpha = F.max_pool3d(alpha, kernel_size=ks, padding=ks // 2, stride=1).view(gridSize[::-1]) - alpha[alpha>=self.alphaMask_thres] = 1 - alpha[alpha0.5] - - xyz_min = valid_xyz.amin(0) - xyz_max = valid_xyz.amax(0) - - new_aabb = torch.stack((xyz_min, xyz_max)) - - total = torch.sum(alpha) - print(f"bbox: {xyz_min, xyz_max} alpha rest %%%f"%(total/total_voxels*100)) - return new_aabb - - @torch.no_grad() - def filtering_rays(self, all_rays, all_rgbs, N_samples=256, chunk=10240*5, bbox_only=False): - print('========> filtering rays ...') - tt = time.time() - - N = torch.tensor(all_rays.shape[:-1]).prod() - - mask_filtered = [] - idx_chunks = torch.split(torch.arange(N), chunk) - for idx_chunk in idx_chunks: - rays_chunk = all_rays[idx_chunk].to(self.device) - - rays_o, rays_d = rays_chunk[..., :3], rays_chunk[..., 3:6] - if bbox_only: - vec = torch.where(rays_d == 0, torch.full_like(rays_d, 1e-6), rays_d) - rate_a = (self.aabb[1] - rays_o) / vec - rate_b = (self.aabb[0] - rays_o) / vec - t_min = torch.minimum(rate_a, rate_b).amax(-1)#.clamp(min=near, max=far) - t_max = torch.maximum(rate_a, rate_b).amin(-1)#.clamp(min=near, max=far) - mask_inbbox = t_max > t_min - - else: - xyz_sampled, _,_ = self.sample_ray(rays_o, rays_d, N_samples=N_samples, is_train=False) - mask_inbbox= (self.alphaMask.sample_alpha(xyz_sampled).view(xyz_sampled.shape[:-1]) > 0).any(-1) - - mask_filtered.append(mask_inbbox.cpu()) - - mask_filtered = 
torch.cat(mask_filtered).view(all_rgbs.shape[:-1]) - - print(f'Ray filtering done! takes {time.time()-tt} s. ray mask ratio: {torch.sum(mask_filtered) / N}') - return all_rays[mask_filtered], all_rgbs[mask_filtered] - - - def feature2density(self, density_features): - if self.fea2denseAct == "softplus": - return F.softplus(density_features+self.density_shift) - elif self.fea2denseAct == "relu": - return F.relu(density_features) - - - def compute_alpha(self, xyz_locs, length=1): - - if self.alphaMask is not None: - alphas = self.alphaMask.sample_alpha(xyz_locs) - alpha_mask = alphas > 0 - else: - alpha_mask = torch.ones_like(xyz_locs[:,0], dtype=bool) - - - sigma = torch.zeros(xyz_locs.shape[:-1], device=xyz_locs.device) - - if alpha_mask.any(): - xyz_sampled = self.normalize_coord(xyz_locs[alpha_mask]) - sigma_feature = self.compute_densityfeature(xyz_sampled) - validsigma = self.feature2density(sigma_feature) - sigma[alpha_mask] = validsigma - - - alpha = 1 - torch.exp(-sigma*length).view(xyz_locs.shape[:-1]) - - return alpha - - - def forward(self, rays_chunk, white_bg=True, is_train=False, ndc_ray=False, N_samples=-1): - - #starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - #starter.record() - - # sample points - viewdirs = rays_chunk[:, 3:6] - if ndc_ray: - xyz_sampled, z_vals, ray_valid = self.sample_ray_ndc(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples) - dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) - rays_norm = torch.norm(viewdirs, dim=-1, keepdim=True) - dists = dists * rays_norm - viewdirs = viewdirs / rays_norm - else: - xyz_sampled, z_vals, ray_valid = self.sample_ray(rays_chunk[:, :3], viewdirs, is_train=is_train,N_samples=N_samples) - dists = torch.cat((z_vals[:, 1:] - z_vals[:, :-1], torch.zeros_like(z_vals[:, :1])), dim=-1) - viewdirs = viewdirs.view(-1, 1, 3).expand(xyz_sampled.shape) - - - if self.alphaMask is not None: - alphas = self.alphaMask.sample_alpha(xyz_sampled[ray_valid]) - alpha_mask = alphas > 0 - ray_invalid = ~ray_valid - ray_invalid[ray_valid] |= (~alpha_mask) - ray_valid = ~ray_invalid - - #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'sample ray = {curr_time}') - #starter.record() - - sigma = torch.zeros(xyz_sampled.shape[:-1], device=xyz_sampled.device) - rgb = torch.zeros((*xyz_sampled.shape[:2], 3), device=xyz_sampled.device) - - if ray_valid.any(): - xyz_sampled = self.normalize_coord(xyz_sampled) - sigma_feature = self.compute_densityfeature(xyz_sampled[ray_valid]) - - validsigma = self.feature2density(sigma_feature) - sigma[ray_valid] = validsigma - - #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'sigma = {curr_time}') - #starter.record() - - alpha, weight, bg_weight = raw2alpha(sigma, dists * self.distance_scale) - - app_mask = weight > self.rayMarch_weight_thres - - print(app_mask.sum(), xyz_sampled.shape) - - if app_mask.any(): - app_features = self.compute_appfeature(xyz_sampled[app_mask]) - valid_rgbs = self.renderModule(xyz_sampled[app_mask], viewdirs[app_mask], app_features) - rgb[app_mask] = valid_rgbs - - acc_map = torch.sum(weight, -1) - rgb_map = torch.sum(weight[..., None] * rgb, -2) - - if white_bg or (is_train and torch.rand((1,))<0.5): - rgb_map = rgb_map + (1. - acc_map[..., None]) - - - rgb_map = rgb_map.clamp(0,1) - - with torch.no_grad(): - depth_map = torch.sum(weight * z_vals, -1) - depth_map = depth_map + (1. 
- acc_map) * rays_chunk[..., -1] - - #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'comp rays = {curr_time}') - - return rgb_map, depth_map # rgb, sigma, alpha, weight, bg_weight - - - -class TensorVM(TensorBase): - def __init__(self, aabb, gridSize, device, **kargs): - super(TensorVM, self).__init__(aabb, gridSize, device, **kargs) - - - def init_svd_volume(self, res, device): - self.plane_coef = torch.nn.Parameter( - 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, res), device=device)) - self.line_coef = torch.nn.Parameter( - 0.1 * torch.randn((3, self.app_n_comp + self.density_n_comp, res, 1), device=device)) - self.basis_mat = torch.nn.Linear(self.app_n_comp * 3, self.app_dim, bias=False, device=device) - - - def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): - grad_vars = [{'params': self.line_coef, 'lr': lr_init_spatialxyz}, {'params': self.plane_coef, 'lr': lr_init_spatialxyz}, - {'params': self.basis_mat.parameters(), 'lr':lr_init_network}] - if isinstance(self.renderModule, torch.nn.Module): - grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}] - return grad_vars - - def compute_features(self, xyz_sampled): - - coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach() - coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach() - - plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view( - -1, *xyz_sampled.shape[:1]) - line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view( - -1, *xyz_sampled.shape[:1]) - - sigma_feature = torch.sum(plane_feats * line_feats, dim=0) - - - plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1) - line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1) - - - app_features = self.basis_mat((plane_feats * line_feats).T) - - return sigma_feature, app_features - - def compute_densityfeature(self, xyz_sampled): - - print(xyz_sampled.shape) - - coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) - coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - plane_feats = F.grid_sample(self.plane_coef[:, -self.density_n_comp:], coordinate_plane, align_corners=True).view( - -1, *xyz_sampled.shape[:1]) - line_feats = F.grid_sample(self.line_coef[:, -self.density_n_comp:], coordinate_line, align_corners=True).view( - -1, *xyz_sampled.shape[:1]) - - sigma_feature = torch.sum(plane_feats * line_feats, dim=0) - - - return sigma_feature - - def compute_appfeature(self, xyz_sampled): - coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) - coordinate_line = torch.stack((xyz_sampled[..., 
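# --- [editor's sketch, not part of the patch] raw2alpha is defined elsewhere
# in this deleted file; under that assumption, this is the standard volume
# rendering step forward() uses: per-sample opacity from density, then
# transmittance-weighted compositing weights.
import torch

def raw2alpha_sketch(sigma, dist):
    alpha = 1.0 - torch.exp(-sigma * dist)                     # [N_rays, N_samples]
    # transmittance T_i = prod_{j<i} (1 - alpha_j), with T_0 = 1
    T = torch.cumprod(torch.cat([torch.ones_like(alpha[:, :1]),
                                 1.0 - alpha + 1e-10], dim=-1), dim=-1)
    weight = alpha * T[:, :-1]                                 # per-sample contribution
    return alpha, weight, T[:, -1]                             # last T = background weight

sigma = torch.rand(2, 8)
dist = torch.full((2, 8), 0.1)
alpha, weight, bg_w = raw2alpha_sketch(sigma, dist)
# rgb_map = (weight[..., None] * rgb).sum(-2); a white background then adds
# (1 - weight.sum(-1)), exactly as in forward() above.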
self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - plane_feats = F.grid_sample(self.plane_coef[:, :self.app_n_comp], coordinate_plane, align_corners=True).view(3 * self.app_n_comp, -1) - line_feats = F.grid_sample(self.line_coef[:, :self.app_n_comp], coordinate_line, align_corners=True).view(3 * self.app_n_comp, -1) - - - app_features = self.basis_mat((plane_feats * line_feats).T) - - - return app_features - - - def vectorDiffs(self, vector_comps): - total = 0 - - for idx in range(len(vector_comps)): - # print(self.line_coef.shape, vector_comps[idx].shape) - n_comp, n_size = vector_comps[idx].shape[:-1] - - dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2)) - # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape) - non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1] - # print(vector_comps[idx].shape, vector_comps[idx].view(n_comp,n_size).transpose(-1,-2).shape, dotp.shape,non_diagonal.shape) - total = total + torch.mean(torch.abs(non_diagonal)) - return total - - def vector_comp_diffs(self): - - return self.vectorDiffs(self.line_coef[:,-self.density_n_comp:]) + self.vectorDiffs(self.line_coef[:,:self.app_n_comp]) - - - @torch.no_grad() - def up_sampling_VM(self, plane_coef, line_coef, res_target): - - for i in range(len(self.vecMode)): - vec_id = self.vecMode[i] - mat_id_0, mat_id_1 = self.matMode[i] - - plane_coef[i] = torch.nn.Parameter( - F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear', - align_corners=True)) - line_coef[i] = torch.nn.Parameter( - F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) - - # plane_coef[0] = torch.nn.Parameter( - # F.interpolate(plane_coef[0].data, size=(res_target[1], res_target[0]), mode='bilinear', - # align_corners=True)) - # line_coef[0] = torch.nn.Parameter( - # F.interpolate(line_coef[0].data, size=(res_target[2], 1), mode='bilinear', align_corners=True)) - # plane_coef[1] = torch.nn.Parameter( - # F.interpolate(plane_coef[1].data, size=(res_target[2], res_target[0]), mode='bilinear', - # align_corners=True)) - # line_coef[1] = torch.nn.Parameter( - # F.interpolate(line_coef[1].data, size=(res_target[1], 1), mode='bilinear', align_corners=True)) - # plane_coef[2] = torch.nn.Parameter( - # F.interpolate(plane_coef[2].data, size=(res_target[2], res_target[1]), mode='bilinear', - # align_corners=True)) - # line_coef[2] = torch.nn.Parameter( - # F.interpolate(line_coef[2].data, size=(res_target[0], 1), mode='bilinear', align_corners=True)) - - return plane_coef, line_coef - - @torch.no_grad() - def upsample_volume_grid(self, res_target): - # self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target) - # self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target) - - scale = res_target[0]/self.line_coef.shape[2] #assuming xyz have the same scale - plane_coef = F.interpolate(self.plane_coef.detach().data, scale_factor=scale, mode='bilinear',align_corners=True) - line_coef = F.interpolate(self.line_coef.detach().data, size=(res_target[0],1), mode='bilinear',align_corners=True) - self.plane_coef, self.line_coef = torch.nn.Parameter(plane_coef), torch.nn.Parameter(line_coef) - 
self.compute_stepSize(res_target) - print(f'upsamping to {res_target}') - - -class TensorVMSplit(TensorBase): - def __init__(self, aabb, gridSize, device, **kargs): - super(TensorVMSplit, self).__init__(aabb, gridSize, device, **kargs) - - - def init_svd_volume(self, res, device): - self.density_plane, self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.1, device) - self.app_plane, self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.1, device) - self.basis_mat = torch.nn.Linear(sum(self.app_n_comp), self.app_dim, bias=False).to(device) - - - def init_one_svd(self, n_component, gridSize, scale, device): - plane_coef, line_coef = [], [] - for i in range(len(self.vecMode)): - vec_id = self.vecMode[i] - mat_id_0, mat_id_1 = self.matMode[i] - plane_coef.append(torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[mat_id_1], gridSize[mat_id_0])))) # - line_coef.append(torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))) - - return torch.nn.ParameterList(plane_coef).to(device), torch.nn.ParameterList(line_coef).to(device) - - - - def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): - grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, {'params': self.density_plane, 'lr': lr_init_spatialxyz}, - {'params': self.app_line, 'lr': lr_init_spatialxyz}, {'params': self.app_plane, 'lr': lr_init_spatialxyz}, - {'params': self.basis_mat.parameters(), 'lr':lr_init_network}] - if isinstance(self.renderModule, torch.nn.Module): - grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}] - return grad_vars - - - def vectorDiffs(self, vector_comps): - total = 0 - - for idx in range(len(vector_comps)): - n_comp, n_size = vector_comps[idx].shape[1:-1] - - dotp = torch.matmul(vector_comps[idx].view(n_comp,n_size), vector_comps[idx].view(n_comp,n_size).transpose(-1,-2)) - non_diagonal = dotp.view(-1)[1:].view(n_comp-1, n_comp+1)[...,:-1] - total = total + torch.mean(torch.abs(non_diagonal)) - return total - - def vector_comp_diffs(self): - return self.vectorDiffs(self.density_line) + self.vectorDiffs(self.app_line) - - def density_L1(self): - total = 0 - for idx in range(len(self.density_plane)): - total = total + torch.mean(torch.abs(self.density_plane[idx])) + torch.mean(torch.abs(self.density_line[idx]))# + torch.mean(torch.abs(self.app_plane[idx])) + torch.mean(torch.abs(self.density_plane[idx])) - return total - - def TV_loss_density(self, reg): - total = 0 - for idx in range(len(self.density_plane)): - total = total + reg(self.density_plane[idx]) * 1e-2 + reg(self.density_line[idx]) * 1e-3 - return total - - def TV_loss_app(self, reg): - total = 0 - for idx in range(len(self.app_plane)): - total = total + reg(self.app_plane[idx]) * 1e-2 + reg(self.app_line[idx]) * 1e-3 - return total - - def compute_densityfeature(self, xyz_sampled): - - #print(xyz_sampled.shape) - - # starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) - # starter.record() - - # plane + line basis - coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) - coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - sigma_feature = 
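# --- [editor's sketch, not part of the patch] `reg` in TV_loss_density /
# TV_loss_app above is a total-variation regularizer supplied by the caller;
# one possible form (squared differences of neighboring grid entries) for a
# [1, C, H, W] plane factor, shown here as an assumption:
import torch

def tv_2d(x):
    h_tv = (x[:, :, 1:, :] - x[:, :, :-1, :]).pow(2).mean()
    w_tv = (x[:, :, :, 1:] - x[:, :, :, :-1]).pow(2).mean()
    return h_tv + w_tv

plane = torch.randn(1, 16, 128, 128, requires_grad=True)
loss = tv_2d(plane) * 1e-2  # same relative weighting as TV_loss_density above
loss.backward()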
torch.zeros((xyz_sampled.shape[0],), device=xyz_sampled.device) - for idx_plane in range(len(self.density_plane)): - plane_coef_point = F.grid_sample(self.density_plane[idx_plane], coordinate_plane[[idx_plane]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = F.grid_sample(self.density_line[idx_plane], coordinate_line[[idx_plane]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - sigma_feature = sigma_feature + torch.sum(plane_coef_point * line_coef_point, dim=0) - - #ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'density = {curr_time}') - - return sigma_feature - - - def compute_appfeature(self, xyz_sampled): - - # plane + line basis - coordinate_plane = torch.stack((xyz_sampled[..., self.matMode[0]], xyz_sampled[..., self.matMode[1]], xyz_sampled[..., self.matMode[2]])).detach().view(3, -1, 1, 2) - coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - plane_coef_point,line_coef_point = [],[] - for idx_plane in range(len(self.app_plane)): - plane_coef_point.append(F.grid_sample(self.app_plane[idx_plane], coordinate_plane[[idx_plane]], - align_corners=True).view(-1, *xyz_sampled.shape[:1])) - line_coef_point.append(F.grid_sample(self.app_line[idx_plane], coordinate_line[[idx_plane]], - align_corners=True).view(-1, *xyz_sampled.shape[:1])) - plane_coef_point, line_coef_point = torch.cat(plane_coef_point), torch.cat(line_coef_point) - - - return self.basis_mat((plane_coef_point * line_coef_point).T) - - - - @torch.no_grad() - def up_sampling_VM(self, plane_coef, line_coef, res_target): - - for i in range(len(self.vecMode)): - vec_id = self.vecMode[i] - mat_id_0, mat_id_1 = self.matMode[i] - plane_coef[i] = torch.nn.Parameter( - F.interpolate(plane_coef[i].data, size=(res_target[mat_id_1], res_target[mat_id_0]), mode='bilinear', - align_corners=True)) - line_coef[i] = torch.nn.Parameter( - F.interpolate(line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) - - - return plane_coef, line_coef - - @torch.no_grad() - def upsample_volume_grid(self, res_target): - self.app_plane, self.app_line = self.up_sampling_VM(self.app_plane, self.app_line, res_target) - self.density_plane, self.density_line = self.up_sampling_VM(self.density_plane, self.density_line, res_target) - - self.update_stepSize(res_target) - print(f'upsamping to {res_target}') - - @torch.no_grad() - def shrink(self, new_aabb): - print("====> shrinking ...") - xyz_min, xyz_max = new_aabb - t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units - # print(new_aabb, self.aabb) - # print(t_l, b_r,self.alphaMask.alpha_volume.shape) - t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1 - b_r = torch.stack([b_r, self.gridSize]).amin(0) - - for i in range(len(self.vecMode)): - mode0 = self.vecMode[i] - self.density_line[i] = torch.nn.Parameter( - self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:] - ) - self.app_line[i] = torch.nn.Parameter( - self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:] - ) - mode0, mode1 = self.matMode[i] - self.density_plane[i] = torch.nn.Parameter( - self.density_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]] - ) - self.app_plane[i] = torch.nn.Parameter( - self.app_plane[i].data[...,t_l[mode1]:b_r[mode1],t_l[mode0]:b_r[mode0]] 
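# --- [editor's sketch, not part of the patch] The VM readout performed by
# compute_densityfeature above, reduced to a single (plane, line) pair: a 2D
# plane factor sampled at (x, y) times a 1D line factor sampled at z, summed
# over components. Coordinates are assumed pre-normalized to [-1, 1], as
# normalize_coord guarantees; the full model sums three such pairs.
import torch
import torch.nn.functional as F

C, res, M = 16, 128, 4096                        # components, grid res, points
plane = torch.randn(1, C, res, res)              # a_c(x, y)
line = torch.randn(1, C, res, 1)                 # b_c(z)
xyz = torch.rand(M, 3) * 2 - 1                   # normalized sample coords

coord_plane = xyz[:, :2].view(1, M, 1, 2)
coord_line = torch.stack([torch.zeros(M), xyz[:, 2]], dim=-1).view(1, M, 1, 2)

plane_feat = F.grid_sample(plane, coord_plane, align_corners=True).view(C, M)
line_feat = F.grid_sample(line, coord_line, align_corners=True).view(C, M)
sigma_feature = (plane_feat * line_feat).sum(0)  # [M]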
- ) - - - if not torch.all(self.alphaMask.gridSize == self.gridSize): - t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1) - correct_aabb = torch.zeros_like(new_aabb) - correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1] - correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1] - print("aabb", new_aabb, "\ncorrect aabb", correct_aabb) - new_aabb = correct_aabb - - newSize = b_r - t_l - self.aabb = new_aabb - self.update_stepSize((newSize[0], newSize[1], newSize[2])) - - - - -class TensorCP(TensorBase): - def __init__(self, aabb, gridSize, device, **kargs): - super(TensorCP, self).__init__(aabb, gridSize, device, **kargs) - - - def init_svd_volume(self, res, device): - self.density_line = self.init_one_svd(self.density_n_comp, self.gridSize, 0.2, device) - self.app_line = self.init_one_svd(self.app_n_comp, self.gridSize, 0.2, device) - self.basis_mat = torch.nn.Linear(self.app_n_comp[0], self.app_dim, bias=False).to(device) - - - def init_one_svd(self, n_component, gridSize, scale, device): - line_coef = [] - for i in range(len(self.vecMode)): - vec_id = self.vecMode[i] - line_coef.append( - torch.nn.Parameter(scale * torch.randn((1, n_component[i], gridSize[vec_id], 1)))) - return torch.nn.ParameterList(line_coef).to(device) - - - def get_optparam_groups(self, lr_init_spatialxyz = 0.02, lr_init_network = 0.001): - grad_vars = [{'params': self.density_line, 'lr': lr_init_spatialxyz}, - {'params': self.app_line, 'lr': lr_init_spatialxyz}, - {'params': self.basis_mat.parameters(), 'lr':lr_init_network}] - if isinstance(self.renderModule, torch.nn.Module): - grad_vars += [{'params':self.renderModule.parameters(), 'lr':lr_init_network}] - return grad_vars - - def compute_densityfeature(self, xyz_sampled): - - coordinate_line = torch.stack((xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - - line_coef_point = F.grid_sample(self.density_line[0], coordinate_line[[0]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.density_line[1], coordinate_line[[1]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.density_line[2], coordinate_line[[2]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - sigma_feature = torch.sum(line_coef_point, dim=0) - - - return sigma_feature - - def compute_appfeature(self, xyz_sampled): - - coordinate_line = torch.stack( - (xyz_sampled[..., self.vecMode[0]], xyz_sampled[..., self.vecMode[1]], xyz_sampled[..., self.vecMode[2]])) - coordinate_line = torch.stack((torch.zeros_like(coordinate_line), coordinate_line), dim=-1).detach().view(3, -1, 1, 2) - - - line_coef_point = F.grid_sample(self.app_line[0], coordinate_line[[0]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.app_line[1], coordinate_line[[1]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - line_coef_point = line_coef_point * F.grid_sample(self.app_line[2], coordinate_line[[2]], - align_corners=True).view(-1, *xyz_sampled.shape[:1]) - - return self.basis_mat(line_coef_point.T) - - - @torch.no_grad() - def up_sampling_Vector(self, density_line_coef, app_line_coef, res_target): - - for i in range(len(self.vecMode)): - vec_id = self.vecMode[i] - density_line_coef[i] = torch.nn.Parameter( - 
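# --- [editor's sketch, not part of the patch] The rank-R CP factorization
# behind TensorCP above: sigma(x, y, z) ~ sum_r a_r(x) * b_r(y) * c_r(z),
# each 1D factor stored as a [1, R, res, 1] grid and read with grid_sample,
# mirroring compute_densityfeature.
import torch
import torch.nn.functional as F

R, res, M = 8, 128, 1024
lines = [torch.randn(1, R, res, 1) for _ in range(3)]   # a_r, b_r, c_r
xyz = torch.rand(M, 3) * 2 - 1

feat = torch.ones(R, M)
for axis in range(3):
    coord = torch.stack([torch.zeros(M), xyz[:, axis]], dim=-1).view(1, M, 1, 2)
    feat = feat * F.grid_sample(lines[axis], coord, align_corners=True).view(R, M)
sigma_feature = feat.sum(0)                             # [M]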
F.interpolate(density_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) - app_line_coef[i] = torch.nn.Parameter( - F.interpolate(app_line_coef[i].data, size=(res_target[vec_id], 1), mode='bilinear', align_corners=True)) - - return density_line_coef, app_line_coef - - @torch.no_grad() - def upsample_volume_grid(self, res_target): - self.density_line, self.app_line = self.up_sampling_Vector(self.density_line, self.app_line, res_target) - - self.update_stepSize(res_target) - print(f'upsamping to {res_target}') - - @torch.no_grad() - def shrink(self, new_aabb): - print("====> shrinking ...") - xyz_min, xyz_max = new_aabb - t_l, b_r = (xyz_min - self.aabb[0]) / self.units, (xyz_max - self.aabb[0]) / self.units - - t_l, b_r = torch.round(torch.round(t_l)).long(), torch.round(b_r).long() + 1 - b_r = torch.stack([b_r, self.gridSize]).amin(0) - - - for i in range(len(self.vecMode)): - mode0 = self.vecMode[i] - self.density_line[i] = torch.nn.Parameter( - self.density_line[i].data[...,t_l[mode0]:b_r[mode0],:] - ) - self.app_line[i] = torch.nn.Parameter( - self.app_line[i].data[...,t_l[mode0]:b_r[mode0],:] - ) - - if not torch.all(self.alphaMask.gridSize == self.gridSize): - t_l_r, b_r_r = t_l / (self.gridSize-1), (b_r-1) / (self.gridSize-1) - correct_aabb = torch.zeros_like(new_aabb) - correct_aabb[0] = (1-t_l_r)*self.aabb[0] + t_l_r*self.aabb[1] - correct_aabb[1] = (1-b_r_r)*self.aabb[0] + b_r_r*self.aabb[1] - print("aabb", new_aabb, "\ncorrect aabb", correct_aabb) - new_aabb = correct_aabb - - newSize = b_r - t_l - self.aabb = new_aabb - self.update_stepSize((newSize[0], newSize[1], newSize[2])) - - def density_L1(self): - total = 0 - for idx in range(len(self.density_line)): - total = total + torch.mean(torch.abs(self.density_line[idx])) - return total - - def TV_loss_density(self, reg): - total = 0 - for idx in range(len(self.density_line)): - total = total + reg(self.density_line[idx]) * 1e-3 - return total - - def TV_loss_app(self, reg): - total = 0 - for idx in range(len(self.app_line)): - total = total + reg(self.app_line[idx]) * 1e-3 - return total \ No newline at end of file diff --git a/tensorf/provider.py b/tensorf/provider.py deleted file mode 100644 index 26d9f429..00000000 --- a/tensorf/provider.py +++ /dev/null @@ -1,194 +0,0 @@ -import os -import glob -import numpy as np - -import cv2 - -import torch -from torch.utils.data import Dataset - -from scipy.spatial.transform import Slerp, Rotation - -# NeRF dataset -import json - - -# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50 -def nerf_matrix_to_ngp(pose, scale=0.33): - # for the fox dataset, 0.33 scales camera radius to ~ 2 - new_pose = np.array([ - [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale], - [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale], - [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale], - [0, 0, 0, 1], - ], dtype=np.float32) - return new_pose - - -class NeRFDataset(Dataset): - def __init__(self, path, type='train', mode='colmap', preload=False, downscale=1, scale=0.33, n_test=10): - super().__init__() - # path: the json file path. - - self.root_path = path - self.type = type # train, val, test - self.mode = mode # colmap, blender, llff - self.downscale = downscale - self.preload = preload # preload data into GPU - - # camera radius scale to make sure camera are inside the bounding box. - self.scale = scale - - # load nerf-compatible format data. 
- if mode == 'colmap': - with open(os.path.join(path, 'transforms.json'), 'r') as f: - transform = json.load(f) - elif mode == 'blender': - # load all splits (train/valid/test), this is what instant-ngp in fact does... - if type == 'all': - transform_paths = glob.glob(os.path.join(path, '*.json')) - transform = None - for transform_path in transform_paths: - with open(transform_path, 'r') as f: - tmp_transform = json.load(f) - if transform is None: - transform = tmp_transform - else: - transform['frames'].extend(tmp_transform['frames']) - # only load one specified split - else: - with open(os.path.join(path, f'transforms_{type}.json'), 'r') as f: - transform = json.load(f) - - else: - raise NotImplementedError(f'unknown dataset mode: {mode}') - - # load image size - if 'h' in transform and 'w' in transform: - self.H = int(transform['h']) // downscale - self.W = int(transform['w']) // downscale - else: - # we have to actually read an image to get H and W later. - self.H = self.W = None - - # read images - frames = transform["frames"] - frames = sorted(frames, key=lambda d: d['file_path']) - - # for colmap, manually interpolate a test set. - if mode == 'colmap' and type == 'test': - - # choose two random poses, and interpolate between. - f0, f1 = np.random.choice(frames, 2, replace=False) - pose0 = nerf_matrix_to_ngp(np.array(f0['transform_matrix'], dtype=np.float32), scale=self.scale) # [4, 4] - pose1 = nerf_matrix_to_ngp(np.array(f1['transform_matrix'], dtype=np.float32), scale=self.scale) # [4, 4] - rots = Rotation.from_matrix(np.stack([pose0[:3, :3], pose1[:3, :3]])) - slerp = Slerp([0, 1], rots) - - self.poses = [] - self.images = None - for i in range(n_test + 1): - ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5 - pose = np.eye(4, dtype=np.float32) - pose[:3, :3] = slerp(ratio).as_matrix() - pose[:3, 3] = (1 - ratio) * pose0[:3, 3] + ratio * pose1[:3, 3] - self.poses.append(pose) - - else: - # for colmap, manually split a valid set (the first frame). - if mode == 'colmap': - if type == 'train': - frames = frames[1:] - elif type == 'val': - frames = frames[:1] - # else 'all': use all frames - - self.poses = [] - self.images = [] - for f in frames: - f_path = os.path.join(self.root_path, f['file_path']) - if mode == 'blender': - f_path += '.png' # so silly... - - # there are non-exist paths in fox... - if not os.path.exists(f_path): - continue - - pose = np.array(f['transform_matrix'], dtype=np.float32) # [4, 4] - pose = nerf_matrix_to_ngp(pose, scale=self.scale) - - image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED) # [H, W, 3] o [H, W, 4] - if self.H is None or self.W is None: - self.H = image.shape[0] // downscale - self.W = image.shape[1] // downscale - - # add support for the alpha channel as a mask. 
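# --- [editor's sketch, not part of the patch] The interpolation ratio used
# for the synthetic test path above, ratio = sin((i/n - 0.5)*pi)*0.5 + 0.5,
# is a smooth ease-in/ease-out that hits 0 and 1 exactly at the endpoints:
import numpy as np

n_test = 10
for i in (0, n_test // 2, n_test):
    ratio = np.sin(((i / n_test) - 0.5) * np.pi) * 0.5 + 0.5
    print(i, round(float(ratio), 3))   # 0 -> 0.0, 5 -> 0.5, 10 -> 1.0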
- if image.shape[-1] == 3: - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - else: - image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA) - - image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA) - image = image.astype(np.float32) / 255 # [H, W, 3/4] - - self.poses.append(pose) - self.images.append(image) - - self.poses = np.stack(self.poses, axis=0) - if self.images is not None: - self.images = np.stack(self.images, axis=0) - - if preload: - self.poses = torch.from_numpy(self.poses).cuda() - if self.images is not None: - self.images = torch.from_numpy(self.images).cuda() - - # load intrinsics - - if 'fl_x' in transform or 'fl_y' in transform: - fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale - fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale - elif 'camera_angle_x' in transform or 'camera_angle_y' in transform: - # blender, assert in radians. already downscaled since we use H/W - fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None - fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None - if fl_x is None: fl_x = fl_y - if fl_y is None: fl_y = fl_x - else: - raise RuntimeError('cannot read focal!') - - cx = (transform['cx'] / downscale) if 'cx' in transform else (self.H / 2) - cy = (transform['cy'] / downscale) if 'cy' in transform else (self.W / 2) - - self.intrinsic = np.eye(3, dtype=np.float32) - self.intrinsic[0, 0] = fl_x - self.intrinsic[1, 1] = fl_y - self.intrinsic[0, 2] = cx - self.intrinsic[1, 2] = cy - - if preload: - self.intrinsic = torch.from_numpy(self.intrinsic).cuda() - - - def __len__(self): - return len(self.poses) - - def __getitem__(self, index): - - results = { - 'pose': self.poses[index], - 'intrinsic': self.intrinsic, - 'index': index, - } - - if self.type == 'test': - # only string can bypass the default collate, so we don't need to call item: https://github.com/pytorch/pytorch/blob/67a275c29338a6c6cc405bf143e63d53abe600bf/torch/utils/data/_utils/collate.py#L84 - results['H'] = str(self.H) - results['W'] = str(self.W) - # blender has test gt, so we also load it - if self.mode == 'blender': - results['image'] = self.images[index] - else: - results['image'] = self.images[index] - - return results \ No newline at end of file diff --git a/tensorf/utils.py b/tensorf/utils.py deleted file mode 100644 index 4e6809f6..00000000 --- a/tensorf/utils.py +++ /dev/null @@ -1,923 +0,0 @@ -import os -import glob -import tqdm -import random -import warnings -import tensorboardX - -import numpy as np -import pandas as pd - -import time -from datetime import datetime - -import cv2 -import matplotlib.pyplot as plt - -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -import torch.distributed as dist -from torch.utils.data import Dataset, DataLoader - -import trimesh -import mcubes -from rich.console import Console -from torch_ema import ExponentialMovingAverage - -def seed_everything(seed): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - #torch.backends.cudnn.deterministic = True - #torch.backends.cudnn.benchmark = True - - -def lift(x, y, z, intrinsics): - # x, y, z: [B, N] - # intrinsics: [B, 3, 3] - - fx = intrinsics[..., 0, 0].unsqueeze(-1) - fy = intrinsics[..., 1, 1].unsqueeze(-1) - cx = intrinsics[..., 0, 2].unsqueeze(-1) - cy = intrinsics[..., 
1, 2].unsqueeze(-1) - sk = intrinsics[..., 0, 1].unsqueeze(-1) - - x_lift = (x - cx + cy * sk / fy - sk * y / fy) / fx * z - y_lift = (y - cy) / fy * z - - # homogeneous - return torch.stack((x_lift, y_lift, z), dim=-1) - -# Never cast get_rays! fp16 rays degenerates results seriously! -@torch.cuda.amp.autocast(enabled=False) -def get_rays(c2w, intrinsics, H, W, N_rays=-1): - # c2w: [B, 4, 4] - # intrinsics: [B, 3, 3] - # return: rays_o, rays_d: [B, N_rays, 3] - # return: select_inds: [B, N_rays] - - device = c2w.device - rays_o = c2w[..., :3, 3] # [B, 3] - prefix = c2w.shape[:-2] - - i, j = torch.meshgrid(torch.linspace(0, W-1, W, device=device), torch.linspace(0, H-1, H, device=device), indexing='ij') # for torch < 1.10, should remove indexing='ij' - i = i.t().reshape([*[1]*len(prefix), H*W]).expand([*prefix, H*W]) + 0.5 - j = j.t().reshape([*[1]*len(prefix), H*W]).expand([*prefix, H*W]) + 0.5 - - if N_rays > 0: - N_rays = min(N_rays, H*W) - select_inds = torch.randint(0, H*W, size=[N_rays], device=device) - select_inds = select_inds.expand([*prefix, N_rays]) - i = torch.gather(i, -1, select_inds) - j = torch.gather(j, -1, select_inds) - else: - select_inds = torch.arange(H*W, device=device).expand([*prefix, H*W]) - - directions = lift(i, j, torch.ones_like(i), intrinsics=intrinsics) - directions = directions / torch.norm(directions, dim=-1, keepdim=True) - - rays_d = directions @ c2w[:, :3, :3].transpose(-1, -2) # (B, N_rays, 3) - rays_o = rays_o[..., None, :].expand_as(rays_d) - - return rays_o, rays_d, select_inds - - -def extract_fields(bound_min, bound_max, resolution, query_func): - N = 64 - X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) - Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) - Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) - - u = np.zeros([resolution, resolution, resolution], dtype=np.float32) - with torch.no_grad(): - for xi, xs in enumerate(X): - for yi, ys in enumerate(Y): - for zi, zs in enumerate(Z): - xx, yy, zz = torch.meshgrid(xs, ys, zs, indexing='ij') # for torch < 1.10, should remove indexing='ij' - pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).unsqueeze(0) # [1, N, 3] - val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() # [1, N, 1] --> [x, y, z] - u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val - return u - - -def extract_geometry(bound_min, bound_max, resolution, threshold, query_func): - #print('threshold: {}'.format(threshold)) - u = extract_fields(bound_min, bound_max, resolution, query_func) - - #print(u.shape, u.max(), u.min(), np.percentile(u, 50)) - - vertices, triangles = mcubes.marching_cubes(u, threshold) - - b_max_np = bound_max.detach().cpu().numpy() - b_min_np = bound_min.detach().cpu().numpy() - - vertices = vertices / (resolution - 1.0) * (b_max_np - b_min_np)[None, :] + b_min_np[None, :] - return vertices, triangles - - -class PSNRMeter: - def __init__(self): - self.V = 0 - self.N = 0 - - def clear(self): - self.V = 0 - self.N = 0 - - def prepare_inputs(self, *inputs): - outputs = [] - for i, inp in enumerate(inputs): - if torch.is_tensor(inp): - inp = inp.detach().cpu().numpy() - outputs.append(inp) - - return outputs - - def update(self, preds, truths): - preds, truths = self.prepare_inputs(preds, truths) # [B, N, 3] or [B, H, W, 3], range[0, 1] - - # simplified since max_pixel_value is 1 here. 
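# --- [editor's sketch, not part of the patch] A compact restatement of what
# lift + get_rays above compute for a zero-skew pinhole camera: back-project
# pixel centers through the intrinsics, normalize, then rotate into world
# space with the camera-to-world rotation. Names are illustrative.
import torch

def pixel_to_ray(i, j, intrinsics, c2w):
    fx, fy = intrinsics[0, 0], intrinsics[1, 1]
    cx, cy = intrinsics[0, 2], intrinsics[1, 2]
    dirs = torch.stack([(i - cx) / fx, (j - cy) / fy, torch.ones_like(i)], dim=-1)
    dirs = dirs / dirs.norm(dim=-1, keepdim=True)
    rays_d = dirs @ c2w[:3, :3].T      # rotate camera-frame directions to world
    rays_o = c2w[:3, 3].expand_as(rays_d)
    return rays_o, rays_d

i = torch.tensor([400.5]); j = torch.tensor([300.5])
K = torch.tensor([[1111.0, 0, 400], [0, 1111.0, 300], [0, 0, 1]])
rays_o, rays_d = pixel_to_ray(i, j, K, torch.eye(4))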
- psnr = -10 * np.log10(np.mean(np.power(preds - truths, 2))) - - self.V += psnr - self.N += 1 - - def measure(self): - return self.V / self.N - - def write(self, writer, global_step, prefix=""): - writer.add_scalar(os.path.join(prefix, "PSNR"), self.measure(), global_step) - - def report(self): - return f'PSNR = {self.measure():.6f}' - -def N_to_reso(n_voxels, bbox): - xyz_min, xyz_max = bbox - voxel_size = ((xyz_max - xyz_min).prod() / n_voxels).pow(1 / 3) - return ((xyz_max - xyz_min) / voxel_size).long().tolist() - -def cal_n_samples(reso, step_ratio=0.5): - return int(np.linalg.norm(reso)/step_ratio) - -def renderer(rays, tensorf, chunk=4096, N_samples=-1, ndc_ray=False, white_bg=True, is_train=False, device='cuda'): - rgbs, alphas, depth_maps, weights, uncertainties = [], [], [], [], [] - N_rays_all = rays.shape[0] - - for chunk_idx in range(N_rays_all // chunk + int(N_rays_all % chunk > 0)): - rays_chunk = rays[chunk_idx * chunk:(chunk_idx + 1) * chunk].to(device) - rgb_map, depth_map = tensorf(rays_chunk, is_train=is_train, white_bg=white_bg, ndc_ray=ndc_ray, N_samples=N_samples) - rgbs.append(rgb_map) - depth_maps.append(depth_map) - - return torch.cat(rgbs), torch.cat(depth_maps) - -class Trainer(object): - def __init__(self, - name, # name of this experiment - conf, # extra conf - model, # network - criterion=None, # loss function, if None, assume inline implementation in train_step - optimizer=None, # optimizer - ema_decay=None, # if use EMA, set the decay - lr_scheduler=None, # scheduler - metrics=[], # metrics for evaluation, if None, use val_loss to measure performance, else use the first metric. - local_rank=0, # which GPU am I - world_size=1, # total num of GPUs - device=None, # device to use, usually setting to None is OK. (auto choose device) - mute=False, # whether to mute all print - fp16=False, # amp optimize level - eval_interval=1, # eval once every $ epoch - max_keep_ckpt=2, # max num of saved ckpts in disk - workspace='workspace', # workspace to save logs & ckpts - best_mode='min', # the smaller/larger result, the better - use_loss_as_metric=True, # use loss as the first metric - report_metric_at_train=True, # also report metrics at training - use_checkpoint="latest", # which ckpt to use at init time - use_tensorboardX=True, # whether to use tensorboard for logging - scheduler_update_every_step=False, # whether to call scheduler.step() after every train step - ): - - self.name = name - self.conf = conf - self.mute = mute - self.metrics = metrics - self.local_rank = local_rank - self.world_size = world_size - self.workspace = workspace - self.ema_decay = ema_decay - self.fp16 = fp16 - self.best_mode = best_mode - self.use_loss_as_metric = use_loss_as_metric - self.report_metric_at_train = report_metric_at_train - self.max_keep_ckpt = max_keep_ckpt - self.eval_interval = eval_interval - self.use_checkpoint = use_checkpoint - self.use_tensorboardX = use_tensorboardX - self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S") - self.scheduler_update_every_step = scheduler_update_every_step - self.device = device if device is not None else torch.device(f'cuda:{local_rank}' if torch.cuda.is_available() else 'cpu') - self.console = Console() - - model.to(self.device) - if self.world_size > 1: - model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) - model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank]) - self.model = model - - if isinstance(criterion, nn.Module): - criterion.to(self.device) - self.criterion = criterion - - if optimizer is 
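# --- [editor's sketch, not part of the patch] Worked example of N_to_reso /
# cal_n_samples above: distribute a voxel budget over the bounding box at a
# uniform voxel size, then derive the per-ray sample count from the grid
# diagonal and the step ratio.
import numpy as np
import torch

bbox = torch.tensor([[-1.0, -1.0, -1.0], [1.0, 1.0, 1.0]])
n_voxels = 128 ** 3
voxel_size = ((bbox[1] - bbox[0]).prod() / n_voxels).pow(1 / 3)   # 2/128 here
reso = ((bbox[1] - bbox[0]) / voxel_size).long().tolist()         # [128, 128, 128]
n_samples = int(np.linalg.norm(reso) / 0.5)                       # step_ratio = 0.5
print(reso, n_samples)                                            # ~443 samples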
None: - self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, weight_decay=5e-4) # naive adam - else: - self.optimizer_fn = optimizer - self.optimizer = optimizer(self.model) - - if lr_scheduler is None: - self.lr_scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, lr_lambda=lambda epoch: 1) # fake scheduler - else: - self.lr_scheduler_fn = lr_scheduler - self.lr_scheduler = lr_scheduler(self.optimizer) - - if ema_decay is not None: - self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) - else: - self.ema = None - - self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16) - - # variable init - self.epoch = 1 - self.global_step = 0 - self.local_step = 0 - self.stats = { - "loss": [], - "valid_loss": [], - "results": [], # metrics[0], or valid_loss - "checkpoints": [], # record path of saved ckpt, to automatically remove old ckpt - "best_result": None, - } - - # auto fix - if len(metrics) == 0 or self.use_loss_as_metric: - self.best_mode = 'min' - - # workspace prepare - self.log_ptr = None - if self.workspace is not None: - os.makedirs(self.workspace, exist_ok=True) - self.log_path = os.path.join(workspace, f"log_{self.name}.txt") - self.log_ptr = open(self.log_path, "a+") - - self.ckpt_path = os.path.join(self.workspace, 'checkpoints') - self.best_path = f"{self.ckpt_path}/{self.name}.pth.tar" - os.makedirs(self.ckpt_path, exist_ok=True) - - self.log(f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}') - self.log(f'[INFO] #parameters: {sum([p.numel() for p in model.parameters() if p.requires_grad])}') - - if self.workspace is not None: - if self.use_checkpoint == "scratch": - self.log("[INFO] Training from scratch ...") - elif self.use_checkpoint == "latest": - self.log("[INFO] Loading latest checkpoint ...") - self.load_checkpoint() - elif self.use_checkpoint == "best": - if os.path.exists(self.best_path): - self.log("[INFO] Loading best checkpoint ...") - self.load_checkpoint(self.best_path) - else: - self.log(f"[INFO] {self.best_path} not found, loading latest ...") - self.load_checkpoint() - else: # path to ckpt - self.log(f"[INFO] Loading {self.use_checkpoint} ...") - self.load_checkpoint(self.use_checkpoint) - - def __del__(self): - if self.log_ptr: - self.log_ptr.close() - - def log(self, *args, **kwargs): - if self.local_rank == 0: - if not self.mute: - #print(*args) - self.console.print(*args, **kwargs) - if self.log_ptr: - print(*args, file=self.log_ptr) - self.log_ptr.flush() # write immediately to file - - ### ------------------------------ - - def train_step(self, data): - images = data["image"] # [B, H, W, 3/4] - poses = data["pose"] # [B, 4, 4] - intrinsics = data["intrinsic"] # [B, 3, 3] - - # sample rays - B, H, W, C = images.shape - rays_o, rays_d, inds = get_rays(poses, intrinsics, H, W, self.conf['num_rays']) - images = torch.gather(images.reshape(B, -1, C), 1, torch.stack(C*[inds], -1)) # [B, N, 3/4] - - # train with random background color if using alpha mixing - bg_color = torch.ones(3, device=self.device) # [3], fixed white background - #bg_color = torch.rand(3, device=self.device) # [3], frame-wise random. 
- if C == 4: - gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) - else: - gt_rgb = images - - rays = torch.cat([rays_o, rays_d], dim=-1).reshape(-1, 6) - pred_rgb, pred_depth = renderer(rays, self.model, chunk=self.conf['num_rays'], N_samples=self.nSamples, white_bg=True, ndc_ray=False, device=self.device, is_train=True) - - pred_rgb = pred_rgb.reshape(B, -1, 3) - loss = self.criterion(pred_rgb, gt_rgb) - - # l1 reg - loss += self.model.density_L1() * self.L1_reg_weight - - return pred_rgb, gt_rgb, loss - - def eval_step(self, data): - images = data["image"] # [B, H, W, 3/4] - poses = data["pose"] # [B, 4, 4] - intrinsics = data["intrinsic"] # [B, 3, 3] - - # sample rays - B, H, W, C = images.shape - rays_o, rays_d, _ = get_rays(poses, intrinsics, H, W, -1) - - bg_color = torch.ones(3, device=self.device) # [3] - # eval with fixed background color - if C == 4: - gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:]) - else: - gt_rgb = images - - rays = torch.cat([rays_o, rays_d], dim=-1).reshape(-1, 6) - pred_rgb, pred_depth = renderer(rays, self.model, chunk=self.conf['num_rays'], N_samples=self.nSamples, white_bg=True, ndc_ray=False, device=self.device, is_train=False) - - pred_rgb = pred_rgb.reshape(B, H, W, -1) - pred_depth = pred_depth.reshape(B, H, W) - - # normalize depth - pred_depth = (pred_depth - self.near_far[0]) / (self.near_far[1] - self.near_far[0] + 1e-8) - - loss = self.criterion(pred_rgb, gt_rgb) - - return pred_rgb, pred_depth, gt_rgb, loss - - # moved out bg_color and perturb for more flexible control... - def test_step(self, data, bg_color=None, perturb=False): - - poses = data["pose"] # [B, 4, 4] - intrinsics = data["intrinsic"] # [B, 3, 3] - H, W = int(data['H'][0]), int(data['W'][0]) # get the target size... - - B = poses.shape[0] - - rays_o, rays_d, _ = get_rays(poses, intrinsics, H, W, -1) - - if bg_color is not None: - bg_color = bg_color.to(self.device) - - rays = torch.cat([rays_o, rays_d], dim=-1).reshape(-1, 6) - pred_rgb, pred_depth = renderer(rays, self.model, chunk=self.conf['num_rays'], N_samples=self.nSamples, white_bg=True, ndc_ray=False, device=self.device, is_train=False) - - pred_rgb = pred_rgb.reshape(B, H, W, -1) - pred_depth = pred_depth.reshape(B, H, W) - - # normalize depth - pred_depth = (pred_depth - self.near_far[0]) / (self.near_far[1] - self.near_far[0] + 1e-8) - - return pred_rgb, pred_depth - - - def save_mesh(self, save_path=None, resolution=256, threshold=10): - - if save_path is None: - save_path = os.path.join(self.workspace, 'meshes', f'{self.name}_{self.epoch}.ply') - - self.log(f"==> Saving mesh to {save_path}") - - os.makedirs(os.path.dirname(save_path), exist_ok=True) - - def query_func(pts): - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=self.fp16): - sdfs = self.model.density(pts.to(self.device)) - return sdfs - - bounds_min = torch.FloatTensor([-self.model.bound] * 3) - bounds_max = torch.FloatTensor([self.model.bound] * 3) - - vertices, triangles = extract_geometry(bounds_min, bounds_max, resolution=resolution, threshold=threshold, query_func=query_func) - - mesh = trimesh.Trimesh(vertices, triangles, process=False) # important, process=True leads to seg fault... 
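# --- [editor's sketch, not part of the patch] The alpha blend used for RGBA
# ground truth in train_step/eval_step above, gt = rgb*a + bg*(1-a): a fully
# transparent pixel becomes the background color, a fully opaque one keeps
# its RGB.
import torch

images = torch.tensor([[0.2, 0.4, 0.6, 0.0],      # transparent pixel
                       [0.2, 0.4, 0.6, 1.0]])     # opaque pixel
bg_color = torch.ones(3)                          # fixed white background
gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
print(gt_rgb)  # [[1, 1, 1], [0.2, 0.4, 0.6]]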
- mesh.export(save_path) - - self.log(f"==> Finished saving mesh.") - - ### ------------------------------ - - def train(self, train_loader, valid_loader, max_epochs): - if self.use_tensorboardX and self.local_rank == 0: - self.writer = tensorboardX.SummaryWriter(os.path.join(self.workspace, "run", self.name)) - - for epoch in range(self.epoch, max_epochs + 1): - self.epoch = epoch - - self.train_one_epoch(train_loader) - - if self.workspace is not None and self.local_rank == 0: - self.save_checkpoint(full=True, best=False) - - if self.epoch % self.eval_interval == 0: - self.evaluate_one_epoch(valid_loader) - self.save_checkpoint(full=False, best=True) - - if self.use_tensorboardX and self.local_rank == 0: - self.writer.close() - - def evaluate(self, loader): - self.use_tensorboardX, use_tensorboardX = False, self.use_tensorboardX - self.evaluate_one_epoch(loader) - self.use_tensorboardX = use_tensorboardX - - def test(self, loader, save_path=None): - - if save_path is None: - save_path = os.path.join(self.workspace, 'results') - - os.makedirs(save_path, exist_ok=True) - - self.log(f"==> Start Test, save results to {save_path}") - - pbar = tqdm.tqdm(total=len(loader), bar_format='{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') - self.model.eval() - with torch.no_grad(): - - for i, data in enumerate(loader): - - data = self.prepare_data(data) - - with torch.cuda.amp.autocast(enabled=self.fp16): - preds, preds_depth = self.test_step(data) - - path = os.path.join(save_path, f'{i:04d}.png') - path_depth = os.path.join(save_path, f'{i:04d}_depth.png') - - #self.log(f"[INFO] saving test image to {path}") - - cv2.imwrite(path, cv2.cvtColor((preds[0].detach().cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) - cv2.imwrite(path_depth, (preds_depth[0].detach().cpu().numpy() * 255).astype(np.uint8)) - - pbar.update(1) - - self.log(f"==> Finished Test.") - - # [GUI] just train for 16 steps, without any other overhead that may slow down rendering. 
- def train_gui(self, train_loader, step=16): - - self.model.train() - - # update grid - if self.model.cuda_ray: - with torch.cuda.amp.autocast(enabled=self.fp16): - self.model.update_extra_state() - - total_loss = torch.tensor([0], dtype=torch.float32, device=self.device) - - loader = iter(train_loader) - - for _ in range(step): - - # mimic an infinite loop dataloader (in case the total dataset is smaller than step) - try: - data = next(loader) - except StopIteration: - loader = iter(train_loader) - data = next(loader) - - self.global_step += 1 - - data = self.prepare_data(data) - - self.optimizer.zero_grad() - - with torch.cuda.amp.autocast(enabled=self.fp16): - preds, truths, loss = self.train_step(data) - - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() - - if self.scheduler_update_every_step: - self.lr_scheduler.step() - - total_loss += loss.detach() - - if self.ema is not None: - self.ema.update() - - average_loss = total_loss.item() / step - - if not self.scheduler_update_every_step: - if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step(average_loss) - else: - self.lr_scheduler.step() - - outputs = { - 'loss': average_loss, - 'lr': self.optimizer.param_groups[0]['lr'], - } - - return outputs - - - # [GUI] test on a single image - def test_gui(self, pose, intrinsics, W, H, bg_color=None, spp=1): - - data = { - 'pose': pose[None, :], - 'intrinsic': intrinsics[None, :], - 'H': [str(H)], - 'W': [str(W)], - } - - data = self.prepare_data(data) - - self.model.eval() - - if self.ema is not None: - self.ema.store() - self.ema.copy_to() - - with torch.no_grad(): - with torch.cuda.amp.autocast(enabled=self.fp16): - # here spp is used as perturb random seed! 
- preds, preds_depth = self.test_step(data, bg_color=bg_color, perturb=spp) - - if self.ema is not None: - self.ema.restore() - - outputs = { - 'image': preds[0].detach().cpu().numpy(), - 'depth': preds_depth[0].detach().cpu().numpy(), - } - - return outputs - - def prepare_data(self, data): - if isinstance(data, list): - for i, v in enumerate(data): - if isinstance(v, np.ndarray): - data[i] = torch.from_numpy(v).to(self.device, non_blocking=True) - if torch.is_tensor(v): - data[i] = v.to(self.device, non_blocking=True) - elif isinstance(data, dict): - for k, v in data.items(): - if isinstance(v, np.ndarray): - data[k] = torch.from_numpy(v).to(self.device, non_blocking=True) - if torch.is_tensor(v): - data[k] = v.to(self.device, non_blocking=True) - elif isinstance(data, np.ndarray): - data = torch.from_numpy(data).to(self.device, non_blocking=True) - else: # is_tensor, or other similar objects that has `to` - data = data.to(self.device, non_blocking=True) - - return data - - def train_one_epoch(self, loader): - self.log(f"==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ...") - - total_loss = 0 - if self.local_rank == 0 and self.report_metric_at_train: - for metric in self.metrics: - metric.clear() - - self.model.train() - - # distributedSampler: must call set_epoch() to shuffle indices across multiple epochs - # ref: https://pytorch.org/docs/stable/data.html - if self.world_size > 1: - loader.sampler.set_epoch(self.epoch) - - if self.local_rank == 0: - pbar = tqdm.tqdm(total=len(loader) * loader.batch_size, bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') - - self.local_step = 0 - - for data in loader: - - self.local_step += 1 - self.global_step += 1 - - data = self.prepare_data(data) - - self.optimizer.zero_grad() - - with torch.cuda.amp.autocast(enabled=self.fp16): - preds, truths, loss = self.train_step(data) - - self.scaler.scale(loss).backward() - self.scaler.step(self.optimizer) - self.scaler.update() - - if self.scheduler_update_every_step: - self.lr_scheduler.step() - - loss_val = loss.item() - total_loss += loss_val - - if self.local_rank == 0: - if self.report_metric_at_train: - for metric in self.metrics: - metric.update(preds, truths) - - if self.use_tensorboardX: - self.writer.add_scalar("train/loss", loss_val, self.global_step) - self.writer.add_scalar("train/lr", self.optimizer.param_groups[0]['lr'], self.global_step) - - if self.scheduler_update_every_step: - pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}") - else: - pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") - pbar.update(1) - - # tensoRF upsampling - if self.global_step in self.conf['update_AlphaMask_list']: - - self.log(f"[INFO] update alphamask at step {self.global_step}") - - if self.reso_cur[0] * self.reso_cur[1] * self.reso_cur[2] < 256**3:# update volume resolution - reso_mask = self.reso_cur - - new_aabb = self.model.updateAlphaMask(tuple(reso_mask)) - - # if self.global_step == self.conf['update_AlphaMask_list'][0]: - # self.model.shrink(new_aabb) # will update self.model.aabb - # # tensorVM.alphaMask = None - # self.L1_reg_weight = self.conf['L1_weight_rest'] - # #self.log(f"[INFO] new aabb {self.model.aabb}") - - # if not self.conf.ndc_ray and self.global_step == self.conf['update_AlphaMask_list'][1]: - # # filter rays outside the bbox - # allrays,allrgbs = self.model.filtering_rays(allrays,allrgbs) - # 
trainingSampler = SimpleSampler(allrgbs.shape[0], self.conf.batch_size) - - - if self.global_step in self.conf['upsamp_list']: - - self.log(f"[INFO] upsample at step {self.global_step}") - - n_voxels = self.N_voxel_list.pop(0) - self.reso_cur = N_to_reso(n_voxels, self.model.aabb) - self.nSamples = min(self.conf['nSamples'], cal_n_samples(self.reso_cur, self.conf['step_ratio'])) - self.model.upsample_volume_grid(self.reso_cur) - - #self.log(f"[INFO] reso {self.reso_cur}, nsamples {self.nSamples}") - - # if self.conf['lr_upsample_reset']: - # print("reset lr to initial") - # lr_scale = 1 #0.1 ** (self.global_step / self.conf.n_iters) - # else: - # lr_scale = self.conf.lr_decay_target_ratio ** (iteration / self.conf.n_iters) - - # re-init optimizer since model params are changed! - # also reset LR. - self.optimizer = self.optimizer_fn(self.model) - self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) - - if self.ema is not None: - self.ema.update() - - average_loss = total_loss / self.local_step - self.stats["loss"].append(average_loss) - - if self.local_rank == 0: - pbar.close() - if self.report_metric_at_train: - for metric in self.metrics: - self.log(metric.report(), style="red") - if self.use_tensorboardX: - metric.write(self.writer, self.epoch, prefix="train") - metric.clear() - - if not self.scheduler_update_every_step: - if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): - self.lr_scheduler.step(average_loss) - else: - self.lr_scheduler.step() - - self.log(f"==> Finished Epoch {self.epoch}.") - - - def evaluate_one_epoch(self, loader): - self.log(f"++> Evaluate at epoch {self.epoch} ...") - - total_loss = 0 - if self.local_rank == 0: - for metric in self.metrics: - metric.clear() - - self.model.eval() - - if self.ema is not None: - self.ema.store() - self.ema.copy_to() - - if self.local_rank == 0: - pbar = tqdm.tqdm(total=len(loader), bar_format='{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') - - with torch.no_grad(): - self.local_step = 0 - - for data in loader: - self.local_step += 1 - - data = self.prepare_data(data) - - with torch.cuda.amp.autocast(enabled=self.fp16): - preds, preds_depth, truths, loss = self.eval_step(data) - - - # all_gather/reduce the statistics (NCCL only support all_*) - if self.world_size > 1: - dist.all_reduce(loss, op=dist.ReduceOp.SUM) - loss = loss / self.world_size - - preds_list = [torch.zeros_like(preds).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] - dist.all_gather(preds_list, preds) - preds = torch.cat(preds_list, dim=0) - - preds_depth_list = [torch.zeros_like(preds_depth).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] - dist.all_gather(preds_depth_list, preds_depth) - preds_depth = torch.cat(preds_depth_list, dim=0) - - truths_list = [torch.zeros_like(truths).to(self.device) for _ in range(self.world_size)] # [[B, ...], [B, ...], ...] - dist.all_gather(truths_list, truths) - truths = torch.cat(truths_list, dim=0) - - loss_val = loss.item() - total_loss += loss_val - - # only rank = 0 will perform evaluation. 
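# --- [editor's sketch, not part of the patch] self.N_voxel_list, consumed by
# the upsampling step above, is built elsewhere; upstream TensoRF derives it
# as a log-spaced schedule from the initial to the final voxel budget, one
# entry per upsampling step — a sketch under that assumption:
import numpy as np
import torch

N_voxel_init, N_voxel_final = 128 ** 3, 300 ** 3
upsamp_list = [2000, 3000, 4000, 5500, 7000]
N_voxel_list = torch.round(torch.exp(torch.linspace(
    np.log(N_voxel_init), np.log(N_voxel_final),
    len(upsamp_list) + 1))).long().tolist()[1:]
print(N_voxel_list)  # monotonically growing budgets, ending at 300**3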
- if self.local_rank == 0: - - for metric in self.metrics: - metric.update(preds, truths) - - # save image - save_path = os.path.join(self.workspace, 'validation', f'{self.name}_{self.epoch:04d}_{self.local_step:04d}.png') - save_path_depth = os.path.join(self.workspace, 'validation', f'{self.name}_{self.epoch:04d}_{self.local_step:04d}_depth.png') - #save_path_gt = os.path.join(self.workspace, 'validation', f'{self.name}_{self.epoch:04d}_{self.local_step:04d}_gt.png') - - #self.log(f"==> Saving validation image to {save_path}") - os.makedirs(os.path.dirname(save_path), exist_ok=True) - cv2.imwrite(save_path, cv2.cvtColor((preds[0].detach().cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) - cv2.imwrite(save_path_depth, (preds_depth[0].detach().cpu().numpy() * 255).astype(np.uint8)) - #cv2.imwrite(save_path_gt, cv2.cvtColor((truths[0].detach().cpu().numpy() * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) - - pbar.set_description(f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})") - pbar.update(1) - - - average_loss = total_loss / self.local_step - self.stats["valid_loss"].append(average_loss) - - if self.local_rank == 0: - pbar.close() - if not self.use_loss_as_metric and len(self.metrics) > 0: - result = self.metrics[0].measure() - self.stats["results"].append(result if self.best_mode == 'min' else - result) # if max mode, use -result - else: - self.stats["results"].append(average_loss) # if no metric, choose best by min loss - - for metric in self.metrics: - self.log(metric.report(), style="blue") - if self.use_tensorboardX: - metric.write(self.writer, self.epoch, prefix="evaluate") - metric.clear() - - if self.ema is not None: - self.ema.restore() - - self.log(f"++> Evaluate epoch {self.epoch} Finished.") - - def save_checkpoint(self, full=False, best=False): - - state = { - 'epoch': self.epoch, - 'stats': self.stats, - } - - if full: - state['optimizer'] = self.optimizer.state_dict() - state['lr_scheduler'] = self.lr_scheduler.state_dict() - state['scaler'] = self.scaler.state_dict() - if self.ema is not None: - state['ema'] = self.ema.state_dict() - - if not best: - - state['model'] = self.model.get_state_dict() - - file_path = f"{self.ckpt_path}/{self.name}_ep{self.epoch:04d}.pth.tar" - - self.stats["checkpoints"].append(file_path) - - if len(self.stats["checkpoints"]) > self.max_keep_ckpt: - old_ckpt = self.stats["checkpoints"].pop(0) - if os.path.exists(old_ckpt): - os.remove(old_ckpt) - - torch.save(state, file_path) - - else: - if len(self.stats["results"]) > 0: - if self.stats["best_result"] is None or self.stats["results"][-1] < self.stats["best_result"]: - self.log(f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}") - self.stats["best_result"] = self.stats["results"][-1] - - # save ema results - if self.ema is not None: - self.ema.store() - self.ema.copy_to() - - state['model'] = self.model.get_state_dict() - - if self.ema is not None: - self.ema.restore() - - torch.save(state, self.best_path) - else: - self.log(f"[WARN] no evaluated results found, skip saving best checkpoint.") - - def load_checkpoint(self, checkpoint=None): - - if checkpoint is None: - checkpoint_list = sorted(glob.glob(f'{self.ckpt_path}/{self.name}_ep*.pth.tar')) - if checkpoint_list: - checkpoint = checkpoint_list[-1] - self.log(f"[INFO] Latest checkpoint is {checkpoint}") - else: - self.log("[WARN] No checkpoint found, model randomly initialized.") - return - - checkpoint_dict = torch.load(checkpoint, map_location=self.device) - - if 'model' not in 
checkpoint_dict: - - # need to re-create the model !!! and re-create optimizer & scheduler... - kwargs = checkpoint_dict['kwargs'] - kwargs.update({'device': self.device}) - self.model = self.model.__class__(**kwargs) - self.optimizer = self.optimizer_fn(self.model) - self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) - - self.model.load(checkpoint_dict) - self.log("[INFO] loaded model.") - return - - kwargs = checkpoint_dict['model']['kwargs'] - kwargs.update({'device': self.device}) - self.model = self.model.__class__(**kwargs) - self.optimizer = self.optimizer_fn(self.model) - self.lr_scheduler = self.lr_scheduler_fn(self.optimizer) - - self.model.load(checkpoint_dict['model']) - self.log("[INFO] loaded model.") - - if self.ema is not None and 'ema' in checkpoint_dict: - self.ema.load_state_dict(checkpoint_dict['ema']) - - self.stats = checkpoint_dict['stats'] - self.epoch = checkpoint_dict['epoch'] - - if self.optimizer and 'optimizer' in checkpoint_dict: - try: - self.optimizer.load_state_dict(checkpoint_dict['optimizer']) - self.log("[INFO] loaded optimizer.") - except: - self.log("[WARN] Failed to load optimizer, use default.") - - # strange bug: keyerror 'lr_lambdas' - if self.lr_scheduler and 'lr_scheduler' in checkpoint_dict: - try: - self.lr_scheduler.load_state_dict(checkpoint_dict['lr_scheduler']) - self.log("[INFO] loaded scheduler.") - except: - self.log("[WARN] Failed to load scheduler, use default.") - - if 'scaler' in checkpoint_dict: - try: - self.scaler.load_state_dict(checkpoint_dict['scaler']) - self.log("[INFO] loaded scaler.") - except: - self.log("[WARN] Failed to load scaler, use default.") \ No newline at end of file diff --git a/testing/test_ffmlp.py b/testing/test_ffmlp.py index 5f4b0429..8b8add23 100644 --- a/testing/test_ffmlp.py +++ b/testing/test_ffmlp.py @@ -118,11 +118,11 @@ def forward(self, x): x2 = x.detach().clone() x3 = x.detach().clone() -starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) #with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU,torch.profiler.ProfilerActivity.CUDA,]) as p: +starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True) starter.record() y2 = net1(x2) ender.record(); torch.cuda.synchronize(); curr_time = starter.elapsed_time(ender); print(f'pytorch MLP (fp32 train) = {curr_time}')
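# --- [editor's sketch, not part of the patch] The timing pattern exercised by
# the test_ffmlp.py hunk above: CUDA events only report correct GPU time when
# the host synchronizes before reading elapsed_time. Minimal example
# (requires a CUDA device):
import torch

starter = torch.cuda.Event(enable_timing=True)
ender = torch.cuda.Event(enable_timing=True)

x = torch.randn(4096, 4096, device='cuda')
starter.record()
y = x @ x                      # work to be timed
ender.record()
torch.cuda.synchronize()       # wait for the GPU before reading the timer
print(f'matmul = {starter.elapsed_time(ender)} ms')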