diff --git a/src/sdk/pynni/nni/_graph_utils.py b/src/sdk/pynni/nni/_graph_utils.py
index b17aa02994..445aaebd58 100644
--- a/src/sdk/pynni/nni/_graph_utils.py
+++ b/src/sdk/pynni/nni/_graph_utils.py
@@ -8,19 +8,21 @@
 from collections import defaultdict
 import torch
 from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
-
 CLASSTYPE_KIND = 'ClassType'
 GETATTR_KIND = 'prim::GetAttr'
 
 _logger = logging.getLogger(__name__)
 
+
 def build_module_graph(model, dummy_input):
     return TorchModuleGraph(model, dummy_input)
 
+
 def build_graph(model, dummy_input, verbose=False):
     g = TorchProtoGraph(model, dummy_input, verbose)
     return g.graph_def, g.stepstats
 
+
 def parse_traced_name(module_name):
     prefix = 'TracedModule['
     suffix = ']'
@@ -28,11 +30,13 @@ def parse_traced_name(module_name):
         module_name = module_name[len(prefix):-len(suffix)]
     return module_name
 
+
 class TorchGraph:
     """
     This class is to extract pytorch model topology graph by tracing
     """
-    def __init__(self, model, dummy_input):
+
+    def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
         Parameters
         ----------
@@ -40,25 +44,39 @@ def __init__(self, model, dummy_input):
             The model user wants to speed up
         dummy_input : pytorch tensor
             The dummy input for ```jit.trace```, users should put it on right device before pass in
+        traced_model : torch._C.torch.jit.TopLevelTracedModule
+            An already traced model. If traced_model is not None, TorchGraph will build
+            the graph based on this traced model and won't trace the model again.
         """
         assert torch.__version__ >= '1.3.1'
-
-        self.bound_model = model
-        self._trace(model, dummy_input)
+        # check if the input is legal
+        if traced_model is not None:
+            assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
+            self.trace = traced_model
+            # it's ok if the graph is already unpacked
+            torch._C._jit_pass_inline(self.trace.graph)
+        elif model is not None and dummy_input is not None:
+            self.bound_model = model
+            self._trace(model, dummy_input)
+        else:
+            raise Exception(
+                'Please provide model & dummy_input or the traced_model as inputs')
 
     def _trace(self, model, dummy_input):
         with torch.onnx.set_training(model, False):
             self.trace = torch.jit.trace(model, dummy_input)
             torch._C._jit_pass_inline(self.trace.graph)
 
+
 class TorchProtoGraph(TorchGraph):
     """
-    Generates model graph for pytorch models in protobuf, this implementation is borrowed from pytorch v1.4.0,
-    and fixed following issues:
+    Generates model graph for pytorch models in protobuf, this implementation
+    is borrowed from pytorch v1.4.0, and fixes the following issues:
     https://github.com/pytorch/pytorch/issues/33691
     https://github.com/pytorch/pytorch/issues/33670
     """
+
     def __init__(self, model, dummy_input, verbose=False):
         super().__init__(model, dummy_input)
 
@@ -70,8 +88,10 @@ def __init__(self, model, dummy_input, verbose=False):
         list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
         if verbose:
             print(self.trace.graph)
-        self.stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
-        self.graph_def = GraphDef(node=list_of_nodes, versions=VersionDef(producer=22))
+        self.stepstats = RunMetadata(step_stats=StepStats(
+            dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
+        self.graph_def = GraphDef(
+            node=list_of_nodes, versions=VersionDef(producer=22))
 
     def parse(self, graph, trace, args=None, omit_useless_nodes=True):
         """This method parses an optimized PyTorch model graph and produces
@@ -94,16 +114,20 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True):
             nodes_py.append(NodePyIO(node, 'input'))
 
         attr_to_scope = dict()
-        node_to_name = lambda d: str(d).split(":")[0].strip()
+
+        def node_to_name(d):
+            return str(d).split(":")[0].strip()
         for node in graph.nodes():
             if node.kind() == GETATTR_KIND:
                 attr_name = node.s('name')
                 node_name = node_to_name(node)
                 parent = node.input().node()
