forked from open-mmlab/mmdetection
Commit message:

* add infinite sampler
* add docstring
* sup dp
* rename it to DistributedInfiniteGroupBatchSampler
* fix default value
* support shuffle is false
* resolve comments
* add two 90k config
* fix dp case
* avoid bc breaking
* fix doc

Co-authored-by: zhangshilong <[email protected]>
Showing 6 changed files with 257 additions and 14 deletions.
First new file: a 90k-iteration Faster R-CNN config (15 lines added) that inherits from `faster_rcnn_r50_caffe_fpn_1x_coco.py` and switches to an iteration-based schedule:

```python
_base_ = 'faster_rcnn_r50_caffe_fpn_1x_coco.py'

# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[60000, 80000])

# Runner type
runner = dict(_delete_=True, type='IterBasedRunner', max_iters=90000)

checkpoint_config = dict(interval=10000)
evaluation = dict(interval=10000, metric='bbox')
```
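A quick sanity check on why these numbers mirror the epoch-based 1x schedule; the COCO train2017 size and the 8-GPU x 2-images-per-GPU global batch below are assumptions for illustration, not values taken from this commit:

```python
# Rough mapping from iterations to epochs under an assumed 16-image global batch.
num_train_images = 118_000      # approximate size of COCO train2017 (assumption)
global_batch_size = 8 * 2       # assumed 8 GPUs x samples_per_gpu=2

iters_per_epoch = num_train_images / global_batch_size   # ~7375 iterations
for iters in (60_000, 80_000, 90_000):
    print(iters, '->', round(iters / iters_per_epoch, 1), 'epochs')
# 60k -> ~8.1, 80k -> ~10.8, 90k -> ~12.2 epochs, which lines up with the usual
# 1x schedule (LR decays at epochs 8 and 11, training stops at epoch 12).
```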
Second new file: the same 90k schedule applied on top of `retinanet_r50_fpn_1x_coco.py` (15 lines added); both configs also checkpoint and evaluate every 10k iterations:

```python
_base_ = 'retinanet_r50_fpn_1x_coco.py'

# learning policy
lr_config = dict(
    policy='step',
    warmup='linear',
    warmup_iters=500,
    warmup_ratio=0.001,
    step=[60000, 80000])

# Runner type
runner = dict(_delete_=True, type='IterBasedRunner', max_iters=90000)

checkpoint_config = dict(interval=10000)
evaluation = dict(interval=10000, metric='bbox')
```
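One detail shared by both configs: `_delete_=True` tells the config inheritance machinery to discard the base file's `runner` dict (the epoch-based runner from the 1x base) instead of merging into it, so only the iteration-based keys survive. A toy illustration of that merge rule, not mmcv's actual implementation:

```python
# What `_delete_=True` means when a child config overrides a dict from `_base_`.
base_runner = {'type': 'EpochBasedRunner', 'max_epochs': 12}   # assumed 1x base
override = {'_delete_': True, 'type': 'IterBasedRunner', 'max_iters': 90000}

if override.pop('_delete_', False):
    merged = dict(override)                 # drop the base dict entirely
else:
    merged = {**base_runner, **override}    # default behaviour: key-wise merge

print(merged)   # {'type': 'IterBasedRunner', 'max_iters': 90000}
```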
(The diffs of the remaining two changed files are collapsed and not shown in this view.)
Changes to the samplers package `__init__.py`, exporting the two new classes:

```diff
@@ -1,5 +1,9 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .distributed_sampler import DistributedSampler
 from .group_sampler import DistributedGroupSampler, GroupSampler
+from .infinite_sampler import InfiniteBatchSampler, InfiniteGroupBatchSampler
 
-__all__ = ['DistributedSampler', 'DistributedGroupSampler', 'GroupSampler']
+__all__ = [
+    'DistributedSampler', 'DistributedGroupSampler', 'GroupSampler',
+    'InfiniteGroupBatchSampler', 'InfiniteBatchSampler'
+]
```
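With the export in place, the new samplers can be imported alongside the existing ones; a trivial check, assuming the package path is `mmdet.datasets.samplers` (the relative imports above suggest it, but the path itself is not shown in this diff):

```python
# Smoke test for the updated exports; the package path is an assumption.
from mmdet.datasets.samplers import (InfiniteBatchSampler,
                                     InfiniteGroupBatchSampler)

print(InfiniteBatchSampler.__name__, InfiniteGroupBatchSampler.__name__)
```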
The new module `infinite_sampler.py` (171 lines added) defines both samplers; first the imports and `InfiniteGroupBatchSampler`, which keeps the aspect-ratio grouping of `GroupSampler`:

```python
import itertools

import numpy as np
import torch
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler


class InfiniteGroupBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `GroupSampler`. It is designed
    for iteration-based runners like `IterBasedRunner` and yields a
    mini-batch of indices each time; all indices in a batch belong to the
    same group.

    The implementation is adapted from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When the model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When the model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`. Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None.
        rank (int, optional): Rank of current process. Default: None.
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the indices of a dummy `epoch`.
            Note that even when `shuffle` is False the yielded indices are
            not guaranteed to be sequential, because every index in a batch
            must come from the same group. Default: True.
    """  # noqa: W605

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle

        assert hasattr(self.dataset, 'flag')
        self.flag = self.dataset.flag
        self.group_sizes = np.bincount(self.flag)
        # buffer used to save indices of each group
        self.buffer_per_group = {k: [] for k in range(len(self.group_sizes))}

        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self):
        # once batch size is reached, yield the indices
        for idx in self.indices:
            flag = self.flag[idx]
            group_buffer = self.buffer_per_group[flag]
            group_buffer.append(idx)
            if len(group_buffer) == self.batch_size:
                yield group_buffer[:]  # yield a copy of the buffer
                del group_buffer[:]  # then clear it for the next batch

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in iteration-based runners."""
        raise NotImplementedError
```
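To make the group-buffer logic concrete, here is a toy run that is not part of the diff; the stand-in dataset only provides the two attributes the sampler actually touches (`__len__` and `flag`):

```python
import itertools

import numpy as np


class _ToyGroupedDataset:
    """Stand-in dataset: InfiniteGroupBatchSampler only needs __len__ and a
    `flag` array assigning each sample to a group (0/1 in mmdetection,
    derived from aspect ratio)."""

    def __init__(self, n):
        self.flag = np.array([i % 2 for i in range(n)], dtype=np.uint8)

    def __len__(self):
        return len(self.flag)


sampler = InfiniteGroupBatchSampler(_ToyGroupedDataset(10), batch_size=4)
for batch in itertools.islice(iter(sampler), 3):
    # every index in a yielded batch belongs to the same group
    assert len({int(sampler.flag[i]) for i in batch}) == 1
    print(batch)
```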
The second class, `InfiniteBatchSampler`, follows the same structure but batches indices without regard to groups:

```python
class InfiniteBatchSampler(Sampler):
    """Similar to `BatchSampler` wrapping a `DistributedSampler`. It is
    designed for iteration-based runners like `IterBasedRunner` and yields a
    mini-batch of indices each time.

    The implementation is adapted from
    https://github.com/facebookresearch/detectron2/blob/main/detectron2/data/samplers/grouped_batch_sampler.py

    Args:
        dataset (object): The dataset.
        batch_size (int): When the model is :obj:`DistributedDataParallel`,
            it is the number of training samples on each GPU.
            When the model is :obj:`DataParallel`, it is
            `num_gpus * samples_per_gpu`. Default: 1.
        world_size (int, optional): Number of processes participating in
            distributed training. Default: None.
        rank (int, optional): Rank of current process. Default: None.
        seed (int): Random seed. Default: 0.
        shuffle (bool): Whether to shuffle the dataset or not. Default: True.
    """  # noqa: W605

    def __init__(self,
                 dataset,
                 batch_size=1,
                 world_size=None,
                 rank=None,
                 seed=0,
                 shuffle=True):
        _rank, _world_size = get_dist_info()
        if world_size is None:
            world_size = _world_size
        if rank is None:
            rank = _rank
        self.rank = rank
        self.world_size = world_size
        self.dataset = dataset
        self.batch_size = batch_size
        self.seed = seed if seed is not None else 0
        self.shuffle = shuffle
        self.size = len(dataset)
        self.indices = self._indices_of_rank()

    def _infinite_indices(self):
        """Infinitely yield a sequence of indices."""
        g = torch.Generator()
        g.manual_seed(self.seed)
        while True:
            if self.shuffle:
                yield from torch.randperm(self.size, generator=g).tolist()
            else:
                yield from torch.arange(self.size).tolist()

    def _indices_of_rank(self):
        """Slice the infinite indices by rank."""
        yield from itertools.islice(self._infinite_indices(), self.rank, None,
                                    self.world_size)

    def __iter__(self):
        # once batch size is reached, yield the indices
        batch_buffer = []
        for idx in self.indices:
            batch_buffer.append(idx)
            if len(batch_buffer) == self.batch_size:
                yield batch_buffer
                batch_buffer = []

    def __len__(self):
        """Length of base dataset."""
        return self.size

    def set_epoch(self, epoch):
        """Not supported in iteration-based runners."""
        raise NotImplementedError
```
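Both classes yield lists of indices, so they slot into a PyTorch `DataLoader` through the `batch_sampler` argument; in mmdetection this wiring would normally happen inside the data-loader builder rather than by hand. A minimal hand-rolled sketch, with a placeholder dataset that is not part of the diff:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class _ToyDataset(Dataset):
    """Placeholder dataset; a real mmdetection dataset would return dicts."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return torch.tensor([float(idx)])


dataset = _ToyDataset()
# Batch formation is delegated entirely to the sampler, so batch_size and
# shuffle must NOT also be passed to the DataLoader.
loader = DataLoader(
    dataset, batch_sampler=InfiniteBatchSampler(dataset, batch_size=4))

# IterBasedRunner-style loop: the loader never raises StopIteration, so the
# training loop simply counts iterations instead of epochs.
for i, batch in zip(range(3), loader):
    print(i, batch.view(-1).tolist())
```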