Skip to content

Commit

Permalink
Explicitly assign a batchsampler that doesn't cast sampler output to int
Browse files Browse the repository at this point in the history
  • Loading branch information
roytseng-tw committed May 14, 2018
1 parent cd26cc1 commit 235885b
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 7 deletions.
53 changes: 51 additions & 2 deletions lib/roi_data/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@

import torch
import torch.utils.data as data
import torch.utils.data.sampler as sampler
import torch.utils.data.sampler as torch_sampler
from torch.utils.data.dataloader import default_collate
from torch._six import int_classes as _int_classes

from core.config import cfg
from roi_data.minibatch import get_minibatch
Expand Down Expand Up @@ -143,7 +144,7 @@ def cal_minibatch_ratio(ratio_list):
return ratio_list_minibatch


class MinibatchSampler(sampler.Sampler):
class MinibatchSampler(torch_sampler.Sampler):
def __init__(self, ratio_list, ratio_index):
self.ratio_list = ratio_list
self.ratio_index = ratio_index
Expand Down Expand Up @@ -178,6 +179,54 @@ def __len__(self):
return self.num_data


class BatchSampler(torch_sampler.BatchSampler):
    r"""Wrap another sampler and yield its output grouped into mini-batches.

    The upstream implementation this class is adapted from appends
    ``int(idx)`` to each batch. This variant appends every item produced by
    the base sampler unchanged, so the base sampler may yield values that
    are not plain integers.

    Args:
        sampler (Sampler): Base sampler. Must be an instance of
            ``torch.utils.data.sampler.Sampler``.
        batch_size (int): Size of mini-batch. Must be a positive integer
            (``bool`` is rejected).
        drop_last (bool): If ``True``, the sampler will drop the last batch
            if its size would be less than ``batch_size``.

    Raises:
        ValueError: If ``sampler`` is not a ``Sampler``, ``batch_size`` is
            not a positive integer, or ``drop_last`` is not a ``bool``.

    Example:
        >>> base = torch_sampler.SequentialSampler(range(10))
        >>> list(BatchSampler(base, batch_size=3, drop_last=False))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
        >>> list(BatchSampler(base, batch_size=3, drop_last=True))
        [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
    """

    def __init__(self, sampler, batch_size, drop_last):
        if not isinstance(sampler, torch_sampler.Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        # bool is a subclass of int, so it has to be excluded explicitly.
        if not isinstance(batch_size, _int_classes) \
                or isinstance(batch_size, bool) or batch_size <= 0:
            raise ValueError("batch_size should be a positive integral value, "
                             "but got batch_size={}".format(batch_size))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        # Attributes are assigned directly; the parent constructor is not
        # invoked.
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of up to ``batch_size`` items from the base sampler."""
        batch = []
        for idx in self.sampler:
            # Keep the item exactly as the base sampler produced it; no
            # int(idx) cast is applied here.
            batch.append(idx)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        # A trailing partial batch is emitted only when drop_last is False.
        if batch and not self.drop_last:
            yield batch

    def __len__(self):
        """Return the number of batches one pass over the sampler yields."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        # Ceiling division: a trailing partial batch counts as one batch.
        return (len(self.sampler) + self.batch_size - 1) // self.batch_size



def collate_minibatch(list_of_blobs):
"""Stack samples seperately and return a list of minibatches
A batch contains NUM_GPUS minibatches and image size in different minibatch may be different.
Expand Down
12 changes: 7 additions & 5 deletions tools/train_net_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import utils.misc as misc_utils
from core.config import cfg, cfg_from_file, cfg_from_list, assert_and_infer_cfg
from datasets.roidb import combined_roidb_for_training
from roi_data.loader import RoiDataLoader, MinibatchSampler, collate_minibatch
from roi_data.loader import RoiDataLoader, MinibatchSampler, BatchSampler, collate_minibatch
from modeling.model_builder import Generalized_RCNN
from utils.detectron_weight_helper import load_detectron_weight
from utils.logging import setup_logging
Expand Down Expand Up @@ -236,16 +236,18 @@ def main():
# Effective training sample size for one epoch
train_size = roidb_size // args.batch_size * args.batch_size

sampler = MinibatchSampler(ratio_list, ratio_index)
batchSampler = BatchSampler(
sampler=MinibatchSampler(ratio_list, ratio_index),
batch_size=args.batch_size,
drop_last=True
)
dataset = RoiDataLoader(
roidb,
cfg.MODEL.NUM_CLASSES,
training=True)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
drop_last=True,
sampler=sampler,
batch_sampler=batchSampler,
num_workers=cfg.DATA_LOADER.NUM_THREADS,
collate_fn=collate_minibatch)
dataiterator = iter(dataloader)
Expand Down

0 comments on commit 235885b

Please sign in to comment.