-                if parent.kind() == GETATTR_KIND:  # If the parent node is not the top-level "self" node
+                # If the parent node is not the top-level "self" node
+                if parent.kind() == GETATTR_KIND:
                     parent_scope = attr_to_scope[node_to_name(parent)]
                     attr_scope = parent_scope.split('/')[-1]
-                    attr_to_scope[node_name] = '{}/{}.{}'.format(parent_scope, attr_scope, attr_name)
+                    attr_to_scope[node_name] = '{}/{}.{}'.format(
+                        parent_scope, attr_scope, attr_name)
                 else:
                     attr_to_scope[node_name] = '__module.{}'.format(attr_name)
                 # We don't need classtype nodes; scope will provide this information
@@ -114,7 +138,8 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True):
             else:
                 nodes_py.append(NodePyOP(node))
 
-        for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
+        # Create sink nodes for output ops
+        for i, node in enumerate(graph.outputs()):
             node_py = NodePyIO(node, 'output')
             node_py.debugName = "output.{}".format(i + 1)
             node_py.inputs = [node.debugName()]
@@ -136,23 +161,33 @@ def parse(self, graph, trace, args=None, omit_useless_nodes=True):
                     node.scopeName = base_name
                 else:
                     module_name += '.' + alias
-                    node.scopeName += '/' + (alias_to_name[module_name] if module_name in alias_to_name else alias)
+                    node.scopeName += '/' + \
+                        (alias_to_name[module_name]
+                         if module_name in alias_to_name else alias)
 
         nodes_py.populate_namespace_from_OP_to_IO()
         return nodes_py.to_proto()
 
+
 class NodePyGroup(NodePy):
     """
     This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
     there are multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
     represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
     """
-    def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=None):
+
+    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None):
         """
         Parameters:
         -----------
         name: str
             node name, such as `conv1`, `backbone.classifier`
+        unique_name: str
+            A globally unique name for the current node. Because some modules,
+            such as relu, may be reused several times, the scope name is not
+            suitable as a globally unique identifier, so we add a unique_name
+            to each node to serve as its globally unique identifier.
+            We should use the unique_name to traverse the module graph.
        node_type: str
            `module` or `func`
        op_type: str
@@ -167,6 +202,7 @@ def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=Non
         super(NodePyGroup, self).__init__(name, [])
         self.node_cpps = node_cpps
         self.name = name
+        self.unique_name = unique_name
         self.op_type = op_type
         self.type = node_type
         self.nodes = []
@@ -178,7 +214,7 @@ def __init__(self, name, node_type, op_type, node_cpps, inputs=None, outputs=Non
     def add_nodes(self, node_cpps):
         for node_cpp in node_cpps:
             nodepy = NodePyOP(node_cpp)
-            nodepy.name = str(node_cpp).split(':')[0].strip().replace('%', '')
+            nodepy.name = node_cpp.scopeName() + '_' + node_cpp.kind()
             self.nodes.append(nodepy)
 
     def sub_node_names(self):
@@ -186,7 +222,8 @@
     def __repr__(self):
         return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
-            self.name, self.type, self.op_type, self.sub_node_names(), self.inputs, self.outputs, self.auxiliary
+            self.name, self.type, self.op_type, self.sub_node_names(),
+            self.inputs, self.outputs, self.auxiliary
         )
 
@@ -194,12 +231,14 @@
 class TorchModuleGraph(TorchGraph):
     """
     Generates model graph, each node is created from single or multiple jit trace nodes.
     """
-    def __init__(self, model, dummy_input):
-        super().__init__(model, dummy_input)
+
+    def __init__(self, model=None, dummy_input=None, traced_model=None):
+        super().__init__(model, dummy_input, traced_model)
         self.global_count = 0
         self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
 
-    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
+    def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node,
+                              module_type):
         """
         For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
         the functions directly called in module ```forward```. For such nodes, some of them are
@@ -217,6 +256,8 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
             key: input name, value: a node that uses this input
         output_to_node : dict
             key: output name, value: a node that generates this output
+        module_type : str
+            can be 'module' or 'func'
 
         Returns
         -------
