
ai-porter/WGAN-GP-tensorflow

 
 


WGAN-GP-tensorflow

This repository is a TensorFlow implementation of WGAN-GP for MNIST, CIFAR-10, and ImageNet64.

  • All samples in this README.md are generated by the neural network, except the first image in each row.

Requirements

  • tensorflow 1.10.0
  • python 3.5.5
  • numpy 1.14.2
  • matplotlib 2.2.2
  • scipy 0.19.1
  • pillow 5.0.0
  • urllib3 1.23
  • jsonschema 2.6.0
  • requests 2.14.2
  • tqdm 4.26.0
  • six 1.11.0

Generated Images

1. Toy Dataset

Results on 2-dimensional toy data: a mixture of 8 Gaussians, a mixture of 25 Gaussians, and the Swiss Roll. See the IPython Notebook.
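A minimal sketch of how the two Gaussian-mixture toy datasets can be sampled. The circle radius, grid spacing, and standard deviations below are assumptions chosen to resemble common toy-GAN setups; the notebook may use different values, and the helper names are hypothetical.

```python
import math
import random


def _sample_mixture(n, centers, std, rng):
    # Pick a mixture component uniformly at random, then add isotropic Gaussian noise.
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        points.append((cx + rng.gauss(0.0, std), cy + rng.gauss(0.0, std)))
    return points


def sample_8gaussians(n, radius=2.0, std=0.02, rng=None):
    """n 2-D points from 8 Gaussians evenly spaced on a circle (assumed parameters)."""
    if rng is None:
        rng = random.Random()
    centers = [(radius * math.cos(2 * math.pi * k / 8),
                radius * math.sin(2 * math.pi * k / 8)) for k in range(8)]
    return _sample_mixture(n, centers, std, rng)


def sample_25gaussians(n, spacing=2.0, std=0.05, rng=None):
    """n 2-D points from a 5x5 grid of Gaussians (assumed parameters)."""
    if rng is None:
        rng = random.Random()
    centers = [(spacing * i, spacing * j) for i in range(-2, 3) for j in range(-2, 3)]
    return _sample_mixture(n, centers, std, rng)
```

Passing a seeded `random.Random` makes the samples reproducible across runs.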

Note: For the following experiment, we held the generator distribution Pg fixed at the real distribution plus unit-variance Gaussian noise.

  • Top: GAN discriminator
  • Middle: WGAN critic with weight clipping
  • Bottom: WGAN critic with gradient penalty
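The middle and bottom rows differ only in how the critic is kept approximately 1-Lipschitz. The sketch below contrasts the two: weight clipping clamps every critic weight to [-c, c], while the gradient penalty adds lambda_ * (||grad D(x_hat)||_2 - 1)^2 at points x_hat sampled uniformly on the line between a real and a generated sample. It is illustrative only: central finite differences stand in for the automatic differentiation a TensorFlow implementation would use, the critic is any Python function of a point, and c = 0.01 is assumed here (the value used in the original WGAN paper).

```python
import math
import random


def numerical_grad(f, x, h=1e-5):
    """Central-difference gradient of the scalar function f at point x (list of floats)."""
    grad = []
    for i in range(len(x)):
        xp, xm = list(x), list(x)
        xp[i] += h
        xm[i] -= h
        grad.append((f(xp) - f(xm)) / (2 * h))
    return grad


def gradient_penalty(critic, real, fake, lambda_=10.0, rng=None):
    """WGAN-GP penalty averaged over a batch of (real, fake) pairs."""
    if rng is None:
        rng = random.Random()
    total = 0.0
    for x_r, x_f in zip(real, fake):
        eps = rng.random()
        # Random point on the straight line between a real and a fake sample.
        x_hat = [eps * a + (1.0 - eps) * b for a, b in zip(x_r, x_f)]
        norm = math.sqrt(sum(g * g for g in numerical_grad(critic, x_hat)))
        total += (norm - 1.0) ** 2
    return lambda_ * total / len(real)


def clip_weights(weights, c=0.01):
    """Weight clipping from the original WGAN: clamp every weight to [-c, c]."""
    return [max(-c, min(c, w)) for w in weights]
```

For a linear critic D(x) = 3*x0 + 4*x1 the gradient norm is 5 everywhere, so with lambda_ = 10 the penalty is 10 * (5 - 1)^2 = 160 regardless of where x_hat lands.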

Note: For the next experiment, the generator was not fixed; the plots show the points produced by the generator.

  • Top: GAN discriminator
  • Middle: WGAN critic with weight clipping
  • Bottom: WGAN critic with gradient penalty

2. MNIST Dataset

3. CIFAR-10

4. IMAGENET64


Documentation

Download Dataset

The MNIST and CIFAR-10 datasets are downloaded automatically by the code if they are not found in the expected data folder. The ImageNet64 dataset can be downloaded from Downsampled ImageNet.
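The download-if-missing behaviour can be sketched as follows, assuming the Data/<dataset> layout shown in the directory tree below. `download_fn` is a hypothetical callback standing in for the repository's own download code, and treating an empty folder as missing is an assumption.

```python
import os
import tempfile

AUTO_DOWNLOAD = {"mnist", "cifar10"}  # ImageNet64 has to be fetched manually


def ensure_dataset(data_dir, name, download_fn):
    """Make sure data_dir/name holds the dataset; return True if a download ran.

    download_fn(target_dir) is a hypothetical callback, not the repo's API.
    """
    target = os.path.join(data_dir, name)
    if os.path.isdir(target) and os.listdir(target):
        return False  # folder exists and is not empty: reuse it
    if name not in AUTO_DOWNLOAD:
        raise FileNotFoundError(
            f"{name} not found in {target}; download Downsampled ImageNet manually")
    os.makedirs(target, exist_ok=True)
    download_fn(target)
    return True


if __name__ == "__main__":
    def fake_download(target_dir):
        # Stand-in for a real download: write one placeholder file.
        with open(os.path.join(target_dir, "placeholder.bin"), "wb") as f:
            f.write(b"\x00")

    with tempfile.TemporaryDirectory() as root:
        print(ensure_dataset(root, "mnist", fake_download))  # first call downloads
        print(ensure_dataset(root, "mnist", fake_download))  # second call reuses the folder
```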

Directory Hierarchy

.
│   WGAN-GP
│   ├── src
│   │   ├── imagenet (folder storing the Inception network weights downloaded by inception_score.py)
│   │   ├── cache.py
│   │   ├── cifar10.py
│   │   ├── dataset.py
│   │   ├── dataset_.py
│   │   ├── download.py
│   │   ├── inception_score.py
│   │   ├── main.py
│   │   ├── plot.py
│   │   ├── solver.py
│   │   ├── tensorflow_utils.py
│   │   ├── utils.py
│   │   └── wgan_gp.py
│   Data
│   ├── mnist
│   ├── cifar10
│   └── imagenet64

src: source code of the WGAN-GP

Training WGAN-GP

Use main.py to train a WGAN-GP network. Example usage:

python main.py
  • gpu_index: gpu index, default: 0

  • batch_size: batch size for one feed forward, default: 64

  • dataset: dataset name from [mnist, cifar10, imagenet64], default: mnist

  • is_train: training or inference mode, default: True

  • learning_rate: initial learning rate for Adam, default: 0.001

  • num_critic: the number of iterations of the critic per generator iteration, default: 5

  • z_dim: dimension of z vector, default: 128

  • lambda_: gradient penalty lambda hyperparameter, default: 10.

  • beta1: beta1 momentum term of Adam, default: 0.5

  • beta2: beta2 (second-moment decay rate) of Adam, default: 0.9

  • iters: number of iterations, default: 200000

  • print_freq: print frequency for loss, default: 100

  • save_freq: save frequency for model, default: 10000

  • sample_freq: sample frequency for saving image, default: 500

  • inception_freq: calculation frequency of the inception score, default: 1000

  • sample_batch: number of sample images used to check generator quality, default: 64

  • load_model: folder of the saved model you wish to test (e.g. 20181120-1558), default: None
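The option list above can be mirrored with Python's argparse, which is handy for checking defaults or wrapping the script. This is a sketch only: main.py may define these options with TensorFlow's own flag utilities, and the `str2bool` helper and the string type for `gpu_index` are assumptions.

```python
import argparse


def str2bool(value):
    """Accept true/false strings such as the --is_train=false used for testing."""
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    p = argparse.ArgumentParser(description="WGAN-GP options (sketch)")
    p.add_argument("--gpu_index", type=str, default="0")  # string, e.g. for CUDA_VISIBLE_DEVICES
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--dataset", choices=["mnist", "cifar10", "imagenet64"], default="mnist")
    p.add_argument("--is_train", type=str2bool, default=True)
    p.add_argument("--learning_rate", type=float, default=1e-3)
    p.add_argument("--num_critic", type=int, default=5)
    p.add_argument("--z_dim", type=int, default=128)
    p.add_argument("--lambda_", type=float, default=10.0)
    p.add_argument("--beta1", type=float, default=0.5)
    p.add_argument("--beta2", type=float, default=0.9)
    p.add_argument("--iters", type=int, default=200000)
    p.add_argument("--print_freq", type=int, default=100)
    p.add_argument("--save_freq", type=int, default=10000)
    p.add_argument("--sample_freq", type=int, default=500)
    p.add_argument("--inception_freq", type=int, default=1000)
    p.add_argument("--sample_batch", type=int, default=64)
    p.add_argument("--load_model", type=str, default=None)
    return p
```

With num_critic = 5, each generator update is preceded by five critic updates.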

WGAN-GP During Training

Note: In the following figures, the y-axis shows the negative critic loss of the WGAN-GP.

  1. MNIST

  2. CIFAR10

  3. IMAGENET64
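For reference, the WGAN-GP critic minimizes E[D(x_fake)] - E[D(x_real)] + lambda_ * GP, so the plotted negative critic loss is, apart from the penalty term, the critic's estimate of the Wasserstein-1 distance between the real and generated distributions. The sketch below computes it from batches of critic scores; whether the plotted curves include the penalty term is not stated here, so `penalty` defaults to zero.

```python
def mean(values):
    return sum(values) / len(values)


def critic_loss(real_scores, fake_scores, penalty=0.0):
    """WGAN-GP critic loss E[D(fake)] - E[D(real)] + penalty.

    `penalty` is the already-scaled term lambda_ * GP.
    """
    return mean(fake_scores) - mean(real_scores) + penalty


def negative_critic_loss(real_scores, fake_scores, penalty=0.0):
    """Quantity on the y-axis: -(critic loss)."""
    return -critic_loss(real_scores, fake_scores, penalty)
```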

Inception Score on CIFAR10 During Training

Note: Inception score was calculated every 1000 iterations.
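The inception score is exp(E_x KL(p(y|x) || p(y))), where p(y|x) is the Inception network's class distribution for a generated image and p(y) is its average over all images. A minimal sketch from precomputed probability rows follows; it uses a single split, whereas reference implementations usually average the score over several splits, and it does not reproduce the Inception forward pass in inception_score.py.

```python
import math


def inception_score(probs, eps=1e-12):
    """probs: one class-probability list per generated image (each row sums to 1)."""
    n, k = len(probs), len(probs[0])
    # Marginal class distribution p(y), averaged over all images.
    marginal = [sum(row[j] for row in probs) / n for j in range(k)]
    kl_sum = 0.0
    for row in probs:
        # KL(p(y|x) || p(y)); zero-probability terms contribute nothing.
        kl_sum += sum(p * (math.log(p + eps) - math.log(q + eps))
                      for p, q in zip(row, marginal) if p > 0)
    return math.exp(kl_sum / n)
```

Uniform predictions give the minimum score of 1, while confident predictions spread evenly over N classes give a score of N.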

Test WGAN-GP

Use main.py to test a WGAN-GP network. Example usage:

python main.py --is_train=false --load_model=folder/you/wish/to/test/e.g./20181120-1558

Please refer to the above arguments.
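The example folder name 20181120-1558 looks like a YYYYmmdd-HHMM timestamp. The sketch below produces and resolves names in that style; the format is inferred from the example rather than confirmed by the code, and both helpers and the `root` directory argument are hypothetical.

```python
import os
from datetime import datetime


def model_folder_name(when=None):
    """Timestamp folder name in the style of 20181120-1558 (assumed format)."""
    if when is None:
        when = datetime.now()
    return when.strftime("%Y%m%d-%H%M")


def resolve_checkpoint_dir(root, load_model):
    """Return root/load_model if it exists, otherwise raise (hypothetical helper)."""
    path = os.path.join(root, load_model)
    if not os.path.isdir(path):
        raise FileNotFoundError(f"no saved model folder at {path}")
    return path
```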

Citation

  @misc{chengbinjin2018wgan-gp,
    author = {Cheng-Bin Jin},
    title = {WGAN-GP-tensorflow},
    year = {2018},
    howpublished = {\url{https://github.com/ChengBinJin/WGAN-GP-tensorflow}},
    note = {commit xxxxxxx}
  }

Attributions/Thanks

License

Copyright (c) 2018 Cheng-Bin Jin. Contact me for commercial use (or rather any use that is not academic research) (email: [email protected]). Free for research use, as long as proper attribution is given and this copyright notice is retained.

Related Projects

About

WGAN-GP tensorflow implementation


Languages

  • Python 84.6%
  • Jupyter Notebook 15.4%