-
Notifications
You must be signed in to change notification settings - Fork 91
/
Copy pathconfigLGAN.py
76 lines (62 loc) · 3.31 KB
/
configLGAN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import os
from utils import ensure_dirs
import argparse
import json
import shutil
class ConfigLGAN(object):
def __init__(self):
self.set_configuration()
# parse command line arguments
parser, args = self.parse()
# set as attributes
print("----Experiment Configuration-----")
for k, v in args.__dict__.items():
print("{0:20}".format(k), v)
self.__setattr__(k, v)
# experiment paths
self.data_root = os.path.join(args.proj_dir, args.exp_name, "results/all_zs_ckpt{}.h5".format(args.ae_ckpt))
self.exp_dir = os.path.join(args.proj_dir, args.exp_name, "lgan_{}".format(args.ae_ckpt))
self.log_dir = os.path.join(self.exp_dir, 'log')
self.model_dir = os.path.join(self.exp_dir, 'model')
if (not args.test) and args.cont is not True and os.path.exists(self.exp_dir):
response = input('Experiment log/model already exists, overwrite? (y/n) ')
if response != 'y':
exit()
shutil.rmtree(self.exp_dir)
ensure_dirs([self.log_dir, self.model_dir])
# GPU usage
if args.gpu_ids is not None:
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_ids)
# save this configuration
if not args.test:
with open('{}/config.txt'.format(self.exp_dir), 'w') as f:
json.dump(self.__dict__, f, indent=2)
def set_configuration(self):
# network configuration
self.n_dim = 64
self.h_dim = 512
self.z_dim = 256
# WGAN-gp configuration
self.beta1 = 0.5
self.critic_iters = 5
self.gp_lambda = 10
def parse(self):
"""initiaize argument parser. Define default hyperparameters and collect from command-line arguments."""
parser = argparse.ArgumentParser()
parser.add_argument('--proj_dir', type=str, default="proj_log",
help="path to project folder where models and logs will be saved")
parser.add_argument('--exp_name', type=str, required=True, help="name of this experiment")
parser.add_argument('--ae_ckpt', type=str, required=True, help="ckpt for autoencoder")
parser.add_argument('--continue', dest='cont', action='store_true', help="continue training from checkpoint")
parser.add_argument('--ckpt', type=str, default='latest', required=False, help="desired checkpoint to restore")
parser.add_argument('--test', action='store_true', help="test mode")
parser.add_argument('--n_samples', type=int, default=100, help="number of samples to generate when testing")
parser.add_argument('-g', '--gpu_ids', type=str, default="0",
help="gpu to use, e.g. 0 0,1,2. CPU not supported.")
parser.add_argument('--batch_size', type=int, default=256, help="batch size")
parser.add_argument('--num_workers', type=int, default=8, help="number of workers for data loading")
parser.add_argument('--n_iters', type=int, default=200000, help="total number of iterations to train")
parser.add_argument('--save_frequency', type=int, default=100000, help="save models every x iterations")
parser.add_argument('--lr', type=float, default=2e-4, help="initial learning rate")
args = parser.parse_args()
return parser, args