A Python library, built on TensorFlow and Keras, for a wide variety of GANs (Generative Adversarial Networks).
To install GANForge directly from GitHub, run the following pip command in your command prompt/terminal:

```shell
pip install git+https://github.com/quadeer15sh/GANForge.git
```
You can get started with building GANs in just a few lines of code.
```python
import tensorflow as tf
from GANForge.dcgan import DCGAN

# train_ds: your image dataset, with images matching input_shape
model = DCGAN(input_shape=(64, 64, 3), latent_dim=128)

# Separate Adam optimizers for the discriminator and the generator
model.compile(d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              loss_fn=tf.keras.losses.BinaryCrossentropy())
model.fit(train_ds, epochs=25)
```
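The snippet above leaves `train_ds` up to you. Below is a minimal sketch of one way to build it with `tf.data`, under two assumptions this README does not state: that `DCGAN.fit` accepts a dataset yielding image batches only (no labels), and that images should be float32 values scaled to [-1, 1], the usual convention when a DCGAN generator ends in `tanh`. Check the DCGAN notebook to confirm both. The helper name `make_train_ds` is hypothetical.

```python
import numpy as np
import tensorflow as tf


def make_train_ds(images: np.ndarray, batch_size: int = 32) -> tf.data.Dataset:
    """Hypothetical helper: shuffled, batched images scaled to [-1, 1].

    Assumes `images` is a uint8 array of shape (N, 64, 64, 3) and that the
    model expects tanh-range inputs (an assumption, not confirmed by GANForge).
    """
    ds = tf.data.Dataset.from_tensor_slices(images)
    # Cast to float and map pixel values [0, 255] -> [-1, 1]
    ds = ds.map(lambda x: tf.cast(x, tf.float32) / 127.5 - 1.0,
                num_parallel_calls=tf.data.AUTOTUNE)
    # Drop the last partial batch so every batch has the same size
    ds = ds.shuffle(buffer_size=1024).batch(batch_size, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)


# Random stand-in data; replace with your own images
dummy = np.random.randint(0, 256, size=(128, 64, 64, 3), dtype=np.uint8)
train_ds = make_train_ds(dummy, batch_size=32)
```

For a folder of image files, `tf.keras.utils.image_dataset_from_directory(path, label_mode=None, image_size=(64, 64))` returns already-batched float32 images in [0, 255]; mapping them with `lambda x: x / 127.5 - 1.0` gives the same scaling. `drop_remainder=True` keeps every batch the same size, which avoids surprises if the training step assumes a fixed batch size.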
Feel free to explore the example notebooks for each GAN model available in GANForge.
Note: this list is updated frequently, so check back to see whether the GAN architecture you want to use is available.
Sr. | GAN Architecture | Status |
---|---|---|
1 | DC GAN | Available |
2 | Conditional GAN | Available |
3 | Info GAN | In Progress |
4 | SR GAN | Available |
5 | ESR GAN | In Progress |
6 | Pix2Pix GAN | In Progress |
7 | Cycle GAN | In Progress |
8 | Attention GAN | In Progress |
Custom callbacks available for use during training:
Sr. | Callback | GAN Applicable |
---|---|---|
1 | DCGANVisualization | DC GAN |
2 | ConditionalGANVisualization | Conditional DCGAN |
For example, to visualize DCGAN outputs during training:

```python
import tensorflow as tf
from GANForge.dcgan import DCGAN
from GANForge.callbacks import DCGANVisualization

# train_ds: your image dataset, with images matching input_shape
model = DCGAN(input_shape=(64, 64, 3), latent_dim=128)
model.compile(d_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              g_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0002),
              loss_fn=tf.keras.losses.BinaryCrossentropy())

# Callback that visualizes generated samples while the model trains
visualizer = DCGANVisualization(n_epochs=5)
model.fit(train_ds, epochs=25, callbacks=[visualizer])
```
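This README does not document how `DCGANVisualization` works internally. If you want a similar hook of your own, the sketch below shows the general pattern with the standard `tf.keras.callbacks.Callback` API: every few epochs it feeds a fixed batch of latent noise through a generator and stores the results for later plotting. It is not GANForge's implementation; the class `SampleCollector` and its parameters are hypothetical, and mapping outputs from [-1, 1] to [0, 1] assumes a `tanh` generator.

```python
import tensorflow as tf


class SampleCollector(tf.keras.callbacks.Callback):
    """Hypothetical callback (not part of GANForge).

    Every `every_n_epochs` epochs, samples images from `generator` using a
    fixed noise batch and keeps them in memory as (epoch, images) pairs.
    """

    def __init__(self, generator: tf.keras.Model, latent_dim: int,
                 num_images: int = 16, every_n_epochs: int = 5, seed: int = 42):
        super().__init__()
        self.generator = generator
        self.every_n_epochs = every_n_epochs
        # Fixed noise so that successive snapshots are directly comparable
        self.noise = tf.random.normal((num_images, latent_dim), seed=seed)
        self.snapshots = []

    def on_epoch_end(self, epoch, logs=None):
        # Keras passes a 0-based epoch index
        if (epoch + 1) % self.every_n_epochs != 0:
            return
        images = self.generator(self.noise, training=False)
        # Assumes tanh output in [-1, 1]; rescale to [0, 1] for display
        images = (images + 1.0) / 2.0
        self.snapshots.append((epoch + 1, images.numpy()))
```

Pass it to `fit` as `callbacks=[SampleCollector(gen, latent_dim=128)]`, where `gen` is the generator network. Whether a `DCGAN` instance exposes its generator as an attribute, and under what name, is not stated in this README, so check the source before wiring it up.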