Skip to content

Commit

Permalink
Add identity loss to avoid color inversion.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 254755026
  • Loading branch information
yashk2810 authored and copybara-github committed Jun 24, 2019
1 parent 6276843 commit ed6dc74
Showing 1 changed file with 47 additions and 20 deletions.
67 changes: 47 additions & 20 deletions site/en/r2/tutorials/generative/cyclegan.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -395,9 +395,9 @@
"id": "cGrL73uCd-_M"
},
"source": [
"Import the generator and the discriminator from the Pix2Pix tutorial by installing the [tensorflow_examples](https://github.com/tensorflow/examples) package.\n",
"Import the generator and the discriminator used in [Pix2Pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py) via the installed [tensorflow_examples](https://github.com/tensorflow/examples) package.\n",
"\n",
"The model architecture used in this tutorial is very similar to what was used in [pix2pix](https://www.tensorflow.org/beta/tutorials/generative/pix2pix). Some of the differences are:\n",
"The model architecture used in this tutorial is very similar to what was used in [pix2pix](https://github.com/tensorflow/examples/blob/master/tensorflow_examples/models/pix2pix/pix2pix.py). Some of the differences are:\n",
"* Cyclegan uses [instance normalization](https://arxiv.org/abs/1607.08022) instead of [batch normalization](https://arxiv.org/abs/1502.03167).\n",
"* The [CycleGAN paper](https://arxiv.org/abs/1703.10593) uses a modified `resnet` based generator. This tutorial is using a modified `unet` generator for simplicity.\n",
"\n",
Expand Down Expand Up @@ -444,22 +444,16 @@
"plt.figure(figsize=(8, 8))\n",
"contrast = 8\n",
"\n",
"plt.subplot(221)\n",
"plt.title('Horse')\n",
"plt.imshow(sample_horse[0] * 0.5 + 0.5)\n",
"\n",
"plt.subplot(222)\n",
"plt.title('To Zebra')\n",
"plt.imshow(to_zebra[0] * 0.5 * contrast + 0.5)\n",
"\n",
"plt.subplot(223)\n",
"plt.title('Zebra')\n",
"plt.imshow(sample_zebra[0] * 0.5 + 0.5)\n",
"\n",
"plt.subplot(224)\n",
"plt.title('To Horse')\n",
"plt.imshow(to_horse[0] * 0.5 * contrast + 0.5)\n",
"imgs = [sample_horse, to_zebra, sample_zebra, to_horse]\n",
"title = ['Horse', 'To Zebra', 'Zebra', 'To Horse']\n",
"\n",
"for i in range(len(imgs)):\n",
" plt.subplot(2, 2, i+1)\n",
" plt.title(title[i])\n",
" if i % 2 == 0:\n",
" plt.imshow(imgs[i][0] * 0.5 + 0.5)\n",
" else:\n",
" plt.imshow(imgs[i][0] * 0.5 * contrast + 0.5)\n",
"plt.show()"
]
},
Expand Down Expand Up @@ -606,6 +600,33 @@
" return LAMBDA * loss1"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "U-tJL-fX0Mq7"
},
"source": [
"As shown above, generator $G$ is responsible for translating image $X$ to image $Y$. Identity loss says that, if you fed image $Y$ to generator $G$, it should yield the real image $Y$ or something close to image $Y$.\n",
"\n",
"$$Identity\\ loss = |G(Y) - Y| + |F(X) - X|$$"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "05ywEH680Aud"
},
"outputs": [],
"source": [
"def identity_loss(real_image, same_image):\n",
" loss = tf.reduce_mean(tf.abs(real_image - same_image))\n",
" return LAMBDA * 0.5 * loss"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -754,13 +775,19 @@
" # once to calculate the gradients.\n",
" with tf.GradientTape(persistent=True) as gen_tape, tf.GradientTape(\n",
" persistent=True) as disc_tape:\n",
"\n",
" # Generator G translates X -\u003e Y\n",
" # Generator F translates Y -\u003e X.\n",
" \n",
" fake_y = generator_g(real_x, training=True)\n",
" cycled_x = generator_f(fake_y, training=True)\n",
"\n",
" fake_x = generator_f(real_y, training=True)\n",
" cycled_y = generator_g(fake_x, training=True)\n",
"\n",
" # same_x and same_y are used for identity loss.\n",
" same_x = generator_f(real_x, training=True)\n",
" same_y = generator_g(real_y, training=True)\n",
"\n",
" disc_real_x = discriminator_x(real_x, training=True)\n",
" disc_real_y = discriminator_y(real_y, training=True)\n",
"\n",
Expand All @@ -772,8 +799,8 @@
" gen_f_loss = generator_loss(disc_fake_x)\n",
" \n",
" # Total generator loss = adversarial loss + cycle loss\n",
" total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x)\n",
" total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y)\n",
" total_gen_g_loss = gen_g_loss + calc_cycle_loss(real_x, cycled_x) + identity_loss(real_x, same_x)\n",
" total_gen_f_loss = gen_f_loss + calc_cycle_loss(real_y, cycled_y) + identity_loss(real_y, same_y)\n",
"\n",
" disc_x_loss = discriminator_loss(disc_real_x, disc_fake_x)\n",
" disc_y_loss = discriminator_loss(disc_real_y, disc_fake_y)\n",
Expand Down

0 comments on commit ed6dc74

Please sign in to comment.