Commit

[quant][graphmode][fx][refactor] Move patterns to separate files (#43891)

Summary: Pull Request resolved: pytorch/pytorch#43891

Test Plan: Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23429759

fbshipit-source-id: f19add96beb7c8bac323ad78f74588ca1393040c
jerryzh168 authored and facebook-github-bot committed Sep 1, 2020
1 parent 8d53df3 commit d15b9d9
Showing 5 changed files with 691 additions and 668 deletions.
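For context, the mechanism being relocated by this commit is a small decorator-based registry. Below is a minimal standalone sketch of that idiom, mirroring register_fusion_pattern / FUSION_PATTERNS from pattern_utils.py; the ToyLinearReLUFusion handler is a hypothetical stand-in, not code from the diff.

# Standalone sketch of the registry idiom used in pattern_utils.py /
# fusion_patterns.py (illustration only; the handler class is a toy).
from collections import OrderedDict

import torch

FUSION_PATTERNS = OrderedDict()

def register_fusion_pattern(pattern):
    def insert(fn):
        FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

# Pattern tuples are read back-to-front: (ReLU, Linear) means a ReLU node
# whose input comes from a Linear node, i.e. relu(linear(x)).
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
class ToyLinearReLUFusion:
    def __init__(self, quantizer, node):
        # node is the ReLU node; its first argument is the Linear node.
        self.relu_node = node
        self.module_node = node.args[0]

print(list(FUSION_PATTERNS.keys()))  # one entry: the (ReLU, Linear) pattern tuple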
110 changes: 1 addition & 109 deletions torch/quantization/fx/fuse.py
@@ -1,5 +1,3 @@
import torch

from torch.fx import (
    GraphModule,
)
@@ -9,120 +7,14 @@
    map_arg,
)

from ..fuse_modules import OP_LIST_TO_FUSER_METHOD

from .pattern_utils import (
    matches,
    register_fusion_pattern,
    get_fusion_patterns,
)

from .utils import _parent_name
from .fusion_patterns import * # noqa: F401

import copy

# Fusion Patterns
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
class ConvBNReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = None
        self.bn_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_module'
        if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]:
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert node.op == 'call_module'
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def fuse(self, quantizer, load_arg):
        op_list = []
        if self.relu_node is not None:
            # since relu can be used multiple times, we'll need to create a relu module for each match
            if self.relu_node.op == 'call_module':
                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
            else:
                # TODO: get inplace argument from functional
                relu = torch.nn.ReLU()
            op_list.append(relu)
            relu.training = self.conv.training
            if self.bn_node is not None:
                op_list.append(self.bn)
            op_list.append(self.conv)
        else:
            assert self.bn_node is not None
            op_list.append(self.bn)
            op_list.append(self.conv)

        # the modules are added in order of relu - bn - conv
        # so we need to correct it
        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
        setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))

        # TODO: do we need to make sure bn is only used once?
        if self.bn_node is not None:
            parent_name, name = _parent_name(self.bn_node.target)
            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
        # relu may be used multiple times, so we don't set relu to identity
        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)

@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
class ModuleReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = node
        node = node.args[0]
        assert node.op == 'call_module'
        self.module_node = node
        self.module = quantizer.modules[self.module_node.target]

    def fuse(self, quantizer, load_arg):
        op_list = []
        # since relu can be used multiple times, we'll need to create a relu module for each match
        if self.relu_node.op == 'call_module':
            relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
        else:
            # TODO: get inplace argument from functional
            relu = torch.nn.ReLU()
        relu.training = self.module.training
        op_list.append(relu)
        op_list.append(self.module)

        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        module_parent_name, module_name = _parent_name(self.module_node.target)
        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
        setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
        return quantizer.fused_graph.node_copy(self.module_node, load_arg)

class Fuser:
    def fuse(self, model, inplace=False):
        input_root = model.root
111 changes: 111 additions & 0 deletions torch/quantization/fx/fusion_patterns.py
@@ -0,0 +1,111 @@
import torch
from .pattern_utils import (
    register_fusion_pattern,
)
from .utils import _parent_name
from ..fuse_modules import OP_LIST_TO_FUSER_METHOD

# ---------------------
# Fusion Patterns
# ---------------------

@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv3d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
class ConvBNReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = None
        self.bn_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_module'
        if type(quantizer.modules[node.target]) in [torch.nn.BatchNorm1d, torch.nn.BatchNorm2d, torch.nn.BatchNorm3d]:
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert node.op == 'call_module'
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def fuse(self, quantizer, load_arg):
        op_list = []
        if self.relu_node is not None:
            # since relu can be used multiple times, we'll need to create a relu module for each match
            if self.relu_node.op == 'call_module':
                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
            else:
                # TODO: get inplace argument from functional
                relu = torch.nn.ReLU()
            op_list.append(relu)
            relu.training = self.conv.training
            if self.bn_node is not None:
                op_list.append(self.bn)
            op_list.append(self.conv)
        else:
            assert self.bn_node is not None
            op_list.append(self.bn)
            op_list.append(self.conv)

        # the modules are added in order of relu - bn - conv
        # so we need to correct it
        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
        setattr(quantizer.modules[conv_parent_name], conv_name, fuser_method(*op_list))

        # TODO: do we need to make sure bn is only used once?
        if self.bn_node is not None:
            parent_name, name = _parent_name(self.bn_node.target)
            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
        # relu may be used multiple times, so we don't set relu to identity
        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)

@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm1d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm1d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.BatchNorm3d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.BatchNorm3d))
class ModuleReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = node
        node = node.args[0]
        assert node.op == 'call_module'
        self.module_node = node
        self.module = quantizer.modules[self.module_node.target]

    def fuse(self, quantizer, load_arg):
        op_list = []
        # since relu can be used multiple times, we'll need to create a relu module for each match
        if self.relu_node.op == 'call_module':
            relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
        else:
            # TODO: get inplace argument from functional
            relu = torch.nn.ReLU()
        relu.training = self.module.training
        op_list.append(relu)
        op_list.append(self.module)

        op_list.reverse()
        op_type_list = tuple(type(m) for m in op_list)
        module_parent_name, module_name = _parent_name(self.module_node.target)
        fuser_method = OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
        if fuser_method is None:
            raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
        setattr(quantizer.modules[module_parent_name], module_name, fuser_method(*op_list))
        return quantizer.fused_graph.node_copy(self.module_node, load_arg)
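As a hedged illustration of the lookup the fuse() methods above perform: matched modules are collected relu-first, reversed into execution order, and the tuple of their types keys into OP_LIST_TO_FUSER_METHOD. The registry and fuser method below are toy stand-ins, not the real table from torch.quantization.fuse_modules.

# Toy stand-in for the OP_LIST_TO_FUSER_METHOD lookup (illustration only).
import torch

def fuse_linear_relu(linear, relu):
    # Real fuser methods return intrinsic fused modules (e.g. torch.nn.intrinsic.LinearReLU);
    # Sequential keeps this sketch self-contained.
    return torch.nn.Sequential(linear, relu)

TOY_OP_LIST_TO_FUSER_METHOD = {
    (torch.nn.Linear, torch.nn.ReLU): fuse_linear_relu,
}

op_list = [torch.nn.ReLU(), torch.nn.Linear(4, 4)]  # collected relu -> module
op_list.reverse()                                   # execution order: module -> relu
op_type_list = tuple(type(m) for m in op_list)
fuser_method = TOY_OP_LIST_TO_FUSER_METHOD.get(op_type_list, None)
if fuser_method is None:
    raise NotImplementedError("Cannot fuse modules: {}".format(op_type_list))
print(fuser_method(*op_list))  # Sequential of Linear followed by ReLU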
10 changes: 5 additions & 5 deletions torch/quantization/fx/pattern_utils.py
@@ -1,4 +1,3 @@

import torch
import sys
from collections import OrderedDict
@@ -14,25 +13,27 @@ def insert(fn):
def get_fusion_patterns():
    return FUSION_PATTERNS

# pattern for both static quantization and qat
QUANTIZATION_PATTERNS = OrderedDict()
# Register pattern for both static quantization and qat
def register_quant_pattern(pattern):
    def insert(fn):
        QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert

# Get patterns for both static quantization and qat
def get_quant_patterns():
    return QUANTIZATION_PATTERNS

# pattern for dynamic quantization
DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
def register_dynamic_pattern(pattern):
# Register pattern for dynamic quantization
def register_dynamic_quant_pattern(pattern):
    def insert(fn):
        DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert

# Get patterns for dynamic quantization
def get_dynamic_quant_patterns():
    return DYNAMIC_QUANTIZATION_PATTERNS

@@ -47,7 +48,6 @@ def get_dynamic_quant_patterns():
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.


def matches(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
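To make the "start from the last node of the graph and traverse back" comment above concrete, here is a simplified matching sketch over a torch.fx trace. It assumes a PyTorch build where torch.fx.symbolic_trace is available and is only an illustration, not the real matches() implementation.

# Simplified back-to-front matching over an FX graph (illustration only).
import torch
import torch.fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 3, 1)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        return self.relu(self.conv(x))

gm = torch.fx.symbolic_trace(M())
modules = dict(gm.named_modules())

# Walk the graph back to front: the ReLU node is visited first, then its
# input is checked against the rest of the pattern, mirroring (ReLU, Conv2d).
for node in reversed(list(gm.graph.nodes)):
    if node.op == 'call_module' and isinstance(modules[node.target], torch.nn.ReLU):
        prev = node.args[0]
        if prev.op == 'call_module' and isinstance(modules[prev.target], torch.nn.Conv2d):
            print('matched (ReLU, Conv2d):', node.name, '<-', prev.name)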