Skip to content

Commit

Permalink
Added the ability to and a config to disable certain data augmentatio…
Browse files Browse the repository at this point in the history
…n steps.
  • Loading branch information
dbolya committed Mar 3, 2019
1 parent 1d67771 commit b76da35
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
18 changes: 18 additions & 0 deletions data/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,16 @@ def print(self):
'mask_proto_double_loss': False,
'mask_proto_double_loss_alpha': 1,

# SSD data augmentation parameters
# Randomize hue, vibrance, etc.
'augment_photometric_distort': True,
# Have a chance to scale down the image and pad (to emulate smaller detections)
'augment_expand': True,
# Potentialy sample a random crop from the image and put it in a random place
'augment_random_sample_crop': True,
# Mirror the image with a probability of 1/2
'augment_random_mirror': True,

# If using batchnorm anywhere in the backbone, freeze the batchnorm layer during training.
# Note: any additional batch norm layers after the backbone will not be frozen.
'freeze_bn': False,
Expand Down Expand Up @@ -1055,6 +1065,14 @@ def print(self):
'dataset': coco2017_dataset,
})

yrm35_noaug_config = yrm35_moredata_config.copy({
'name': 'yrm35_noaug',

'augment_expand': False,
'augment_random_mirror': False,
'augment_random_sample_crop': False,
})

yrm35_resnet50_config = yrm35_moredata_config.copy({
'name': 'yrm35_resnet50',

Expand Down
13 changes: 9 additions & 4 deletions utils/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,19 +627,24 @@ def __call__(self, img):
# Return value is in channel order [n, c, h, w] and RGB
return img

def do_nothing(img=None, masks=None, boxes=None, labels=None):
return img, masks, boxes, labels


def enable_if(condition, obj):
return obj if condition else do_nothing

class SSDAugmentation(object):
""" Transform to be used when training. """

def __init__(self, mean=MEANS, std=STD):
self.augment = Compose([
ConvertFromInts(),
ToAbsoluteCoords(),
PhotometricDistort(),
Expand(mean),
RandomSampleCrop(),
RandomMirror(),
enable_if(cfg.augment_photometric_distort, PhotometricDistort()),
enable_if(cfg.augment_expand, Expand(mean)),
enable_if(cfg.augment_random_sample_crop, RandomSampleCrop()),
enable_if(cfg.augment_random_mirror, RandomMirror()),
Resize(),
Pad(cfg.max_size, cfg.max_size, mean),
ToPercentCoords(),
Expand Down

0 comments on commit b76da35

Please sign in to comment.