Skip to content

Commit

Permalink
Add support for CUDA
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin authored and soumith committed Oct 1, 2017
1 parent ab7cb38 commit 5f24730
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions vae/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test(epoch):
recon_batch, mu, logvar = model(data)
test_loss += loss_function(recon_batch, data, mu, logvar).data[0]
if i == 0:
save_image(recon_batch.data.view(args.batch_size, 1, 28, 28),
save_image(recon_batch.data.cpu().view(args.batch_size, 1, 28, 28),
'reconstruction_' + str(epoch) + '.png')

test_loss /= len(test_loader.dataset)
Expand All @@ -141,5 +141,8 @@ def test(epoch):
for epoch in range(1, args.epochs + 1):
train(epoch)
test(epoch)
sample = model.decode(Variable(torch.randn(64, 20)))
sample = Variable(torch.randn(64, 20))
if args.cuda:
sample = sample.cuda()
sample = model.decode(sample).cpu()
save_image(sample.data.view(64, 1, 28, 28), 'sample_' + str(epoch) + '.png')

0 comments on commit 5f24730

Please sign in to comment.