Skip to content

Commit

Permalink
Make separable conv backend tests efficient (keras-team#9570)
Browse files Browse the repository at this point in the history
  • Loading branch information
taehoonlee authored and fchollet committed Mar 6, 2018
1 parent 614a8b4 commit 62c395e
Showing 1 changed file with 29 additions and 21 deletions.
50 changes: 29 additions & 21 deletions tests/keras/backend/backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,11 @@ def ref_depthwise_conv(x, w, padding, data_format):
return y


def ref_separable_conv(x, w1, w2, padding, data_format):
x2 = ref_depthwise_conv(x, w1, padding, data_format)
return ref_conv(x2, w2, padding, data_format)


def ref_rnn(x, w, init, go_backwards=False, mask=None, unroll=False, input_length=None):
w_i, w_h, w_o = w
h = []
Expand Down Expand Up @@ -1086,27 +1091,30 @@ def legacy_test_conv3d(self):
BACKENDS, cntk_dynamicity=True,
data_format=data_format)

def test_separable_conv2d(self):
for (input_shape, data_format) in [
((2, 3, 4, 5), 'channels_first'),
((2, 3, 5, 6), 'channels_first'),
((1, 6, 5, 3), 'channels_last')]:
input_depth = input_shape[1] if data_format == 'channels_first' else input_shape[-1]
_, x_val = parse_shape_or_val(input_shape)
x_tf = KTF.variable(x_val)
for kernel_shape in [(2, 2), (4, 3)]:
for depth_multiplier in [1, 2]:
_, depthwise_val = parse_shape_or_val(kernel_shape + (input_depth, depth_multiplier))
_, pointwise_val = parse_shape_or_val((1, 1) + (input_depth * depth_multiplier, 7))

z_tf = KTF.eval(KTF.separable_conv2d(x_tf, KTF.variable(depthwise_val),
KTF.variable(pointwise_val),
data_format=data_format))
z_c = cntk_func_three_tensor('separable_conv2d', input_shape,
depthwise_val,
pointwise_val,
data_format=data_format)([x_val])[0]
assert_allclose(z_tf, z_c, 1e-3)
@pytest.mark.skipif(K.backend() == 'theano', reason='Not supported.')
@pytest.mark.parametrize('op,input_shape,kernel_shape,depth_multiplier,padding,data_format', [
('separable_conv2d', (2, 3, 4, 5), (3, 3), 1, 'same', 'channels_first'),
('separable_conv2d', (2, 3, 5, 6), (4, 3), 2, 'valid', 'channels_first'),
('separable_conv2d', (1, 6, 5, 3), (3, 4), 1, 'valid', 'channels_last'),
('separable_conv2d', (1, 7, 6, 3), (3, 3), 2, 'same', 'channels_last'),
])
def test_separable_conv2d(self, op, input_shape, kernel_shape, depth_multiplier, padding, data_format):
input_depth = input_shape[1] if data_format == 'channels_first' else input_shape[-1]
_, x = parse_shape_or_val(input_shape)
_, depthwise = parse_shape_or_val(kernel_shape + (input_depth, depth_multiplier))
_, pointwise = parse_shape_or_val((1, 1) + (input_depth * depth_multiplier, 7))
y1 = ref_separable_conv(x, depthwise, pointwise, padding, data_format)
if K.backend() == 'cntk':
y2 = cntk_func_three_tensor(
op, input_shape,
depthwise, pointwise,
padding=padding, data_format=data_format)([x])[0]
else:
y2 = K.eval(getattr(K, op)(
K.variable(x),
K.variable(depthwise), K.variable(pointwise),
padding=padding, data_format=data_format))
assert_allclose(y1, y2, atol=1e-05)

def test_pool2d(self):
check_single_tensor_operation('pool2d', (5, 10, 12, 3),
Expand Down

0 comments on commit 62c395e

Please sign in to comment.