
Commit

fix a dtype conversion issue for the diffusion timesteps in the diffusion prior, thanks to @JiaHeng-DLUT
lucidrains committed Oct 19, 2022
1 parent 5975e82 commit 41fabf2
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions dalle2_pytorch/dalle2_pytorch.py
@@ -100,6 +100,9 @@ def inner(model, *args, **kwargs):
return out
return inner

def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

def is_list_str(x):
if not isinstance(x, (list, tuple)):
return False
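
For reference, a small standalone sketch of how the new helper classifies dtypes (the example tensors below are illustrative, not part of the commit):

import torch

def is_float_dtype(dtype):
    # floating point dtypes pass, integer/bool dtypes do not
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

timesteps = torch.randint(0, 1000, (4,))        # sampler-style integer timesteps, torch.int64
print(is_float_dtype(timesteps.dtype))          # False
print(is_float_dtype(timesteps.float().dtype))  # True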
@@ -968,6 +971,8 @@ def __init__(
Rearrange('b (n d) -> b n d', n = num_text_embeds)
)

self.continuous_embedded_time = not exists(num_timesteps)

self.to_time_embeds = nn.Sequential(
nn.Embedding(num_timesteps, dim * num_time_embeds) if exists(num_timesteps) else nn.Sequential(SinusoidalPosEmb(dim), MLP(dim, dim * num_time_embeds)), # also offer a continuous version of timestep embeddings, with a 2 layer MLP
Rearrange('b (n d) -> b n d', n = num_time_embeds)
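
A reduced, hypothetical illustration of what the new flag records: the continuous (SinusoidalPosEmb + MLP) path is only used when num_timesteps is None, and it is the path that needs floating point timesteps, whereas the discrete nn.Embedding path needs integer indices:

import torch
from torch import nn

def exists(val):
    return val is not None

num_timesteps = 1000                        # discrete schedule; set to None for continuous time
continuous_embedded_time = not exists(num_timesteps)

t = torch.randint(0, num_timesteps or 1000, (4,))   # integer timesteps from a sampler
if not continuous_embedded_time:
    embed = nn.Embedding(num_timesteps, 512)
    out = embed(t)                          # nn.Embedding requires long indices, so no cast
else:
    # the real module is SinusoidalPosEmb(dim) followed by an MLP; both expect float input,
    # which is why forward() now casts the timesteps on this branch
    out = t.float()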
@@ -1095,6 +1100,9 @@ def forward(
# whether text embedding is used for conditioning depends on whether text encodings are available for attention (for classifier free guidance, even though it seems from the paper it was not used in the prior ddpm, as the objective is different)
# but let's just do it right

if self.continuous_embedded_time:
diffusion_timesteps = diffusion_timesteps.type(dtype)

time_embed = self.to_time_embeds(diffusion_timesteps)

learned_queries = repeat(self.learned_query, 'd -> b 1 d', b = batch)
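
In isolation, the guarded cast added to forward() behaves as below (a minimal sketch; dtype in the real method comes from the surrounding tensors, here assumed to be torch.float32):

import torch

continuous_embedded_time = True             # i.e. num_timesteps was None at construction
dtype = torch.float32                       # assumed for the sketch

diffusion_timesteps = torch.randint(0, 1000, (4,))   # torch.int64 from the DDPM sampler
if continuous_embedded_time:
    diffusion_timesteps = diffusion_timesteps.type(dtype)

print(diffusion_timesteps.dtype)            # torch.float32, safe for the sinusoidal embedding + MLP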
@@ -1538,6 +1546,8 @@ def __init__(self, dim):

def forward(self, x):
dtype, device = x.dtype, x.device
assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'

half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
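
Putting the hunk in context, a self-contained sketch of the module with the new assertion; the closing sin/cos lines are reconstructed from the standard sinusoidal embedding and may differ cosmetically from the repository:

import math
import torch
from torch import nn

def is_float_dtype(dtype):
    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        dtype, device = x.dtype, x.device
        assert is_float_dtype(dtype), 'input to sinusoidal pos emb must be a float type'

        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device = device, dtype = dtype) * -emb)
        emb = x[:, None] * emb[None, :]
        return torch.cat((emb.sin(), emb.cos()), dim = -1)

pos_emb = SinusoidalPosEmb(64)
t = torch.randint(0, 1000, (4,))
out = pos_emb(t.float())                    # (4, 64); passing t directly would now trip the assertion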
2 changes: 1 addition & 1 deletion dalle2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.10.7'
+__version__ = '1.10.8'
