Skip to content

Commit

Permalink
added nll metric
Browse files Browse the repository at this point in the history
  • Loading branch information
niki-amini-naieni committed Nov 5, 2023
1 parent 9e08dc4 commit 7b571ae
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 2 deletions.
2 changes: 1 addition & 1 deletion eval_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
# decay to 0.1 * init_lr at last iter step
scheduler = lambda optimizer: optim.lr_scheduler.LambdaLR(optimizer, lambda iter: 0.1 ** min(iter / opt.iters, 1))

metrics = [PSNRMeter(), LPIPSMeter(device=device), SSIMMeter(device=device)]
metrics = [PSNRMeter(), LPIPSMeter(device=device), SSIMMeter(device=device), NLLMeter()]

ensembles = []
for model_ind in range(opt.M):
Expand Down
51 changes: 50 additions & 1 deletion nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from packaging import version as pver
import lpips
from torchmetrics.functional import structural_similarity_index_measure
from scipy.stats import multivariate_normal

def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
Expand Down Expand Up @@ -312,6 +313,50 @@ def write(self, writer, global_step, prefix=""):

def report(self):
return f'LPIPS ({self.net}) = {self.measure():.6f}'


class NLLMeter:
def __init__(self):
self.V = 0
self.N = 0

def nll(truths, preds, vars):
truths = truths.flatten(end_dim=-2)
preds = preds.flatten(end_dim=-2)
vars = vars.flatten(end_dim=-2)
covs = []
for px in vars:
covs.append(torch.diag(px))
covs = torch.tensor(covs)
return torch.log(multivariate_normal(truths, preds, covs)).mean()

self.fn = nll

def clear(self):
self.V = 0
self.N = 0

def prepare_inputs(self, *inputs):
outputs = []
for i, inp in enumerate(inputs):
inp = inp.permute(0, 3, 1, 2).contiguous() # [B, 3, H, W]
inp = inp.to(self.device)
outputs.append(inp)
return outputs

def update(self, preds, truths, vars):
v = self.fn(truths, preds, vars).item() # normalize=True: [0, 1] to [-1, 1]
self.V += v
self.N += 1

def measure(self):
return self.V / self.N

def write(self, writer, global_step, prefix=""):
writer.add_scalar(os.path.join(prefix, f"NLL"), self.measure(), global_step)

def report(self):
return f'NLL = {self.measure():.6f}'


def get_ensemble_metrics(ensemble, loader):
Expand All @@ -331,9 +376,13 @@ def get_ensemble_metrics(ensemble, loader):
preds, _, truths, _ = model.eval_step(data)
preds_ensemble.append(preds.cpu().numpy())
preds = torch.from_numpy(np.array(preds_ensemble).sum(axis=0) / M)
vars = torch.from_numpy(np.array(preds_ensemble).var(axis=0))
# Use the first ensemble member to save results.
for metric in ensemble[0].metrics:
metric.update(preds, truths)
if not isinstance(metric, NLLMeter):
metric.update(preds, truths)
else:
metric.update(preds, truths, vars)

print("Ensemble Metrics:")
for metric in ensemble[0].metrics:
Expand Down

0 comments on commit 7b571ae

Please sign in to comment.