@@ -224,11 +265,12 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
             the expanded non-prim node
         """
         # TODO: scope name could be empty
-        node_name = '.'.join([self._get_module_name(node.scopeName()), node.kind(), str(self.global_count)])
+        node_name = '.'.join([self._get_module_name(
+            node.scopeName()), node.kind(), str(self.global_count)])
+        unique_name = node_name
         _logger.debug("expand non-prim node, node name: %s", node_name)
         self.global_count += 1
         op_type = node.kind()
-
         node_group = [node]
         inputs = list()
         outputs = list()
@@ -239,38 +281,88 @@ def _expand_non_prim_node(self, node, nodes, input_to_node, output_to_node):
             for _input in curr_node.inputs():
                 input_name = _input.debugName()
                 if input_name in output_to_node and output_to_node[input_name] in nodes:
-                        predecessor_node = output_to_node[input_name]
-                        if predecessor_node.kind().startswith('prim::'):
-                            node_group.append(predecessor_node)
-                            node_queue.put(predecessor_node)
-                        else:
-                            inputs.append(input_name)
+                    predecessor_node = output_to_node[input_name]
+                    if predecessor_node.kind().startswith('prim::'):
+                        node_group.append(predecessor_node)
+                        node_queue.put(predecessor_node)
+                    else:
+                        inputs.append(input_name)
                 else:
                     inputs.append(input_name)
         for output in node.outputs():
             outputs.append(output.debugName())
-        nodepy = NodePyGroup(node_name, 'func', op_type, node_group, inputs=inputs, outputs=outputs)
+        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
+                             node_group, inputs=inputs, outputs=outputs)
         return nodepy
 
-    def _build_module_node_group(self, module_name, op_type, node_cpps, input_to_node, output_to_node):
-        graph = self.trace.graph
-        inputs, outputs = [], []
-        for n in node_cpps:
-            for i in n.inputs():
-                name = i.debugName()
-                if not name in output_to_node and i in graph.inputs():
-                    inputs.append(name)
-                elif output_to_node[name] not in node_cpps:
-                    inputs.append(name)
-            for o in n.outputs():
-                name = o.debugName()
-                if not name in input_to_node and o in graph.outputs():
-                    outputs.append(name)
-                elif input_to_node[name] not in node_cpps:
-                    outputs.append(name)
-
-        return NodePyGroup(module_name, 'module', op_type, node_cpps, inputs, outputs)
+    def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
+                            input_to_node, output_to_node, module_type):
+        """
+        Merge the adjacent nodes of the module. The difference between
+        _expand_module_node and _expand_non_prim_node is that _expand_non_prim_node
+        only merges the prim:: nodes into the aten:: node, while _expand_module_node
+        merges all adjacent nodes into the same NodePyGroup.
+
+        Parameters
+        ----------
+        node : trace graph node
+            The module node to expand
+        node_name : str
+            specify the node_name for NodePyGroup
+        unique_name : str
+            unique_name for the NodePyGroup
+        op_type : str
+            specify the op_type for the NodePyGroup
+        nodes : list of trace graph node
+            All the trace graph nodes within the same scope as the module node
+        input_to_node : dict
+            key: input name, value: a node that uses this input
+        output_to_node : dict
+            key: output name, value: a node that generates this output
+        module_type : str
+            can be 'module' or 'func'
+        Returns
+        -------
+        node
+            the expanded module node
+
+        """
+        _logger.debug("expand module node, node name: %s", node_name)
+        self.global_count += 1
+        if not op_type:
+            op_type = node.kind()
+        node_group = [node]
+        inputs = list()
+        outputs = list()
+        node_queue = queue.Queue()
+        node_queue.put(node)
+        visited = {node}
+        while not node_queue.empty():
+            curr_node = node_queue.get()
+            for _input in curr_node.inputs():
+                input_name = _input.debugName()
+                if input_name in output_to_node and output_to_node[input_name] in nodes:
+                    predecessor_node = output_to_node[input_name]
+                    if predecessor_node not in visited:
+                        node_group.append(predecessor_node)
+                        node_queue.put(predecessor_node)
+                        visited.add(predecessor_node)
+                else:
+                    inputs.append(input_name)
+            for _output in curr_node.outputs():
+                output_name = _output.debugName()
+                if output_name in input_to_node and input_to_node[output_name] in nodes:
+                    successor_node = input_to_node[output_name]
+                    if successor_node not in visited:
+                        node_group.append(successor_node)
+                        node_queue.put(successor_node)
+                        visited.add(successor_node)
+                else:
+                    outputs.append(output_name)
+
+        nodepy = NodePyGroup(node_name, unique_name, module_type, op_type,
+                             node_group, inputs=inputs, outputs=outputs)
+        return nodepy
 
     def _extract_shape_info(self, node):
         """
