Skip to content

Commit

Permalink
fixing wgan-gp
Browse files Browse the repository at this point in the history
  • Loading branch information
rfelixmg committed Oct 23, 2018
1 parent 0434b75 commit 31b75e4
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions models/wgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ def __init__(self, hparams):

def cwgan_loss(self):
alpha = tf.random_uniform(shape=tf.shape(self.generator.output), minval=0., maxval=1.)
interpolation = alpha * self.generator.output + (1. - alpha) * self.generator.output
interpolation = alpha * self.discriminator.x + (1. - alpha) * self.generator.output

d_input = tf.concat([interpolation, self.generator.a], -1)
grad = tf.gradients(self.discriminator.forward(d_input), [interpolation])[0]
grad_norm = tf.norm(grad, axis=1, ord='euclidean')
self.grad_pen = self.lmbda * tf.reduce_mean(tf.square(grad_norm - 1))

return self.d_real - (self.d_fake + self.aux_loss) + self.grad_pen
return self.d_real - self.d_fake + self.grad_pen

__MODEL__=WGAN

0 comments on commit 31b75e4

Please sign in to comment.