Skip to content

Commit

Permalink
parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
laura-gibbs committed May 28, 2021
1 parent 464e8dc commit 6e5df01
Showing 1 changed file with 16 additions and 15 deletions.
31 changes: 16 additions & 15 deletions implementations/dcgan/dcgan.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@

parser = argparse.ArgumentParser()
parser.add_argument("--n_epochs", type=int, default=301, help="number of epochs of training")
parser.add_argument("--batch_size", type=int, default=144, help="size of the batches")
parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of first order momentum of gradient")
parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
parser.add_argument("--img_size", type=int, default=64, help="size of each image dimension")
parser.add_argument("--channels", type=int, default=1, help="number of image channels")
parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
parser.add_argument("--sample_interval", type=int, default=2, help="interval between image sampling")
opt = parser.parse_args()
print(opt)

Expand Down Expand Up @@ -136,7 +136,7 @@ def forward(self, img):
# batch_size=opt.batch_size,
# shuffle=True,
CSDataset(
root_dir = '../../../MDT-Calculations/saved_tiles' + '/training/tiles_32',
root_dir = '../../../MDT-Calculations/saved_tiles' + '/training/mdt_tiles',
transform=transforms.Compose([
transforms.Resize(opt.img_size),
transforms.ToTensor(),
Expand Down Expand Up @@ -204,20 +204,21 @@ def forward(self, img):
% (epoch, opt.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
)
# Add code to resize image up
batches_done = epoch * len(dataloader) + i
if batches_done % opt.sample_interval == 0:
# batches_done = epoch * len(dataloader) + i
if epoch % opt.sample_interval == 0:
gen_imgs = F.interpolate(gen_imgs, (32,32), mode='bilinear')
save_image(gen_imgs.data[:144], "images/%d.png" % epoch, nrow=12, normalize=True)
# if epoch == 300:
# print("running")
# with torch.no_grad():
# for j in range(700):
# z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
# gen_imgs = generator(z)
# gen_imgs = F.interpolate(gen_imgs, (32,32), mode='bilinear')
# for i in range(144):
# # print("running")
# save_image(gen_imgs.data[i].unsqueeze(0), "tiles_resize/%d_"% epoch + str(j) + '_' + str(i) + ".png", normalize=True)
if epoch == 300:
print(epoch, "running")
with torch.no_grad():
for j in range(3000):
z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))
gen_imgs = generator(z)
gen_imgs = F.interpolate(gen_imgs, (32,32), mode='bilinear')
for k in range(opt.batch_size):
# print("running")
save_image(gen_imgs.data[k].unsqueeze(0), "mdt_tiles/%d_"% epoch + str(k) + '_' + str(j) + ".png", normalize=True)
# save_image(gen_img[k], f"gen_tiles/tile_{j*opt.batch_size+k}.png", normalize=True)

# Save tiles code from home PC
# save_image(gen_imgs.data[:64], "images/%d.png" % epoch, nrow=8, normalize=True)
Expand Down

0 comments on commit 6e5df01

Please sign in to comment.