Skip to content

Commit

Permalink
Merge pull request ashawkey#143 from DomaradzkiMaciej/main
Browse files Browse the repository at this point in the history
Adding new metric (SSIM)
  • Loading branch information
ashawkey authored Feb 15, 2023
2 parents 35a8ae5 + eb719b0 commit ccc7030
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
39 changes: 39 additions & 0 deletions nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

from packaging import version as pver
import lpips
from torchmetrics.functional import structural_similarity_index_measure

def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
Expand Down Expand Up @@ -238,6 +239,44 @@ def write(self, writer, global_step, prefix=""):
def report(self):
return f'PSNR = {self.measure():.6f}'


class SSIMMeter:
def __init__(self, device=None):
self.V = 0
self.N = 0

self.device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def clear(self):
self.V = 0
self.N = 0

def prepare_inputs(self, *inputs):
outputs = []
for i, inp in enumerate(inputs):
inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]
inp = inp.to(self.device)
outputs.append(inp)
return outputs

def update(self, preds, truths):
preds, truths = self.prepare_inputs(preds, truths) # [B, H, W, 3] --> [B, 3, H, W], range in [0, 1]

ssim = structural_similarity_index_measure(preds, truths)

self.V += ssim
self.N += 1

def measure(self):
return self.V / self.N

def write(self, writer, global_step, prefix=""):
writer.add_scalar(os.path.join(prefix, "SSIM"), self.measure(), global_step)

def report(self):
return f'SSIM = {self.measure():.6f}'


class LPIPSMeter:
def __init__(self, net='alex', device=None):
self.V = 0
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@ dearpygui
packaging
scipy
lpips
imageio
imageio
torchmetrics

0 comments on commit ccc7030

Please sign in to comment.