[quant][fx][graphmode][api] Change API for custom module (pytorch#45920)
Summary:
Pull Request resolved: pytorch#45920

See docs for new way of defining custom modules
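
For reference, a minimal sketch of the new way (illustrative only, not the official docs; class names follow the updated test below, and float_model is a placeholder): the observed/quantized custom module classes are no longer registered through register_observed_custom_module_mapping / register_quantized_custom_module_mapping together with a hand-written tracer, but are passed through qconfig_dict instead:

    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx

    # Before this PR: register_observed_custom_module_mapping(...) /
    # register_quantized_custom_module_mapping(...) plus a hand-written Tracer
    # that marked CustomModule as a leaf module.
    # After this PR: the same information is passed through qconfig_dict.
    qconfig_dict = {
        '': default_qconfig,
        'custom_module_class': [
            (CustomModule, ObservedCustomModule, QuantizedCustomModule),
        ],
    }
    prepared = prepare_fx(float_model, qconfig_dict)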

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D24145856

fbshipit-source-id: 488673fba503e39e8e303ed5a776fe36899ea4e3
jerryzh168 authored and facebook-github-bot committed Oct 13, 2020
1 parent e6d30c8 commit 7f6a1b2
Showing 2 changed files with 31 additions and 23 deletions.
25 changes: 5 additions & 20 deletions test/quantization/test_quantize_fx.py
@@ -12,8 +12,6 @@
prepare_fx,
convert_fx,
prepare_qat_fx,
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
)

from torch.quantization import (
@@ -627,7 +625,6 @@ def test_save_observer_state_dict(self):
self.assertEqual(quant(x), quant_2(x))

@skipIfNoFBGEMM
@unittest.skip("Fix in next PR, will need to change API")
def test_custom_module_class(self):
class CustomModule(torch.nn.Module):
def __init__(self):
@@ -715,26 +712,14 @@ def forward(self, x):
original_ref_m.conv2.weight = torch.nn.Parameter(original_m.custom.conv.weight.detach())
original_ref_m.conv2.bias = torch.nn.Parameter(original_m.custom.conv.bias.detach())

from torch.fx.symbolic_trace import Tracer

# define a custom tracer to not trace through the custom module

class CustomTracer(Tracer):
def is_leaf_module(self, m, module_qualified_name):
return (m.__module__.startswith('torch.nn') and
not isinstance(m, torch.nn.Sequential)) or \
isinstance(m, CustomModule)

# TODO: add other quant types after mixed mode support
for quant_type in [QuantType.STATIC]:
# register observed and quantized custom module classes
register_observed_custom_module_mapping(CustomModule, ObservedCustomModule)
register_quantized_custom_module_mapping(CustomModule, QuantizedCustomModule)

m = torch.fx.GraphModule(original_m, CustomTracer().trace(original_m))
qconfig_dict = {'': default_qconfig}
qconfig_dict = {
'': default_qconfig,
'custom_module_class':
[(CustomModule, ObservedCustomModule, QuantizedCustomModule)]}
# check prepared model
m = prepare_fx(m, qconfig_dict)
m = prepare_fx(original_m, qconfig_dict)
# calibration
m(data)
# all activation observers are inserted in the top level module
29 changes: 26 additions & 3 deletions torch/quantization/quantize_fx.py
@@ -2,6 +2,10 @@
from torch.fx import GraphModule # type: ignore
from torch.fx import symbolic_trace # type: ignore
from torch.fx.symbolic_trace import Tracer # type: ignore
from .custom_module_class_mappings import (
register_observed_custom_module_mapping,
register_quantized_custom_module_mapping,
)
from .fx import Fuser # noqa: F401
from .fx import Quantizer # noqa: F401
from .fx.utils import graph_pretty_str # noqa: F401
@@ -13,6 +17,11 @@ def _check_is_graph_module(model):
'Got type:' + str(type(model)) + ' Please make ' +
'sure to follow the tutorials.')

def _register_custom_module_class(custom_module_config):
for custom, observed, quantized in custom_module_config:
register_observed_custom_module_mapping(custom, observed)
register_quantized_custom_module_mapping(custom, quantized)

def _fuse_fx(graph_module, inplace=False):
r""" Internal helper function to fuse modules in preparation for quantization
@@ -24,14 +33,16 @@ def _fuse_fx(graph_module, inplace=False):
return fuser.fuse(graph_module, inplace)

class CustomTracer(Tracer):
def __init__(self, standalone_modules):
def __init__(self, standalone_modules, custom_module_classes):
super().__init__()
self.standalone_modules = standalone_modules
self.custom_module_classes = custom_module_classes

def is_leaf_module(self, m, module_qualified_name):
return (m.__module__.startswith('torch.nn') and
not isinstance(m, torch.nn.Sequential)) or \
module_qualified_name in self.standalone_modules
module_qualified_name in self.standalone_modules or \
type(m) in self.custom_module_classes


def _prepare_fx(model, qconfig_dict, inplace, is_standalone_module=False):
@@ -51,8 +62,14 @@ def _prepare_fx(model, qconfig_dict, inplace, is_standalone_module=False):
graph_module = symbolic_trace(model)
else:
standalone_modules = qconfig_dict.get('standalone_module_name', [])
custom_module_config = qconfig_dict.get('custom_module_class', [])
custom_module_classes = [config[0] for config in custom_module_config]
# TODO: currently we are registering classes globally,
# we want to make custom module class mapping local
_register_custom_module_class(custom_module_config)
# skipping tracing standalone modules when tracing top level module
graph_module = GraphModule(model, CustomTracer(standalone_modules).trace(model))
tracer = CustomTracer(standalone_modules, custom_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, inplace)
quantizer = Quantizer()
return quantizer.prepare(graph_module, qconfig_dict, inplace=True, is_standalone_module=is_standalone_module)
@@ -132,6 +149,12 @@ def prepare_fx(model, qconfig_dict, inplace=False):
# These modules are symbolically traced and quantized as one unit
"standalone_module_name": [
"submodule.standalone"
],
# optional: specify the custom module class and provide the corresponding
# observed and quantized custom module classes
"custom_module_class": [
(CustomModuleClass, ObservedCustomModuleClass, QuantizedCustomModuleClass)
]
}
`inplace`: flag for carrying out model transformations in-place,
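
For illustration, an end-to-end sketch of the custom_module_class entry documented above (a sketch, not part of this commit; CustomModuleClass, ObservedCustomModuleClass, QuantizedCustomModuleClass, float_model and calibration_data are user-supplied placeholders):

    from torch.quantization import default_qconfig
    from torch.quantization.quantize_fx import prepare_fx, convert_fx

    qconfig_dict = {
        '': default_qconfig,
        # each entry: (float custom module class, observed class, quantized class)
        'custom_module_class': [
            (CustomModuleClass, ObservedCustomModuleClass, QuantizedCustomModuleClass),
        ],
    }

    # prepare_fx treats CustomModuleClass instances as leaves (not traced through)
    # and registers the observed/quantized mapping internally.
    prepared = prepare_fx(float_model, qconfig_dict)
    prepared(calibration_data)          # calibration run with observers inserted
    quantized = convert_fx(prepared)    # produce the quantized model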
