Skip to content

Commit

Permalink
added weights sum
Browse files Browse the repository at this point in the history
  • Loading branch information
niki-amini-naieni committed Nov 5, 2023
1 parent ac1644a commit b60168e
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 1 deletion.
1 change: 1 addition & 0 deletions nerf/renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,7 @@ def run_cuda(self, rays_o, rays_d, dt_gamma=0, bg_color=None, perturb=False, for
depth = torch.clamp(depth - nears, min=0) / (fars - nears)
image = image.view(*prefix, 3)
depth = depth.view(*prefix)
results['weights_sum'] = weights_sum

results['depth'] = depth
results['image'] = image
Expand Down
30 changes: 29 additions & 1 deletion nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,8 @@ def get_ensemble_metrics(ensemble, loader):
for model in ensemble:
with torch.no_grad():
with torch.cuda.amp.autocast(enabled=model.fp16):
preds, _, truths, _ = model.eval_step(data)
preds, _, truths, _, weights_sum = model.eval_step_ensemble(data)
print(weights_sum)
preds_ensemble.append(preds.cpu().numpy())
preds = torch.from_numpy(np.array(preds_ensemble).sum(axis=0) / M)
vars = torch.from_numpy(np.array(preds_ensemble).var(axis=0))
Expand Down Expand Up @@ -668,6 +669,33 @@ def eval_step(self, data):
loss = self.criterion(pred_rgb, gt_rgb).mean()

return pred_rgb, pred_depth, gt_rgb, loss

def eval_step_ensemble(self, data):

rays_o = data['rays_o'] # [B, N, 3]
rays_d = data['rays_d'] # [B, N, 3]
images = data['images'] # [B, H, W, 3/4]
B, H, W, C = images.shape

if self.opt.color_space == 'linear':
images[..., :3] = srgb_to_linear(images[..., :3])

# eval with fixed background color
bg_color = 1
if C == 4:
gt_rgb = images[..., :3] * images[..., 3:] + bg_color * (1 - images[..., 3:])
else:
gt_rgb = images

outputs = self.model.render(rays_o, rays_d, staged=True, bg_color=bg_color, perturb=False, **vars(self.opt))

pred_rgb = outputs['image'].reshape(B, H, W, 3)
pred_depth = outputs['depth'].reshape(B, H, W)
weights_sum = outputs['weights_sum'].reshape(B, H, W)

loss = self.criterion(pred_rgb, gt_rgb).mean()

return pred_rgb, pred_depth, gt_rgb, loss, weights_sum

# moved out bg_color and perturb for more flexible control...
def test_step(self, data, bg_color=None, perturb=False):
Expand Down

0 comments on commit b60168e

Please sign in to comment.