forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_fx_param_shape_control_flow.py
155 lines (109 loc) · 4.88 KB
/
test_fx_param_shape_control_flow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# Owner(s): ["module: fx"]
import unittest
import torch
import torch.fx
from torch.testing._internal.common_utils import TestCase
class MyModuleBase(torch.nn.Module):
    """Multiply the input by a matrix, applying relu unless ``no_relu()`` is true.

    Subclasses must implement ``no_relu()`` and set ``self.param`` (or override
    ``get_mul_matrix()``). The subclasses in this file decide ``no_relu()``
    from queries on the parameter's shape, so which branch ``forward`` takes
    depends only on that shape.
    """

    def forward(self, x):
        matrix = self.get_mul_matrix()
        if self.no_relu():
            return torch.mm(x, matrix)
        else:
            return torch.relu(torch.mm(x, matrix))

    def get_mul_matrix(self):
        """Return the matrix multiplied into the input (``self.param`` by default)."""
        return self.param

    def no_relu(self):
        """Return True to skip the relu in ``forward``; subclasses must override.

        Raises:
            NotImplementedError: always, in this base implementation.
        """
        # NotImplementedError is the conventional signal for an abstract hook;
        # it subclasses Exception, so existing ``except Exception`` still works.
        raise NotImplementedError("not implemented")
class MyModuleParamShape(MyModuleBase):
    """Module that skips the relu when its parameter has fewer than 10 rows.

    The row count is read through ``Tensor.shape``.
    """

    def __init__(self, in_channels):
        super().__init__()
        initial = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(initial)

    def no_relu(self):
        # Branch on the first dimension of the (in_channels, 3) parameter.
        rows = self.param.shape[0]
        return rows < 10
class MyModuleParamSize(MyModuleBase):
    """Module that skips the relu when its parameter has fewer than 10 rows.

    The row count is read through ``Tensor.size()``.
    """

    def __init__(self, in_channels):
        super().__init__()
        init_values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(init_values)

    def no_relu(self):
        # First entry of size() is the in_channels dimension.
        row_count = self.param.size()[0]
        return row_count < 10
class MyModuleParamDim(MyModuleBase):
    """Module that branches on the parameter's rank, queried via ``Tensor.dim()``.

    A 3-D parameter contributes its first slice as the matrix and skips the
    relu; a parameter of any other rank is used directly and the relu applies.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        if self.param.dim() == 3:
            # Index away the leading dimension to get a 2-D matrix for torch.mm.
            return self.param[0]
        return self.param

    def no_relu(self):
        is_3d = self.param.dim() == 3
        return is_3d
class MyModuleParamNDim(MyModuleBase):
    """Module that branches on the parameter's rank, queried via ``Tensor.ndim``.

    A 3-D parameter contributes its first slice as the matrix and skips the
    relu; a parameter of any other rank is used directly and the relu applies.
    """

    def __init__(self, param):
        super().__init__()
        self.param = param

    def get_mul_matrix(self):
        if self.param.ndim == 3:
            # Index away the leading dimension to get a 2-D matrix for torch.mm.
            return self.param[0]
        return self.param

    def no_relu(self):
        is_3d = self.param.ndim == 3
        return is_3d
class MyModuleParamNumEl(MyModuleBase):
    """Module that skips the relu when its parameter has fewer than 10 * 3 elements.

    The element count is read through ``Tensor.numel()``.
    """

    def __init__(self, in_channels):
        super().__init__()
        values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(values)

    def no_relu(self):
        # With 3 columns per row this is the same as "fewer than 10 rows".
        element_limit = 10 * 3
        return self.param.numel() < element_limit
class MyModuleParamNElement(MyModuleBase):
    """Module that skips the relu when its parameter has fewer than 10 * 3 elements.

    The element count is read through ``Tensor.nelement()``.
    """

    def __init__(self, in_channels):
        super().__init__()
        values = torch.randn(in_channels, 3)
        self.param = torch.nn.Parameter(values)

    def no_relu(self):
        # With 3 columns per row this is the same as "fewer than 10 rows".
        element_limit = 10 * 3
        return self.param.nelement() < element_limit
class TestConstParamShapeInControlFlow(TestCase):
    """FX tracing with ``param_shapes_constant=True`` resolves shape-based branches.

    Each test builds two instances of one module class whose parameters differ
    in shape, so ``no_relu()`` picks a different branch for each, then checks
    that each traced graph contains only the ops of the branch actually taken.
    """

    def _trace_and_check(self, mod, reference_fn):
        """Trace ``mod`` with constant parameter shapes and validate its outputs.

        Args:
            mod: module exposing ``get_mul_matrix()`` whose forward multiplies
                its input by that matrix.
            reference_fn: callable mapping an input tensor to the expected output.

        Returns:
            The ``torch.fx.Graph`` produced by tracing ``mod``.
        """
        # Size the input from the module's own matrix so torch.mm is
        # well-formed for whatever parameter shape the test chose.
        in_features = mod.get_mul_matrix().shape[0]
        x = torch.randn(10, in_features)
        expected = reference_fn(x)
        torch.testing.assert_close(mod(x), expected)

        tracer = torch.fx.Tracer(param_shapes_constant=True)
        graph = tracer.trace(mod)
        # The GraphModule built from the trace must compute the same result.
        graph_mod = torch.fx.GraphModule(mod, graph)
        torch.testing.assert_close(graph_mod(x), expected)
        return graph

    def verify_mm_relu_mods(self, mm_only_mod, relu_mod):
        """
        Verify one module only does a mm op while the other
        performs both mm and relu ops in cascade.
        """
        mm_graph = self._trace_and_check(
            mm_only_mod,
            lambda x: torch.mm(x, mm_only_mod.get_mul_matrix()),
        )
        # A module with a different parameter shape goes down the other
        # code path of forward and should therefore trace to a relu as well.
        relu_graph = self._trace_and_check(
            relu_mod,
            lambda x: torch.relu(torch.mm(x, relu_mod.get_mul_matrix())),
        )

        mm_targets = [n.target for n in mm_graph.nodes]
        relu_targets = [n.target for n in relu_graph.nodes]
        # Both graphs contain the matmul; only the second has an extra relu node.
        self.assertIn(torch.mm, mm_targets)
        self.assertIn(torch.mm, relu_targets)
        self.assertNotIn(torch.relu, mm_targets)
        self.assertIn(torch.relu, relu_targets)

    def test_param_shape_const(self):
        mymod = MyModuleParamShape(in_channels=5)
        mymod2 = MyModuleParamShape(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_size_const(self):
        mymod = MyModuleParamSize(in_channels=5)
        mymod2 = MyModuleParamSize(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_dim_const(self):
        mymod = MyModuleParamDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_ndim_const(self):
        mymod = MyModuleParamNDim(torch.nn.Parameter(torch.randn(2, 5, 3)))
        mymod2 = MyModuleParamNDim(torch.nn.Parameter(torch.randn(15, 3)))
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_numel_const(self):
        mymod = MyModuleParamNumEl(in_channels=5)
        mymod2 = MyModuleParamNumEl(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)

    def test_param_nelement_const(self):
        mymod = MyModuleParamNElement(in_channels=5)
        mymod2 = MyModuleParamNElement(in_channels=15)
        self.verify_mm_relu_mods(mymod, mymod2)
if __name__ == '__main__':
    # PyTorch's TestCase subclasses are meant to be driven by run_tests(),
    # which handles the torch-specific command-line and CI options that
    # unittest.main() would ignore.
    from torch.testing._internal.common_utils import run_tests

    run_tests()