Skip to content

Commit b0324a9

Browse files
edmundw314 authored and facebook-github-bot committed
_jit_pass_fold_convbn wrapped with fuse_conv_bn_script (pytorch#40224)
Summary: Pull Request resolved: pytorch#40224 Test Plan: Imported from OSS Differential Revision: D22117111 Pulled By: edmundw314 fbshipit-source-id: 9252674bd770ba6669d50090849d9f9bc13edaa3
1 parent b7bfdcb commit b0324a9

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

test/quantization/test_quantize_script.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from torch.quantization.quantize_script import prepare_dynamic_script
2424
from torch.quantization.quantize_script import convert_dynamic_script
2525
from torch.quantization.quantize_script import quantize_dynamic_script
26+
from torch.quantization.quantize_script import fuse_conv_bn_script
2627

2728
# Testing utils
2829
from torch.testing._internal.common_quantization import test_only_eval_fn as _test_only_eval_fn
@@ -86,7 +87,7 @@ def forward(self, x):
8687
.run(str(get_forward(scripted_or_traced._c).graph))
8788

8889
# Run FoldConvBatchnorm2d pass.
89-
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
90+
scripted_or_traced = fuse_conv_bn_script(scripted_or_traced)
9091

9192
# Check that after the pass one of the CallMethods is gone (supposedly,
9293
# the bn.forward).
@@ -130,7 +131,7 @@ def forward(self, x):
130131
.run(str(get_forward_graph(scripted_or_traced._c)))
131132

132133
# Run FoldConvBatchnorm2d pass.
133-
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
134+
scripted_or_traced = fuse_conv_bn_script(scripted_or_traced)
134135

135136
# Check that after the pass one of the CallMethods is gone (supposedly,
136137
# the bn.forward).
@@ -176,7 +177,7 @@ def forward(self, x):
176177
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
177178
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
178179

179-
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
180+
scripted_or_traced = fuse_conv_bn_script(scripted_or_traced)
180181

181182
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 1, exactly=True) \
182183
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
@@ -227,7 +228,7 @@ def forward(self, x):
227228
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
228229
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
229230

230-
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
231+
scripted_or_traced = fuse_conv_bn_script(scripted_or_traced)
231232

232233
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", 2, exactly=True) \
233234
.run(str(get_forward_graph(scripted_or_traced.sub._c)))
@@ -268,7 +269,7 @@ def forward(self, x):
268269
else:
269270
scripted_or_traced = torch.jit.script(eager).copy()
270271
torch._C._jit_pass_dedup_module_uses(scripted_or_traced ._c)
271-
folded = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced ._c))
272+
folded = fuse_conv_bn_script(scripted_or_traced)
272273
x = torch.rand(1, 5, 6, 6)
273274
self.assertEqual(eager(x), scripted_or_traced(x))
274275

@@ -321,7 +322,7 @@ def forward(self, x):
321322
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers * 2, exactly=True) \
322323
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
323324

324-
scripted_or_traced = wrap_cpp_module(torch._C._jit_pass_fold_convbn(scripted_or_traced._c))
325+
scripted_or_traced = fuse_conv_bn_script(scripted_or_traced)
325326

326327
FileCheck().check_count("prim::CallMethod[name=\"forward\"]", num_layers, exactly=True) \
327328
.run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))

torch/quantization/quantize_script.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,17 @@ def script_qconfig(qconfig):
2727
def script_qconfig_dict(qconfig_dict):
2828
return {k: script_qconfig(v) if v else None for k, v in qconfig_dict.items()}
2929

30+
def fuse_conv_bn_script(model):
31+
return torch.jit._recursive.wrap_cpp_module(torch._C._jit_pass_fold_convbn(model._c))
32+
3033
def _prepare_script(model, qconfig_dict, inplace=False, quant_type=QuantType.STATIC):
3134
assert not inplace, "The inplace support is still in development"
3235
_check_is_script_module(model)
3336
_check_forward_method(model)
3437
if not all(isinstance(x, str) for x in qconfig_dict.keys()):
3538
raise ValueError('qconfig_dict should only contain names(str) as keys.')
3639
scripted_qconfig_dict = script_qconfig_dict(qconfig_dict)
37-
model = wrap_cpp_module(torch._C._jit_pass_fold_convbn(model._c))
40+
model = fuse_conv_bn_script(model)
3841
return wrap_cpp_module(torch._C._jit_pass_insert_observers(model._c,
3942
'forward',
4043
scripted_qconfig_dict,

0 commit comments

Comments (0)