forked from ZhangGe6/onnx-modifier
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path make_nodes.py
87 lines (77 loc) · 3.22 KB
/
make_nodes.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import onnx
from onnx import AttributeProto
from .parse_tools import parse_str2val
def _filled_io_names(io_slots):
    """Flatten ``{key: [tensor_name, ...]}`` into a single list of tensor names.

    Names starting with ``'list_custom'`` are dropped: they are placeholders
    for list entries that have not been filled in.
    """
    return [
        io_name
        for slot_names in io_slots.values()
        for io_name in slot_names
        if not io_name.startswith('list_custom')
    ]


def make_new_node(node_info):
    """Create an ONNX node from a node-description dict.

    Args:
        node_info: dict with
            - ``'properties'``: must contain ``'name'`` and ``'op_type'``;
            - ``'attributes'``: ``{attr_name: [value_str, type_str]}``;
            - ``'inputs'`` / ``'outputs'``: ``{key: [tensor_name, ...]}``.

    Attributes whose value is ``'undefined'`` or made only of spaces are left
    out; every other attribute value is converted with ``parse_str2val``.

    Returns:
        The NodeProto produced by ``onnx.helper.make_node``.
    """
    props = node_info['properties']
    name, op_type = props['name'], props['op_type']

    attributes = {}
    for attr_name, (attr_value, attr_type) in node_info['attributes'].items():
        # Unset ('undefined') or blank values mean "use the operator default".
        if attr_value == 'undefined' or not attr_value.replace(' ', ''):
            continue
        attributes[attr_name] = parse_str2val(attr_value, attr_type)

    # https://github.com/onnx/onnx/blob/main/onnx/helper.py#L82
    return onnx.helper.make_node(
        op_type=op_type,
        inputs=_filled_io_names(node_info['inputs']),
        outputs=_filled_io_names(node_info['outputs']),
        name=name,
        **attributes
    )
def make_attr_changed_node(node, attr_change_info):
    """Return a copy of *node* in which some existing attributes take new values.

    Edited values arrive as strings. ``onnx.helper.make_attribute`` infers the
    AttributeProto type purely from the Python value it is given
    (https://github.com/onnx/onnx/blob/4e24b635c940801555bee574b4eb3a34cab9acd5/onnx/helper.py#L472),
    so each edited value is first converted to the type of the attribute it
    replaces.

    Args:
        node: the ``onnx.NodeProto`` to copy. It is not modified.
        attr_change_info: mapping ``{attr_name: [value, type]}``. Only
            ``value`` is used, and only for attribute names already present on
            *node*; entries for other names are ignored.

    Returns:
        A new ``onnx.NodeProto``. All node-level fields of *node* (op_type,
        inputs, outputs, name, domain, doc_string, ...) are copied. Unchanged
        attributes are copied verbatim, and attribute order is preserved.

    Raises:
        RuntimeError: if an edited attribute's type is not FLOAT, INT, STRING,
            FLOATS, INTS or STRINGS.
    """
    def make_type_value(value, AttributeProto_type):
        # Convert the edited string `value` to the Python type that makes
        # make_attribute() produce an attribute of `AttributeProto_type`.
        attr_type = AttributeProto.AttributeType.Name(AttributeProto_type)
        converters = {
            "FLOAT": float,
            "INT": int,
            "STRING": str,
            "FLOATS": lambda v: parse_str2val(v, "float[]"),
            "INTS": lambda v: parse_str2val(v, "int[]"),
            "STRINGS": lambda v: parse_str2val(v, "string[]"),
        }
        if attr_type not in converters:
            raise RuntimeError(
                "type {} is not considered in current version. "
                "You can kindly report an issue for this problem. Thanks!".format(attr_type)
            )
        return converters[attr_type](value)

    # Copy the whole node instead of rebuilding it with make_node(). The
    # rebuild dropped node-level fields such as `domain` (a custom-domain op
    # silently moved to the default ONNX domain) and `doc_string`. It also
    # round-tripped every unchanged attribute through get_attribute_value() /
    # make_attribute(), which can raise, e.g. for empty list attributes in
    # recent onnx releases.
    new_node = onnx.NodeProto()
    new_node.CopyFrom(node)

    new_attrs = []
    for attr in node.attribute:
        if attr.name in attr_change_info:
            # attr_change_info: {attr: [value, type]}
            new_value = make_type_value(attr_change_info[attr.name][0], attr.type)
            new_attrs.append(
                onnx.helper.make_attribute(
                    attr.name, new_value, doc_string=attr.doc_string or None
                )
            )
        else:
            # Untouched attributes are kept exactly as they were.
            new_attrs.append(attr)

    # extend() copies each AttributeProto into new_node, so `node` is untouched.
    new_node.ClearField("attribute")
    new_node.attribute.extend(new_attrs)
    return new_node