Skip to content

Commit

Permalink
Merge pull request BVLC#4408 from cdoersch/draw_net_phase
Browse files Browse the repository at this point in the history
Add phase support for draw net
  • Loading branch information
shelhamer authored Jul 11, 2016
2 parents 3e94c0e + f0b1a9e commit 776b301
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
32 changes: 27 additions & 5 deletions python/caffe/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def choose_color_by_layertype(layertype):
return color


def get_pydot_graph(caffe_net, rankdir, label_edges=True):
def get_pydot_graph(caffe_net, rankdir, label_edges=True, phase=None):
"""Create a data structure which represents the `caffe_net`.
Parameters
Expand All @@ -137,6 +137,9 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
Direction of graph layout.
label_edges : boolean, optional
Label the edges (default is True).
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
Returns
-------
Expand All @@ -148,6 +151,19 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
pydot_nodes = {}
pydot_edges = []
for layer in caffe_net.layer:
if phase is not None:
included = False
if len(layer.include) == 0:
included = True
if len(layer.include) > 0 and len(layer.exclude) > 0:
raise ValueError('layer ' + layer.name + ' has both include '
'and exclude specified.')
for layer_phase in layer.include:
included = included or layer_phase.phase == phase
for layer_phase in layer.exclude:
included = included and not layer_phase.phase == phase
if not included:
continue
node_label = get_layer_label(layer, rankdir)
node_name = "%s_%s" % (layer.name, layer.type)
if (len(layer.bottom) == 1 and len(layer.top) == 1 and
Expand Down Expand Up @@ -186,7 +202,7 @@ def get_pydot_graph(caffe_net, rankdir, label_edges=True):
return pydot_graph


def draw_net(caffe_net, rankdir, ext='png'):
def draw_net(caffe_net, rankdir, ext='png', phase=None):
"""Draws a caffe net and returns the image string encoded using the given
extension.
Expand All @@ -195,16 +211,19 @@ def draw_net(caffe_net, rankdir, ext='png'):
caffe_net : a caffe.proto.caffe_pb2.NetParameter protocol buffer.
ext : string, optional
The image extension (the default is 'png').
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
Returns
-------
string :
Postscript representation of the graph.
"""
return get_pydot_graph(caffe_net, rankdir).create(format=ext)
return get_pydot_graph(caffe_net, rankdir, phase=phase).create(format=ext)


def draw_net_to_file(caffe_net, filename, rankdir='LR'):
def draw_net_to_file(caffe_net, filename, rankdir='LR', phase=None):
"""Draws a caffe net, and saves it to file using the format given as the
file extension. Use '.raw' to output raw text that you can manually feed
to graphviz to draw graphs.
Expand All @@ -216,7 +235,10 @@ def draw_net_to_file(caffe_net, filename, rankdir='LR'):
The path to a file where the networks visualization will be stored.
rankdir : {'LR', 'TB', 'BT'}
Direction of graph layout.
phase : {caffe_pb2.Phase.TRAIN, caffe_pb2.Phase.TEST, None} optional
Include layers from this network phase. If None, include all layers.
(the default is None)
"""
ext = filename[filename.rfind('.')+1:]
with open(filename, 'wb') as fid:
fid.write(draw_net(caffe_net, rankdir, ext))
fid.write(draw_net(caffe_net, rankdir, ext, phase))
15 changes: 14 additions & 1 deletion python/draw_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def parse_args():
'http://www.graphviz.org/doc/info/'
'attrs.html#k:rankdir'),
default='LR')
parser.add_argument('--phase',
help=('Which network phase to draw: can be TRAIN, '
'TEST, or ALL. If ALL, then all layers are drawn '
'regardless of phase.'),
default="ALL")

args = parser.parse_args()
return args
Expand All @@ -38,7 +43,15 @@ def main():
net = caffe_pb2.NetParameter()
text_format.Merge(open(args.input_net_proto_file).read(), net)
print('Drawing net to %s' % args.output_image_file)
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir)
phase=None;
if args.phase == "TRAIN":
phase = caffe.TRAIN
elif args.phase == "TEST":
phase = caffe.TEST
elif args.phase != "ALL":
raise ValueError("Unknown phase: " + args.phase)
caffe.draw.draw_net_to_file(net, args.output_image_file, args.rankdir,
phase)


if __name__ == '__main__':
Expand Down

0 comments on commit 776b301

Please sign in to comment.