Skip to content

Commit

Permalink
update: image scaling function at 'res_block'
Browse files Browse the repository at this point in the history
  • Loading branch information
kozistr committed Jan 10, 2019
1 parent 664b775 commit ca912d9
Showing 1 changed file with 14 additions and 12 deletions.
26 changes: 14 additions & 12 deletions BigGAN/biggan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,32 +98,35 @@ def __init__(self, s, batch_size=64, height=128, width=128, channel=3, n_classes
# Placeholders
self.x = tf.placeholder(tf.float32,
shape=[None, self.height, self.width, self.channel],
name="x-image") # (64, 64, 64, 3)
name="x-image") # (bs, 128, 128, 3)
self.y = tf.placeholder(tf.float32,
shape=[None, self.n_classes],
name="y-label") # (64, n_classes)
name="y-label") # (bs, n_classes)
self.z = tf.placeholder(tf.float32, shape=[None, self.z_dim], name="z-noise") # (-1, 128)

self.build_sagan() # build SAGAN model
self.build_sagan() # build BigGAN model

@staticmethod
def res_block(x, c, z, f, scale_type, name):
def res_block(x, f, scale_type, name):
with tf.variable_scope("res_block_up-%s" % name):
assert scale_type in ["up", "down"]
scale_up = False if scale_type == "down" else True

ssc = x
cz = tf.concat([c, z], axis=-1)

x = t.batch_norm(tf.concat([x, cz], axis=-1), name="bn-1")
x = t.batch_norm(x, name="bn-1")
x = tf.nn.relu(x)
x = t.conv2d_alt(x, f, s=2, sn=True, name="conv2d-1") if not scale_up \
else t.deconv2d_alt(x, f, s=2, sn=True, name="deconv2d-1")
x = t.conv2d_alt(x, f, sn=True, name="conv2d-1")

x = t.batch_norm(tf.concat([x, cz], axis=-1), name="bn-2")
x = t.batch_norm(x, name="bn-2")
x = tf.nn.relu(x)
x = t.conv2d_alt(x, f, s=2, sn=True, name="conv2d-2") if not scale_up \
else t.deconv2d_alt(x, f, s=2, sn=True, name="deconv2d-2")

if not scale_up:
x = t.conv2d_alt(x, f, sn=True, name="conv2d-2")
x = tf.layers.average_pooling2d(x, pool_size=(2, 2))
else:
x = t.deconv2d_alt(x, f, sn=True, name="up-sampling")

return x + ssc

@staticmethod
Expand Down Expand Up @@ -229,7 +232,6 @@ def generator(self, z, c=None, reuse=None):
res = x
for i in range(4):
res = self.res_block(res,
c=None, z=z[i],
f=(16 // (2 ** i)) * self.channel,
scale_type="up",
name="res%d" % (i + 1))
Expand Down

0 comments on commit ca912d9

Please sign in to comment.