implement VAE
y0ast authored and Joost van Amersfoort committed Dec 22, 2016
1 parent 4ba9ae6 commit 1174146
Showing 4 changed files with 151 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
mnist/data
VAE/data
*.pyc
13 changes: 13 additions & 0 deletions VAE/README.md
@@ -0,0 +1,13 @@
# Basic VAE Example

This is an improved implementation of the paper [Stochastic Gradient VB and the
Variational Auto-Encoder](http://arxiv.org/abs/1312.6114) by Kingma and Welling.
It uses ReLUs and the Adam optimizer instead of sigmoids and Adagrad; these changes make the network converge much faster.
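
Concretely, `main.py` minimizes the negative variational lower bound: a binary cross-entropy reconstruction term plus the analytic KL divergence from Appendix B of the paper, written here for the 20-dimensional latent code used below:

```latex
\mathcal{L}(x) = \mathrm{BCE}(\hat{x}, x) - \frac{1}{2} \sum_{j=1}^{20} \left( 1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2 \right)
```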

We reuse the data preparation script of the MNIST experiment:

```bash
pip install -r requirements.txt
python ../mnist/data.py
python main.py
```
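
After training, new digits can be generated by decoding draws from the prior z ~ N(0, I). A minimal sketch, assuming the `model` and `cuda` flag defined in `main.py` (and its late-2016 PyTorch `Variable` API):

```python
import torch
from torch.autograd import Variable

z = Variable(torch.randn(64, 20), volatile=True)  # 64 samples from the prior
if cuda:
    z = z.cuda()
samples = model.decode(z)                # 64 x 784 Bernoulli means in (0, 1)
images = samples.data.view(64, 28, 28)   # reshape for viewing
```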
134 changes: 134 additions & 0 deletions VAE/main.py
@@ -0,0 +1,134 @@
from __future__ import print_function
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

cuda = torch.cuda.is_available()

print('Running with CUDA: {0}'.format(cuda))


def print_header(msg):
    print('===>', msg)


assert os.path.exists('data/processed/training.pt'), \
"Please run python ../mnist/data.py before starting the VAE."

# Data
print_header('Loading data')
with open('data/processed/training.pt', 'rb') as f:
    training_set = torch.load(f)
with open('data/processed/test.pt', 'rb') as f:
    test_set = torch.load(f)

# Flatten each 28x28 image into a 784-vector and rescale pixels to [0, 1].
training_data = training_set[0].view(-1, 784).div(255)
test_data = test_set[0].view(-1, 784).div(255)

del training_set
del test_set

# Model
print_header('Building model')


class VAE(nn.Container):
    def __init__(self):
        super(VAE, self).__init__()

        # Encoder: 784 -> 400 -> two 20-dim heads for q(z|x)
        self.fc1 = nn.Linear(784, 400)
        self.relu = nn.ReLU()
        self.fc21 = nn.Linear(400, 20)  # mean of q(z|x)
        self.fc22 = nn.Linear(400, 20)  # log-variance of q(z|x)
        # Decoder: 20 -> 400 -> 784, sigmoid gives Bernoulli pixel probabilities
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)
        self.sigmoid = nn.Sigmoid()

    def encode(self, x):
        h1 = self.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I),
        # which keeps the sample differentiable w.r.t. mu and logvar.
        std = logvar.mul(0.5).exp_()
        # Draw eps with the same tensor type as std so this also works on CUDA.
        eps = Variable(std.data.new(std.size()).normal_(), requires_grad=False)
        return eps.mul(std).add_(mu)

    def decode(self, z):
        h3 = self.relu(self.fc3(z))
        return self.sigmoid(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparametrize(mu, logvar)
        return self.decode(z), mu, logvar


model = VAE()
if cuda:
    model.cuda()

reconstruction_function = nn.BCELoss()
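# Sum the reconstruction error over the batch instead of averaging, so BCE
# is on the same scale as the summed KL term in loss_function below.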
reconstruction_function.size_average = False


def loss_function(recon_x, x, mu, logvar):
    BCE = reconstruction_function(recon_x, x)

    # KL(q(z|x) || p(z)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    # (Appendix B of the VAE paper)
    KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar)
    KLD = torch.sum(KLD_element).mul_(-0.5)

    return BCE + KLD


# Training settings
BATCH_SIZE = 150
TEST_BATCH_SIZE = 1000
NUM_EPOCHS = 2
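# With MNIST's 60,000 training images, BATCH_SIZE = 150 gives 400 updates per epoch.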

optimizer = optim.Adam(model.parameters(), lr=1e-3)


def train(epoch):
    # Preallocate one batch buffer, optionally on the GPU.
    batch_data_t = torch.FloatTensor(BATCH_SIZE, 784)
    if cuda:
        batch_data_t = batch_data_t.cuda()
    batch_data = Variable(batch_data_t, requires_grad=False)
    for i in range(0, training_data.size(0), BATCH_SIZE):
        optimizer.zero_grad()
        batch_data.data[:] = training_data[i:i + BATCH_SIZE]
        recon_batch_data, mu, logvar = model(batch_data)
        loss = loss_function(recon_batch_data, batch_data, mu, logvar)
        loss.backward()
        loss = loss.data[0]
        optimizer.step()
        # Log progress every 10 batches (i advances in steps of BATCH_SIZE).
        if (i // BATCH_SIZE) % 10 == 0:
            print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(
                epoch,
                i + BATCH_SIZE, training_data.size(0),
                float(i + BATCH_SIZE) / training_data.size(0) * 100,
                loss / BATCH_SIZE))


def test(epoch):
    test_loss = 0
    batch_data_t = torch.FloatTensor(TEST_BATCH_SIZE, 784)
    if cuda:
        batch_data_t = batch_data_t.cuda()
    # volatile=True disables graph construction, since no gradients are needed.
    batch_data = Variable(batch_data_t, volatile=True)
    for i in range(0, test_data.size(0), TEST_BATCH_SIZE):
        print('Testing model: {}/{}'.format(i, test_data.size(0)), end='\r')
        batch_data.data[:] = test_data[i:i + TEST_BATCH_SIZE]
        recon_batch_data, mu, logvar = model(batch_data)
        test_loss += loss_function(recon_batch_data, batch_data, mu, logvar)

    test_loss = test_loss.data[0] / test_data.size(0)
    print('TEST SET RESULTS:' + ' ' * 20)
    print('Average loss: {:.4f}'.format(test_loss))


for epoch in range(1, NUM_EPOCHS + 1):
    train(epoch)
    test(epoch)
3 changes: 3 additions & 0 deletions VAE/requirements.txt
@@ -0,0 +1,3 @@
torch
tqdm
six
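# tqdm and six are presumably needed by the shared ../mnist/data.py download script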
