Skip to content

Commit

Permalink
adding new metric (SSIM)
Browse files Browse the repository at this point in the history
  • Loading branch information
DomaradzkiMaciej committed Feb 14, 2023
1 parent 0793857 commit 1079526
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 14 deletions.
4 changes: 3 additions & 1 deletion main_tensoRF.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@

if opt.test:

trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt)
metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device)]
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, criterion=criterion, fp16=opt.fp16, metrics=metrics, use_checkpoint=opt.ckpt)

if opt.gui:
gui = NeRFGUI(opt, trainer)
Expand All @@ -126,6 +127,7 @@
# decay to 0.1 * init_lr at last iter step
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))

metrics = [PSNRMeter(), SSIMMeter(), LPIPSMeter(device=device)]
trainer = Trainer('ngp', opt, model, device=device, workspace=opt.workspace, optimizer=optimizer, criterion=criterion, ema_decay=None, fp16=opt.fp16, lr_scheduler=scheduler, scheduler_update_every_step=True, metrics=[PSNRMeter()], use_checkpoint=opt.ckpt, eval_interval=50)

# calc upsample target resolutions
Expand Down
18 changes: 7 additions & 11 deletions nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,7 @@

from packaging import version as pver
import lpips

import skimage
import skimage.metrics
from torchmetrics.functional import structural_similarity_index_measure

def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
Expand Down Expand Up @@ -254,18 +252,16 @@ def clear(self):
def prepare_inputs(self, *inputs):
outputs = []
for i, inp in enumerate(inputs):
if torch.is_tensor(inp):
inp = inp.detach().cpu().numpy()
inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]
inp = inp.to(self.device)
outputs.append(inp)

return outputs

def update(self, preds, truths):
preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3], range[0, 1]

pred, truth = preds[0], truths[0] # B=1
ssim = skimage.metrics.structural_similarity(pred, truth, data_range=1)

preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]

ssim = structural_similarity_index_measure(preds, truths)

self.V += ssim
self.N += 1

Expand Down
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,4 @@ packaging
scipy
lpips
imageio
scipy
scikit-image
torchmetrics

0 comments on commit 1079526

Please sign in to comment.