Skip to content

Commit

Permalink
[Bugfix][Relay][Keras] Fix SeparableConv2D conversion in dilation_rat…
Browse files Browse the repository at this point in the history
…e attribute (apache#15122)

* Update keras.py

fix the _convert_separable_convolution. The dilation_rate always be the default value (e.g., [1,1]).

* add new test cases to capture bug in separableConv2d

* Update test_forward.py
  • Loading branch information
jikechao authored Jun 21, 2023
1 parent 31be726 commit 54b9741
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 5 additions & 1 deletion python/tvm/relay/frontend/keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,13 +569,17 @@ def _convert_separable_convolution(inexpr, keras_layer, etab, data_layout, input
weight0 = weightList[0].transpose([2, 3, 0, 1])
else:
weight0 = weightList[0]
if isinstance(keras_layer.dilation_rate, (list, tuple)):
dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
else:
dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
params0 = {
"weight": etab.new_const(weight0),
"channels": in_channels * depth_mult,
"groups": in_channels,
"kernel_size": [kernel_h, kernel_w],
"strides": [stride_h, stride_w],
"dilation": [1, 1],
"dilation": dilation,
"padding": [0, 0],
"data_layout": data_layout,
"kernel_layout": kernel_layout,
Expand Down
2 changes: 2 additions & 0 deletions tests/python/frontend/keras/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ def test_forward_conv(self, keras_mod):
keras_mod.layers.DepthwiseConv2D(kernel_size=(3, 3), padding="same"),
keras_mod.layers.Conv2DTranspose(filters=10, kernel_size=(3, 3), padding="valid"),
keras_mod.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), padding="same"),
keras_mod.layers.SeparableConv2D(filters=10, kernel_size=(3, 3), dilation_rate=(2, 2)),
keras_mod.layers.SeparableConv2D(filters=2, kernel_size=(3, 3), dilation_rate=2),
]
for conv_func in conv_funcs:
x = conv_func(data)
Expand Down

0 comments on commit 54b9741

Please sign in to comment.