Skip to content

Commit

Permalink
rename onnx functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
lanpa committed Dec 14, 2018
1 parent 4f3e138 commit 4e7bb73
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 78 deletions.
1 change: 1 addition & 0 deletions examples/demo_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@

with SummaryWriter() as w:
w.add_onnx_graph('examples/mnist/model.onnx')
# w.add_onnx_graph('/Users/dexter/Downloads/resnet50/model.onnx')
88 changes: 10 additions & 78 deletions tensorboardX/onnx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,25 @@
from .proto.tensor_shape_pb2 import TensorShapeProto
# from .proto.onnx_pb2 import ModelProto


def gg(fname):
import onnx # 0.2.1
m = onnx.load(fname)
g = m.graph
return parse(g)

def parse(graph):
nodes_proto = []
nodes = []
g = m.graph
import itertools
for node in itertools.chain(g.input, g.output):
for node in itertools.chain(graph.input, graph.output):
nodes_proto.append(node)

for node in nodes_proto:
print(node.name)
shapeproto = TensorShapeProto(
dim=[TensorShapeProto.Dim(size=d.dim_value) for d in node.type.tensor_type.shape.dim])
nodes.append(NodeDef(
name=node.name,
name=node.name.encode(encoding='utf_8'),
op='Variable',
input=[],
attr={
Expand All @@ -29,14 +32,14 @@ def gg(fname):
})
)

for node in g.node:
for node in graph.node:
attr = []
for s in node.attribute:
attr.append(' = '.join([str(f[1]) for f in s.ListFields()]))
attr = ', '.join(attr).encode(encoding='utf_8')

print(node.output[0])
nodes.append(NodeDef(
name=node.output[0],
name=node.output[0].encode(encoding='utf_8'),
op=node.op_type,
input=node.input,
attr={'parameters': AttrValue(s=attr)},
Expand All @@ -46,10 +49,6 @@ def gg(fname):
for node in nodes:
mapping[node.name] = node.op + '_' + node.name

nodes, mapping = updatenodes(nodes, mapping)
mapping = smartGrouping(nodes, mapping)
nodes, mapping = updatenodes(nodes, mapping)

return GraphDef(node=nodes, versions=VersionDef(producer=22))


Expand Down Expand Up @@ -89,70 +88,3 @@ def parser(s, nodes, node):
parser(s[1], nodes, findnode(nodes, n))
else:
return False


# TODO: use recursive parse

def smartGrouping(nodes, mapping):
# a Fully Conv is: (TODO: check var1.size(0)==var2.size(0))
# GEMM <-- Variable (c1)
# ^-- Transpose (c2) <-- Variable (c3)

# a Conv with bias is: (TODO: check var1.size(0)==var2.size(0))
# Add <-- Conv (c2) <-- Variable (c3)
# ^-- Variable (c1)
#
# gemm = ('Gemm', ('Variable', ('Transpose', ('Variable'))))

FCcounter = 1
Convcounter = 1
for node in nodes:
if node.op == 'Gemm':
c1 = c2 = c3 = False
for name_in in node.input:
n = findnode(nodes, name_in)
if n.op == 'Variable':
c1 = True
c1name = n.name
if n.op == 'Transpose':
c2 = True
c2name = n.name
if len(n.input) == 1:
nn = findnode(nodes, n.input[0])
if nn.op == 'Variable':
c3 = True
c3name = nn.name
# print(n.op, n.name, c1, c2, c3)
if c1 and c2 and c3:
# print(c1name, c2name, c3name)
mapping[c1name] = 'FC{}/{}'.format(FCcounter, c1name)
mapping[c2name] = 'FC{}/{}'.format(FCcounter, c2name)
mapping[c3name] = 'FC{}/{}'.format(FCcounter, c3name)
mapping[node.name] = 'FC{}/{}'.format(FCcounter, node.name)
FCcounter += 1
continue
if node.op == 'Add':
c1 = c2 = c3 = False
for name_in in node.input:
n = findnode(nodes, name_in)
if n.op == 'Variable':
c1 = True
c1name = n.name
if n.op == 'Conv':
c2 = True
c2name = n.name
if len(n.input) >= 1:
for nn_name in n.input:
nn = findnode(nodes, nn_name)
if nn.op == 'Variable':
c3 = True
c3name = nn.name

if c1 and c2 and c3:
# print(c1name, c2name, c3name)
mapping[c1name] = 'Conv{}/{}'.format(Convcounter, c1name)
mapping[c2name] = 'Conv{}/{}'.format(Convcounter, c2name)
mapping[c3name] = 'Conv{}/{}'.format(Convcounter, c3name)
mapping[node.name] = 'Conv{}/{}'.format(Convcounter, node.name)
Convcounter += 1
return mapping

0 comments on commit 4e7bb73

Please sign in to comment.