diff --git a/dcgan/main.py b/dcgan/main.py index 8847017453..a65874cb11 100644 --- a/dcgan/main.py +++ b/dcgan/main.py @@ -198,9 +198,6 @@ def forward(self, input): input, label = input.cuda(), label.cuda() noise, fixed_noise = noise.cuda(), fixed_noise.cuda() -input = Variable(input) -label = Variable(label) -noise = Variable(noise) fixed_noise = Variable(fixed_noise) # setup optimizer @@ -216,21 +213,25 @@ def forward(self, input): netD.zero_grad() real_cpu, _ = data batch_size = real_cpu.size(0) - input.data.resize_(real_cpu.size()).copy_(real_cpu) - label.data.resize_(batch_size).fill_(real_label) - - output = netD(input) - errD_real = criterion(output, label) + if opt.cuda: + real_cpu = real_cpu.cuda() + input.resize_as_(real_cpu).copy_(real_cpu) + label.resize_(batch_size).fill_(real_label) + inputv = Variable(input) + labelv = Variable(label) + + output = netD(inputv) + errD_real = criterion(output, labelv) errD_real.backward() D_x = output.data.mean() # train with fake - noise.data.resize_(batch_size, nz, 1, 1) - noise.data.normal_(0, 1) - fake = netG(noise) - label.data.fill_(fake_label) + noise.resize_(batch_size, nz, 1, 1).normal_(0, 1) + noisev = Variable(noise) + fake = netG(noisev) + labelv = Variable(label.fill_(fake_label)) output = netD(fake.detach()) - errD_fake = criterion(output, label) + errD_fake = criterion(output, labelv) errD_fake.backward() D_G_z1 = output.data.mean() errD = errD_real + errD_fake @@ -240,9 +241,9 @@ def forward(self, input): # (2) Update G network: maximize log(D(G(z))) ########################### netG.zero_grad() - label.data.fill_(real_label) # fake labels are real for generator cost + labelv = Variable(label.fill_(real_label)) # fake labels are real for generator cost output = netD(fake) - errG = criterion(output, label) + errG = criterion(output, labelv) errG.backward() D_G_z2 = output.data.mean() optimizerG.step()