Skip to content

Commit

Permalink
Remove some dead code: the text mask is not needed when training DALL-E
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 4, 2022
1 parent 07f6f3f commit 4511d2d
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 13 deletions.
9 changes: 3 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,13 @@ dalle = DALLE(

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = dalle(text, images, mask = mask, return_loss = True)
loss = dalle(text, images, return_loss = True)
loss.backward()

# do the above for a long time with a lot of data ... then

images = dalle.generate_images(text, mask = mask)
images = dalle.generate_images(text)
images.shape # (4, 3, 256, 256)
```

Expand All @@ -141,7 +140,6 @@ img_prime = torch.randn(4, 3, 256, 256)

images = dalle.generate_images(
text,
mask = mask,
img = img_prime,
num_init_img_tokens = (14 * 32) # you can set the size of the initial crop; it defaults to slightly less than half of the image tokens, as done in the paper
)
Expand Down Expand Up @@ -179,9 +177,8 @@ dalle = DALLE(

text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()

loss = dalle(text, images, mask = mask, return_loss = True)
loss = dalle(text, images, return_loss = True)
loss.backward()
```

Expand Down
7 changes: 1 addition & 6 deletions dalle_pytorch/dalle_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,6 @@ def generate_images(
text,
*,
clip = None,
mask = None,
filter_thres = 0.5,
temperature = 1.,
img = None,
Expand Down Expand Up @@ -494,7 +493,7 @@ def generate_images(

text, image = out[:, :text_seq_len], out[:, text_seq_len:]

logits = self(text, image, mask = mask)
logits = self(text, image)
logits = logits[:, -1, :]

filtered_logits = top_k(logits, thres = filter_thres)
Expand All @@ -503,9 +502,6 @@ def generate_images(
sample -= (num_text_tokens if is_image else 0) # offset sampled token if it is an image token, since logit space is composed of text and then image tokens
out = torch.cat((out, sample[:, None]), dim=-1)

if out.shape[1] <= text_seq_len:
mask = F.pad(mask, (0, 1), value = True)

text_seq = out[:, :text_seq_len]

img_seq = out[:, -image_seq_len:]
Expand All @@ -521,7 +517,6 @@ def forward(
self,
text,
image = None,
mask = None,
return_loss = False
):
assert text.shape[-1] == self.text_seq_len, f'the length {text.shape[-1]} of the text tokens you passed in does not have the correct length ({self.text_seq_len})'
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
name = 'dalle-pytorch',
packages = find_packages(),
include_package_data = True,
version = '1.1.8',
version = '1.2.0',
license='MIT',
description = 'DALL-E - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 4511d2d

Please sign in to comment.