Skip to content

Commit

Permalink
separated multichannel ll calculation into better conditioned sum and…
Browse files Browse the repository at this point in the history
… reduced epsilon
  • Loading branch information
niki-amini-naieni committed Nov 5, 2023
1 parent c78d85b commit 0534cfd
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from torchmetrics.functional import structural_similarity_index_measure
from scipy.stats import multivariate_normal

EPS = 1e-8

def custom_meshgrid(*args):
# ref: https://pytorch.org/docs/stable/generated/torch.meshgrid.html?highlight=meshgrid#torch.meshgrid
if pver.parse(torch.__version__) < pver.parse('1.10'):
Expand Down Expand Up @@ -320,7 +322,6 @@ def __init__(self):
self.N = 0

def nll(truths, preds, vars, epistems):
eps = 1e-12
truths = truths.flatten(end_dim=-2).cpu().numpy()
preds = preds.flatten(end_dim=-2).cpu().numpy()
vars = vars.flatten(end_dim=-2).cpu().numpy()
Expand All @@ -331,8 +332,8 @@ def nll(truths, preds, vars, epistems):
mu = preds[px_ind]
var = np.mean(vars[px_ind], axis=-1)
epistem = epistems[px_ind]
var = var + epistem + eps
log_pdf = np.log((np.exp(-0.5 * (gt - mu) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)).prod() + eps)
var = var + epistem + EPS
log_pdf = np.log((np.exp(-0.5 * (gt[0] - mu[0]) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)) + EPS) + np.log((np.exp(-0.5 * (gt[1] - mu[1]) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)) + EPS) + np.log((np.exp(-0.5 * (gt[2] - mu[2]) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)) + EPS)
log_pdf_vals.append(-log_pdf)
return np.mean(log_pdf_vals)

Expand Down Expand Up @@ -370,16 +371,15 @@ def __init__(self):
self.N = 0

def nll(truths, preds, vars):
eps = 1e-5
truths = truths.flatten(end_dim=-2).cpu().numpy()
preds = preds.flatten(end_dim=-2).cpu().numpy()
vars = vars.flatten(end_dim=-2).cpu().numpy()
log_pdf_vals = []
for px_ind in range(vars.shape[0]):
gt = truths[px_ind]
mu = preds[px_ind]
var = vars[px_ind] + eps
log_pdf = np.log((np.exp(-0.5 * (gt - mu) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)).prod() + eps)
var = vars[px_ind] + EPS
log_pdf = np.log((np.exp(-0.5 * (gt[0] - mu[0]) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)) + EPS) + np.log((np.exp(-0.5 * (gt[1] - mu[1]) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)) + EPS) + np.log((np.exp(-0.5 * (gt[2] - mu[2]) ** 2 / var) / np.sqrt(var * 2.0 * np.pi)) + EPS)
log_pdf_vals.append(-log_pdf)
return np.mean(log_pdf_vals)

Expand Down

0 comments on commit 0534cfd

Please sign in to comment.