@@ -318,11 +410,12 @@ def is_parent(name1, name2):
             parts1, parts2 = name1.split('.'), name2.split('.')
             if len(parts1) >= len(parts2):
                 return False
-            for i in range(len(parts1)):
+            for i, _ in enumerate(parts1):
                 if parts2[i] != parts1[i]:
                     return False
             return True
 
-        module_names = sorted([x[0] for x in self.trace.named_modules() if x[0]])
+        module_names = sorted([x[0]
+                               for x in self.trace.named_modules() if x[0]])
         leaf_nodes = []
         for i, name in enumerate(module_names):
             if i + 1 >= len(module_names) or not is_parent(name, module_names[i + 1]):
@@ -354,7 +447,7 @@ def _build_index(self, nodes_op):
         input_to_node = defaultdict(list)
         output_to_node = dict()
         for node in nodes_op:
-            name_to_node[node.name] = node
+            name_to_node[node.unique_name] = node
             for _input in node.inputs:
                 input_to_node[_input].append(node)
             for output in node.outputs:
@@ -385,9 +478,11 @@ def _build_graph(self):
         graph = self.trace.graph
         _logger.debug(graph)
         # build output mapping, from output debugName to its node
-        output_to_node = {x.debugName(): n for n in graph.nodes() for x in n.outputs()}
+        output_to_node = {x.debugName(): n for n in graph.nodes()
+                          for x in n.outputs()}
         # build input mapping, from input debugName to its node
-        input_to_node = {x.debugName(): n for n in graph.nodes() for x in n.inputs()}
+        input_to_node = {x.debugName(): n for n in graph.nodes()
+                         for x in n.inputs()}
         # build module mapping, from module name to all nodes (as list) under this module scope
         module_to_nodes = defaultdict(list)
         # the mapping of function (non-module in forward) to nodes, key is scope name
@@ -403,7 +498,8 @@ def _build_graph(self):
                 nodes_py.append(NodePyIO(node, 'input'))
 
         self.leaf_modules = self._extract_leaf_modules()
-        module_to_type = {name: parse_traced_name(module._name) for name, module in self.trace.named_modules()}
+        module_to_type = {name: parse_traced_name(
+            module._name) for name, module in self.trace.named_modules()}
 
         # associate module name with their trace graph nodes
         for node in graph.nodes():
@@ -412,14 +508,24 @@
                 module_to_nodes[module_name].append(node)
             else:
                 func_to_nodes[node.scopeName()].append(node)
-
         # build node group for module
         for module_name, node_cpps in module_to_nodes.items():
-            node_group = self._build_module_node_group(
-                module_name, module_to_type[module_name], node_cpps, input_to_node, output_to_node
-            )
-            _logger.debug('node_group: %s', node_group)
-            nodes_py.nodes_op.append(node_group)
+            use_count = 0
+            merged = set()
+            for node in node_cpps:
+                if node not in merged:
+                    # modules that have the same scope name may appear at different
+                    # locations in the graph. Furthermore, there are also lots of
+                    # prim:: nodes in node_cpps, so we also need to call _expand_module_node.
+                    unique_name = module_name
+                    if use_count > 0:
+                        unique_name = module_name + '.%d' % use_count
+                    node_group = self._expand_module_node(
+                        node, module_name, unique_name, module_to_type[module_name],
+                        node_cpps, input_to_node, output_to_node, 'module')
+                    nodes_py.nodes_op.append(node_group)
+                    use_count += 1
+                    merged.update(node_group.node_cpps)
 
         # each scope_name may have multiple funcs, we split them and create node for each of them
         # build node group for torch.nn.functional
@@ -431,11 +537,13 @@
                     non_prim_nodes.append(node)
                 # for each non prim node, expand it
                 for node in non_prim_nodes:
-                    node_group = self._expand_non_prim_node(node, nodes, input_to_node, output_to_node)
+                    node_group = self._expand_non_prim_node(
+                        node, nodes, input_to_node, output_to_node, 'func')
                     nodes_py.nodes_op.append(node_group)
                     # get shape infor for view (aten::view) func
                     if node_group.op_type in ['aten::view', 'aten::flatten']:
                         node_group.auxiliary = self._extract_shape_info(node)
