Skip to content

Commit

Permalink
Merge pull request lanpa#271 from orionr/fix-multiple-outputs
Browse files Browse the repository at this point in the history
Fix graph.py to handle multiple output nodes
  • Loading branch information
lanpa authored Nov 9, 2018
2 parents d56cc06 + 6fe9fb3 commit 4202508
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 20 deletions.
12 changes: 12 additions & 0 deletions examples/demo_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ def forward(self, x):



class MultipleOutput(nn.Module):
def __init__(self):
super(MultipleOutput, self).__init__()
self.Linear_1 = nn.Linear(3, 5)
self.Linear_2 = nn.Linear(3, 7)

def forward(self, x):
return self.Linear_1(x), self.Linear_2(x)

with SummaryWriter(comment='MultipleOutput') as w:
w.add_graph(MultipleOutput(), dummy_input, True)



class SimpleModel(nn.Module):
Expand Down
42 changes: 22 additions & 20 deletions tensorboardX/pytorch_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,23 @@ def parse(graph):
scope = {}
for n in graph.nodes():
if n.kind() == 'prim::Undefined':
scope[next(iter(n.outputs())).uniqueName()] = 'Undefined'
for outputnode in iter(n.outputs()):
scope[outputnode.uniqueName()] = 'Undefined'
continue
inputs = [i.uniqueName() for i in n.inputs()]
for i in range(0, len(inputs)):
if inputs[i] not in scope.keys():
scope[inputs[i]] = n.scopeName()

uname = next(iter(n.outputs())).uniqueName()
if n.scopeName() == '':
scopename = n.scopeName()
if not scopename:
print('{} has empty scope name. FIXME!'.format(n))
scope[uname] = 'unknownScope'
else:
scope[uname] = n.scopeName()
scopename = 'unknownScope'

for outputnode in iter(n.outputs()):
uname = outputnode.uniqueName()
scope[uname] = scopename

if LooseVersion(torch.__version__) >= LooseVersion("0.4"):
scope['0'] = 'input'
else:
Expand All @@ -51,20 +55,18 @@ def parse(graph):
"Error getting attributes of node {}, error is {}".format(attrs, e))
# singlequote will be escaped by tensorboard
attrs = attrs.replace("'", ' ')
inputs = [i.uniqueName() for i in n.inputs()]
# FIXME: only first output is considered (only Dropout)
outputnode = next(iter(n.outputs()))
uname = outputnode.uniqueName()
if outputnode.type().kind() == 'TensorType':
outputsize = outputnode.type().sizes()
nodes.append({'name': uname,
'op': n.kind(),
'inputs': inputs,
'attr': attrs,
'outputsize': outputsize})
else:
nodes.append({'name': uname, 'op': n.kind(),
'inputs': inputs, 'attr': attrs})
for outputnode in iter(n.outputs()):
inputs = [i.uniqueName() for i in n.inputs()]
uname = outputnode.uniqueName()
if outputnode.type().kind() == 'TensorType':
outputsize = outputnode.type().sizes()
nodes.append({'name': uname,
'op': n.kind(),
'inputs': inputs,
'attr': attrs,
'outputsize': outputsize})
else:
nodes.append({'name': uname, 'op': n.kind(), 'inputs': inputs, 'attr': attrs})

for n in graph.inputs():
uname = n.uniqueName()
Expand Down

0 comments on commit 4202508

Please sign in to comment.