forked from open-mmlab/mmcv
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add registry and build_from_cfg (open-mmlab#195)
* add registry and build_from_cfg * add some corner cases * add some corner cases * fix the unittest for python 3.5 * minor fix
- Loading branch information
Showing
3 changed files
with
248 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
import inspect | ||
from functools import partial | ||
|
||
from .misc import is_str | ||
|
||
|
||
class Registry(object): | ||
"""A registry to map strings to classes. | ||
Args: | ||
name (str): Registry name. | ||
""" | ||
|
||
def __init__(self, name): | ||
self._name = name | ||
self._module_dict = dict() | ||
|
||
def __len__(self): | ||
return len(self._module_dict) | ||
|
||
def __repr__(self): | ||
format_str = self.__class__.__name__ + '(name={}, items={})'.format( | ||
self._name, list(self._module_dict.keys())) | ||
return format_str | ||
|
||
@property | ||
def name(self): | ||
return self._name | ||
|
||
@property | ||
def module_dict(self): | ||
return self._module_dict | ||
|
||
def get(self, key): | ||
"""Get the registry record. | ||
Args: | ||
key (str): The class name in string format. | ||
Returns: | ||
class: The corresponding class. | ||
""" | ||
return self._module_dict.get(key, None) | ||
|
||
def _register_module(self, module_class, force=False): | ||
if not inspect.isclass(module_class): | ||
raise TypeError('module must be a class, but got {}'.format( | ||
type(module_class))) | ||
module_name = module_class.__name__ | ||
if not force and module_name in self._module_dict: | ||
raise KeyError('{} is already registered in {}'.format( | ||
module_name, self.name)) | ||
self._module_dict[module_name] = module_class | ||
|
||
def register_module(self, cls=None, force=False): | ||
"""Register a module. | ||
A record will be added to `self._module_dict`, whose key is the class | ||
name and value is the class itself. | ||
It can be used as a decorator or a normal function. | ||
Example: | ||
>>> backbones = Registry('backbone') | ||
>>> @backbones.register_module | ||
>>> class ResNet(object): | ||
>>> pass | ||
Example: | ||
>>> backbones = Registry('backbone') | ||
>>> class ResNet(object): | ||
>>> pass | ||
>>> backbones.register_module(ResNet) | ||
Args: | ||
module (:obj:`nn.Module`): Module to be registered. | ||
force (bool, optional): Whether to override an existing class with | ||
the same name. Default: False. | ||
""" | ||
if cls is None: | ||
return partial(self.register_module, force=force) | ||
self._register_module(cls, force=force) | ||
return cls | ||
|
||
|
||
def build_from_cfg(cfg, registry, default_args=None): | ||
"""Build a module from config dict. | ||
Args: | ||
cfg (dict): Config dict. It should at least contain the key "type". | ||
registry (:obj:`Registry`): The registry to search the type from. | ||
default_args (dict, optional): Default initialization arguments. | ||
Returns: | ||
obj: The constructed object. | ||
""" | ||
if not (isinstance(cfg, dict) and 'type' in cfg): | ||
raise TypeError('cfg must be a dict containing the key "type"') | ||
if not isinstance(registry, Registry): | ||
raise TypeError( | ||
'registry must be an mmcv.Registry object, but got {}'.format( | ||
type(registry))) | ||
if not (isinstance(default_args, dict) or default_args is None): | ||
raise TypeError( | ||
'default_args must be a dict or None, but got {}'.format( | ||
type(default_args))) | ||
|
||
args = cfg.copy() | ||
obj_type = args.pop('type') | ||
if is_str(obj_type): | ||
obj_cls = registry.get(obj_type) | ||
if obj_cls is None: | ||
raise KeyError('{} is not in the {} registry'.format( | ||
obj_type, registry.name)) | ||
elif inspect.isclass(obj_type): | ||
obj_cls = obj_type | ||
else: | ||
raise TypeError('type must be a str or valid type, but got {}'.format( | ||
type(obj_type))) | ||
|
||
if default_args is not None: | ||
for name, value in default_args.items(): | ||
args.setdefault(name, value) | ||
return obj_cls(**args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import pytest | ||
|
||
import mmcv | ||
|
||
|
||
def test_registry(): | ||
reg_name = 'cat' | ||
CATS = mmcv.Registry(reg_name) | ||
assert CATS.name == reg_name | ||
assert CATS.module_dict == {} | ||
assert len(CATS) == 0 | ||
|
||
@CATS.register_module | ||
class BritishShorthair: | ||
pass | ||
|
||
assert len(CATS) == 1 | ||
assert CATS.get('BritishShorthair') is BritishShorthair | ||
|
||
class Munchkin: | ||
pass | ||
|
||
CATS.register_module(Munchkin) | ||
assert len(CATS) == 2 | ||
assert CATS.get('Munchkin') is Munchkin | ||
|
||
with pytest.raises(KeyError): | ||
CATS.register_module(Munchkin) | ||
|
||
CATS.register_module(Munchkin, force=True) | ||
assert len(CATS) == 2 | ||
|
||
with pytest.raises(KeyError): | ||
|
||
@CATS.register_module | ||
class BritishShorthair: | ||
pass | ||
|
||
@CATS.register_module(force=True) | ||
class BritishShorthair: | ||
pass | ||
|
||
assert len(CATS) == 2 | ||
|
||
assert CATS.get('PersianCat') is None | ||
|
||
# The order of dict keys are not preserved in python 3.5 | ||
assert repr(CATS) in [ | ||
"Registry(name=cat, items=['BritishShorthair', 'Munchkin'])", | ||
"Registry(name=cat, items=['Munchkin', 'BritishShorthair'])" | ||
] | ||
|
||
# the registered module should be a class | ||
with pytest.raises(TypeError): | ||
CATS.register_module(0) | ||
|
||
|
||
def test_build_from_cfg(): | ||
BACKBONES = mmcv.Registry('backbone') | ||
|
||
@BACKBONES.register_module | ||
class ResNet: | ||
|
||
def __init__(self, depth, stages=4): | ||
self.depth = depth | ||
self.stages = stages | ||
|
||
@BACKBONES.register_module | ||
class ResNeXt: | ||
|
||
def __init__(self, depth, stages=4): | ||
self.depth = depth | ||
self.stages = stages | ||
|
||
cfg = dict(type='ResNet', depth=50) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES) | ||
assert isinstance(model, ResNet) | ||
assert model.depth == 50 and model.stages == 4 | ||
|
||
cfg = dict(type='ResNet', depth=50) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args={'stages': 3}) | ||
assert isinstance(model, ResNet) | ||
assert model.depth == 50 and model.stages == 3 | ||
|
||
cfg = dict(type='ResNeXt', depth=50, stages=3) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES) | ||
assert isinstance(model, ResNeXt) | ||
assert model.depth == 50 and model.stages == 3 | ||
|
||
cfg = dict(type=ResNet, depth=50) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES) | ||
assert isinstance(model, ResNet) | ||
assert model.depth == 50 and model.stages == 4 | ||
|
||
# non-registered class | ||
with pytest.raises(KeyError): | ||
cfg = dict(type='VGG') | ||
model = mmcv.build_from_cfg(cfg, BACKBONES) | ||
|
||
# cfg['type'] should be a str or class | ||
with pytest.raises(TypeError): | ||
cfg = dict(type=1000) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES) | ||
|
||
# cfg should contain the key "type" | ||
with pytest.raises(TypeError): | ||
cfg = dict(depth=50, stages=4) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES) | ||
|
||
# incorrect registry type | ||
with pytest.raises(TypeError): | ||
dict(type='ResNet', depth=50) | ||
model = mmcv.build_from_cfg(cfg, 'BACKBONES') | ||
|
||
# incorrect default_args type | ||
with pytest.raises(TypeError): | ||
dict(type='ResNet', depth=50) | ||
model = mmcv.build_from_cfg(cfg, BACKBONES, default_args=0) |