From 34806663e38cc84877d8eb0840d36407b95d0781 Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Sat, 13 Aug 2022 10:03:40 -0700 Subject: [PATCH] make it so diffusion prior p_sample_loop returns unnormalized image embeddings --- dalle2_pytorch/dalle2_pytorch.py | 9 +++++---- dalle2_pytorch/version.py | 2 +- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/dalle2_pytorch/dalle2_pytorch.py b/dalle2_pytorch/dalle2_pytorch.py index 94f129f1..f7c3b1c8 100644 --- a/dalle2_pytorch/dalle2_pytorch.py +++ b/dalle2_pytorch/dalle2_pytorch.py @@ -1279,9 +1279,12 @@ def p_sample_loop(self, *args, timesteps = None, **kwargs): is_ddim = timesteps < self.noise_scheduler.num_timesteps if not is_ddim: - return self.p_sample_loop_ddpm(*args, **kwargs) + normalized_image_embed = self.p_sample_loop_ddpm(*args, **kwargs) + else: + normalized_image_embed = self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps) - return self.p_sample_loop_ddim(*args, **kwargs, timesteps = timesteps) + image_embed = normalized_image_embed / self.image_embed_scale + return image_embed def p_losses(self, image_embed, times, text_cond, noise = None): noise = default(noise, lambda: torch.randn_like(image_embed)) @@ -1350,8 +1353,6 @@ def sample( # retrieve original unscaled image embed - image_embeds /= self.image_embed_scale - text_embeds = text_cond['text_embed'] text_embeds = rearrange(text_embeds, '(b r) d -> b r d', r = num_samples_per_batch) diff --git a/dalle2_pytorch/version.py b/dalle2_pytorch/version.py index d07785c5..f3df7f04 100644 --- a/dalle2_pytorch/version.py +++ b/dalle2_pytorch/version.py @@ -1 +1 @@ -__version__ = '1.6.4' +__version__ = '1.6.5'