Skip to content

Commit

Permalink
Add registry and build_from_cfg (open-mmlab#195)
Browse files Browse the repository at this point in the history
* add registry and build_from_cfg

* add some corner cases

* add some corner cases

* fix the unittest for python 3.5

* minor fix
  • Loading branch information
hellock authored Feb 24, 2020
1 parent 0863073 commit 34197c5
Show file tree
Hide file tree
Showing 3 changed files with 248 additions and 5 deletions.
12 changes: 7 additions & 5 deletions mmcv/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
mkdir_or_exist, scandir, symlink)
from .progressbar import (ProgressBar, track_iter_progress,
track_parallel_progress, track_progress)
from .registry import Registry, build_from_cfg
from .timer import Timer, TimerError, check_time

__all__ = [
'ConfigDict', 'Config', 'is_str', 'iter_cast', 'list_cast', 'tuple_cast',
'is_seq_of', 'is_list_of', 'is_tuple_of', 'slice_list', 'concat_list',
'check_prerequisites', 'requires_package', 'requires_executable',
'is_filepath', 'fopen', 'check_file_exist', 'mkdir_or_exist', 'symlink',
'scandir', 'FileNotFoundError', 'ProgressBar', 'track_progress',
'ConfigDict', 'Config', 'Registry', 'build_from_cfg', 'is_str',
'iter_cast', 'list_cast', 'tuple_cast', 'is_seq_of', 'is_list_of',
'is_tuple_of', 'slice_list', 'concat_list', 'check_prerequisites',
'requires_package', 'requires_executable', 'is_filepath', 'fopen',
'check_file_exist', 'mkdir_or_exist', 'symlink', 'scandir',
'FileNotFoundError', 'ProgressBar', 'track_progress',
'track_iter_progress', 'track_parallel_progress', 'Timer', 'TimerError',
'check_time'
]
123 changes: 123 additions & 0 deletions mmcv/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import inspect
from functools import partial

from .misc import is_str


class Registry(object):
"""A registry to map strings to classes.
Args:
name (str): Registry name.
"""

def __init__(self, name):
self._name = name
self._module_dict = dict()

def __len__(self):
return len(self._module_dict)

def __repr__(self):
format_str = self.__class__.__name__ + '(name={}, items={})'.format(
self._name, list(self._module_dict.keys()))
return format_str

@property
def name(self):
return self._name

@property
def module_dict(self):
return self._module_dict

def get(self, key):
"""Get the registry record.
Args:
key (str): The class name in string format.
Returns:
class: The corresponding class.
"""
return self._module_dict.get(key, None)

def _register_module(self, module_class, force=False):
if not inspect.isclass(module_class):
raise TypeError('module must be a class, but got {}'.format(
type(module_class)))
module_name = module_class.__name__
if not force and module_name in self._module_dict:
raise KeyError('{} is already registered in {}'.format(
module_name, self.name))
self._module_dict[module_name] = module_class

def register_module(self, cls=None, force=False):
"""Register a module.
A record will be added to `self._module_dict`, whose key is the class
name and value is the class itself.
It can be used as a decorator or a normal function.
Example:
>>> backbones = Registry('backbone')
>>> @backbones.register_module
>>> class ResNet(object):
>>> pass
Example:
>>> backbones = Registry('backbone')
>>> class ResNet(object):
>>> pass
>>> backbones.register_module(ResNet)
Args:
module (:obj:`nn.Module`): Module to be registered.
force (bool, optional): Whether to override an existing class with
the same name. Default: False.
"""
if cls is None:
return partial(self.register_module, force=force)
self._register_module(cls, force=force)
return cls


def build_from_cfg(cfg, registry, default_args=None):
"""Build a module from config dict.
Args:
cfg (dict): Config dict. It should at least contain the key "type".
registry (:obj:`Registry`): The registry to search the type from.
default_args (dict, optional): Default initialization arguments.
Returns:
obj: The constructed object.
"""
if not (isinstance(cfg, dict) and 'type' in cfg):
raise TypeError('cfg must be a dict containing the key "type"')
if not isinstance(registry, Registry):
raise TypeError(
'registry must be an mmcv.Registry object, but got {}'.format(
type(registry)))
if not (isinstance(default_args, dict) or default_args is None):
raise TypeError(
'default_args must be a dict or None, but got {}'.format(
type(default_args)))

args = cfg.copy()
obj_type = args.pop('type')
if is_str(obj_type):
obj_cls = registry.get(obj_type)
if obj_cls is None:
raise KeyError('{} is not in the {} registry'.format(
obj_type, registry.name))
elif inspect.isclass(obj_type):
obj_cls = obj_type
else:
raise TypeError('type must be a str or valid type, but got {}'.format(
type(obj_type)))

if default_args is not None:
for name, value in default_args.items():
args.setdefault(name, value)
return obj_cls(**args)
118 changes: 118 additions & 0 deletions tests/test_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import pytest

import mmcv


def test_registry():
reg_name = 'cat'
CATS = mmcv.Registry(reg_name)
assert CATS.name == reg_name
assert CATS.module_dict == {}
assert len(CATS) == 0

@CATS.register_module
class BritishShorthair:
pass

assert len(CATS) == 1
assert CATS.get('BritishShorthair') is BritishShorthair

class Munchkin:
pass

CATS.register_module(Munchkin)
assert len(CATS) == 2
assert CATS.get('Munchkin') is Munchkin

with pytest.raises(KeyError):
CATS.register_module(Munchkin)

CATS.register_module(Munchkin, force=True)
assert len(CATS) == 2

with pytest.raises(KeyError):

@CATS.register_module
class BritishShorthair:
pass

@CATS.register_module(force=True)
class BritishShorthair:
pass

assert len(CATS) == 2

assert CATS.get('PersianCat') is None

# The order of dict keys are not preserved in python 3.5
assert repr(CATS) in [
"Registry(name=cat, items=['BritishShorthair', 'Munchkin'])",
"Registry(name=cat, items=['Munchkin', 'BritishShorthair'])"
]

# the registered module should be a class
with pytest.raises(TypeError):
CATS.register_module(0)


def test_build_from_cfg():
BACKBONES = mmcv.Registry('backbone')

@BACKBONES.register_module
class ResNet:

def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages

@BACKBONES.register_module
class ResNeXt:

def __init__(self, depth, stages=4):
self.depth = depth
self.stages = stages

cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4

cfg = dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3})
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 3

cfg = dict(type='ResNeXt', depth=50, stages=3)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNeXt)
assert model.depth == 50 and model.stages == 3

cfg = dict(type=ResNet, depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES)
assert isinstance(model, ResNet)
assert model.depth == 50 and model.stages == 4

# non-registered class
with pytest.raises(KeyError):
cfg = dict(type='VGG')
model = mmcv.build_from_cfg(cfg, BACKBONES)

# cfg['type'] should be a str or class
with pytest.raises(TypeError):
cfg = dict(type=1000)
model = mmcv.build_from_cfg(cfg, BACKBONES)

# cfg should contain the key "type"
with pytest.raises(TypeError):
cfg = dict(depth=50, stages=4)
model = mmcv.build_from_cfg(cfg, BACKBONES)

# incorrect registry type
with pytest.raises(TypeError):
dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, 'BACKBONES')

# incorrect default_args type
with pytest.raises(TypeError):
dict(type='ResNet', depth=50)
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0)

0 comments on commit 34197c5

Please sign in to comment.