forked from AliaksandrSiarohin/pose-gan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cmd.py
90 lines (65 loc) · 4.88 KB
/
cmd.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import argparse
def args():
"""
Define args that is used in project
"""
parser = argparse.ArgumentParser(description="Pose guided image generation usign deformable skip layers")
parser.add_argument("--output_dir", default='output/displayed_samples', help="Directory with generated sample images")
parser.add_argument("--batch_size", default=4, type=int, help='Size of the batch')
parser.add_argument("--training_ratio", default=1, type=int,
help="The training ratio is the number of discriminator updates per generator update.")
parser.add_argument("--l1_penalty_weight", default=100, type=float, help='Weight of l1 loss')
parser.add_argument('--gan_penalty_weight', default=1, type=float, help='Weight of GAN loss')
parser.add_argument('--tv_penalty_weight', default=0, type=float, help='Weight of total variation loss')
parser.add_argument('--lstruct_penalty_weight', default=0, type=float, help="Weight of lstruct")
parser.add_argument("--number_of_epochs", default=500, type=int, help="Number of training epochs")
parser.add_argument("--content_loss_layer", default='none', help='Name of content layer (vgg19)'
' e.g. block4_conv1 or none')
parser.add_argument("--checkpoints_dir", default="output/checkpoints", help="Folder with checkpoints")
parser.add_argument("--checkpoint_ratio", default=30, type=int, help="Number of epochs between consecutive checkpoints")
parser.add_argument("--generator_checkpoint", default=None, help="Previosly saved model of generator")
parser.add_argument("--discriminator_checkpoint", default=None, help="Previosly saved model of discriminator")
parser.add_argument("--nn_loss_area_size", default=1, type=int, help="Use nearest neighbour loss")
parser.add_argument("--use_validation", default=1, type=int, help="Use validation")
parser.add_argument('--dataset', default='market', choices=['market', 'fasion', 'prw', 'fasion128', 'fasion128128'],
help='Market, fasion or prw')
parser.add_argument("--display_ratio", default=1, type=int, help='Number of epochs between ploting')
parser.add_argument("--start_epoch", default=0, type=int, help='Start epoch for starting from checkpoint')
parser.add_argument("--pose_estimator", default='pose_estimator.h5',
help='Pretrained model for cao pose estimator')
parser.add_argument("--images_for_test", default=12000, type=int, help="Number of images for testing")
parser.add_argument("--use_input_pose", default=True, type=int, help='Feed to generator input pose')
parser.add_argument("--warp_skip", default='stn', choices=['none', 'full', 'mask', 'stn'],
help="Type of warping skip layers to use.")
parser.add_argument("--warp_agg", default='max', choices=['max', 'avg'],
help="Type of aggregation.")
parser.add_argument("--disc_type", default='call', choices=['call', 'sim', 'warp'],
help="Type of discriminator call - concat all, sim - siamease, sharewarp - warp.")
parser.add_argument("--use_bg", default=0, type=int, help='Use background images only for prw dataset')
parser.add_argument("--pose_rep_type", default='hm', choices=['hm', 'stickman'],
help='Representation of the pose, hm - heatmap, stickman - like in vunet (https://github.com/CompVis/vunet).')
parser.add_argument("--cache_pose_rep", default=1, type=int, help="Cache pose representation on disk.")
parser.add_argument("--generated_images_dir", default='output/generated_images',
help='Folder with generated images from training dataset')
parser.add_argument('--load_generated_images', default=0, type=int,
help='Load images from generated_images_dir or generate')
parser.add_argument('--use_dropout_test', default=0, type=int,
help='To use dropout when generate images')
args = parser.parse_args()
args.images_dir_train = 'data/' + args.dataset + '-dataset/train'
args.images_dir_test = 'data/' + args.dataset + '-dataset/test'
args.annotations_file_train = 'data/' + args.dataset + '-annotation-train.csv'
args.annotations_file_test = 'data/' + args.dataset + '-annotation-test.csv'
args.pairs_file_train = 'data/' + args.dataset + '-pairs-train.csv'
args.pairs_file_test = 'data/' + args.dataset + '-pairs-test.csv'
args.bg_images_dir_train = 'data/' + args.dataset + '-dataset/train-bg'
args.bg_images_dir_test = 'data/' + args.dataset + '-dataset/test-bg'
if args.dataset == 'fasion':
args.image_size = (256, 256)
elif args.dataset == 'fasion128128':
args.image_size = (128, 128)
else:
args.image_size = (128, 64)
args.tmp_pose_dir = 'tmp/' + args.dataset + '/'
del args.dataset
return args