forked from tensorflow/examples
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PiperOrigin-RevId: 234666763
- Loading branch information
Showing
2 changed files
with
272 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
"""DCGAN. | ||
""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import os | ||
from absl import app | ||
from absl import flags | ||
import tensorflow as tf # TF2 | ||
import tensorflow_datasets as tfds | ||
assert tf.__version__.startswith('2') | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
flags.DEFINE_integer('buffer_size', 10000, 'Shuffle buffer size') | ||
flags.DEFINE_integer('batch_size', 64, 'Batch Size') | ||
flags.DEFINE_integer('epochs', 1, 'Number of epochs') | ||
flags.DEFINE_boolean('enable_function', True, 'Enable Function?') | ||
|
||
AUTOTUNE = tf.data.experimental.AUTOTUNE | ||
|
||
|
||
def scale(image, label): | ||
image = tf.cast(image, tf.float32) | ||
image = (image - 127.5) / 127.5 | ||
|
||
return image, label | ||
|
||
|
||
def create_dataset(buffer_size, batch_size): | ||
dataset, _ = tfds.load('mnist', as_supervised=True, with_info=True) | ||
train_dataset, _ = dataset['train'], dataset['test'] | ||
train_dataset = train_dataset.map(scale, num_parallel_calls=AUTOTUNE) | ||
train_dataset = train_dataset.shuffle(buffer_size).batch(batch_size) | ||
|
||
return train_dataset | ||
|
||
|
||
def make_generator_model(): | ||
"""Generator. | ||
Returns: | ||
Keras Sequential model | ||
""" | ||
model = tf.keras.Sequential([ | ||
tf.keras.layers.Dense(7*7*256, use_bias=False), | ||
tf.keras.layers.BatchNormalization(), | ||
tf.keras.layers.LeakyReLU(), | ||
tf.keras.layers.Reshape((7, 7, 256)), | ||
tf.keras.layers.Conv2DTranspose(128, 5, strides=(1, 1), | ||
padding='same', use_bias=False), | ||
tf.keras.layers.BatchNormalization(), | ||
tf.keras.layers.LeakyReLU(), | ||
tf.keras.layers.Conv2DTranspose(64, 5, strides=(2, 2), | ||
padding='same', use_bias=False), | ||
tf.keras.layers.BatchNormalization(), | ||
tf.keras.layers.LeakyReLU(), | ||
tf.keras.layers.Conv2DTranspose(1, 5, strides=(2, 2), | ||
padding='same', use_bias=False, | ||
activation='tanh') | ||
]) | ||
|
||
return model | ||
|
||
|
||
def make_discriminator_model(): | ||
"""Discriminator. | ||
Returns: | ||
Keras Sequential model | ||
""" | ||
model = tf.keras.Sequential([ | ||
tf.keras.layers.Conv2D(64, 5, strides=(2, 2), padding='same'), | ||
tf.keras.layers.LeakyReLU(), | ||
tf.keras.layers.Dropout(0.3), | ||
tf.keras.layers.Conv2D(128, 5, strides=(2, 2), padding='same'), | ||
tf.keras.layers.LeakyReLU(), | ||
tf.keras.layers.Dropout(0.3), | ||
tf.keras.layers.Flatten(), | ||
tf.keras.layers.Dense(1) | ||
]) | ||
|
||
return model | ||
|
||
|
||
def get_checkpoint_prefix(): | ||
checkpoint_dir = './training_checkpoints' | ||
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt') | ||
|
||
return checkpoint_prefix | ||
|
||
|
||
class Dcgan(object): | ||
"""Dcgan class. | ||
Args: | ||
epochs: Number of epochs. | ||
enable_function: If true, train step is decorated with tf.function. | ||
batch_size: Batch size. | ||
""" | ||
|
||
def __init__(self, epochs, enable_function, batch_size): | ||
self.epochs = epochs | ||
self.enable_function = enable_function | ||
self.batch_size = batch_size | ||
self.noise_dim = 100 | ||
self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True) | ||
self.generator_optimizer = tf.keras.optimizers.Adam(1e-4) | ||
self.discriminator_optimizer = tf.keras.optimizers.Adam(1e-4) | ||
self.generator = make_generator_model() | ||
self.discriminator = make_discriminator_model() | ||
self.checkpoint = tf.train.Checkpoint( | ||
generator_optimizer=self.generator_optimizer, | ||
discriminator_optimizer=self.discriminator_optimizer, | ||
generator=self.generator, | ||
discriminator=self.discriminator) | ||
|
||
def generator_loss(self, generated_output): | ||
return self.loss_object(tf.ones_like(generated_output), generated_output) | ||
|
||
def discriminator_loss(self, real_output, generated_output): | ||
real_loss = self.loss_object(tf.ones_like(real_output), real_output) | ||
generated_loss = self.loss_object( | ||
tf.zeros_like(generated_output), generated_output) | ||
|
||
total_loss = real_loss + generated_loss | ||
|
||
return total_loss | ||
|
||
def train_step(self, image): | ||
"""One train step over the generator and discriminator model. | ||
Args: | ||
image: Input image. | ||
Returns: | ||
generator loss, discriminator loss. | ||
""" | ||
noise = tf.random.normal([self.batch_size, self.noise_dim]) | ||
|
||
with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: | ||
generated_images = self.generator(noise, training=True) | ||
|
||
real_output = self.discriminator(image, training=True) | ||
generated_output = self.discriminator(generated_images, training=True) | ||
|
||
gen_loss = self.generator_loss(generated_output) | ||
disc_loss = self.discriminator_loss(real_output, generated_output) | ||
|
||
gradients_of_generator = gen_tape.gradient( | ||
gen_loss, self.generator.trainable_variables) | ||
gradients_of_discriminator = disc_tape.gradient( | ||
disc_loss, self.discriminator.trainable_variables) | ||
|
||
self.generator_optimizer.apply_gradients(zip( | ||
gradients_of_generator, self.generator.trainable_variables)) | ||
self.discriminator_optimizer.apply_gradients(zip( | ||
gradients_of_discriminator, self.discriminator.trainable_variables)) | ||
|
||
return gen_loss, disc_loss | ||
|
||
def train(self, dataset, checkpoint_pr): | ||
"""Train the GAN for x number of epochs. | ||
Args: | ||
dataset: train dataset. | ||
checkpoint_pr: prefix in which the checkpoints are stored. | ||
""" | ||
if self.enable_function: | ||
self.train_step = tf.function(self.train_step) | ||
|
||
for epoch in range(self.epochs): | ||
for image, _ in dataset: | ||
gen_loss, disc_loss = self.train_step(image) | ||
|
||
# saving (checkpoint) the model every 15 epochs | ||
if (epoch + 1) % 15 == 0: | ||
self.checkpoint.save(file_prefix=checkpoint_pr) | ||
|
||
template = 'Epoch {}, Generator loss {}, Discriminator Loss {}' | ||
print (template.format(epoch, gen_loss, disc_loss)) | ||
|
||
|
||
def run_main(argv): | ||
del argv | ||
kwargs = {'epochs': FLAGS.epochs, 'enable_function': FLAGS.enable_function, | ||
'buffer_size': FLAGS.buffer_size, 'batch_size': FLAGS.batch_size} | ||
main(**kwargs) | ||
|
||
|
||
def main(epochs, enable_function, buffer_size, batch_size): | ||
train_dataset = create_dataset(buffer_size, batch_size) | ||
checkpoint_pr = get_checkpoint_prefix() | ||
|
||
dcgan_obj = Dcgan(epochs, enable_function, batch_size) | ||
print ('Training ...') | ||
dcgan_obj.train(train_dataset, checkpoint_pr) | ||
|
||
if __name__ == '__main__': | ||
app.run(run_main) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
"""DCGAN tests.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import time | ||
from absl import flags | ||
import tensorflow as tf # TF2 | ||
from tensorflow_examples.test_models.DCGAN import dcgan | ||
|
||
FLAGS = flags.FLAGS | ||
|
||
|
||
class DcganTest(tf.test.TestCase): | ||
|
||
def test_one_epoch_with_function(self): | ||
epochs = 1 | ||
batch_size = 1 | ||
enable_function = True | ||
|
||
input_image = tf.random.uniform((28, 28, 1)) | ||
label = tf.zeros((1,)) | ||
train_dataset = tf.data.Dataset.from_tensors( | ||
(input_image, label)).batch(batch_size) | ||
checkpoint_pr = dcgan.get_checkpoint_prefix() | ||
|
||
dcgan_obj = dcgan.Dcgan(epochs, enable_function, batch_size) | ||
dcgan_obj.train(train_dataset, checkpoint_pr) | ||
|
||
def test_one_epoch_without_function(self): | ||
epochs = 1 | ||
batch_size = 1 | ||
enable_function = False | ||
|
||
input_image = tf.random.uniform((28, 28, 1)) | ||
label = tf.zeros((1,)) | ||
train_dataset = tf.data.Dataset.from_tensors( | ||
(input_image, label)).batch(batch_size) | ||
checkpoint_pr = dcgan.get_checkpoint_prefix() | ||
|
||
dcgan_obj = dcgan.Dcgan(epochs, enable_function, batch_size) | ||
dcgan_obj.train(train_dataset, checkpoint_pr) | ||
|
||
|
||
class DCGANBenchmark(tf.test.Benchmark): | ||
|
||
def __init__(self, output_dir=None): | ||
self.output_dir = output_dir | ||
|
||
def benchmark_with_function(self): | ||
kwargs = {"epochs": 1, "enable_function": True, | ||
"buffer_size": 10000, "batch_size": 64} | ||
self._run_and_report_benchmark(**kwargs) | ||
|
||
def benchmark_without_function(self): | ||
kwargs = {"epochs": 1, "enable_function": False, | ||
"buffer_size": 10000, "batch_size": 64} | ||
self._run_and_report_benchmark(**kwargs) | ||
|
||
def _run_and_report_benchmark(self, **kwargs): | ||
start_time_sec = time.time() | ||
dcgan.main(**kwargs) | ||
wall_time_sec = time.time() - start_time_sec | ||
|
||
self.report_benchmark(wall_time=wall_time_sec) | ||
|
||
if __name__ == "__main__": | ||
assert tf.__version__.startswith('2') | ||
tf.test.main() |