+
         for node in graph.outputs():  # Create sink nodes for output ops
             node_py = NodePyIO(node, 'output')
             nodes_py.append(node_py)
@@ -444,14 +552,14 @@
         # build index
         return self._build_index(self.nodes_py.nodes_op)
 
-    def find_predecessors(self, module_name):
+    def find_predecessors(self, unique_name):
         """
         Find predecessor node of the given node
 
         Parameters
         ----------
-        module_name : str
-            The name of the node
+        unique_name : str
+            The unique name of the node
 
         Returns
         -------
@@ -459,22 +567,22 @@
             a list of nodes who are the given node's predecessor
         """
         predecessors = []
-        for _input in self.name_to_node[module_name].inputs:
+        for _input in self.name_to_node[unique_name].inputs:
             if not _input in self.output_to_node:
                 _logger.debug("cannot find node with %s as its output", _input)
             else:
                 node_py = self.output_to_node[_input]
-                predecessors.append(node_py.name)
+                predecessors.append(node_py.unique_name)
         return predecessors
 
-    def find_successors(self, module_name):
+    def find_successors(self, unique_name):
         """
         Find successor nodes of the given node
 
         Parameters
         ----------
-        module_name : str
-            The name of the node
+        unique_name : str
+            The unique name of the node
 
         Returns
         -------
@@ -482,9 +590,11 @@
             a list of nodes who are the given node's successor
         """
         successors = []
-        for output in self.name_to_node[module_name].outputs:
-            assert output in self.input_to_node, "No node with input {}".format(output)
{}".format(output) + for output in self.name_to_node[unique_name].outputs: + if output not in self.input_to_node: + # may reach the output of the whole graph + continue nodes_py = self.input_to_node[output] for node_py in nodes_py: - successors.append(node_py.name) + successors.append(node_py.unique_name) return successors diff --git a/src/sdk/pynni/tests/test_graph_utils.py b/src/sdk/pynni/tests/test_graph_utils.py index 38ff67c52c..92851bc91c 100644 --- a/src/sdk/pynni/tests/test_graph_utils.py +++ b/src/sdk/pynni/tests/test_graph_utils.py @@ -15,7 +15,7 @@ import unittest from unittest import TestCase, main -from nni._graph_utils import build_module_graph, build_graph +from nni._graph_utils import build_module_graph, build_graph, TorchModuleGraph class BackboneModel1(nn.Module): def __init__(self): @@ -153,6 +153,46 @@ def forward(self, x): torch.randn(4, 5), os.path.join(os.path.dirname(__file__), "expect", "test_graph_module3.expect") ) + + @unittest.skipIf(torch.__version__ < "1.4.0", "not supported") + def test_module_reuse(self): + class MyModule(nn.Module): + def __init__(self): + super().__init__() + self.liner1 = nn.Linear(10, 10) + self.relu = nn.ReLU(inplace=True) + self.liner2 = nn.Linear(10, 20) + self.liner3 = nn.Linear(20, 10) + + def forward(self, x): + x = self.liner1(x) + x = self.relu(x) + x = self.liner2(x) + x = self.relu(x) + x = self.liner3(x) + x = self.relu(x) + return x + + data = torch.rand(10, 10) + net = MyModule() + traced = torch.jit.trace(net, data) + modulegraph = TorchModuleGraph(traced_model=traced) + # Traverse the TorchModuleGraph, due the resue of the relu module, + # there will be three cpp_nodes corrspoding to the same module. + # During traversing the graph, there should be only one + # successor of each cpp-node (including the cpp_nodes that corresponds + # to the same relu module). + for name, nodeio in modulegraph.nodes_py.nodes_io.items(): + if nodeio.input_or_output == 'input': + # Find the first node of the whole graph + start_nodes = modulegraph.input_to_node[name] + # We have only one single path top-down + assert len(start_nodes) == 1 + node = start_nodes[0].unique_name + while modulegraph.find_successors(node): + nodes = modulegraph.find_successors(node) + assert len(nodes) == 1 + node = nodes[0] if __name__ == '__main__': main()