Skip to content

Commit

Permalink
make it so diffusion prior p_sample_loop returns unnormalized image e…
Browse files Browse the repository at this point in the history
…mbeddings
  • Loading branch information
lucidrains committed Aug 13, 2022
1 parent dc816b1 commit 3480666
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
9 changes: 5 additions & 4 deletions dalle2_pytorch/dalle2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -1279,9 +1279,12 @@ def p_sample_loop(self, *args, timesteps = None, **kwargs):
is_ddim = timesteps < self.noise_scheduler.num_timesteps

if not is_ddim:
return self.p_sample_loop_ddpm(*args, **kwargs)
normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs)
else:
normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)

return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps)
image_embed = normalized_image_embed / self.image_embed_scale
return image_embed

def p_losses(self, image_embed, times, text_cond, noise = None):
noise = default(noise, lambda: torch.randn_like(image_embed))
Expand Down Expand Up @@ -1350,8 +1353,6 @@ def sample(

# retrieve original unscaled image embed

image_embeds /= self.image_embed_scale

text_embeds = text_cond['text_embed']

text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch)
Expand Down
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.6.4'
__version__ = '1.6.5'

0 comments on commit 3480666

Please sign in to comment.