|
23 | 23 | from torch.quantization.quantize_script import prepare_dynamic_script
|
24 | 24 | from torch.quantization.quantize_script import convert_dynamic_script
|
25 | 25 | from torch.quantization.quantize_script import quantize_dynamic_script
|
| 26 | +from torch.quantization.quantize_script import fuse_conv_bn_script |
26 | 27 |
|
27 | 28 | # Testing utils
|
28 | 29 | from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
|
@@ -86,7 +87,7 @@ def forward(self, x):
|
86 | 87 | .run(str(get_forward(scripted_or_traced._c).graph))
|
87 | 88 |
|
88 | 89 | # Run FoldConvBatchnorm2d pass.
|
89 |
| - scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c)) |
| 90 | + scripted_or_traced = fuse_conv_bn_script(scripted_or_traced) |
90 | 91 |
|
91 | 92 | # Check that after the pass one of the CallMethods is gone (supposedly,
|
92 | 93 | # the bn.forward).
|
@@ -130,7 +131,7 @@ def forward(self, x):
|
130 | 131 | .run(str(get_forward_graph(scripted_or_traced._c)))
|
131 | 132 |
|
132 | 133 | # Run FoldConvBatchnorm2d pass.
|
133 |
| - scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c)) |
| 134 | + scripted_or_traced = fuse_conv_bn_script(scripted_or_traced) |
134 | 135 |
|
135 | 136 | # Check that after the pass one of the CallMethods is gone (supposedly,
|
136 | 137 | # the bn.forward).
|
@@ -176,7 +177,7 @@ def forward(self, x):
|
176 | 177 | FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
177 | 178 | .run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
178 | 179 |
|
179 |
| - scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c)) |
| 180 | + scripted_or_traced = fuse_conv_bn_script(scripted_or_traced) |
180 | 181 |
|
181 | 182 | FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
|
182 | 183 | .run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
@@ -227,7 +228,7 @@ def forward(self, x):
|
227 | 228 | FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
228 | 229 | .run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
229 | 230 |
|
230 |
| - scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c)) |
| 231 | + scripted_or_traced = fuse_conv_bn_script(scripted_or_traced) |
231 | 232 |
|
232 | 233 | FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
|
233 | 234 | .run(str(get_forward_graph(scripted_or_traced.sub._c)))
|
@@ -268,7 +269,7 @@ def forward(self, x):
|
268 | 269 | else:
|
269 | 270 | scripted_or_traced = torch.jit.script(eager).copy()
|
270 | 271 | torch._C._jit_pass_dedup_module_uses(scripted_or_traced ._c)
|
271 |
| - folded = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced ._c)) |
| 272 | + folded = fuse_conv_bn_script(scripted_or_traced) |
272 | 273 | x = torch.rand(1, 5, 6, 6)
|
273 | 274 | self.assertEqual(eager(x), scripted_or_traced(x))
|
274 | 275 |
|
@@ -321,7 +322,7 @@ def forward(self, x):
|
321 | 322 | FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers * 2, exactly=True) \
|
322 | 323 | .run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
|
323 | 324 |
|
324 |
| - scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c)) |
| 325 | + scripted_or_traced = fuse_conv_bn_script(scripted_or_traced) |
325 | 326 |
|
326 | 327 | FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers, exactly=True) \
|
327 | 328 | .run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
|
|
0 commit comments