Skip to content

Commit

Permalink
Uniform variable names to match the slides
Browse files Browse the repository at this point in the history
  • Loading branch information
Atcold committed Nov 21, 2018
1 parent a9a8827 commit 8631437
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions 11-VAE.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
" nn.Sigmoid(),\n",
" )\n",
"\n",
" def reparameterize(self, mu, logvar):\n",
" def reparameterise(self, mu, logvar):\n",
" if self.training:\n",
" std = logvar.mul(0.5).exp_()\n",
" eps = std.data.new(std.size()).normal_()\n",
Expand All @@ -119,7 +119,7 @@
" mu_logvar = self.encoder(x.view(-1, 784)).view(-1, 2, d)\n",
" mu = mu_logvar[:, 0, :]\n",
" logvar = mu_logvar[:, 1, :]\n",
" z = self.reparameterize(mu, logvar)\n",
" z = self.reparameterise(mu, logvar)\n",
" return self.decoder(z), mu, logvar\n",
"\n",
"model = VAE().to(device)"
Expand Down Expand Up @@ -147,9 +147,9 @@
"source": [
"# Reconstruction + KL divergence losses summed over all elements and batch\n",
"\n",
"def loss_function(recon_x, x, mu, logvar):\n",
"def loss_function(x_hat, x, mu, logvar):\n",
" BCE = nn.functional.binary_cross_entropy(\n",
" recon_x, x.view(-1, 784), size_average=False\n",
" x_hat, x.view(-1, 784), reduction='sum'\n",
" )\n",
" KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
"\n",
Expand All @@ -171,11 +171,11 @@
" # Training\n",
" model.train()\n",
" train_loss = 0\n",
" for data, _ in train_loader:\n",
" data = data.to(device)\n",
" for x, _ in train_loader:\n",
" x = x.to(device)\n",
" # ===================forward=====================\n",
" recon_batch, mu, logvar = model(data)\n",
" loss = loss_function(recon_batch, data, mu, logvar)\n",
" x_hat, mu, logvar = model(x)\n",
" loss = loss_function(x_hat, x, mu, logvar)\n",
" train_loss += loss.item()\n",
" # ===================backward====================\n",
" optimizer.zero_grad()\n",
Expand All @@ -189,15 +189,15 @@
" with torch.no_grad():\n",
" model.eval()\n",
" test_loss = 0\n",
" for data, _ in test_loader:\n",
" data = data.to(device)\n",
" for x, _ in test_loader:\n",
" x = x.to(device)\n",
" # ===================forward=====================\n",
" recon_batch, mu, logvar = model(data)\n",
" test_loss += loss_function(recon_batch, data, mu, logvar).item()\n",
" x_hat, mu, logvar = model(x)\n",
" test_loss += loss_function(x_hat, x, mu, logvar).item()\n",
" # ===================log========================\n",
" test_loss /= len(test_loader.dataset)\n",
" print(f'====> Test set loss: {test_loss:.4f}')\n",
" display_images(data, recon_batch, 1, f'Epoch {epoch}')"
" display_images(x, x_hat, 1, f'Epoch {epoch}')"
]
},
{
Expand All @@ -209,7 +209,7 @@
"# Generating a few samples\n",
"\n",
"N = 16\n",
"sample = torch.randn((N, 20), requires_grad=False).to(device)\n",
"sample = torch.randn((N, d)).to(device)\n",
"sample = model.decoder(sample)\n",
"display_images(None, sample, N // 4, count=True)"
]
Expand All @@ -222,7 +222,7 @@
"source": [
"# Display last test batch\n",
"\n",
"display_images(None, data, 4, count=True)"
"display_images(None, x, 4, count=True)"
]
},
{
Expand All @@ -233,11 +233,11 @@
"source": [
"# Choose starting and ending point for the interpolation -> shows original and reconstructed\n",
"\n",
"A, B = 5, 14\n",
"A, B = 3, 8\n",
"sample = model.decoder(torch.stack((mu[A].data, mu[B].data), 0))\n",
"display_images(None, torch.stack(((\n",
" data[A].data.view(-1),\n",
" data[B].data.view(-1),\n",
" x[A].data.view(-1),\n",
" x[B].data.view(-1),\n",
" sample.data[0],\n",
" sample.data[1]\n",
")), 0))"
Expand All @@ -254,7 +254,7 @@
"N = 16\n",
"code = torch.Tensor(N, 20).to(device)\n",
"for i in range(N):\n",
" code[i] = i / N * mu[B].data + (1 - i / N) * mu[A].data\n",
" code[i] = i / (N - 1) * mu[B].data + (1 - i / (N - 1) ) * mu[A].data\n",
"code = torch.tensor(code, requires_grad=True)\n",
"sample = model.decoder(code)\n",
"display_images(None, sample, N // 4, count=True)"
Expand All @@ -263,9 +263,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "latest",
"display_name": "Python 3",
"language": "python",
"name": "latest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand Down

0 comments on commit 8631437

Please sign in to comment.