Split the batching logic from build_detection_train_loader
Summary: to be reused in customized dataloaders

Reviewed By: rbgirshick

Differential Revision: D21694836

fbshipit-source-id: caf1a48e6d259e753b09628bb631eaa1bc20c7ab
ppwwyyxx authored and facebook-github-bot committed May 22, 2020
1 parent bf11a9b commit de09842
Showing 1 changed file with 58 additions and 39 deletions.
detectron2/data/build.py
@@ -253,6 +253,56 @@ def get_detection_dataset_dicts(
     return dataset_dicts
 
 
+def build_batch_data_loader(
+    dataset, sampler, total_batch_size, *, aspect_ratio_grouping=False, num_workers=0
+):
+    """
+    Build a batched dataloader for training.
+
+    Args:
+        dataset (torch.utils.data.Dataset): map-style PyTorch dataset. Can be indexed.
+        sampler (torch.utils.data.sampler.Sampler): a sampler that produces indices.
+        total_batch_size (int): total batch size across GPUs.
+        aspect_ratio_grouping (bool): whether to group images with similar
+            aspect ratios for efficiency. When enabled, it requires each
+            element in the dataset to be a dict with keys "width" and "height".
+        num_workers (int): number of parallel data loading workers.
+
+    Returns:
+        iterable[list]. Length of each list is the batch size of the current
+        GPU. Each element in the list comes from the dataset.
+    """
+    world_size = get_world_size()
+    assert (
+        total_batch_size > 0 and total_batch_size % world_size == 0
+    ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format(
+        total_batch_size, world_size
+    )
+
+    batch_size = total_batch_size // world_size
+    if aspect_ratio_grouping:
+        data_loader = torch.utils.data.DataLoader(
+            dataset,
+            sampler=sampler,
+            num_workers=num_workers,
+            batch_sampler=None,
+            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
+            worker_init_fn=worker_init_reset_seed,
+        )  # yields individual mapped dicts
+        return AspectRatioGroupedDataset(data_loader, batch_size)
+    else:
+        batch_sampler = torch.utils.data.sampler.BatchSampler(
+            sampler, batch_size, drop_last=True
+        )  # drop_last so the batch always has the same size
+        return torch.utils.data.DataLoader(
+            dataset,
+            num_workers=num_workers,
+            batch_sampler=batch_sampler,
+            collate_fn=trivial_batch_collator,
+            worker_init_fn=worker_init_reset_seed,
+        )
+
+
 def build_detection_train_loader(cfg, mapper=None):
     """
     A data loader is created by the following steps:
@@ -274,20 +324,6 @@ def build_detection_train_loader(cfg, mapper=None):
     Returns:
         an infinite iterator of training data
     """
-    num_workers = get_world_size()
-    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
-    assert (
-        images_per_batch % num_workers == 0
-    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
-        images_per_batch, num_workers
-    )
-    assert (
-        images_per_batch >= num_workers
-    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
-        images_per_batch, num_workers
-    )
-    images_per_worker = images_per_batch // num_workers
-
     dataset_dicts = get_detection_dataset_dicts(
         cfg.DATASETS.TRAIN,
         filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
@@ -305,6 +341,7 @@ def build_detection_train_loader(cfg, mapper=None):
     sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
     logger = logging.getLogger(__name__)
     logger.info("Using training sampler {}".format(sampler_name))
+    # TODO avoid if-else?
     if sampler_name == "TrainingSampler":
         sampler = samplers.TrainingSampler(len(dataset))
     elif sampler_name == "RepeatFactorTrainingSampler":
@@ -313,31 +350,13 @@ def build_detection_train_loader(cfg, mapper=None):
         )
     else:
         raise ValueError("Unknown training sampler: {}".format(sampler_name))
 
-    if cfg.DATALOADER.ASPECT_RATIO_GROUPING:
-        data_loader = torch.utils.data.DataLoader(
-            dataset,
-            sampler=sampler,
-            num_workers=cfg.DATALOADER.NUM_WORKERS,
-            batch_sampler=None,
-            collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
-            worker_init_fn=worker_init_reset_seed,
-        )  # yield individual mapped dict
-        data_loader = AspectRatioGroupedDataset(data_loader, images_per_worker)
-    else:
-        batch_sampler = torch.utils.data.sampler.BatchSampler(
-            sampler, images_per_worker, drop_last=True
-        )
-        # drop_last so the batch always have the same size
-        data_loader = torch.utils.data.DataLoader(
-            dataset,
-            num_workers=cfg.DATALOADER.NUM_WORKERS,
-            batch_sampler=batch_sampler,
-            collate_fn=trivial_batch_collator,
-            worker_init_fn=worker_init_reset_seed,
-        )
-
-    return data_loader
+    return build_batch_data_loader(
+        dataset,
+        sampler,
+        cfg.SOLVER.IMS_PER_BATCH,
+        aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING,
+        num_workers=cfg.DATALOADER.NUM_WORKERS,
+    )


 def build_detection_test_loader(cfg, dataset_name, mapper=None):
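The commit summary says the split-out batching logic is meant to be reused in customized dataloaders. Below is a minimal sketch of such reuse; the import paths and the DatasetFromList / MapDataset / DatasetMapper / TrainingSampler helpers reflect detectron2's layout around this commit and are assumptions, not part of the diff:

    from detectron2.data import DatasetCatalog
    from detectron2.data.build import build_batch_data_loader
    from detectron2.data.common import DatasetFromList, MapDataset
    from detectron2.data.dataset_mapper import DatasetMapper
    from detectron2.data.samplers import TrainingSampler

    def build_my_train_loader(cfg, mapper=None):
        # Load the raw dataset dicts and wrap them in an indexable dataset.
        dataset_dicts = DatasetCatalog.get(cfg.DATASETS.TRAIN[0])
        dataset = DatasetFromList(dataset_dicts, copy=False)
        # Apply the usual per-sample mapping (image loading, augmentation, ...).
        dataset = MapDataset(dataset, mapper or DatasetMapper(cfg, is_train=True))
        # Any custom sampler can go here; TrainingSampler is the default infinite one.
        sampler = TrainingSampler(len(dataset))
        # Batching concerns (per-GPU batch size, aspect-ratio grouping,
        # worker seeding) are now handled in one place by the new helper.
        return build_batch_data_loader(
            dataset,
            sampler,
            cfg.SOLVER.IMS_PER_BATCH,
            aspect_ratio_grouping=cfg.DATALOADER.ASPECT_RATIO_GROUPING,
            num_workers=cfg.DATALOADER.NUM_WORKERS,
        )

Keeping the dataset, mapper, and sampler outside the helper means a custom loader only has to assemble those pieces; the distributed batch-size split and aspect-ratio grouping no longer need to be reimplemented.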

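A hypothetical consumption loop, illustrating the contract stated in the new docstring: the loader yields lists whose length is the per-GPU batch size. build_my_train_loader is the sketch above; cfg setup, dataset registration, and process() are assumed for illustration:

    # With cfg.SOLVER.IMS_PER_BATCH = 16 and world_size = 2, each process
    # receives lists of 16 // 2 = 8 mapped dataset dicts per iteration.
    data_loader = build_my_train_loader(cfg)
    for batch in data_loader:            # the sampler is infinite; stop manually
        assert isinstance(batch, list)   # one list per GPU per iteration
        # each element is one mapped dataset dict (e.g. "image", "instances")
        process(batch)                   # hypothetical training step
        break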