Results #10

Open
NaxAlpha opened this issue Jan 7, 2021 · 57 comments

Comments

@NaxAlpha
Contributor

NaxAlpha commented Jan 7, 2021

I have trained DiscreteVAE on the 128x128 FFHQ dataset using this configuration:

vae = DiscreteVAE(
    num_layers = 2,
    num_tokens = 4096,
    dim = 1024,
    hidden_dim = 256
)

Here are the results after 3 epochs (top original, bottom reconstructed):

image
image
image

@NaxAlpha
Contributor Author

NaxAlpha commented Jan 8, 2021

An even smaller model works pretty well:

vae = DiscreteVAE(
    num_layers = 3,
    num_tokens = 4096,
    codebook_dim = 512,
    hidden_dim = 256,
)

Here are the samples:

image
image
image

Here is the (expected) loss after ~3 epochs:
image

@mrconter1

Are you inputting descriptions for images or just let it randomly generate an image?

@adrian-spataru

Are you inputting descriptions for images or just let it randomly generate an image?

Not OP, but this is just the VQVAE, and these images are reconstructions, not samples: the top image is the input and the bottom image is the output of the VAE.
The VQVAE is used to build the codebook, which the transformer will then use to generate an image from a description.

@NaxAlpha
Contributor Author

NaxAlpha commented Jan 8, 2021

Yes, these results are for the VAE. It took only ~30 min to 1 hr on Colab Pro (V100). I am in the process of training DALLE, and results should be ready soon!!!

@mrconter1

Would you mind sharing the Colab you have so far? :)

@NaxAlpha
Contributor Author

NaxAlpha commented Jan 8, 2021

Sure! Here is the notebook so far.
Also, there is an update: after the discussion in #12 and applying the fix, here are the results I got. They are not as promising as the previous ones, but it should actually work now (hopefully).

from dalle_pytorch import DiscreteVAE

NUM_LAYERS = 2
IMAGE_SIZE = 128

BATCH_SIZE = 32
NUM_TOKENS = 8192

EMB_DIM = 256
HID_DIM = 128

vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim =  EMB_DIM,
    hidden_dim =    HID_DIM,
)

(Top is ground truth, middle is soft decoded [via gumbel_softmax], bottom is hard decoded [via argmax], which was only noise previously because of the bug)
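For readers unfamiliar with the two decoding paths, here is a rough sketch of the general idea in PyTorch. It is not taken from the library's source; the tensor sizes, the `codebook` embedding and the variable names are made up for illustration.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Tiny, made-up sizes purely for illustration.
num_tokens, codebook_dim = 8, 4
codebook = torch.nn.Embedding(num_tokens, codebook_dim)

# Encoder output: one score per codebook entry at every spatial position.
logits = torch.randn(2, num_tokens, 3, 3)  # (batch, tokens, height, width)

# Soft path: a differentiable, noisy mixture over codebook entries.
soft_one_hot = F.gumbel_softmax(logits, tau=0.9, dim=1)
soft_latents = torch.einsum("bnhw,nd->bdhw", soft_one_hot, codebook.weight)

# Hard path: pick exactly one codebook index per position, then look it up.
indices = logits.argmax(dim=1)                        # (batch, height, width)
hard_latents = codebook(indices).permute(0, 3, 1, 2)  # (batch, dim, height, width)

print(soft_latents.shape, hard_latents.shape)
```

Both paths yield latents of the same shape for the decoder, but only the hard path corresponds to the discrete tokens a transformer would later predict.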

image
image
image

@mrconter1

mrconter1 commented Jan 8, 2021

Thanks! I'm a noob but I tried to help:
https://colab.research.google.com/drive/1KxG1iGBoKt2fLVH7uXG_vhvll2OlFkey?usp=sharing
:)

@mrconter1

mrconter1 commented Jan 8, 2021

Okay. Here is a fully working Colab for at least VAE training. Thanks to NaxAlpha of course!

https://colab.research.google.com/drive/1KxG1iGBoKt2fLVH7uXG_vhvll2OlFkey?usp=sharing

image
After around 600 training pairs.

@NaxAlpha
Contributor Author

NaxAlpha commented Jan 8, 2021

Here are results after a few hours of training of DALL-E:

image
image
image

Loss is still very high right now, but it's going down slowly

image

@mrconter1

Which dataset are you using to train DALL-E? Don't you need text as well? Also, what are you training on? Do you have access to Google Colab Pro?

@lucidrains
Owner

@NaxAlpha nice! I just realized, without text, this essentially becomes iGPT! (If that is what you are doing)

@VIVelev

VIVelev commented Jan 8, 2021

@lucidrains Isn't iGPT on pixel level or close to pixel level (a.k.a. the 9-bit color palette), whereas DALL-E operates on codebook vectors level? In a sense, DALL-E works at the right level of abstraction (pixels and local features are too fine, and entire scenes are too coarse).

@lucidrains
Owner

lucidrains commented Jan 8, 2021

@VIVelev Yup, you are correct! iGPT is pixel level, but clustered into 512 (9-bit) discrete tokens. Equivalent to a 0 layer discrete VAE with a codebook of 512
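To make the comparison concrete, here is a back-of-the-envelope sketch in plain Python. The helper names are made up for illustration, and it assumes each VAE layer halves the spatial resolution, which is an assumption about the encoder rather than something stated in this thread.

```python
def palette_size(bits: int) -> int:
    """Number of discrete tokens addressable with the given number of bits."""
    return 2 ** bits


def pixel_level_seq_len(image_size: int) -> int:
    """One token per pixel, as in a 0-layer discrete VAE."""
    return image_size * image_size


def codebook_level_seq_len(image_size: int, num_layers: int) -> int:
    """One token per feature-map cell, assuming each layer halves the resolution."""
    fmap_size = image_size // (2 ** num_layers)
    return fmap_size * fmap_size


print(palette_size(9))                 # 512
print(pixel_level_seq_len(128))        # 16384
print(codebook_level_seq_len(128, 3))  # 256
```

Under that assumption, a 128x128 image costs the transformer 256 positions at the codebook level instead of 16384 at the pixel level.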

@lucidrains
Owner

@NaxAlpha I just added a temperature flag on the DiscreteVAE class so you can control the hardness of gumbel during training! just fyi!

@NaxAlpha
Contributor Author

NaxAlpha commented Jan 9, 2021

Awesome! Yeah I am training it unconditionally - (just 1 text token which is random xD). Here are the results after 9 more hours:

image
image
image

I feel like it is going slower than I expected (might need to scale up the transformer). Here is the DALLE configuration I am using:

from dalle_pytorch import DALLE, DiscreteVAE

NUM_LAYERS = 3
IMAGE_SIZE = 128

BATCH_SIZE = 16
NUM_TOKENS = 8192

EMB_DIM = 256
HID_DIM = 128

vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim =  EMB_DIM,
    hidden_dim =    HID_DIM,
)

dalle = DALLE(
    dim = EMB_DIM,
    vae = vae,
    num_text_tokens = 1024,     # 1024 fixed latents (model should learn to ignore it)
    text_seq_len = 1,           # Acts like a latent variable
    depth = 16,
    heads = 24,
)

@lucidrains
Owner

@NaxAlpha haha yea, they used 64 layers! perhaps this could be tried on something small scale, like cifar sized

@mrconter1

Would it be possible that using more coherent text (instead of random) would also result in more coherent images?

@NaxAlpha
Contributor Author

NaxAlpha commented Jan 9, 2021

@lucidrains wow! The temperature feature is awesome! I am gradually decreasing it from 5 to 0.05 over 5 epochs; convergence is really fast and the results look much better!!!

@mrconter1 Yes, using coherent text should help, but since I do not have any text for now, I am using just 1 token to make it work xD.
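A minimal sketch of a 5 to 0.05 schedule like the one mentioned above, assuming an exponential (multiplicative) decay applied once per training step. The comment does not say which decay shape was used, and `steps_per_epoch` is a made-up value.

```python
def decay_factor(start: float, end: float, total_steps: int) -> float:
    """Per-step multiplier that takes `start` to `end` in `total_steps` steps."""
    return (end / start) ** (1.0 / total_steps)


start_temp, end_temp = 5.0, 0.05
epochs, steps_per_epoch = 5, 1000  # steps_per_epoch is hypothetical
total_steps = epochs * steps_per_epoch

factor = decay_factor(start_temp, end_temp, total_steps)

temperature = start_temp
for _ in range(total_steps):
    # In a real loop, the training step would run here with the current
    # temperature (e.g. assigned to vae.temperature) before it is decayed.
    temperature *= factor

print(round(temperature, 4))
```

A multiplicative schedule spends proportionally more steps at low temperatures than a linear one would, which may or may not be desirable depending on how stable training is near the end.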

@mrconter1

mrconter1 commented Jan 9, 2021

@NaxAlpha

I created an image + desc fetcher. You can see it here. Could it be useful?

@mrconter1

mrconter1 commented Jan 9, 2021

I just benchmarked my scraper on Google Colab Pro. It takes around 3.46 hours/10 000 image+desc pairs. I will upload the data when I'm done.

@lucidrains
Owner

@NaxAlpha Added reversible networks! https://github.com/lucidrains/DALLE-pytorch#scaling-depth Maybe depth will help!

@mrconter1

mrconter1 commented Jan 9, 2021

Never mind my scraper. Just use the COCO dataset. It has 500 000 images with descriptions for each one. It takes 10 minutes to download on Colab Pro. If anyone wants me to set up a Colab, just tell me what format you want the data in.

@TheodoreGalanos

hi everyone, thanks for all the amazing work and sharing results!

I have a really noobish question, hope it's okay. What scale do we think the set of image+text pairs needs to reach to produce something useful? I want to train it on my specific domain (architecture) and I'll probably need to create custom datasets. Any idea what minimum scale is worth trying? Also, concerning the codebook, does it need to be built on a similar dataset, or is variety better?

Thanks in advance!

@NaxAlpha
Contributor Author

@lucidrains Awesome, I have scaled the model - let's wait and see the results 😁.

The main problem right now is that the VAE output is not really great. When the temperature is high (>1) results look good, but when the temperature goes near 0.1, it becomes horrible. Ideally we want the temperature to be close to 0, because otherwise, no matter how good the language model is, the decoded output would be rough.
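One way to see why the temperature matters for hard decoding is to look at how a temperature-scaled softmax sharpens as the temperature drops. The sketch below is plain Python with made-up logits, and it leaves out the Gumbel noise so the output is deterministic; it only illustrates the trend and is not the library's code.

```python
import math


def softmax_with_temperature(logits, tau):
    """Softmax over logits divided by temperature tau (numerically stabilised)."""
    scaled = [x / tau for x in logits]
    peak = max(scaled)
    exps = [math.exp(s - peak) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, 0.0]  # made-up scores for four codebook entries

for tau in (2.9, 1.8, 0.6, 0.1):
    probs = softmax_with_temperature(logits, tau)
    print(tau, [round(p, 3) for p in probs])
```

At high temperature the decoder sees a blend of several codebook vectors, while near zero it sees almost a single one, which is what the argmax path always provides.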

Below are the outputs: the top row is ground truth, the middle row is the output of the VAE through gumbel softmax at different temperatures, and the last row is the output of the following code:

codes = vae.get_codebook_indices(images[:k])
image = vae.decode(codes)

@ temperature = 2.9
image

@ temperature = 1.8
image

@ temperature = 0.6
image

@ temperature = 0.1
image

BTW Here is the config that I am using:

from dalle_pytorch import DiscreteVAE

NUM_LAYERS = 3
IMAGE_SIZE = 128

BATCH_SIZE = 8
NUM_TOKENS = 8192

EMB_DIM = 1024
HID_DIM = 256

vae = DiscreteVAE(
    image_size = IMAGE_SIZE,
    num_layers = NUM_LAYERS,
    num_tokens = NUM_TOKENS,
    codebook_dim =  EMB_DIM,
    hidden_dim =    HID_DIM,
)

@HenryHengZJ

Never mind my scraper. Just use the COCO dataset. It has 500 000 images with descriptions for each one. It takes 10 minutes to download on Colab Pro. If anyone wants me to set up a Colab, just tell me what format you want the data in.

@mrconter1 how do you feed the text descriptions with their corresponding images into DALLE training? Would you mind sharing your Colab?

@lucidrains
Owner

lucidrains commented Jan 10, 2021

@NaxAlpha thanks for sharing your results! So I have an end to end version at a different branch in the repository that could be tried, perhaps with an annealing schedule

I'll also add resnet blocks to the VAE later today, per suggestion of Aran

Keep us posted!

Edit - will also reread https://arxiv.org/abs/2012.09841 for insights

@mrconter1

mrconter1 commented Jan 10, 2021 via email

@htoyryla

htoyryla commented Jan 10, 2021

I've just put together quick and dirty code to train DALLE. Not directly usable for anyone, I am afraid. I am using a small dataset of 2000+ landscapes, for which I automatically generated captions into a text file. The script reads the image filenames and captions from that text file, builds a vocabulary, and uses it to convert text tokens into numeric IDs.

There is not even a proper PyTorch dataset, just quick code to iterate through the data. So far, it appears to be learning. Loss is decreasing and the generated images are starting to roughly resemble landscapes.
dallevae-cdim256_epoch_20

https://github.com/htoyryla/DALLE-pytorch/blob/main/trainDALLE.py

PS. The vocabulary class is missing from my repo. My code uses one from this page https://www.kdnuggets.com/2019/11/create-vocabulary-nlp-tasks-python.html
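The caption-to-ID step described in this comment could look roughly like the sketch below. This is not the linked script: the "filename<TAB>caption" line format, the `<pad>` token and the helper names are all assumptions made for illustration.

```python
def build_vocab(captions):
    """Map each distinct lowercase word to an integer ID; 0 is reserved for padding."""
    vocab = {"<pad>": 0}
    for caption in captions:
        for word in caption.lower().split():
            if word not in vocab:
                vocab[word] = len(vocab)
    return vocab


def encode(caption, vocab, seq_len):
    """Convert a caption to a fixed-length list of IDs, truncating or padding."""
    ids = [vocab[word] for word in caption.lower().split()][:seq_len]
    return ids + [vocab["<pad>"]] * (seq_len - len(ids))


# Hypothetical file contents: one "filename<TAB>caption" pair per line.
lines = [
    "img_0001.jpg\ta lake near a mountain",
    "img_0002.jpg\ta field under a cloudy sky",
]
pairs = [line.split("\t", 1) for line in lines]
vocab = build_vocab(caption for _, caption in pairs)
encoded = {name: encode(caption, vocab, seq_len=8) for name, caption in pairs}

print(encoded["img_0001.jpg"])  # [1, 2, 3, 1, 4, 0, 0, 0]
```

With a vocabulary like this, `num_text_tokens` would need to be at least `len(vocab)` and `text_seq_len` would match `seq_len`. Words outside the vocabulary raise a `KeyError` here, so a real script would probably want an unknown-word token.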

@lucidrains
Owner

@htoyryla ohh sorry, I should have read the code, the codebook size (number of unique image tokens) is 2048 https://github.com/htoyryla/DALLE-pytorch/blob/main/trainDALLE.py#L30

@htoyryla

htoyryla commented Jan 10, 2021

I am just dabbling in this really. My main interest is in working with images. Got interested in the discrete VAE and then decided to give a try on the text dimension as well. Some new concepts here... which is good.

Evening here... so I guess in the morning I will see if the images have improved.

@lucidrains
Owner

@htoyryla your code looks ok on first glance. when did you train this? does this include the latest change with axial positional embeddings for the images?

@htoyryla

The VAE was trained yesterday (East European time), and the DALLE has just been training for a couple of hours. So it most probably does not include the latest changes (I have not synced since yesterday anyway).

@lucidrains
Owner

@htoyryla thank you for sharing your results! this is encouraging if there are no errors lol

@lucidrains
Owner

lucidrains commented Jan 10, 2021

@htoyryla is there a data standard for multimodal data? I may endeavor to turn this into a light command line tool, so people without coding experience can train small DALL-E's

would you be willing to share your dataset with the landscape pictures and generated captions?

@htoyryla

@htoyryla is there a data standard for multimodal data? I may endeavor to turn this into a light command line tool, so people without coding experience can train small DALL-E's

Sounds like a great plan. But no, I don't know of a standard for multimodal data.

would you be willing to share your dataset with the landscape pictures and generated captions?

Not this set, I am sorry. The images are my own and material which I use in my artistic work. Also, the captions are quite poor as I used a quite simple image captioning tutorial to generate them. It should be possible to make a better set without too much work.

@htoyryla

htoyryla commented Jan 11, 2021

I have added to my repo a script to generate images from text input with a trained VAE and DALLE https://github.com/htoyryla/DALLE-pytorch/blob/main/genDALLE.py

Using my models trained with landscapes, results at least look like landscapes. The semantics then do not really match, mainly because the generated captions are not really descriptive of the content.

Images generated during training Dalle:

dallevae-cdim256_epoch_227

Image generated from text input:

gendallevae-cdim256_epoch_220-1610364710

@lucidrains
Owner

lucidrains commented Jan 11, 2021

@htoyryla 🥇

@CDitzel

CDitzel commented Jan 11, 2021

I'm a little confused. Isn't the VAE supposed to be trained individually first and then used in DALLE later on, as a pretrained component?

Using ResNet blocks in the VAE is surely expedient, but why not use U-Nets to begin with, then?

Also, Phil. apparently there is sth. called 'discussions' now on github for the very purpose of discussing stuff which is not inherently related to a code issue. Maybe this can be helpful to foster the debate?

@htoyryla

I'm a little confused. Isn't the VAE supposed to be trained individually first and then used in DALLE later on, as a pretrained component?

If this is related to my scripts, that is exactly what I am doing. My fork has scripts to train a VAE, then train a DALLE, and finally use DALLE to generate images from text. I have not yet tried CLIP. I made the scripts purely for my own use, but at the very least they can work as examples.

@lucidrains
Owner

lucidrains commented Jan 11, 2021

@CDitzel the resnet block is just to add a little more depth to the VAE. A U-Net has skip connections that would make a codebook unlearnable

Ok I started the discussion! Thanks for the suggestion!

@mrconter1

Is it possible that this could be of use to you @htoyryla? I've created a script that generates easy to use data from the COCO dataset for DALL-E.

@NaxAlpha
Contributor Author

@lucidrains (after res block update 😁)

image

BTW I also tested a gradual increase of filters, i.e. [64, 128, 256] like StyleGAN, but it did not improve the results. I also tried instance norm and dropout, but they did not help much either!

@lucidrains
Owner

@NaxAlpha Oh no! Thanks for letting me know

I rolled it back to what it was, and also made it so the resnet blocks happen at the lowest resolution feature map (more akin to a working version of VQVAE that I know of) if you are willing to try it again

@CDitzel

CDitzel commented Jan 11, 2021

This is sth. I don't particularly like about the field of deep learning. Obvious improvements according to contemporary research are capable of ruining decent results...

@lucidrains
Owner

I have added to my repo a script to generate images from text input with a trained VAE and DALLE https://github.com/htoyryla/DALLE-pytorch/blob/main/genDALLE.py

Using my models trained with landscapes, results at least look like landscapes. The semantics then do not really match, mainly because the generated captions are not really descriptive of the content.

Images generated during training Dalle:

dallevae-cdim256_epoch_227

Image generated from text input:

gendallevae-cdim256_epoch_220-1610364710

is it ok if I share these results in the readme?

@htoyryla

is it ok if I share these results in the readme?

Yes, of course. At least until there is something better :)

@lucidrains
Owner

@htoyryla Thank you Hannu!

@NaxAlpha
Contributor Author

Thanks @lucidrains, I have restarted training. BTW this is the code I am using for training the VAE. It's very rough for now, but hopefully we can integrate it once the trainer class is ready!

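import torch
import wandb
from torch.cuda import amp
import torch.nn.functional as F
import torchvision.transforms.functional as VTF  # VTF is presumably this module
from torchvision.utils import make_grid
from torch.nn.utils import clip_grad_norm_
from IPython.display import display

# Assumed to be defined earlier in the notebook: vae, opt (optimizer),
# dl (DataLoader), DEVICE, and STEP_ID (global step counter for wandb).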


def loss_fn(x, y):
    return F.mse_loss(x, y) + F.smooth_l1_loss(x, y)

vae.temperature = 5.

k = 4
dk = 0.7 ** (1/len(dl))
print('Scale Factor:', dk)

running_loss = 1.
running_error = 1.
for epoch in range(10):
    for i, (images, _) in enumerate(dl):
        images = images.to(DEVICE)

        recons = vae(images)
        loss = loss_fn(images, recons)

        opt.zero_grad()    
        loss.backward()
        # clip_grad_norm_(vae.parameters(), 1)
        opt.step()

        if i % 200 == 0:
            with torch.no_grad():
                codes = vae.get_codebook_indices(images)
                imgx = vae.decode(codes)
                error = loss_fn(images, imgx)
                running_error = 0.9 * running_error + 0.1 * error.item()

            grid = torch.cat([images[:k], recons[:k], imgx[:k]])
            # Newer torchvision versions use `value_range`; the original post used `range`.
            grid = make_grid(grid, nrow=k, normalize=True, value_range=(-1, 1))
            imag = VTF.to_pil_image(grid)
            display(imag)
            torch.save(vae.state_dict(), 'vae.pt')
            wandb.log({
                "Sample Images": wandb.Image(imag), 
                'Running Error': running_error
            }, step=STEP_ID)

        running_loss = 0.9*running_loss + 0.1*loss.item()
        if i % 10 == 0:
            print(
                epoch,
                i, 
                round(running_loss, 3),
                round(running_error, 3),
                round(vae.temperature, 3),
            )
            wandb.log({
                "Running Loss": running_loss, 
                'Temperature': vae.temperature
            }, step=STEP_ID)


        vae.temperature *= dk
        STEP_ID += 1

    print('Current Temperature:', vae.temperature)

torch.save(vae.state_dict(), 'vaex.pt')

I also tested 16-bit training using the PyTorch amp module, but I think gumbel_softmax and the following steps cause it to diverge, so those steps might need to be wrapped in a with autocast(enabled=False): block. I will send a PR for this soon!!!
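A rough sketch of what such a guard might look like is below. It is only a guess at the shape of the fix, not the PR itself: `sample_codes_fp32` is a made-up helper name, and it uses the generic `torch.autocast` context manager available in recent PyTorch versions.

```python
import torch
import torch.nn.functional as F


def sample_codes_fp32(logits: torch.Tensor, tau: float) -> torch.Tensor:
    """Run gumbel_softmax in float32 even if the caller is inside an autocast region."""
    # Locally disable autocast and upcast the logits, so the Gumbel noise
    # and the softmax are computed in full precision.
    with torch.autocast(device_type=logits.device.type, enabled=False):
        return F.gumbel_softmax(logits.float(), tau=tau, dim=1)


# Half-precision logits stand in for what an autocast region might produce.
logits = torch.randn(2, 8, 3, 3).half()
soft_one_hot = sample_codes_fp32(logits, tau=0.9)
print(soft_one_hot.dtype)
```

Whether this alone is enough to stop the divergence is not established in this thread; the steps that follow the sampling may need the same treatment.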

@lucidrains
Owner

@NaxAlpha Thanks for sharing the code! I'll wait a bit longer to hear what others' experience during training is

Sid and Ben are doing a run at the moment with temperature annealing from 1 -> 0.05 over 25k steps on the mesh tensorflow version

@htoyryla

htoyryla commented Jan 12, 2021

My trainVAE now includes @NaxAlpha 's temperature scheduling and using both mse and l1 loss.

Things are developing faster than tests have time to complete :)

I think I'll add command line arguments to my scripts next.

@Apokar

Apokar commented Jan 13, 2021

@NaxAlpha Hi Nax, I can't access your Colab page due to some network reasons. Could you please upload a copy to your forked repo? It would be very helpful if you had the time to do this. I'm curious about your progress right now.

Anyway, thank you. XD

@Dakini

Dakini commented Jan 24, 2021

I've had a go using @htoyryla's code for the last few days on a subset of the Danbooru portraits and Danbooru2019.
For the discrete VAE I got these results:
vae_epoch_497
batch_32300_vae_epoch_2

When I trained the DALLE implementation with the portrait dataset, it went okish with 2000 images:
test_dalle_epoch_500

Think I might try it with a new learning rate scheduler and training with smaller images first.

@CaicaiJason

I've had a go using @htoyryla's code for the last few days on a subset of the Danbooru portraits and Danbooru2019.
For the discrete VAE I got these results:
vae_epoch_497
batch_32300_vae_epoch_2

When I trained the DALLE implementation with the portrait dataset, it went okish with 2000 images:
test_dalle_epoch_500

Think I might try it with a new learning rate scheduler and training with smaller images first.

Hey, where can I get the animation dataset?

@Dakini

Dakini commented Jan 25, 2021

@CaicaiJason It's Danbooru2019, and you can get it here. The full dataset is 3.4 TB, but the portraits subset is only around 16 GB.
https://www.gwern.net/Danbooru2020

@CaicaiJason

https://www.gwern.net/Danbooru2020

@Dakini thank you!
