# generated_op.py
import logging

import numpy as np
import torch
from torch import nn
class Register:
    """Name -> callable registry; generated op runners register themselves here."""

    def __init__(self, registry_name):
        self._dict = {}
        self._name = registry_name

    def __setitem__(self, key, value):
        if not callable(value):
            raise TypeError("Value of a Registry must be a callable")
        if key is None:
            key = value.__name__
        if key in self._dict:
            logging.warning("Key %s already in registry %s.", key, self._name)
        self._dict[key] = value

    def register(self, key_name):
        """Decorator to register a function or class, e.g. @reg.register('alias')."""
        def add(key, value):
            self[key] = value
            return value
        return lambda func: add(key_name, func)

    def __getitem__(self, key):
        return self._dict[key]

    def __contains__(self, key):
        return key in self._dict

    def keys(self):
        """Return all registered op names."""
        return self._dict.keys()


op_reg = Register("op_register")
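
# How the registry is consumed (a sketch; the `op_type` field name is an
# assumption -- the real node schema in this repo may name it differently):
#
#   if node.op_type in op_reg:
#       op_reg[node.op_type](node)
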
@op_reg.register("Conv_Add_fused")
def run_Conv_Add_fused(node):
inp_0 = node.input[0]
inp_1 = node.input[1]
inp_2 = node.input[2]
inp_0_tensor = torch.tensor(np.array(inp_0.value,dtype=np.float32).reshape(inp_0.dims))
inp_1_tensor = torch.tensor(np.array(inp_1.value,dtype=np.float32).reshape(inp_1.dims))
inp_2_tensor = torch.tensor(np.array(inp_2.value,dtype=np.float32).reshape(inp_2.dims))
param_0 = node.attr
conv_0 = nn.Conv2d(param_0.c_in, param_0.c_out, param_0.ksize, param_0.stride, param_0.pad,1,1,True)
conv_0.weight.data = inp_1_tensor
conv_0.bias.data = inp_2_tensor
tmp_0 = conv_0(inp_0_tensor)
out = tmp_0.detach().numpy()
out_0 = node.output[0]
out_0.value=out
if out_0.reshaped==0:
out_0.dims=out.shape
@op_reg.register("Relu")
def run_Relu(node):
inp_0 = node.input[0]
inp_0_tensor = torch.tensor(np.array(inp_0.value,dtype=np.float32).reshape(inp_0.dims))
relu_0 = nn.ReLU()
tmp_0 = relu_0(inp_0_tensor)
out = np.array(tmp_0)
out_0 = node.output[0]
out_0.value=out
if out_0.reshaped==0:
out_0.dims=out.shape
@op_reg.register("MaxPool")
def run_MaxPool(node):
inp_0 = node.input[0]
inp_0_tensor = torch.tensor(np.array(inp_0.value,dtype=np.float32).reshape(inp_0.dims))
param_0 = node.attr
maxpool_0 = nn.MaxPool2d(param_0.ksize, param_0.stride, param_0.pad)
tmp_0 = maxpool_0(inp_0_tensor)
out = np.array(tmp_0)
out_0 = node.output[0]
out_0.value=out
if out_0.reshaped==0:
out_0.dims=out.shape
@op_reg.register("MatMul_Add_fused")
def run_MatMul_Add_fused(node):
inp_0 = node.input[0]
inp_1 = node.input[1]
inp_2 = node.input[2]
inp_0_tensor = torch.tensor(np.array(inp_0.value,dtype=np.float32).reshape(inp_0.dims))
inp_1_tensor = torch.tensor(np.array(inp_1.value,dtype=np.float32).reshape(inp_1.dims))
inp_2_tensor = torch.tensor(np.array(inp_2.value,dtype=np.float32).reshape(inp_2.dims))
tmp_0 = torch.matmul(inp_0_tensor,inp_1_tensor)
tmp_1 = torch.add(tmp_0,inp_2_tensor)
out = np.array(tmp_1)
out_0 = node.output[0]
out_0.value=out
if out_0.reshaped==0:
out_0.dims=out.shape
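

# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the generated file's real
# execution path). _Blob and _Node below are hypothetical stand-ins for the
# tensor/node objects the runners above expect (.value/.dims/.reshaped and
# .input/.output); the real schema comes from the surrounding framework.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _Blob:
        """Stand-in for a graph tensor (assumed fields: value, dims, reshaped)."""
        def __init__(self, value=None, dims=None):
            self.value = value
            self.dims = dims
            self.reshaped = 0

    class _Node:
        """Stand-in for a graph node with input/output blob lists."""
        def __init__(self, inputs, outputs):
            self.input = inputs
            self.output = outputs

    # Dispatch a Relu through the registry and inspect the produced output.
    x = _Blob(value=[-1.0, 0.5, 2.0, -3.0], dims=(2, 2))
    y = _Blob()
    op_reg["Relu"](_Node([x], [y]))
    print(y.value)  # [[0.  0.5] [2.  0. ]]
    print(y.dims)   # (2, 2)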