-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathmain.py
115 lines (99 loc) · 3.59 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse, os
from GAN import GAN
from CGAN import CGAN
from WGAN import WGAN
from VAE import VAE
from LSGAN import LSGAN
from CVAE import CVAE
from WGAN_GP import WGAN_GP
from EBGAN import EBGAN
from infoGAN import infoGAN
from ACGAN import ACGAN
from SAGAN import SAGAN
"""parsing and configuration"""
def _str2bool(value):
    """Convert a command-line string such as 'True', 'false', '1' or 'no' to a bool.

    argparse with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string, so ``--gpu_mode False`` would silently keep the GPU on.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling (argparse turns this into a usage error).
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('boolean value expected, got %r' % value)


def _build_parser():
    """Build the argparse parser for all command-line options of this script."""
    desc = "Pytorch implementation of GAN collections"
    parser = argparse.ArgumentParser(description=desc)
    # NOTE(review): 'BEGAN' and 'DRAGAN' are accepted here, but main() has no
    # model class for them and will reject them after parsing.
    parser.add_argument('--gan_type', type=str, default='EBGAN',
                        choices=['GAN', 'CGAN', 'VAE', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN',
                                 'WGAN', 'WGAN_GP', 'DRAGAN', 'LSGAN', 'CVAE', 'SAGAN'],
                        help='The type of GAN')
    parser.add_argument('--dataset', type=str, default='mnist',
                        choices=['mnist', 'fashion-mnist', 'celebA'],
                        help='The name of dataset')
    parser.add_argument('--epoch', type=int, default=25, help='The number of epochs to run')
    parser.add_argument('--batch_size', type=int, default=64, help='The size of batch')
    parser.add_argument('--save_dir', type=str, default='models',
                        help='Directory name to save the model')
    parser.add_argument('--result_dir', type=str, default='results',
                        help='Directory name to save the generated images')
    parser.add_argument('--log_dir', type=str, default='logs',
                        help='Directory name to save training logs')
    parser.add_argument('--lrG', type=float, default=0.0002, help='Generator learning rate')
    parser.add_argument('--lrD', type=float, default=0.0002, help='Discriminator learning rate')
    parser.add_argument('--beta1', type=float, default=0.5)
    parser.add_argument('--beta2', type=float, default=0.999)
    # Parse explicitly as a boolean: type=bool would map the string 'False' to True.
    parser.add_argument('--gpu_mode', type=_str2bool, default=True,
                        help='Whether to run on the GPU (true/false)')
    return parser


def parse_args():
    """Parse the command line and pass the result through ``check_args``.

    Returns:
        Whatever ``check_args`` returns for the parsed ``argparse.Namespace``.
    """
    return check_args(_build_parser().parse_args())
"""checking arguments"""
def check_args(args):
    """Validate parsed arguments, then create the output directories.

    Checks that ``args.epoch`` and ``args.batch_size`` are at least one and
    prints a message for each violation. Only when both are valid are
    ``args.save_dir``, ``args.result_dir`` and ``args.log_dir`` created.

    Returns:
        ``args`` when valid, otherwise ``None``.
    """
    valid = True
    # --epoch
    if args.epoch < 1:
        print('number of epochs must be larger than or equal to one')
        valid = False
    # --batch_size
    if args.batch_size < 1:
        print('batch size must be larger than or equal to one')
        valid = False
    if not valid:
        return None

    # --save_dir, --result_dir, --log_dir
    # Created only after validation so a bad invocation leaves no stray
    # folders; exist_ok avoids the exists()/makedirs() race.
    for directory in (args.save_dir, args.result_dir, args.log_dir):
        os.makedirs(directory, exist_ok=True)
    return args
"""main"""
def main():
    """Parse arguments, build the selected model, train it and visualize results.

    Exits with status 1 when ``parse_args`` returns ``None``.

    Raises:
        ValueError: if ``--gan_type`` names a model with no class mapped below.
    """
    # parse arguments; a None result means the arguments failed validation
    args = parse_args()
    if args is None:
        raise SystemExit(1)

    # Map each supported --gan_type to its model class. Parser choices that
    # are not listed here (e.g. BEGAN, DRAGAN) fall through to the error.
    models = {
        'GAN': GAN,
        'CGAN': CGAN,
        'WGAN': WGAN,
        'VAE': VAE,
        'LSGAN': LSGAN,
        'CVAE': CVAE,
        'WGAN_GP': WGAN_GP,
        'EBGAN': EBGAN,
        'infoGAN': infoGAN,
        'ACGAN': ACGAN,
        'SAGAN': SAGAN,
    }
    model_cls = models.get(args.gan_type)
    if model_cls is None:
        raise ValueError("[!] There is no option for " + args.gan_type)
    gan = model_cls(args)

    # train the selected model
    gan.train()
    print(" [*] Training finished!")

    # visualize learned generator
    gan.visualize_results(args.epoch)
    print(" [*] Testing finished!")


if __name__ == '__main__':
    main()