Skip to content

Commit

Permalink
explicitly compute nll instead of using scipy
Browse files Browse the repository at this point in the history
  • Loading branch information
niki-amini-naieni committed Nov 5, 2023
1 parent fc2e9f9 commit 347d12a
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions nerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,11 +326,11 @@ def nll(truths, preds, vars):
vars = vars.flatten(end_dim=-2).cpu().numpy()
log_pdf_vals = []
for px_ind in range(vars.shape[0]):
log_pdf_vals.append(
-np.log(
multivariate_normal.pdf(truths[px_ind], mean=preds[px_ind], cov=np.diag(vars[px_ind]), allow_singular=False)
) + 1e-12
)
gt = truths[px_ind]
mu = preds[px_ind]
var = vars[px_ind]
log_pdf = np.log(np.exp(-0.5 * (gt - mu) ** 2 / var) / np.sqrt(var * 2.0 * np.pi) + 1e-12).sum()
log_pdf_vals.append(log_pdf)
return np.mean(log_pdf_vals)

self.fn = nll
Expand Down

0 comments on commit 347d12a

Please sign in to comment.