diff --git a/tensorboardX/caffe2_graph.py b/tensorboardX/caffe2_graph.py index 63539e13..341d1f7c 100644 --- a/tensorboardX/caffe2_graph.py +++ b/tensorboardX/caffe2_graph.py @@ -1,43 +1,37 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from __future__ import unicode_literals import copy +import logging import os import re import six from builtins import bytes - from caffe2.proto import caffe2_pb2 from caffe2.python import core, workspace from .proto.graph_pb2 import GraphDef from .proto.node_def_pb2 import NodeDef from .proto.tensor_shape_pb2 import TensorShapeProto -from .proto.versions_pb2 import VersionDef - - -def _propagate_device_option(net): - '''Propagate the device options from net to operators.''' - if not net.HasField("device_option"): - return - for op in net.op: - if not op.HasField("device_option"): - op.device_option.CopyFrom(net.device_option) -def _get_blob_names(ops): - '''Get all the operators.''' - names = set() - for op in ops: - names.update(op.input) - names.update(op.output) - return {name: name for name in names} +def _make_unique_name(seen, name, min_version=0): + ''' + Make the name unique by appending a unique number to the name. Used for SSA. + Args: + seen (set): Set of names that have already been used (with respect to + some context). + name (string): The name to make unique + min_version (number): Starting index. Is incremented continually until + it can make the resulting name unique relative to 'seen'. -def _make_unique_name(seen, name, min_version=0): - '''Make the name unique.''' + Returns: + x (string): A version of name that is not in seen. 
+ ''' assert name is not None i = min_version x = '%s_%d' % (name, i) if i else name @@ -48,54 +42,22 @@ def _make_unique_name(seen, name, min_version=0): return x -def _remap_keys(m, f): - '''Remap keys for names.''' - m2 = {f(key): value for key, value in six.iteritems(m)} - m.clear() - m.update(m2) - - -def _rename_all(track_blob_names, ops, f): - '''Rename all the names in the operators.''' - seen = set() - renamed = {} - - def g(name): - ''' Collision-free version.''' - if name is None: - return None - if name in renamed: - return renamed[name] - new_name = _make_unique_name(seen, f(name)) - renamed[name] = new_name - return new_name - - for op in ops: - inputs = list(op.input) - outputs = list(op.output) - del op.input[:] - del op.output[:] - op.input.extend(g(name) for name in inputs) - op.output.extend(g(name) for name in outputs) - - if track_blob_names: - _remap_keys(track_blob_names, g) - - seen.clear() - renamed.clear() - for op in ops: - op.name = g(op.name) - - -def _replace_colons(track_blob_names, ops): - '''`:i` has a special meaning in Tensorflow.''' - def f(name): - return name.replace(':', '$') - _rename_all(track_blob_names, ops, f) - - -def _formalize_for_tensorflow(track_blob_names, ops): - '''Convert some of the common names in caffe2 to tensorflow.''' +def _rename_tensorflow_style(shapes, blob_name_tracker, ops): + ''' + Convert some of the common names in Caffe2 to tensorflow. + NOTE: The common names in both Caffe2 and Tensorflow are currently + hardcoded, if either side changes at some point, then this code should + change as well. + + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + blob_name_tracker: Dictionary of all unique blob names (with respect to + some context). + ops: List of Caffe2 operators + + Returns: + None. The _rename_all() call modifies blob_name_tracker and ops in-place. 
+ ''' WEIGHT = re.compile(r"(_w)$") WEIGHT_ = re.compile(r"(_w_)") BN = re.compile(r"(_bn)$") @@ -116,32 +78,45 @@ def f(name): inter_name = SUM_.sub('/sum_', SUM.sub('/sum', inter_name)) new_name = BRANCH.sub('/branch', inter_name) return new_name - _rename_all(track_blob_names, ops, f) + _rename_all(shapes, blob_name_tracker, ops, f) -def _convert_to_ssa(track_blob_names, ops): +def _convert_to_ssa(shapes, blob_name_tracker, ops): ''' Convert an operator graph to SSA (i.e. out-of-place). - I.e. blobs will be renamed so that each blob is produced only once. + i.e. blobs will be renamed so that each blob is produced only once. + + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + blob_name_tracker: Dictionary of all unique blob names (with respect to + some context). + ops: List of Caffe2 operators + + Returns: + None. Modifies blob_name_tracker and ops in-place. ''' ir = core.IR(ops) seen = set() versioned = {} - new_track_blob_names = {} + new_shapes = {} + new_blob_name_tracker = {} def ssa_name(name, versions): assert name in versions version = versions[name] if (name, version) in versioned: return versioned[(name, version)] - # Always setting new_name = `{name}_{version}` would work, but we also try + # Always setting name2 = `{name}_{version}` would work, but we also try # to avoid a trailing `_0`, so we have to be careful not to introduce # name collisions, such as (foo_1, 0) = foo_1 = (foo, 1). # Note: operator names (if any) will be handled later. new_name = _make_unique_name(seen, name, min_version=version) versioned[(name, version)] = new_name - if track_blob_names and name in track_blob_names: - new_track_blob_names[new_name] = track_blob_names[name] + # Transfer shape. 
+ if name in shapes: + new_shapes[new_name] = shapes[name] + if blob_name_tracker and name in blob_name_tracker: + new_blob_name_tracker[new_name] = blob_name_tracker[name] return new_name for (op, ssa) in zip(ops, ir.ssa): @@ -153,80 +128,304 @@ def ssa_name(name, versions): op.input.extend(ssa_name(name, ssa.in_versions) for name in inputs) op.output.extend(ssa_name(name, ssa.out_versions) for name in outputs) - if track_blob_names: - track_blob_names.clear() - track_blob_names.update(new_track_blob_names) + shapes.clear() + shapes.update(new_shapes) + if blob_name_tracker: + blob_name_tracker.clear() + blob_name_tracker.update(new_blob_name_tracker) + + +def _get_blob_names(ops): + ''' + Get all the operator input and output blobs and perform dedup on their names. + Args: + ops: List of Caffe2 operators to extract inputs and outputs from -def _add_gradient_scope(track_blob_names, ops): - '''Separate out gradient and momentum for names.''' + Returns: + set containing distinct inputs and outputs from 'ops' + ''' + names = set() + for op in ops: + names.update(op.input) + names.update(op.output) + return {name: name for name in names} + + +def _remap_keys(old_dict, rename_fn): + ''' + Rename keys of 'old_dict' according to 'rename_fn'. + + Args: + old_dict: Dictionary (i.e. containing blob_name -> blob_name + relationships.) + remap_fn: Function string -> string for renaming. + + Returns: + None. Modifies old_dict in-place. + ''' + new_dict = {rename_fn(key): value for key, + value in six.iteritems(old_dict)} + old_dict.clear() + old_dict.update(new_dict) + + +def _rename_all(shapes, blob_name_tracker, ops, rename_fn): + ''' + Rename all the names in the operators. + + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + blob_name_tracker: Dictionary of all unique blob names (with respect to + some context). + ops: List of Caffe2 operators + rename_fn: Function string -> string that specifies how to rename + + Returns: + None. 
Modifies shapes, blob_name_tracker and ops in-place using the + specified 'rename_fn'. + ''' + seen = set() + renamed = {} + + def g(name): + """ Collision-free version of f. + """ + if name is None: + return None + if name in renamed: + return renamed[name] + new_name = _make_unique_name(seen, rename_fn(name)) + renamed[name] = new_name + return new_name + + for op in ops: + inputs = list(op.input) + outputs = list(op.output) + del op.input[:] + del op.output[:] + op.input.extend(g(name) for name in inputs) + op.output.extend(g(name) for name in outputs) + + _remap_keys(shapes, g) + if blob_name_tracker: + _remap_keys(blob_name_tracker, g) + # Rename all operator names (if any) independently so that the + # unique-fication happens only once in _fill_missing_operator_names(). + seen.clear() + renamed.clear() + for op in ops: + op.name = g(op.name) + + +def _add_gradient_scope(shapes, blob_name_tracker, ops): + """ + For all operators or blobs with name containing "_grad", add a + "GRADIENTS/" scope. + Note: breaks graph execution since the blob -> gradient mapping is + hardcoded. + + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + blob_name_tracker: Dictionary of all unique blob names (with respect to + some context). + ops: List of Caffe2 operators + + Returns: + None. Modifies shapes, blob_name_tracker and ops in-place by renaming. + """ def f(name): - new_name = name if '_grad' in name: - new_name = 'Gradients/{}'.format(new_name.replace('_grad', '')) - if '_momentum' in name: - new_name = 'Momentum/{}'.format(new_name.replace('_momentum', '')) - return new_name - _rename_all(track_blob_names, ops, f) + return 'GRADIENTS/{}'.format(name) + else: + return name + _rename_all(shapes, blob_name_tracker, ops, f) + + +def _replace_colons(shapes, blob_name_tracker, ops, repl): + ''' + `:i` has a special meaning in Tensorflow. This function replaces all colons + with $ to avoid any possible conflicts. 
+ + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + blob_name_tracker: Dictionary of all unique blob names (with respect to + some context). + ops: List of Caffe2 operators + repl: String representing the text to replace ':' with. Usually this is + '$'. + + Returns: + None. Modifies blob_name_tracker in-place. + + ''' + def f(name): + return name.replace(':', repl) + _rename_all(shapes, blob_name_tracker, ops, f) + + +def _fill_missing_operator_names(ops): + ''' + Give missing operators a name. + We expect C2 operators to be generally unnamed. This gives them a scope + (inferred from their outputs) and a name after their type. Duplicates will + be postfixed by an index. + + Args: + ops: List of Caffe2 operators to assign names to. + + Returns: + None: Modifies 'ops' in-place. + ''' + seen = set() + for op in ops: + # Make sure operator names don't collide with blobs. + seen.update(op.input) + seen.update(op.output) + for op in ops: + if op.name: + name = op.name + elif op.output or op.input: + name_list = [os.path.dirname(name) + for name in op.output or op.input] + scope = os.path.commonprefix(name_list) + name = os.path.join(scope, op.type) + else: + name = op.type + assert(name) + op.name = _make_unique_name(seen, name) def _tf_device(device_option): - '''Handle the devices.''' + ''' + Handle the devices. + + Args: + device_option (caffe2_pb2.DeviceOption): DeviceOption protobuf, + associated to an operator, that contains information such as + device_type (optional), cuda_gpu_id (optional), node_name (optional, + tells which node the operator should execute on). See caffe2.proto + in caffe2/proto for the full list. + + Returns: + Formatted string representing device information contained in + device_option. 
+ ''' if not device_option.HasField("device_type"): return "" if device_option.device_type == caffe2_pb2.CPU: return "/cpu:*" if device_option.device_type == caffe2_pb2.CUDA: return "/gpu:{}".format(device_option.cuda_gpu_id) - raise Exception("Un-handled device", device_option) + raise Exception("Unhandled device", device_option) -def _add_tf_shape(m, ints): - '''Add shapes of the node.''' - sh = TensorShapeProto() +def _add_tf_shape(attr_dict, ints): + ''' + Converts a list of ints to a TensorShapeProto representing the dimensions of + a blob/object. + + Args: + attr_dict: Dictionary to update (usually attributes of a Node) + ints: List of integers representing dimensions of some object. + + Returns: + None. Modifies attr_dict in-place. + ''' + shape_proto = TensorShapeProto() for i in ints: dim = TensorShapeProto.Dim() dim.size = i - sh.dim.extend([dim]) - m['_output_shapes'].list.shape.extend([sh]) + shape_proto.dim.extend([dim]) + attr_dict['_output_shapes'].list.shape.extend([shape_proto]) -def _set_tf_attr(m, arg): - '''Add some other attributes.''' +def _set_tf_attr(attr_dict, arg): + ''' + Add attributes to a node. Key is the arg.name, and values can be shape, + floats, strings, ints or an empty list. + + Args: + attr_dict: Dictionary to update (usually attributes of a Node) + arg: Object with name and data fields. + + Returns: + None. Modifies attr_dict in-place. 
+ ''' k = arg.name if k == 'shape' and arg.ints: - _add_tf_shape(m, arg.ints) + _add_tf_shape(attr_dict, arg.ints) return - # float + # Float if arg.HasField("f"): - m[k].f = arg.f + attr_dict[k].f = arg.f return - # integer + # Integer if arg.HasField("i"): - m[k].i = arg.i + attr_dict[k].i = arg.i return - # string + # String if arg.HasField("s"): - m[k].s = ( - arg.s if isinstance(arg.s, bytes) else str(arg.s).encode('utf-8')) + attr_dict[k].s = ( + arg.s if isinstance(arg.s, bytes) else str(arg.s).encode('utf-8') + ) return if arg.floats: - m[k].list.f.extend(arg.floats) + attr_dict[k].list.f.extend(arg.floats) return if arg.ints: - m[k].list.i.extend(arg.ints) + attr_dict[k].list.i.extend(arg.ints) return if arg.strings: - m[k].list.s.extend( + attr_dict[k].list.s.extend( s if isinstance(s, bytes) else str(s).encode('utf-8') - for s in arg.strings) + for s in arg.strings + ) return # The value is an empty list. - m[k].list.s.extend([]) + attr_dict[k].list.s.extend([]) -def _operator_to_node(op, inter_blobs, seen): - '''Convert the operators to nodes.''' +def _operator_to_node(shapes, op): + ''' + Converts an operator to a node in a TF graph. + + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + op: The Caffe2 operator to convert to a TF graph node. + + Returns: + n: The TF graph node created from op. + ''' + assert op.name, op + n = NodeDef() + n.name = op.name + n.input.extend(op.input) + n.op = op.type + n.device = _tf_device(op.device_option) + if shapes: + # Add shapes in order. + for output in op.output: + if output not in shapes: + break + _add_tf_shape(n.attr, shapes[output]) + for arg in op.arg: + _set_tf_attr(n.attr, arg) + return n + + +def _operator_to_node_simp(op, inter_blobs, seen): + ''' + Convert the operators to nodes. 
+ + Args: + op: Caffe2 operator to convert to node + inter_blobs: Set of intermediate blobs + seen: Names that have already been used and are not unique + + Returns: + nodes: Nodes representing 'op' and the outputs of 'op' + ''' assert op nodes = [] outputs = [o for o in op.output if o not in inter_blobs] @@ -247,8 +446,8 @@ def _operator_to_node(op, inter_blobs, seen): if op.name: name = op.name else: - l = [name for name in outputs] - scope = os.path.commonprefix(l) + name_list = [name for name in outputs] + scope = os.path.commonprefix(name_list) name = os.path.join(scope, op.type) assert(name) op.name = _make_unique_name(seen, name) @@ -276,69 +475,104 @@ def _operator_to_node(op, inter_blobs, seen): return nodes -def _input_blob_to_node(name): - '''Input blobs to node.''' +def _blob_to_node(producing_ops, shapes, name): + ''' + Converts a blob (operator input or output) to a node in a TF graph. + + Args: + producing_ops: Dictionary of blob name to list of + (producing_op, blob_index within producing_op.output) mapping. + shapes: Dictionary mapping blob names to their shapes/dimensions. + name: String representing the name of this blob. + + Returns: + n: The TF graph node created from this blob. + ''' assert name n = NodeDef() n.name = name - n.op = 'Placeholder' + # Get all ops that have the blob corresponding to 'name' as one of their + # outputs. See _operators_to_graph_def. + produced_by = producing_ops.get(name, []) + if len(produced_by) > 0: + n.op = 'Blob' + else: + # This blob is not produced but is instead a TF Placeholder where a + # value is passed in. 
+ n.op = 'Placeholder' + n.input.extend('%s:%d' % (p_op.name, i) for p_op, i in produced_by) + if produced_by: + device = produced_by[0][0].device_option + if (all(producer[0].device_option == device for producer in produced_by)): + n.device = _tf_device(device) + if shapes and name in shapes: + _add_tf_shape(n.attr, shapes[name]) return n -def _clear_debug_info(ops): - '''Remove the debug information, they are copious.''' +def _clear_debug_info(ops, perform_clear): + ''' + Removes debug information from operators, they are copious. + + Args: + ops: List of Caffe2 operators + perform_clear: Boolean passed from _operators_to_graph_def specifying + whether to remove the debug information. This boolean is passed into + this function to reduce the complexity of _operators_to_graph_def. + + Returns: + None. Modifies the list of Caffe2 operators in-place and removes the + 'debug_info' field. + + ''' + if not perform_clear: + return + for op in ops: if op.HasField('debug_info'): op.ClearField('debug_info') -def _get_gpu_zero(track_blob_names, ops): - '''Just display the nodes that involve GPU zero as the output.''' - def f(op): - output = str(op.output[0]) - if output.startswith('gpu_0/'): - return True - return False - new_ops = [op for op in ops if f(op)] +def _check_if_forward(blob): + ''' + Blobs with names containing '_m' or 'grad' are part of the backward pass. + This function references facebookresearch/Detectron/detectron/utils/net.py. - # Remove scope. 
- GPU0 = re.compile(r"^(gpu_0/)") - - def g(name): - new_name = GPU0.sub('', name) - return new_name - _rename_all(track_blob_names, new_ops, g) - return new_ops + Args: + blob: The blob to inspect + Returns: + Boolean representing whether this blob is part of the forward pass + ''' + # + return (blob.find('__m') < 0 or blob.find('grad') < 0) -def _remove_unwanted(ops): - '''Remove some unwanted operators.''' - def f(blob): - flag = True - flag &= blob.find('__m') < 0 - flag &= not blob.startswith('_gpu') - return flag - new_ops = [] - for op in ops: - inputs = list(op.input) - outputs = list(op.output) - del op.input[:] - del op.output[:] - new_inputs = [i for i in inputs if f(i)] - new_outputs = [o for o in outputs if f(o)] +def _check_if_cpu(blob): + ''' + Check if the blob's name starts with '_gpu'. - # Only add the op if output is not empty - if new_outputs: - op.input.extend(new_inputs) - op.output.extend(new_outputs) - new_ops.append(op) + Args: + blob: The blob to inspect - return new_ops + Returns: + Boolean representing whether this blob is associated with a gpu + ''' + return not blob.startswith('_gpu') def _compute_in_out(ops): - # Find the input and output nodes. + ''' + Find the input, intermediate and output nodes of a set of operators. 
+ + Args: + ops: List of Caffe2 operators to look through + + Returns: + input_blobs: The input nodes of the set of operators + inter_blobs: The intermediate nodes of the set of operators + output_blobs: The output nodes of the set of operators + ''' in_blobs = set() out_blobs = set() @@ -350,61 +584,238 @@ def _compute_in_out(ops): input_blobs = list(in_blobs.difference(out_blobs)) output_blobs = list(out_blobs.difference(in_blobs)) - inter_blobs = {b: 1 for b in output_blobs if b.startswith('_')} + inter_blobs = {b for b in output_blobs if b.startswith('_')} output_blobs = [b for b in output_blobs if b not in inter_blobs] return input_blobs, inter_blobs, output_blobs -def _operators_to_graph_def(ops, - clear_debug_info=True, - single_gpu=False, - remove_unwanted=True, - custom_rename=None): - '''Main function to convert set of operators to a graph.''' - - track_blob_names = {} - track_blob_names.update(_get_blob_names(ops)) - if clear_debug_info: - _clear_debug_info(ops) - - if single_gpu: - ops = _get_gpu_zero(track_blob_names, ops) - if remove_unwanted: - ops = _remove_unwanted(ops) - _convert_to_ssa(track_blob_names, ops) - _formalize_for_tensorflow(track_blob_names, ops) - _replace_colons(track_blob_names, ops) - if custom_rename: - _rename_all(track_blob_names, ops, custom_rename) - _add_gradient_scope(track_blob_names, ops) +def _filter_ops(ops, filter_fn, perform_filter): + ''' + Filter unwanted operators based on criteria in 'filter_fn'. + + Args: + ops: List of Caffe2 operators to filter + filter_fn: Criteria function for whether inputs/outputs in an operator + should be filtered. + perform_filter: Boolean passed from _operators_to_graph_def specifying + whether to filter operators + + Returns: + new_ops: Subset of ops containing a subset of their inputs and outputs. 
+ ''' + if not perform_filter: + return ops + + new_ops = [] + for op in ops: + inputs = list(op.input) + outputs = list(op.output) + del op.input[:] + del op.output[:] + new_inputs = [i for i in inputs if filter_fn(i)] + new_outputs = [o for o in outputs if filter_fn(o)] + # Only add the op if output is not empty + if new_outputs: + op.input.extend(new_inputs) + op.output.extend(new_outputs) + new_ops.append(op) + + return new_ops + + +def _operators_to_graph_def( + shapes, + ops, + colon_replacement='$', + with_ssa=True, + with_gradient_scope=True, + blob_name_tracker=None, + show_simplified=False, + custom_rename=None +): + ''' + Main function to convert set of operators to a graph. + + Args: + shapes: Dictionary mapping blob names to their shapes/dimensions. + ops: List of Caffe2 operators, representing some computation graph + ### **kwargs (model_to_graph_def, nets_to_graph_def, protos_to_graph_def) ### + colon_replacement: Symbol to replace ':' with. ':i' in TF has a special + meaning, so we need to replace it with a non-conflicting symbol. + with_ssa: Boolean + with_gradient_scope: Boolean + blob_name_tracker: Dictionary tracking names of blobs (inputs/outputs + from operators) + show_simplified: Whether to show a simplified version of the model graph + Sets all of the following values: + clear_debug_info: Boolean representing whether to silence debug + info (which can be very verbose) + show_forward_only: Boolean representing whether to only show + blobs involved in the forward pass + show_cpu_only: Boolean representing whether to only show blobs + that are not associated with a gpu + use_tensorflow_naming: Boolean representing whether to convert + some common Caffe2 naming conventions to their Tensorflow + counterparts + custom_rename: Function string -> string that defines a custom + renaming function to use. + + Returns: + current_graph: GraphDef representing the computation graph formed by the + set of operators. 
+ blob_name_tracker: (Filtered) list of blob names corresponding to input + and output nodes of the operators in the graph. + ''' + if blob_name_tracker is not None: + blob_name_tracker.clear() + else: + blob_name_tracker = {} + + blob_name_tracker.update(_get_blob_names(ops)) + + _clear_debug_info(ops, show_simplified) # clear_debug_info + ops = _filter_ops(ops, _check_if_forward, + show_simplified) # show_forward_only + ops = _filter_ops(ops, _check_if_cpu, show_simplified) # show_cpu_only + if custom_rename: + _rename_all(shapes, blob_name_tracker, ops, custom_rename) + if colon_replacement: + _replace_colons(shapes, blob_name_tracker, ops, colon_replacement) + if with_ssa: + _convert_to_ssa(shapes, blob_name_tracker, ops) + if with_gradient_scope: + _add_gradient_scope(shapes, blob_name_tracker, ops) + _fill_missing_operator_names(ops) + if show_simplified: # use_tensorflow_naming + _rename_tensorflow_style(shapes, blob_name_tracker, ops) + producing_ops = {} + blobs = set() input_blobs, inter_blobs, _ = _compute_in_out(ops) - current_graph = GraphDef(versions=VersionDef(producer=22)) + current_graph = GraphDef() seen = set(input_blobs) - for blob in input_blobs: - current_graph.node.extend([_input_blob_to_node(blob)]) for op in ops: - current_graph.node.extend(_operator_to_node(op, inter_blobs, seen)) + nodes_from_op = _operator_to_node_simp(op, inter_blobs, seen) if \ + show_simplified else \ + [_operator_to_node(shapes, op)] # .extend() expects an iterable + current_graph.node.extend(nodes_from_op) + for input_blob in op.input: + blobs.add(input_blob) + for i, output_blob in enumerate(op.output): + blobs.add(output_blob) + producing_ops.setdefault(output_blob, []).append((op, i)) + + if show_simplified: + # Show a cleaner, easier-to-interpret version of the model graph + blobs = input_blobs + + for blob in blobs: + current_graph.node.extend([_blob_to_node(producing_ops, {}, blob)]) + + return current_graph - return current_graph, track_blob_names +def 
_propagate_device_option(net_def): + ''' + Propagate the device options from net to operators. + + Args: + net_def: A caffe2_pb2.NetDef representing a computation graph. The graph + consists of Caffe2 operators. -def model_to_graph(model, **kwargs): - '''Convert a caffe2 model to a tensorflow graph.''' + Returns: + None. Iterates through all ops contained within the net. For each op, + modifies the op device_option in-place to be the net device_option + if the op has no pre-existing device_option, and leaves the op as-is + if it already has a device_option. + ''' + if not net_def.HasField("device_option"): + return + for op in net_def.op: + if not op.HasField("device_option"): + op.device_option.CopyFrom(net_def.device_option) + + +def _try_get_shapes(nets): + ''' + Get missing shapes for all blobs contained in the nets. + + Args: + nets: List of core.Net to extract blob shape information from. + + Returns: + Dictionary containing blob name to shape/dimensions mapping. The net + is a computation graph that is composed of operators, and the + operators have input and output blobs, each with their own dims. + ''' + try: + # Note: this will inspect the workspace for better or worse. + # We don't care about the types, only the shapes + shapes, _ = workspace.InferShapesAndTypes(nets) + return shapes + except Exception as e: + logging.warning('Failed to compute shapes: %s', e) + return {} + + +def model_to_graph_def(model, **kwargs): + ''' + Convert a Caffe2 model to a Tensorflow graph. This function extracts + 'param_init_net' and 'net' from the model and passes it to nets_to_graph() + for further processing. + + Args: + model (cnn.CNNModelHelper, model_helper.ModelHelper): The model to + extract the nets (instances of core.Net) from. + + Returns: + Call to nets_to_graph_def() with extracted 'param_init_net', 'net' and + **kwargs. See _operators_to_graph_def for detailed **kwargs. 
+ ''' nets = [model.param_init_net, model.net] - return nets_to_graph(nets, **kwargs) + return nets_to_graph_def(nets, **kwargs) + +def nets_to_graph_def(nets, shapes=None, **kwargs): + ''' + Convert a set of Caffe2 nets to a Tensorflow graph. + + Args: + nets: List of core.Nets. core.Net is a wrapper around a NetDef protobuf. + The corresponding protobuf can be extracted using .Proto(). + shapes: Dictionary mapping blob names to their shapes/dimensions. -def nets_to_graph(nets, **kwargs): - '''Convert a set of caffe2 nets to a tensorflow graph.''' + Returns: + Call to protos_to_graph_def() with the extracted NetDef protobufs and + **kwargs. See _operators_to_graph_def for detailed **kwargs. + ''' + # if shapes is None: + # shapes = _try_get_shapes(nets) + # _try_get_shapes(nets) depends on workspace.InferShapesAndTypes(nets), + # which is currently broken (segfault). We omit the shapes for now. + shapes = {} nets = [copy.deepcopy(net.Proto()) for net in nets] - return protos_to_graph(nets, **kwargs) + shapes = copy.deepcopy(shapes) + return protos_to_graph_def(nets, shapes, **kwargs) -def protos_to_graph(nets, **kwargs): - '''Convert a set of caffe2 net definitions to a tensorflow graph.''' - for net in nets: +def protos_to_graph_def(net_defs, shapes=None, **kwargs): + ''' + Convert a set of Caffe2 net definitions to a Tensorflow graph. + + Args: + net_defs: List of caffe2_pb2.NetDef protobufs representing computation + graphs. + shapes: Dictionary mapping blob names to their shapes/dimensions. + + Returns: + Call to _operators_to_graph_def() with the extracted operators from the + NetDefs and **kwargs. See _operators_to_graph_def for detailed + **kwargs. 
+ ''' + for net in net_defs: _propagate_device_option(net) - ops = [op for net in nets for op in net.op] - return _operators_to_graph_def(ops, **kwargs) + shapes = copy.deepcopy(shapes or {}) + ops = [op for net_def in net_defs for op in net_def.op] + return _operators_to_graph_def(shapes, ops, **kwargs) diff --git a/tests/test_caffe2.py b/tests/test_caffe2.py new file mode 100644 index 00000000..a6739cdd --- /dev/null +++ b/tests/test_caffe2.py @@ -0,0 +1,1786 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + +import unittest + +try: + import caffe2.python.brew as brew + import caffe2.python.cnn as cnn + import caffe2.python.core as core + import caffe2.python.model_helper as model_helper + from caffe2.proto import caffe2_pb2 + import tensorboardX.caffe2_graph as tb + caffe2_installed = True +except (SystemExit, ImportError): + print('Caffe2 is not installed, skipping test') + caffe2_installed = False + + +EXPECTED_CNN = """ +node { + name: "conv1/XavierFill" + op: "XavierFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 96 + } + dim { + size: 3 + } + dim { + size: 11 + } + dim { + size: 11 + } + } + } + } + } +} +node { + name: "conv1/ConstantFill" + op: "ConstantFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 96 + } + } + } + } + } +} +node { + name: "classifier/XavierFill" + op: "XavierFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1000 + } + dim { + size: 4096 + } + } + } + } + } +} +node { + name: "classifier/ConstantFill" + op: "ConstantFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1000 + } + } + } + } + } +} +node { + name: "ImageInput" + op: "ImageInput" + input: "db" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + 
value { + i: 0 + } + } + attr { + key: "is_test" + value { + i: 0 + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "NHWC2NCHW" + op: "NHWC2NCHW" + input: "data_nhwc" + device: "/gpu:0" +} +node { + name: "conv1/Conv" + op: "Conv" + input: "data" + input: "conv1/conv1_w" + input: "conv1/conv1_b" + device: "/gpu:0" + attr { + key: "exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 11 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 4 + } + } +} +node { + name: "conv1/Relu" + op: "Relu" + input: "conv1/conv1" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "conv1/MaxPool" + op: "MaxPool" + input: "conv1/conv1_1" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 2 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 2 + } + } +} +node { + name: "classifier/FC" + op: "FC" + input: "conv1/pool1" + input: "classifier/fc_w" + input: "classifier/fc_b" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "classifier/Softmax" + op: "Softmax" + input: "classifier/fc" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "classifier/LabelCrossEntropy" + op: "LabelCrossEntropy" + input: "classifier/pred" + input: "label" + device: "/gpu:0" +} +node { + name: "classifier/AveragedLoss" + op: "AveragedLoss" + input: "classifier/xent" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/ConstantFill" + op: "ConstantFill" + input: "classifier/loss" + device: "/gpu:0" + 
attr { + key: "value" + value { + f: 1.0 + } + } +} +node { + name: "GRADIENTS/classifier/AveragedLossGradient" + op: "AveragedLossGradient" + input: "classifier/xent" + input: "GRADIENTS/classifier/loss_autogen_grad" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/LabelCrossEntropyGradient" + op: "LabelCrossEntropyGradient" + input: "classifier/pred" + input: "label" + input: "GRADIENTS/classifier/xent_grad" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/SoftmaxGradient" + op: "SoftmaxGradient" + input: "classifier/pred" + input: "GRADIENTS/classifier/pred_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "GRADIENTS/c/FCGradient" + op: "FCGradient" + input: "conv1/pool1" + input: "classifier/fc_w" + input: "GRADIENTS/classifier/fc_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "GRADIENTS/conv1/MaxPoolGradient" + op: "MaxPoolGradient" + input: "conv1/conv1_1" + input: "conv1/pool1" + input: "GRADIENTS/conv1/pool1_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 2 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 2 + } + } +} +node { + name: "GRADIENTS/conv1/ReluGradient" + op: "ReluGradient" + input: "conv1/conv1_1" + input: "GRADIENTS/conv1/conv1_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "GRADIENTS/ConvGradient" + op: "ConvGradient" + input: "data" + input: "conv1/conv1_w" + input: "GRADIENTS/conv1/conv1_grad_1" + device: "/gpu:0" + attr { + key: "exhaustive_search" + value { + i: 0 + } + } + attr { + key: 
"kernel" + value { + i: 11 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 4 + } + } +} +node { + name: "GRADIENTS/NCHW2NHWC" + op: "NCHW2NHWC" + input: "GRADIENTS/data_grad" + device: "/gpu:0" +} +node { + name: "conv1/conv1_w" + op: "Blob" + input: "conv1/XavierFill:0" + device: "/gpu:0" +} +node { + name: "classifier/fc" + op: "Blob" + input: "classifier/FC:0" + device: "/gpu:0" +} +node { + name: "data_nhwc" + op: "Blob" + input: "ImageInput:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv1_b_grad" + op: "Blob" + input: "GRADIENTS/ConvGradient:1" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/pred_grad" + op: "Blob" + input: "GRADIENTS/classifier/LabelCrossEntropyGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc_grad" + op: "Blob" + input: "GRADIENTS/classifier/SoftmaxGradient:0" + device: "/gpu:0" +} +node { + name: "conv1/conv1_b" + op: "Blob" + input: "conv1/ConstantFill:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc_b_grad" + op: "Blob" + input: "GRADIENTS/c/FCGradient:1" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc_w_grad" + op: "Blob" + input: "GRADIENTS/c/FCGradient:0" + device: "/gpu:0" +} +node { + name: "label" + op: "Blob" + input: "ImageInput:1" + device: "/gpu:0" +} +node { + name: "GRADIENTS/data_grad" + op: "Blob" + input: "GRADIENTS/ConvGradient:2" + device: "/gpu:0" +} +node { + name: "classifier/loss" + op: "Blob" + input: "classifier/AveragedLoss:0" + device: "/gpu:0" +} +node { + name: "conv1/conv1" + op: "Blob" + input: "conv1/Conv:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv1_grad" + op: "Blob" + input: "GRADIENTS/conv1/MaxPoolGradient:0" + device: "/gpu:0" +} +node { + name: "classifier/xent" + op: "Blob" + input: "classifier/LabelCrossEntropy:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/loss_autogen_grad" + op: "Blob" + input: 
"GRADIENTS/classifier/ConstantFill:0" + device: "/gpu:0" +} +node { + name: "classifier/fc_w" + op: "Blob" + input: "classifier/XavierFill:0" + device: "/gpu:0" +} +node { + name: "conv1/conv1_1" + op: "Blob" + input: "conv1/Relu:0" + device: "/gpu:0" +} +node { + name: "db" + op: "Placeholder" +} +node { + name: "classifier/pred" + op: "Blob" + input: "classifier/Softmax:0" + device: "/gpu:0" +} +node { + name: "classifier/fc_b" + op: "Blob" + input: "classifier/ConstantFill:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/xent_grad" + op: "Blob" + input: "GRADIENTS/classifier/AveragedLossGradient:0" + device: "/gpu:0" +} +node { + name: "data" + op: "Blob" + input: "NHWC2NCHW:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv1_w_grad" + op: "Blob" + input: "GRADIENTS/ConvGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv1_grad_1" + op: "Blob" + input: "GRADIENTS/conv1/ReluGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/data_nhwc_grad" + op: "Blob" + input: "GRADIENTS/NCHW2NHWC:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/pool1_grad" + op: "Blob" + input: "GRADIENTS/c/FCGradient:2" + device: "/gpu:0" +} +node { + name: "conv1/pool1" + op: "Blob" + input: "conv1/MaxPool:0" + device: "/gpu:0" +} +""" + +EXPECTED_MNIST = """ +node { + name: "conv1/XavierFill" + op: "XavierFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 20 + } + dim { + size: 1 + } + dim { + size: 5 + } + dim { + size: 5 + } + } + } + } + } +} +node { + name: "conv1/ConstantFill" + op: "ConstantFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 20 + } + } + } + } + } +} +node { + name: "conv1/XavierFill_1" + op: "XavierFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + dim { + size: 20 + } + dim { + size: 5 + } + dim { + size: 5 + } + } + } + } + } +} +node { + 
name: "conv1/ConstantFill_1" + op: "ConstantFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 100 + } + } + } + } + } +} +node { + name: "classifier/XavierFill" + op: "XavierFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 500 + } + dim { + size: 1600 + } + } + } + } + } +} +node { + name: "classifier/ConstantFill" + op: "ConstantFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 500 + } + } + } + } + } +} +node { + name: "classifier/XavierFill_1" + op: "XavierFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + dim { + size: 500 + } + } + } + } + } +} +node { + name: "classifier/ConstantFill_1" + op: "ConstantFill" + device: "/gpu:0" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 10 + } + } + } + } + } +} +node { + name: "ImageInput" + op: "ImageInput" + input: "db" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "is_test" + value { + i: 0 + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "NHWC2NCHW" + op: "NHWC2NCHW" + input: "data_nhwc" + device: "/gpu:0" +} +node { + name: "conv1/Conv" + op: "Conv" + input: "data" + input: "conv1/conv1_w" + input: "conv1/conv1_b" + device: "/gpu:0" + attr { + key: "exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 5 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "conv1/MaxPool" + op: "MaxPool" + input: "conv1/conv1" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 2 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 2 + } + } +} +node { + name: "conv1/Conv_1" + op: "Conv" + input: "conv1/pool1" + input: 
"conv1/conv2_w" + input: "conv1/conv2_b" + device: "/gpu:0" + attr { + key: "exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 5 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "conv1/MaxPool_1" + op: "MaxPool" + input: "conv1/conv2" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 2 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 2 + } + } +} +node { + name: "classifier/FC" + op: "FC" + input: "conv1/pool2" + input: "classifier/fc3_w" + input: "classifier/fc3_b" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "classifier/Relu" + op: "Relu" + input: "classifier/fc3" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "classifier/FC_1" + op: "FC" + input: "classifier/fc3_1" + input: "classifier/pred_w" + input: "classifier/pred_b" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "classifier/Softmax" + op: "Softmax" + input: "classifier/pred" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "classifier/LabelCrossEntropy" + op: "LabelCrossEntropy" + input: "classifier/softmax" + input: "label" + device: "/gpu:0" +} +node { + name: "classifier/AveragedLoss" + op: "AveragedLoss" + input: "classifier/xent" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/ConstantFill" + op: "ConstantFill" + input: "classifier/loss" + device: 
"/gpu:0" + attr { + key: "value" + value { + f: 1.0 + } + } +} +node { + name: "GRADIENTS/classifier/AveragedLossGradient" + op: "AveragedLossGradient" + input: "classifier/xent" + input: "GRADIENTS/classifier/loss_autogen_grad" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/LabelCrossEntropyGradient" + op: "LabelCrossEntropyGradient" + input: "classifier/softmax" + input: "label" + input: "GRADIENTS/classifier/xent_grad" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/SoftmaxGradient" + op: "SoftmaxGradient" + input: "classifier/softmax" + input: "GRADIENTS/classifier/softmax_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "GRADIENTS/classifier/FCGradient" + op: "FCGradient" + input: "classifier/fc3_1" + input: "classifier/pred_w" + input: "GRADIENTS/classifier/pred_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "GRADIENTS/classifier/ReluGradient" + op: "ReluGradient" + input: "classifier/fc3_1" + input: "GRADIENTS/classifier/fc3_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "GRADIENTS/c/FCGradient" + op: "FCGradient" + input: "conv1/pool2" + input: "classifier/fc3_w" + input: "GRADIENTS/classifier/fc3_grad_1" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "use_cudnn" + value { + i: 1 + } + } +} +node { + name: "GRADIENTS/conv1/MaxPoolGradient" + op: "MaxPoolGradient" + input: "conv1/conv2" + input: "conv1/pool2" + input: "GRADIENTS/conv1/pool2_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + 
} + attr { + key: "kernel" + value { + i: 2 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 2 + } + } +} +node { + name: "GRADIENTS/conv1/ConvGradient" + op: "ConvGradient" + input: "conv1/pool1" + input: "conv1/conv2_w" + input: "GRADIENTS/conv1/conv2_grad" + device: "/gpu:0" + attr { + key: "exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 5 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "GRADIENTS/conv1/MaxPoolGradient_1" + op: "MaxPoolGradient" + input: "conv1/conv1" + input: "conv1/pool1" + input: "GRADIENTS/conv1/pool1_grad" + device: "/gpu:0" + attr { + key: "cudnn_exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 2 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } + attr { + key: "stride" + value { + i: 2 + } + } +} +node { + name: "GRADIENTS/ConvGradient" + op: "ConvGradient" + input: "data" + input: "conv1/conv1_w" + input: "GRADIENTS/conv1/conv1_grad" + device: "/gpu:0" + attr { + key: "exhaustive_search" + value { + i: 0 + } + } + attr { + key: "kernel" + value { + i: 5 + } + } + attr { + key: "order" + value { + s: "NCHW" + } + } +} +node { + name: "GRADIENTS/NCHW2NHWC" + op: "NCHW2NHWC" + input: "GRADIENTS/data_grad" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc3_grad_1" + op: "Blob" + input: "GRADIENTS/classifier/ReluGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/xent_grad" + op: "Blob" + input: "GRADIENTS/classifier/AveragedLossGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/pred_w_grad" + op: "Blob" + input: "GRADIENTS/classifier/FCGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/data_nhwc_grad" + op: "Blob" + input: "GRADIENTS/NCHW2NHWC:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc3_w_grad" + op: "Blob" + input: "GRADIENTS/c/FCGradient:0" + device: "/gpu:0" +} +node { + name: 
"GRADIENTS/conv1/conv1_grad" + op: "Blob" + input: "GRADIENTS/conv1/MaxPoolGradient_1:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv1_b_grad" + op: "Blob" + input: "GRADIENTS/ConvGradient:1" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv2_w_grad" + op: "Blob" + input: "GRADIENTS/conv1/ConvGradient:0" + device: "/gpu:0" +} +node { + name: "classifier/pred" + op: "Blob" + input: "classifier/FC_1:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/pool2_grad" + op: "Blob" + input: "GRADIENTS/c/FCGradient:2" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv1_w_grad" + op: "Blob" + input: "GRADIENTS/ConvGradient:0" + device: "/gpu:0" +} +node { + name: "data" + op: "Blob" + input: "NHWC2NCHW:0" + device: "/gpu:0" +} +node { + name: "classifier/xent" + op: "Blob" + input: "classifier/LabelCrossEntropy:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/pool1_grad" + op: "Blob" + input: "GRADIENTS/conv1/ConvGradient:2" + device: "/gpu:0" +} +node { + name: "db" + op: "Placeholder" +} +node { + name: "classifier/fc3_b" + op: "Blob" + input: "classifier/ConstantFill:0" + device: "/gpu:0" +} +node { + name: "classifier/pred_b" + op: "Blob" + input: "classifier/ConstantFill_1:0" + device: "/gpu:0" +} +node { + name: "classifier/softmax" + op: "Blob" + input: "classifier/Softmax:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/data_grad" + op: "Blob" + input: "GRADIENTS/ConvGradient:2" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/pred_b_grad" + op: "Blob" + input: "GRADIENTS/classifier/FCGradient:1" + device: "/gpu:0" +} +node { + name: "label" + op: "Blob" + input: "ImageInput:1" + device: "/gpu:0" +} +node { + name: "conv1/pool1" + op: "Blob" + input: "conv1/MaxPool:0" + device: "/gpu:0" +} +node { + name: "data_nhwc" + op: "Blob" + input: "ImageInput:0" + device: "/gpu:0" +} +node { + name: "conv1/conv2" + op: "Blob" + input: "conv1/Conv_1:0" + device: "/gpu:0" +} +node { + name: 
"GRADIENTS/conv1/conv2_grad" + op: "Blob" + input: "GRADIENTS/conv1/MaxPoolGradient:0" + device: "/gpu:0" +} +node { + name: "conv1/conv2_b" + op: "Blob" + input: "conv1/ConstantFill_1:0" + device: "/gpu:0" +} +node { + name: "conv1/conv1_b" + op: "Blob" + input: "conv1/ConstantFill:0" + device: "/gpu:0" +} +node { + name: "classifier/fc3_w" + op: "Blob" + input: "classifier/XavierFill:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc3_b_grad" + op: "Blob" + input: "GRADIENTS/c/FCGradient:1" + device: "/gpu:0" +} +node { + name: "classifier/pred_w" + op: "Blob" + input: "classifier/XavierFill_1:0" + device: "/gpu:0" +} +node { + name: "conv1/pool2" + op: "Blob" + input: "conv1/MaxPool_1:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/conv1/conv2_b_grad" + op: "Blob" + input: "GRADIENTS/conv1/ConvGradient:1" + device: "/gpu:0" +} +node { + name: "classifier/fc3_1" + op: "Blob" + input: "classifier/Relu:0" + device: "/gpu:0" +} +node { + name: "classifier/loss" + op: "Blob" + input: "classifier/AveragedLoss:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/fc3_grad" + op: "Blob" + input: "GRADIENTS/classifier/FCGradient:2" + device: "/gpu:0" +} +node { + name: "conv1/conv1_w" + op: "Blob" + input: "conv1/XavierFill:0" + device: "/gpu:0" +} +node { + name: "conv1/conv1" + op: "Blob" + input: "conv1/Conv:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/loss_autogen_grad" + op: "Blob" + input: "GRADIENTS/classifier/ConstantFill:0" + device: "/gpu:0" +} +node { + name: "classifier/fc3" + op: "Blob" + input: "classifier/FC:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/pred_grad" + op: "Blob" + input: "GRADIENTS/classifier/SoftmaxGradient:0" + device: "/gpu:0" +} +node { + name: "GRADIENTS/classifier/softmax_grad" + op: "Blob" + input: "GRADIENTS/classifier/LabelCrossEntropyGradient:0" + device: "/gpu:0" +} +node { + name: "conv1/conv2_w" + op: "Blob" + input: "conv1/XavierFill_1:0" + device: "/gpu:0" +} +""" 
if caffe2_installed:
    class Caffe2Test(unittest.TestCase):
        """Tests for the Caffe2 -> GraphDef conversion helpers in `tb`.

        The class is only defined when Caffe2 could be imported, so the whole
        suite is silently skipped on machines without it.
        """

        @staticmethod
        def _sorted_nodes(graph_text):
            # Split a text-format graph dump into its "node {" entries and
            # return them sorted and re-joined. Node order in the dump is not
            # guaranteed, so both sides of a comparison go through this.
            sep = "node {"
            return "\n".join(sorted(
                sep + "\n " + part.strip()
                for part in graph_text.strip().split(sep)
                if part.strip()
            ))

        def test_that_operators_gets_non_colliding_names(self):
            # An unnamed op whose type collides with one of its blob names
            # must get a suffixed name; the blob itself is left alone.
            op = caffe2_pb2.OperatorDef()
            op.type = 'foo'
            op.input.extend(['foo'])
            tb._fill_missing_operator_names([op])
            self.assertEqual(op.input[0], 'foo')
            self.assertEqual(op.name, 'foo_1')

        def test_that_replacing_colons_gives_non_colliding_names(self):
            # .. and update shapes
            op = caffe2_pb2.OperatorDef()
            op.name = 'foo:0'
            op.input.extend(['foo:0', 'foo$0'])
            shapes = {'foo:0': [1]}
            blob_name_tracker = tb._get_blob_names([op])
            tb._replace_colons(shapes, blob_name_tracker, [op], '$')
            self.assertEqual(op.input[0], 'foo$0')
            self.assertEqual(op.input[1], 'foo$0_1')
            # Collision but blobs and op names are handled later by
            # _fill_missing_operator_names.
            self.assertEqual(op.name, 'foo$0')
            self.assertEqual(len(shapes), 1)
            self.assertEqual(shapes['foo$0'], [1])
            self.assertEqual(len(blob_name_tracker), 2)
            self.assertEqual(blob_name_tracker['foo$0'], 'foo:0')
            self.assertEqual(blob_name_tracker['foo$0_1'], 'foo$0')

        def test_that_adding_gradient_scope_does_no_fancy_renaming(self):
            # because it cannot create collisions
            op = caffe2_pb2.OperatorDef()
            op.name = 'foo_grad'
            op.input.extend(['foo_grad', 'foo_grad_1'])
            shapes = {'foo_grad': [1]}
            blob_name_tracker = tb._get_blob_names([op])
            tb._add_gradient_scope(shapes, blob_name_tracker, [op])
            self.assertEqual(op.input[0], 'GRADIENTS/foo_grad')
            self.assertEqual(op.input[1], 'GRADIENTS/foo_grad_1')
            self.assertEqual(op.name, 'GRADIENTS/foo_grad')
            self.assertEqual(len(shapes), 1)
            self.assertEqual(shapes['GRADIENTS/foo_grad'], [1])
            self.assertEqual(len(blob_name_tracker), 2)
            self.assertEqual(
                blob_name_tracker['GRADIENTS/foo_grad'], 'foo_grad')
            self.assertEqual(
                blob_name_tracker['GRADIENTS/foo_grad_1'], 'foo_grad_1')

        def test_that_auto_ssa_gives_non_colliding_names(self):
            # op2 overwrites 'foo' and also emits a blob literally named
            # 'foo_1'; SSA conversion has to keep all three versions apart.
            op1 = caffe2_pb2.OperatorDef()
            op1.output.extend(['foo'])
            op2 = caffe2_pb2.OperatorDef()
            op2.input.extend(['foo'])
            op2.output.extend(['foo'])
            op2.output.extend(['foo_1'])
            shapes = {'foo': [1], 'foo_1': [2]}
            blob_name_tracker = tb._get_blob_names([op1, op2])
            tb._convert_to_ssa(shapes, blob_name_tracker, [op1, op2])
            self.assertEqual(op1.output[0], 'foo')
            self.assertEqual(op2.input[0], 'foo')
            self.assertEqual(op2.output[0], 'foo_1')
            # Unfortunate name but we do not parse original `_` for now.
            self.assertEqual(op2.output[1], 'foo_1_1')
            self.assertEqual(len(shapes), 3)
            self.assertEqual(shapes['foo'], [1])
            self.assertEqual(shapes['foo_1'], [1])
            self.assertEqual(shapes['foo_1_1'], [2])
            self.assertEqual(len(blob_name_tracker), 3)
            self.assertEqual(blob_name_tracker['foo'], 'foo')
            self.assertEqual(blob_name_tracker['foo_1'], 'foo')
            self.assertEqual(blob_name_tracker['foo_1_1'], 'foo_1')

        def test_renaming_tensorflow_style(self):
            # Construct some dummy operators here
            # NOTE: '_w', '_bn', etc without the postfix '_' are only renamed
            # when they are at the very end of the name.
            # Test that '_w', '_w_' are renamed to '/weight', '/weight_', resp.
            op1 = caffe2_pb2.OperatorDef()
            op1.input.extend(['foo_w'])
            op1.output.extend(['foo_w_2'])
            # Test that '_bn', '_bn_' are renamed to '/batchnorm',
            # '/batchnorm_', respectively.
            op2 = caffe2_pb2.OperatorDef()
            op2.input.extend(['foo_bn'])
            op2.output.extend(['foo_bn_2'])
            # Test that '_b', '_b_', are renamed to '/bias', '/bias_', resp.
            op3 = caffe2_pb2.OperatorDef()
            op3.input.extend(['foo_b'])
            op3.output.extend(['foo_b_2'])
            # Test that '_s', '_s_', are renamed to '/scale', '/scale_', resp.
            op4 = caffe2_pb2.OperatorDef()
            op4.input.extend(['foo_s'])
            op4.output.extend(['foo_s_2'])
            # Test that '_sum', '_sum_', are renamed to '/sum', '/sum_', resp.
            op5 = caffe2_pb2.OperatorDef()
            op5.input.extend(['foo_sum'])
            op5.output.extend(['foo_sum_2'])
            # Test that '_branch', '_branch_', are renamed to '/branch',
            # '/branch_', respectively. Multiple inputs/outputs are also
            # tested in this case.
            op6 = caffe2_pb2.OperatorDef()
            op6.input.extend(['foo_branch'])
            op6.input.extend(['test_branch_2'])
            op6.output.extend(['foo_branch_3'])
            op6.output.extend(['test_branch4'])
            shapes = {
                'foo_w': [1], 'foo_w_2': [2], 'foo_bn': [3], 'foo_bn_2': [4],
                'foo_b': [5], 'foo_b_2': [6], 'foo_s': [7], 'foo_s_2': [8],
                'foo_sum': [9], 'foo_sum_2': [10], 'foo_branch': [11],
                'test_branch_2': [12], 'foo_branch_3': [13],
                'test_branch4': [14],
            }
            ops = [op1, op2, op3, op4, op5, op6]
            blob_name_tracker = tb._get_blob_names(ops)
            tb._rename_tensorflow_style(shapes, blob_name_tracker, ops)
            # Testing that keys in blob name tracker were renamed correctly
            self.assertEqual(blob_name_tracker['foo/weight'], 'foo_w')
            self.assertEqual(blob_name_tracker['foo/weight_2'], 'foo_w_2')
            self.assertEqual(blob_name_tracker['foo/batchnorm'], 'foo_bn')
            self.assertEqual(blob_name_tracker['foo/batchnorm_2'], 'foo_bn_2')
            self.assertEqual(blob_name_tracker['foo/bias'], 'foo_b')
            self.assertEqual(blob_name_tracker['foo/bias_2'], 'foo_b_2')
            self.assertEqual(blob_name_tracker['foo/scale'], 'foo_s')
            self.assertEqual(blob_name_tracker['foo/scale_2'], 'foo_s_2')
            self.assertEqual(blob_name_tracker['foo/sum'], 'foo_sum')
            self.assertEqual(blob_name_tracker['foo/sum_2'], 'foo_sum_2')
            self.assertEqual(blob_name_tracker['foo/branch'], 'foo_branch')
            self.assertEqual(
                blob_name_tracker['test/branch_2'], 'test_branch_2')
            self.assertEqual(blob_name_tracker['foo/branch_3'], 'foo_branch_3')
            self.assertEqual(blob_name_tracker['test/branch4'], 'test_branch4')
            # Testing that keys in shapes were renamed correctly
            self.assertEqual(shapes['foo/weight'], [1])
            self.assertEqual(shapes['foo/batchnorm_2'], [4])
            self.assertEqual(shapes['foo/sum'], [9])
            self.assertEqual(shapes['test/branch_2'], [12])
            # Testing that the ops were renamed correctly
            self.assertEqual(op1.input[0], 'foo/weight')
            self.assertEqual(op1.output[0], 'foo/weight_2')
            self.assertEqual(op2.input[0], 'foo/batchnorm')
            self.assertEqual(op2.output[0], 'foo/batchnorm_2')
            self.assertEqual(op3.input[0], 'foo/bias')
            self.assertEqual(op3.output[0], 'foo/bias_2')
            self.assertEqual(op4.input[0], 'foo/scale')
            self.assertEqual(op4.output[0], 'foo/scale_2')
            self.assertEqual(op5.input[0], 'foo/sum')
            self.assertEqual(op5.output[0], 'foo/sum_2')
            self.assertEqual(op6.input[0], 'foo/branch')
            self.assertEqual(op6.input[1], 'test/branch_2')
            self.assertEqual(op6.output[0], 'foo/branch_3')
            self.assertEqual(op6.output[1], 'test/branch4')

        def test_filter_ops(self):
            op1 = caffe2_pb2.OperatorDef()
            op1.input.extend(['remove_this'])
            op1.output.extend(['random_output'])
            op2 = caffe2_pb2.OperatorDef()
            op2.input.extend(['leave_this'])
            op2.output.extend(['leave_this_also'])
            op3 = caffe2_pb2.OperatorDef()
            op3.input.extend(['random_input'])
            op3.output.extend(['remove_this_also'])

            def filter_fn(blob):
                # Filter all blobs with names containing 'remove'
                return 'remove' not in str(blob)

            op_set1 = [op1, op2, op3]
            op_set2 = [op1, op2, op3]

            # Test case for when perform_filter = False. op_set2 should remain
            # unchanged. This case runs first: both op sets share the same
            # OperatorDef objects, so checking it after the filtering call
            # below would only ever see already-filtered ops.
            result_ops2 = tb._filter_ops(op_set2, filter_fn, False)
            self.assertEqual(result_ops2, op_set2)
            self.assertEqual(len(result_ops2), 3)
            self.assertEqual(list(op1.input), ['remove_this'])
            self.assertEqual(list(op3.output), ['remove_this_also'])

            # Test case for when perform_filter = True.
            result_ops1 = tb._filter_ops(op_set1, filter_fn, True)
            new_op1, new_op2 = result_ops1[0], result_ops1[1]
            # input named 'remove_this' should have been filtered
            self.assertEqual(len(new_op1.input), 0)
            self.assertEqual(new_op1.output, ['random_output'])
            self.assertEqual(new_op2.input, ['leave_this'])
            self.assertEqual(new_op2.output, ['leave_this_also'])
            # output named 'remove_this_also' should have been filtered as
            # well. This should have also removed op3 as the filter function
            # excludes ops with no outputs.
            self.assertEqual(len(result_ops1), 2)

        # Use show_simplified=False. This shows the original style of graph
        # visualization from caffe2.contrib.tensorboard.
        # TODO: Add test for show_simplified=True.
        def test_simple_cnnmodel(self):
            model = cnn.CNNModelHelper("NCHW", name="overfeat")
            data, label = model.ImageInput(
                ["db"], ["data", "label"], is_test=0)
            with core.NameScope("conv1"):
                conv1 = model.Conv(data, "conv1", 3, 96, 11, stride=4)
                relu1 = model.Relu(conv1, conv1)
                pool1 = model.MaxPool(relu1, "pool1", kernel=2, stride=2)
            with core.NameScope("classifier"):
                fc = model.FC(pool1, "fc", 4096, 1000)
                pred = model.Softmax(fc, "pred")
                xent = model.LabelCrossEntropy([pred, label], "xent")
                loss = model.AveragedLoss(xent, "loss")
            model.net.RunAllOnGPU()
            model.param_init_net.RunAllOnGPU()
            model.AddGradientOperators([loss], skip=1)
            blob_name_tracker = {}
            graph = tb.model_to_graph_def(
                model,
                blob_name_tracker=blob_name_tracker,
                shapes={},
                show_simplified=False,
            )
            self.assertEqual(
                blob_name_tracker['GRADIENTS/conv1/conv1_b_grad'],
                'conv1/conv1_b_grad',
            )
            self.maxDiff = None
            # We can't guarantee the order in which they appear, so we sort
            # both before we compare them
            expected = self._sorted_nodes(EXPECTED_CNN)
            actual = self._sorted_nodes(str(graph))
            self.assertMultiLineEqual(actual, expected)

        # cnn.CNNModelHelper is deprecated, so we also test with
        # model_helper.ModelHelper. The model used in this test is taken from
        # the Caffe2 MNIST tutorial. Also use show_simplified=False here.
        def test_simple_model(self):
            model = model_helper.ModelHelper(name="mnist")
            data, label = brew.image_input(
                model,
                ["db"],
                ["data", "label"],
                order="NCHW",
                use_gpu_transform=False,
                is_test=0
            )
            with core.NameScope("conv1"):
                conv1 = brew.conv(
                    model, data, 'conv1', dim_in=1, dim_out=20, kernel=5)
                # Image size: 24 x 24 -> 12 x 12
                pool1 = brew.max_pool(
                    model, conv1, 'pool1', kernel=2, stride=2)
                # Image size: 12 x 12 -> 8 x 8
                conv2 = brew.conv(
                    model, pool1, 'conv2', dim_in=20, dim_out=100, kernel=5)
                # Image size: 8 x 8 -> 4 x 4
                pool2 = brew.max_pool(
                    model, conv2, 'pool2', kernel=2, stride=2)
            with core.NameScope("classifier"):
                # 100 * 4 * 4 stands for dim_out from previous layer
                # multiplied by the image size
                fc3 = brew.fc(
                    model, pool2, 'fc3', dim_in=100 * 4 * 4, dim_out=500)
                relu = brew.relu(model, fc3, fc3)
                pred = brew.fc(model, relu, 'pred', 500, 10)
                softmax = brew.softmax(model, pred, 'softmax')
                xent = model.LabelCrossEntropy([softmax, label], 'xent')
                # compute the expected loss
                loss = model.AveragedLoss(xent, "loss")
            model.net.RunAllOnGPU()
            model.param_init_net.RunAllOnGPU()
            model.AddGradientOperators([loss], skip=1)
            blob_name_tracker = {}
            graph = tb.model_to_graph_def(
                model,
                blob_name_tracker=blob_name_tracker,
                shapes={},
                show_simplified=False,
            )
            self.assertEqual(
                blob_name_tracker['GRADIENTS/conv1/conv1_b_grad'],
                'conv1/conv1_b_grad',
            )
            self.maxDiff = None
            # We can't guarantee the order in which they appear, so we sort
            # both before we compare them
            expected = self._sorted_nodes(EXPECTED_MNIST)
            actual = self._sorted_nodes(str(graph))
            self.assertMultiLineEqual(actual, expected)


if __name__ == "__main__":
    unittest.main()