Skip to content

Commit

Permalink
Support Speedup for Slim Pruner. (microsoft#4008)
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-ningxin authored Aug 3, 2021
1 parent d8e5685 commit 0aea0a5
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 11 deletions.
43 changes: 37 additions & 6 deletions nni/compression/pytorch/utils/mask_conflict.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,10 @@ def __init__(self, masks, model=None, dummy_input=None, traced=None):
super(ChannelMaskConflict, self).__init__(
masks, model, dummy_input, traced)
self.conv_prune_dim = detect_mask_prune_dim(masks, model)
self.channel_prune_type = detect_channel_prune_type(masks, model)
_logger.info('Dectected conv prune dim" %d', self.conv_prune_dim)

def fix_mask(self):
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
mask inference of the 'speedup' module.
"""
"""
Fix the mask conflict before the mask inference for the layers that
has shape dependencies. This function should be called before the
Expand All @@ -200,7 +196,8 @@ def fix_mask(self):
"""
if self.conv_prune_dim == 0:
channel_depen = ChannelDependency(
self.model, self.dummy_input, self.traced)
self.model, self.dummy_input, self.traced, self.channel_prune_type)

else:
channel_depen = InputChannelDependency(
self.model, self.dummy_input, self.traced)
Expand Down Expand Up @@ -307,10 +304,44 @@ def fix_mask(self):

return self.masks

def detect_channel_prune_type(masks, model):
    """
    Detect through which kind of layer the masks prune channels.

    User can prune a channel through two ways: 1) prune
    the corresponding filter of the conv layer (all the
    filter related pruners), 2) prune the BN layers that
    follow a conv (Slim pruner). This function finds
    the pruning type of the masks.

    Parameters
    ----------
    masks: dict
        A dict object that stores the masks, keyed by layer name.
    model: nn.Module
        Model object which the mask can be applied on.

    Returns
    -------
    prune_type: str
        'Batchnorm' if `masks` is non-empty and every masked layer is a
        ``torch.nn.BatchNorm2d``, otherwise 'Filter'.
    """
    if not masks:
        # With no masks there is no evidence of BatchNorm pruning, so keep
        # the default type instead of letting the "all layers are
        # BatchNorm" check succeed vacuously.
        return 'Filter'
    for layer_name in masks:
        _, module = get_module_by_name(model, layer_name)
        # isinstance(None, ...) is False, so a lookup that yields None is
        # treated the same way as any other non-BatchNorm layer.
        if not isinstance(module, torch.nn.BatchNorm2d):
            return 'Filter'
    # All masks are for batchnorm layers, so the prune_type is Batchnorm.
    # Note, actually we currently do not support pruning both Conv and
    # BatchNorm at the same time.
    return 'Batchnorm'

def detect_mask_prune_dim(masks, model):
"""
Detect how the masks of convolutional layers are pruned.
Parameters
----------
masks: dict
Expand Down
23 changes: 18 additions & 5 deletions nni/compression/pytorch/utils/shape_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def reshape_break_channel_dependency(op_node):


class ChannelDependency(Dependency):
def __init__(self, model=None, dummy_input=None, traced_model=None):
def __init__(self, model=None, dummy_input=None, traced_model=None, prune_type='Filter'):
"""
This class analyzes the channel dependencies between the conv
layers in a model.
Expand All @@ -98,7 +98,18 @@ def __init__(self, model=None, dummy_input=None, traced_model=None):
traced_model : torch._C.Graph
if we already have the traced graph of the target model, we do not
need to trace the model again.
"""
prune_type: str
This parameter indicates the channel pruning type: 1) `Filter`
prune the filter of the convolution layer to prune the corresponding
channels 2) `Batchnorm`: prune the channel in the batchnorm layer
"""
self.prune_type = prune_type
self.target_types = []
if self.prune_type == 'Filter':
self.target_types.extend(['Conv2d', 'Linear', 'ConvTranspose2d'])
elif self.prune_type == 'Batchnorm':
self.target_types.append('BatchNorm2d')

super(ChannelDependency, self).__init__(
model, dummy_input, traced_model)

Expand All @@ -114,12 +125,13 @@ def _get_parent_layers(self, node):
parent_layers: list
nearest father conv/linear layers for the target worknode.
"""

parent_layers = []
queue = []
queue.append(node)
while queue:
curnode = queue.pop(0)
if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
if curnode.op_type in self.target_types:
# find the first met conv
parent_layers.append(curnode.name)
continue
Expand All @@ -130,6 +142,7 @@ def _get_parent_layers(self, node):
parents = [self.graph.name_to_node[name] for name in parents]
for parent in parents:
queue.append(parent)

return parent_layers

def build_dependency(self):
Expand Down Expand Up @@ -193,7 +206,7 @@ def export(self, filepath):
csv_w = csv.writer(csvf, delimiter=',')
csv_w.writerow(header)
for node in self.graph.nodes_py.nodes_op:
if node.op_type != 'Conv2d' or node in visited:
if node.op_type not in self.target_types or node in visited:
continue
setid += 1
row = ['Set %d' % setid]
Expand All @@ -220,7 +233,7 @@ def dependency_sets(self):
d_sets = []
visited = set()
for node in self.graph.nodes_py.nodes_op:
if (node.op_type != 'Conv2d' and node.op_type != 'Linear') or node in visited:
if node.op_type not in self.target_types or node in visited:
continue
tmp_set = set()
if node.name not in self.dependency:
Expand Down

0 comments on commit 0aea0a5

Please sign in to comment.