change to use config file for training
yeyun111 committed Sep 8, 2017
1 parent a2750f0 commit d14ff4c
Showing 4 changed files with 94 additions and 67 deletions.
83 changes: 46 additions & 37 deletions random_bonus/unet_segmentation/argparser.py
@@ -1,6 +1,11 @@
import os
import argparse
import torch.optim as optim


def parse_param_file(filepath):
with open(filepath, 'r') as f:
kw_exprs = [x.strip() for x in f.readlines() if x.strip()]
return eval('dict({})'.format(','.join(kw_exprs)))


def parse_args():
@@ -15,54 +20,58 @@ def parse_args():
help='Directory containing training images in "images" and "segmentations" or test images')
parser.add_argument('--cpu',
help='Set to CPU mode', action='store_true')
parser.add_argument('--color_labels',
help='Colors of labels in segmentation image',
type=str, default='(0,0,0),(255,255,255)')
parser.add_argument('--image-width',
help='width of image',
type=int, default=256)
parser.add_argument('--image-height',
help='height of image',
type=int, default=256)
parser.add_argument('--output-dir',
help='Directory of output for both train/test',
type=str, default='')
parser.add_argument('--no-data-aug',
help='Disable data-augmentation', action='store_true')

# training options
parser.add_argument('--img-dir',
help='Directory under [dataroot] containing images',
type=str, default='images')
parser.add_argument('--seg-dir',
help='Directory under [dataroot] containing segmentations',
type=str, default='segmentations')
parser.add_argument('--epochs',
help='Num of training epochs',
type=int, default=20)
parser.add_argument('--batch-size',
help='Batch size',
type=int, default=4)
parser.add_argument('--lr',
help='Learning rate, for Adadelta it is the base learning rate',
type=float, default=0.0002)
parser.add_argument('--lr-policy',
help='Learning rate policy, example:"5:0.0005,10:0.0001,18:1e-5"',
# train options
parser.add_argument('--config',
help='Path to config file',
type=str, default='')
parser.add_argument('--no-batchnorm',
help='Do NOT use batch normalization', action='store_true')
parser.add_argument('--print-interval',
help='Print info every specified number of iterations',
type=int, default=20)

# test options
parser.add_argument('--model',
help='Path to pre-trained model',
type=str, default='')

args = parser.parse_args()

# other params specified in config file
if args.mode == 'train':
kwargs = parse_param_file(args.config)

# default: no augmentation, with batch-norm
params = {
# general params
'color_labels': [],

# training params
'image_width': 256,
'image_height': 256,
'lr_policy': {0: 1e-4},
'momentum': 0.9,
'nesterov': True,
'batch_norm': True,
'batch_size': 4,
'epochs': 24,
'print_interval': 20,
'random_horizontal_flip': False,
'random_square_crop': False,
'random_crop': None, # example: (0.81, 0.1) uses 0.81 as the area ratio and 0.1 as the height/width ratio variation
'random_rotation': 0,
'img_dir': 'images',
'seg_dir': 'segmentations'
}

# update params from config
for k, v in kwargs.items():
if k in params:
params[k] = v

# set params to args
for k, v in params.items():
setattr(args, k, v)

args.dataroot = args.dataroot.rstrip(os.sep)
args.color_labels = eval('[{}]'.format(args.color_labels))
args.lr_policy = eval('{{{}}}'.format(args.lr_policy)) if args.lr_policy else {}

return args
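
Aside: parse_param_file above treats each non-empty line of the config file as one Python keyword argument, joins the lines with commas, and evaluates the result as a single dict(...) call, so every value must be a valid Python literal. A minimal, self-contained sketch of that mechanism (the sample lines are illustrative, not part of this commit):

# Sketch of the parsing done by parse_param_file: config lines become
# keyword arguments of one dict(...) expression.
lines = [
    "epochs=24",
    "lr_policy={0: 1e-3, 10: 1e-5}",
    "img_dir='images'",
]
params = eval('dict({})'.format(','.join(lines)))
print(params)
# -> {'epochs': 24, 'lr_policy': {0: 0.001, 10: 1e-05}, 'img_dir': 'images'}
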
12 changes: 12 additions & 0 deletions random_bonus/unet_segmentation/example.unetparams
@@ -0,0 +1,12 @@
lr_policy={0: 1e-3, 1: 1e-4, 2: 5e-5, 5: 2e-5, 10: 1e-5, 18: 2e-6}
batch_size=4
epochs=24
print_interval=20
color_labels=[(0, 0, 255), (0, 255, 0), (255, 0, 0)]
image_width=64
image_height=64
random_horizontal_flip=True
random_square_crop=True
random_rotation=3
img_dir='images'
seg_dir='profiles'
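
Aside: the lr_policy entry maps epoch indices to learning rates; as the main.py diff below shows, the SGD optimizer is rebuilt with the new rate whenever the current epoch appears in this dict. A small sketch (not part of this commit) of the schedule this particular config produces:

# Sketch of the learning-rate schedule implied by the lr_policy above,
# mirroring the per-epoch check in main.py's training loop.
lr_policy = {0: 1e-3, 1: 1e-4, 2: 5e-5, 5: 2e-5, 10: 1e-5, 18: 2e-6}
lr = None
for epoch in range(24):  # epochs=24 in this config
    if epoch in lr_policy:
        lr = lr_policy[epoch]  # main.py recreates the SGD optimizer here
    print('epoch {:2d}: lr = {}'.format(epoch + 1, lr))
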
48 changes: 23 additions & 25 deletions random_bonus/unet_segmentation/main.py
@@ -30,50 +30,48 @@ def train(args):
console.setFormatter(formatter)
logging.getLogger('').addHandler(console)

logging.info('=========== Task {} started! ==========='.format(args.output_dir))
for arg in vars(args):
logging.info('{}: {}'.format(arg, getattr(args, arg)))
logging.info('========================================')

# initialize loader
if args.no_data_aug:
train_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'train']),
image_folder=args.img_dir,
segmentation_folder=args.seg_dir,
labels=args.color_labels,
image_size=(args.image_width, args.image_height))
else:
train_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'train']),
image_folder=args.img_dir,
segmentation_folder=args.seg_dir,
labels=args.color_labels,
image_size=(args.image_width, args.image_height),
random_horizontal_flip=True,
random_rotation=1,
random_crop=(0.85, 0.1))
train_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'train']),
image_folder=args.img_dir,
segmentation_folder=args.seg_dir,
labels=args.color_labels,
image_size=(args.image_width, args.image_height),
random_horizontal_flip=args.random_horizontal_flip,
random_rotation=args.random_rotation,
random_crop=args.random_crop,
random_square_crop=args.random_square_crop)
val_set = utils.SegmentationImageFolder(os.sep.join([args.dataroot, 'val']),
image_folder=args.img_dir,
segmentation_folder=args.seg_dir,
labels=args.color_labels,
image_size=(args.image_width, args.image_height))
image_size=(args.image_width, args.image_height),
random_square_crop=args.random_square_crop)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size, shuffle=True)

# initialize model, input channels need to be calculated by hand
model = networks.UNet([32, 64, 128, 256, 512], 3, 2)
model = networks.UNet([32, 64, 128, 256, 512], 3, len(args.color_labels), use_bn=args.batch_norm)
if not args.cpu:
model.cuda()

criterion = utils.CrossEntropyLoss2D()

# optimizer & lr policy
lr = args.lr
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
logging.info('| Learning Rate\t| Initialized learning rate: {}'.format(lr))

# train
for epoch in range(args.epochs):
model.train()
# update lr if lr_policy is defined
# update lr according to lr policy
if epoch in args.lr_policy:
lr = args.lr_policy[epoch]
optimizer = optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
logging.info('| Learning Rate\t| Epoch: {}\t| Change learning rate to {}'.format(epoch+1, lr))
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, nesterov=args.nesterov)
if epoch > 0:
logging.info('| Learning Rate\t| Epoch: {}\t| Change learning rate to {}'.format(epoch+1, lr))
else:
logging.info('| Learning Rate\t| Initial learning rate: {}'.format(lr))

# iterate all samples
losses = utils.AverageMeter()
18 changes: 13 additions & 5 deletions random_bonus/unet_segmentation/utils.py
@@ -33,6 +33,7 @@ def __init__(self, root,
random_horizontal_flip=False,
random_rotation=None,
random_crop=None,
random_square_crop=False,
loader=default_loader):
super(SegmentationImageFolder, self).__init__(root, loader=loader)
pair_len = len(self.imgs) / 2
@@ -48,6 +49,7 @@ def __init__(self, root,
self.flip_lr = random_horizontal_flip
self.random_rotation = random_rotation
self.random_crop = random_crop
self.random_square_crop = random_square_crop

def __getitem__(self, index):
"""
@@ -62,11 +64,6 @@ def __getitem__(self, index):
seg = self.loader(segpath)

# manually transform to incorporate horizontal flip & one-hot coding for segmentation labels
if (self.random_rotation or self.random_crop) and self.image_size:
w, h = self.image_size
img = img.resize((w*2, h*2))
seg = seg.resize((w*2, h*2), Image.NEAREST)

if self.random_rotation:
w, h = img.size
angle = self.random_rotation % 360
@@ -107,6 +104,17 @@ def __getitem__(self, index):
img = img.crop((x0, y0, x0+w_crop, y0+h_crop))
seg = seg.crop((x0, y0, x0+w_crop, y0+h_crop))

if self.random_square_crop:
w, h = img.size
if w > h:
x0 = random.randint(0, w-h-1)
img = img.crop((x0, 0, x0+h, h))
seg = seg.crop((x0, 0, x0+h, h))
elif w < h:
y0 = random.randint(0, h-w-1)
img = img.crop((0, y0, w, y0+w))
seg = seg.crop((0, y0, w, y0+w))

if self.image_size:
img = img.resize(self.image_size)
seg = seg.resize(self.image_size, Image.NEAREST)
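
Aside: the comment in __getitem__ above refers to turning the segmentation image into per-pixel labels via color_labels (the same list whose length now sets the UNet output channels in main.py). That conversion is outside this diff; the following is only a generic sketch of mapping label colors to per-pixel class indices, a common precursor to one-hot encoding, and not the repository's implementation.

import numpy as np

def colors_to_class_indices(seg, color_labels):
    # Map an RGB segmentation (PIL image) to an HxW array of class indices:
    # pixels equal to color_labels[i] become class i. Generic sketch only.
    seg_arr = np.array(seg.convert('RGB'))                # H x W x 3
    target = np.zeros(seg_arr.shape[:2], dtype=np.int64)
    for idx, color in enumerate(color_labels):
        mask = np.all(seg_arr == np.array(color), axis=-1)
        target[mask] = idx
    return target

# e.g. with the colors from example.unetparams:
# target = colors_to_class_indices(seg, [(0, 0, 255), (0, 255, 0), (255, 0, 0)])
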
