Skip to content

Commit

Permalink
fix #20
Browse files Browse the repository at this point in the history
  • Loading branch information
tqch committed Aug 7, 2024
1 parent 52a56bb commit b60eb8d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions ddpm_torch/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,16 +252,16 @@ def _prior_bpd(self, x_0):
def calc_all_bpd(self, denoise_fn, x_0, clip_denoised=True):
B, T = x_0.shape, self.timesteps
t = torch.empty([B, ], dtype=torch.int64)
t.fill_(T - 1)
losses = torch.zeros([B, T], dtype=torch.float32)
mses = torch.zeros([B, T], dtype=torch.float32)

for i in range(T - 1, -1, -1):
for ti in range(T - 1, -1, -1):
t.fill_(ti)
x_t = self.q_sample(x_0, t=t)
loss, pred_x_0 = self._loss_term_bpd(
denoise_fn, x_0, x_t=x_t, t=t, clip_denoised=clip_denoised, return_pred=True)
losses[:, i] = loss
mses[:, i] = flat_mean((pred_x_0 - x_0).pow(2))
losses[:, ti] = loss
mses[:, ti] = flat_mean((pred_x_0 - x_0).pow(2))

prior_bpd = self._prior_bpd(x_0)
total_bpd = torch.sum(losses, dim=1) + prior_bpd
Expand Down

0 comments on commit b60eb8d

Please sign in to comment.