Make pix2pix flexible so that cyclegan can use it.
PiperOrigin-RevId: 253589693
yashk2810 authored and copybara-github committed Jun 17, 2019
1 parent d6cfa25 commit f28b124
Showing 1 changed file with 85 additions and 35 deletions.
120 changes: 85 additions & 35 deletions tensorflow_examples/models/pix2pix/pix2pix.py
@@ -158,15 +158,41 @@ def create_dataset(path_to_train_images, path_to_test_images, buffer_size,
return train_dataset, test_dataset


def downsample(filters, size, apply_batchnorm=True):
class InstanceNormalization(tf.keras.layers.Layer):
"""Instance Normalization Layer (https://arxiv.org/abs/1607.08022)."""

def __init__(self, epsilon=1e-5):
super(InstanceNormalization, self).__init__()
self.epsilon = epsilon

def build(self, input_shape):
self.scale = self.add_weight(
shape=input_shape[-1:],
initializer=tf.random_normal_initializer(0., 0.02),
trainable=True)

self.offset = self.add_weight(
shape=input_shape[-1:],
initializer='zeros',
trainable=True)

def call(self, x):
mean, variance = tf.nn.moments(x, axes=[1, 2], keepdims=True)
inv = tf.math.rsqrt(variance + self.epsilon)
normalized = (x - mean) * inv
return self.scale * normalized + self.offset
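
The new InstanceNormalization layer normalizes each sample's feature maps over the spatial axes only, independent of the rest of the batch, which is the normalization CycleGAN-style generators typically use. A minimal usage sketch (illustrative, assuming the class above is in scope):

import tensorflow as tf

# (batch, height, width, channels); each sample/channel pair is normalized
# on its own, so the result does not depend on the other items in the batch.
x = tf.random.normal([2, 8, 8, 4])
layer = InstanceNormalization()
y = layer(x)
print(y.shape)  # (2, 8, 8, 4), same shape as the input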


def downsample(filters, size, norm_type='batchnorm', apply_norm=True):
"""Downsamples an input.
Conv2D => Batchnorm => LeakyRelu
Args:
filters: number of filters
size: filter size
apply_batchnorm: If True, adds the batchnorm layer
norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
apply_norm: If True, adds the batchnorm layer
Returns:
Downsample Sequential Model
@@ -178,27 +204,32 @@ def downsample(filters, size, apply_batchnorm=True):
tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
kernel_initializer=initializer, use_bias=False))

if apply_batchnorm:
result.add(tf.keras.layers.BatchNormalization())
if apply_norm:
if norm_type.lower() == 'batchnorm':
result.add(tf.keras.layers.BatchNormalization())
elif norm_type.lower() == 'instancenorm':
result.add(InstanceNormalization())

result.add(tf.keras.layers.LeakyReLU())

return result
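
A short sketch (illustrative, assuming the function above is in scope) of the reworked signature: norm_type selects BatchNormalization or the InstanceNormalization layer defined earlier, while the strides-2 convolution still halves the spatial resolution:

import tensorflow as tf

x = tf.random.normal([1, 256, 256, 3])

# CycleGAN-style block: instance normalization instead of batch normalization.
down = downsample(64, 4, norm_type='instancenorm')
print(down(x).shape)   # (1, 128, 128, 64)

# First encoder block: no normalization at all, as in the stacks below.
first = downsample(64, 4, 'batchnorm', apply_norm=False)
print(first(x).shape)  # (1, 128, 128, 64)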


def upsample(filters, size, apply_dropout=False):
def upsample(filters, size, norm_type='batchnorm', apply_dropout=False):
"""Upsamples an input.
Conv2DTranspose => Batchnorm => Dropout => Relu
Args:
filters: number of filters
size: filter size
norm_type: Normalization type; either 'batchnorm' or 'instancenorm'.
apply_dropout: If True, adds the dropout layer
Returns:
Upsample Sequential Model
"""

initializer = tf.random_normal_initializer(0., 0.02)

result = tf.keras.Sequential()
@@ -208,7 +239,10 @@ def upsample(filters, size, apply_dropout=False):
kernel_initializer=initializer,
use_bias=False))

result.add(tf.keras.layers.BatchNormalization())
if norm_type.lower() == 'batchnorm':
result.add(tf.keras.layers.BatchNormalization())
elif norm_type.lower() == 'instancenorm':
result.add(InstanceNormalization())

if apply_dropout:
result.add(tf.keras.layers.Dropout(0.5))
@@ -218,33 +252,36 @@ def upsample(filters, size, apply_dropout=False):
return result
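
The upsample block mirrors downsample with the same norm_type switch. An illustrative sketch, assuming the collapsed Conv2DTranspose context uses strides=2 and 'same' padding as in the original pix2pix code, so the spatial resolution doubles:

import tensorflow as tf

# A typical decoder activation after concatenation with a skip connection.
x = tf.random.normal([1, 16, 16, 1024])
up = upsample(512, 4, norm_type='instancenorm', apply_dropout=True)
print(up(x).shape)  # (1, 32, 32, 512), assuming strides=2 with 'same' padding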


def generator_model(output_channels):
"""Modified u-net generator model.
def unet_generator(output_channels, norm_type='batchnorm'):
"""Modified u-net generator model (https://arxiv.org/abs/1611.07004).
Args:
output_channels: Output channels
norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
Returns:
Generator model
"""

down_stack = [
downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64)
downsample(128, 4), # (bs, 64, 64, 128)
downsample(256, 4), # (bs, 32, 32, 256)
downsample(512, 4), # (bs, 16, 16, 512)
downsample(512, 4), # (bs, 8, 8, 512)
downsample(512, 4), # (bs, 4, 4, 512)
downsample(512, 4), # (bs, 2, 2, 512)
downsample(512, 4), # (bs, 1, 1, 512)
downsample(64, 4, norm_type, apply_norm=False), # (bs, 128, 128, 64)
downsample(128, 4, norm_type), # (bs, 64, 64, 128)
downsample(256, 4, norm_type), # (bs, 32, 32, 256)
downsample(512, 4, norm_type), # (bs, 16, 16, 512)
downsample(512, 4, norm_type), # (bs, 8, 8, 512)
downsample(512, 4, norm_type), # (bs, 4, 4, 512)
downsample(512, 4, norm_type), # (bs, 2, 2, 512)
downsample(512, 4, norm_type), # (bs, 1, 1, 512)
]

up_stack = [
upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024)
upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024)
upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024)
upsample(512, 4), # (bs, 16, 16, 1024)
upsample(256, 4), # (bs, 32, 32, 512)
upsample(128, 4), # (bs, 64, 64, 256)
upsample(64, 4), # (bs, 128, 128, 128)
upsample(512, 4, norm_type, apply_dropout=True), # (bs, 2, 2, 1024)
upsample(512, 4, norm_type, apply_dropout=True), # (bs, 4, 4, 1024)
upsample(512, 4, norm_type, apply_dropout=True), # (bs, 8, 8, 1024)
upsample(512, 4, norm_type), # (bs, 16, 16, 1024)
upsample(256, 4, norm_type), # (bs, 32, 32, 512)
upsample(128, 4, norm_type), # (bs, 64, 64, 256)
upsample(64, 4, norm_type), # (bs, 128, 128, 128)
]

initializer = tf.random_normal_initializer(0., 0.02)
@@ -276,39 +313,52 @@ def generator_model(output_channels):
return tf.keras.Model(inputs=inputs, outputs=x)
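
With both stacks parameterized, the renamed unet_generator wires the encoder and decoder together with skip connections. An illustrative sketch, assuming 256x256 RGB inputs as the shape comments above suggest:

import tensorflow as tf

generator = unet_generator(output_channels=3, norm_type='instancenorm')
x = tf.random.normal([1, 256, 256, 3])
fake = generator(x)
print(fake.shape)  # (1, 256, 256, 3), same spatial size as the input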


def discriminator_model():
"""PatchGan discriminator model.
def discriminator(norm_type='batchnorm', target=True):
"""PatchGan discriminator model (https://arxiv.org/abs/1611.07004).
Args:
norm_type: Type of normalization. Either 'batchnorm' or 'instancenorm'.
target: Bool, indicating whether target image is an input or not.
Returns:
Discriminator model
"""

initializer = tf.random_normal_initializer(0., 0.02)

inp = tf.keras.layers.Input(shape=[None, None, 3], name='input_image')
tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
x = inp

x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)
if target:
tar = tf.keras.layers.Input(shape=[None, None, 3], name='target_image')
x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2)

down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64)
down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128)
down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256)
down1 = downsample(64, 4, norm_type, False)(x) # (bs, 128, 128, 64)
down2 = downsample(128, 4, norm_type)(down1) # (bs, 64, 64, 128)
down3 = downsample(256, 4, norm_type)(down2) # (bs, 32, 32, 256)

zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256)
conv = tf.keras.layers.Conv2D(
512, 4, strides=1, kernel_initializer=initializer,
use_bias=False)(zero_pad1) # (bs, 31, 31, 512)

batchnorm1 = tf.keras.layers.BatchNormalization()(conv)
if norm_type.lower() == 'batchnorm':
norm1 = tf.keras.layers.BatchNormalization()(conv)
elif norm_type.lower() == 'instancenorm':
norm1 = InstanceNormalization()(conv)

leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)
leaky_relu = tf.keras.layers.LeakyReLU()(norm1)

zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512)

last = tf.keras.layers.Conv2D(
1, 4, strides=1,
kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1)

return tf.keras.Model(inputs=[inp, tar], outputs=last)
if target:
return tf.keras.Model(inputs=[inp, tar], outputs=last)
else:
return tf.keras.Model(inputs=inp, outputs=last)
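
The target flag is what lets CycleGAN reuse this discriminator: with target=True it is the conditional pix2pix PatchGAN that concatenates input and target images, with target=False it scores a single image. An illustrative sketch, assuming the function above is in scope:

import tensorflow as tf

inp = tf.random.normal([1, 256, 256, 3])
tar = tf.random.normal([1, 256, 256, 3])

# pix2pix-style: conditioned on the target image, so the model takes two inputs.
disc_pix2pix = discriminator(norm_type='batchnorm', target=True)
print(disc_pix2pix([inp, tar]).shape)  # (1, 30, 30, 1) grid of patch logits

# CycleGAN-style: unconditional, a single input plus instance normalization.
disc_cyclegan = discriminator(norm_type='instancenorm', target=False)
print(disc_cyclegan(inp).shape)        # (1, 30, 30, 1)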


def get_checkpoint_prefix():
@@ -335,8 +385,8 @@ def __init__(self, epochs, enable_function):
self.loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
self.generator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
self.discriminator_optimizer = tf.keras.optimizers.Adam(2e-4, beta_1=0.5)
self.generator = generator_model(output_channels=3)
self.discriminator = discriminator_model()
self.generator = unet_generator(output_channels=3)
self.discriminator = discriminator()
self.checkpoint = tf.train.Checkpoint(
generator_optimizer=self.generator_optimizer,
discriminator_optimizer=self.discriminator_optimizer,

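The hunk above only covers the trainer's constructor; the loss functions that consume self.loss_object are outside this diff. For orientation, a typical pix2pix-style pairing of discriminator and generator losses with a from_logits BinaryCrossentropy looks like the sketch below. This is an assumption based on the standard pix2pix recipe, not code shown in this commit, and the LAMBDA weight is likewise assumed:

import tensorflow as tf

loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)
LAMBDA = 100  # assumed L1 weight, as in the pix2pix paper and tutorial

def discriminator_loss(disc_real_output, disc_generated_output):
  # Real pairs should be scored as 1, generated pairs as 0, patch by patch.
  real_loss = loss_object(tf.ones_like(disc_real_output), disc_real_output)
  generated_loss = loss_object(tf.zeros_like(disc_generated_output),
                               disc_generated_output)
  return real_loss + generated_loss

def generator_loss(disc_generated_output, gen_output, target):
  # Adversarial term plus an L1 reconstruction term weighted by LAMBDA.
  gan_loss = loss_object(tf.ones_like(disc_generated_output),
                         disc_generated_output)
  l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
  return gan_loss + LAMBDA * l1_loss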