Dcgan benchmark
PiperOrigin-RevId: 234666763
yashk2810 authored and lamberta committed Feb 19, 2019
1 parent 3a23533 commit 73bcf46
Showing 2 changed files with 272 additions and 0 deletions.
202 changes: 202 additions & 0 deletions test_models/DCGAN/dcgan.py
@@ -0,0 +1,202 @@
"""DCGAN.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from absl import app
from absl import flags
import tensorflow as tf # TF2
import tensorflow_datasets as tfds
assert tf.__version__.startswith('2')

FLAGS = flags.FLAGS

flags.DEFINE_integer('buffer_size', 10000, 'Shuffle buffer size')
flags.DEFINE_integer('batch_size', 64, 'Batch Size')
flags.DEFINE_integer('epochs', 1, 'Number of epochs')
flags.DEFINE_boolean('enable_function', True, 'Enable Function?')

AUTOTUNE = tf.data.experimental.AUTOTUNE


def scale(image, label):
  # Rescale pixel values from [0, 255] to [-1, 1], matching the generator's
  # tanh output range.
  image = tf.cast(image, tf.float32)
  image = (image - 127.5) / 127.5

  return image, label


def create_dataset(buffer_size, batch_size):
  dataset, _ = tfds.load('mnist', as_supervised=True, with_info=True)
  train_dataset, _ = dataset['train'], dataset['test']
  train_dataset = train_dataset.map(scale, num_parallel_calls=AUTOTUNE)
  train_dataset = train_dataset.shuffle(buffer_size).batch(batch_size)

  return train_dataset
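
# The resulting pipeline yields (image, label) batches, with images cast to
# float32 in [-1, 1]; only the 'train' split is consumed.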


def make_generator_model():
  """Generator.

  Returns:
    Keras Sequential model
  """
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(7*7*256, use_bias=False),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(),
      tf.keras.layers.Reshape((7, 7, 256)),
      tf.keras.layers.Conv2DTranspose(128, 5, strides=(1, 1),
                                      padding='same', use_bias=False),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(),
      tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2),
                                      padding='same', use_bias=False),
      tf.keras.layers.BatchNormalization(),
      tf.keras.layers.LeakyReLU(),
      tf.keras.layers.Conv2DTranspose(1, 5, strides=(2, 2),
                                      padding='same', use_bias=False,
                                      activation='tanh')
  ])

  return model
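
# Shape sanity check (illustrative sketch): a batch of 100-dim noise vectors
# (the noise_dim used by Dcgan below) is upsampled 7 -> 14 -> 28 into
# MNIST-sized images:
#
#   generator = make_generator_model()
#   noise = tf.random.normal([16, 100])
#   assert generator(noise, training=False).shape == (16, 28, 28, 1)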


def make_discriminator_model():
  """Discriminator.

  Returns:
    Keras Sequential model
  """
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(64, 5, strides=(2, 2), padding='same'),
      tf.keras.layers.LeakyReLU(),
      tf.keras.layers.Dropout(0.3),
      tf.keras.layers.Conv2D(128, 5, strides=(2, 2), padding='same'),
      tf.keras.layers.LeakyReLU(),
      tf.keras.layers.Dropout(0.3),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(1)
  ])

  return model
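
# The final Dense(1) has no activation, so the discriminator emits raw logits;
# this pairs with the BinaryCrossentropy(from_logits=True) loss configured in
# Dcgan below.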


def get_checkpoint_prefix():
  checkpoint_dir = './training_checkpoints'
  checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')

  return checkpoint_prefix


class Dcgan(object):
  """Dcgan class.

  Args:
    epochs: Number of epochs.
    enable_function: If true, train step is decorated with tf.function.
    batch_size: Batch size.
  """

  def __init__(self, epochs, enable_function, batch_size):
    self.epochs = epochs
    self.enable_function = enable_function
    self.batch_size = batch_size
    self.noise_dim = 100
    self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
    self.generator_optimizer = tf.keras.optimizers.Adam(1e-4)
    self.discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
    self.generator = make_generator_model()
    self.discriminator = make_discriminator_model()
    self.checkpoint = tf.train.Checkpoint(
        generator_optimizer=self.generator_optimizer,
        discriminator_optimizer=self.discriminator_optimizer,
        generator=self.generator,
        discriminator=self.discriminator)

  def generator_loss(self, generated_output):
    return self.loss_object(tf.ones_like(generated_output), generated_output)

  def discriminator_loss(self, real_output, generated_output):
    real_loss = self.loss_object(tf.ones_like(real_output), real_output)
    generated_loss = self.loss_object(
        tf.zeros_like(generated_output), generated_output)

    total_loss = real_loss + generated_loss

    return total_loss
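
  # Equivalently, since from_logits=True applies a sigmoid inside the loss:
  #   L_D = -E[log sigmoid(D(x))] - E[log(1 - sigmoid(D(G(z))))]
  #   L_G = -E[log sigmoid(D(G(z)))]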

  def train_step(self, image):
    """One train step over the generator and discriminator model.

    Args:
      image: Input image.

    Returns:
      generator loss, discriminator loss.
    """
    noise = tf.random.normal([self.batch_size, self.noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = self.generator(noise, training=True)

      real_output = self.discriminator(image, training=True)
      generated_output = self.discriminator(generated_images, training=True)

      gen_loss = self.generator_loss(generated_output)
      disc_loss = self.discriminator_loss(real_output, generated_output)

    gradients_of_generator = gen_tape.gradient(
        gen_loss, self.generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(
        disc_loss, self.discriminator.trainable_variables)

    self.generator_optimizer.apply_gradients(zip(
        gradients_of_generator, self.generator.trainable_variables))
    self.discriminator_optimizer.apply_gradients(zip(
        gradients_of_discriminator, self.discriminator.trainable_variables))

    return gen_loss, disc_loss

  def train(self, dataset, checkpoint_pr):
    """Train the GAN for the given number of epochs.

    Args:
      dataset: train dataset.
      checkpoint_pr: prefix under which the checkpoints are stored.
    """
    if self.enable_function:
      self.train_step = tf.function(self.train_step)

    for epoch in range(self.epochs):
      for image, _ in dataset:
        gen_loss, disc_loss = self.train_step(image)

      # Save a checkpoint every 15 epochs.
      if (epoch + 1) % 15 == 0:
        self.checkpoint.save(file_prefix=checkpoint_pr)

      template = 'Epoch {}, Generator loss {}, Discriminator Loss {}'
      print(template.format(epoch + 1, gen_loss, disc_loss))


def run_main(argv):
  del argv
  kwargs = {'epochs': FLAGS.epochs, 'enable_function': FLAGS.enable_function,
            'buffer_size': FLAGS.buffer_size, 'batch_size': FLAGS.batch_size}
  main(**kwargs)


def main(epochs, enable_function, buffer_size, batch_size):
  train_dataset = create_dataset(buffer_size, batch_size)
  checkpoint_pr = get_checkpoint_prefix()

  dcgan_obj = Dcgan(epochs, enable_function, batch_size)
  print('Training ...')
  dcgan_obj.train(train_dataset, checkpoint_pr)


if __name__ == '__main__':
  app.run(run_main)
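
A short usage sketch, assembled only from functions defined in this file (the
flag defaults above give the same configuration): the script runs end to end
via the absl flags, e.g. python dcgan.py --epochs=1 --batch_size=64, or the
training loop can be driven directly:

    from tensorflow_examples.test_models.DCGAN import dcgan

    train_dataset = dcgan.create_dataset(buffer_size=10000, batch_size=64)
    dcgan_obj = dcgan.Dcgan(epochs=1, enable_function=True, batch_size=64)
    dcgan_obj.train(train_dataset, dcgan.get_checkpoint_prefix())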
70 changes: 70 additions & 0 deletions test_models/DCGAN/dcgan_test.py
@@ -0,0 +1,70 @@
"""DCGAN tests."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
from absl import flags
import tensorflow as tf # TF2
from tensorflow_examples.test_models.DCGAN import dcgan

FLAGS = flags.FLAGS


class DcganTest(tf.test.TestCase):

  def test_one_epoch_with_function(self):
    epochs = 1
    batch_size = 1
    enable_function = True

    input_image = tf.random.uniform((28, 28, 1))
    label = tf.zeros((1,))
    train_dataset = tf.data.Dataset.from_tensors(
        (input_image, label)).batch(batch_size)
    checkpoint_pr = dcgan.get_checkpoint_prefix()

    dcgan_obj = dcgan.Dcgan(epochs, enable_function, batch_size)
    dcgan_obj.train(train_dataset, checkpoint_pr)

  def test_one_epoch_without_function(self):
    epochs = 1
    batch_size = 1
    enable_function = False

    input_image = tf.random.uniform((28, 28, 1))
    label = tf.zeros((1,))
    train_dataset = tf.data.Dataset.from_tensors(
        (input_image, label)).batch(batch_size)
    checkpoint_pr = dcgan.get_checkpoint_prefix()

    dcgan_obj = dcgan.Dcgan(epochs, enable_function, batch_size)
    dcgan_obj.train(train_dataset, checkpoint_pr)


class DCGANBenchmark(tf.test.Benchmark):

  def __init__(self, output_dir=None):
    self.output_dir = output_dir

  def benchmark_with_function(self):
    kwargs = {"epochs": 1, "enable_function": True,
              "buffer_size": 10000, "batch_size": 64}
    self._run_and_report_benchmark(**kwargs)

  def benchmark_without_function(self):
    kwargs = {"epochs": 1, "enable_function": False,
              "buffer_size": 10000, "batch_size": 64}
    self._run_and_report_benchmark(**kwargs)

  def _run_and_report_benchmark(self, **kwargs):
    start_time_sec = time.time()
    dcgan.main(**kwargs)
    wall_time_sec = time.time() - start_time_sec

    self.report_benchmark(wall_time=wall_time_sec)


if __name__ == "__main__":
  assert tf.__version__.startswith('2')
  tf.test.main()
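
A short usage sketch: tf.test.main() also parses the --benchmarks flag, so the
benchmarks can typically be invoked with
python dcgan_test.py --benchmarks=DCGANBenchmark (exact flag handling may vary
by TF version). A benchmark method can also be called directly, assuming the
module import path mirrors the one used above:

    from tensorflow_examples.test_models.DCGAN import dcgan_test

    benchmark = dcgan_test.DCGANBenchmark()
    benchmark.benchmark_with_function()  # trains one epoch, reports wall time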
