Skip to content

Commit

Permalink
Use NCW/NWC for conv1d data format in TF backend.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed May 7, 2018
1 parent 895dba6 commit 9c68977
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3325,12 +3325,12 @@ def _preprocess_conv1d_input(x, data_format):
"""
if dtype(x) == 'float64':
x = tf.cast(x, 'float32')
tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
tf_data_format = 'NWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = tf.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
tf_data_format = 'NCHW'
tf_data_format = 'NCW'
return x, tf_data_format


Expand Down Expand Up @@ -3441,7 +3441,7 @@ def conv1d(x, kernel, strides=1, padding='valid',
padding=padding,
data_format=tf_data_format)

if data_format == 'channels_first' and tf_data_format in {'NWC', 'NHWC'}:
if data_format == 'channels_first' and tf_data_format == 'NWC':
x = tf.transpose(x, (0, 2, 1)) # NWC -> NCW
return x

Expand Down Expand Up @@ -3571,6 +3571,10 @@ def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
dilation_rate = (dilation_rate,)

x, tf_data_format = _preprocess_conv1d_input(x, data_format)
if tf_data_format == 'NWC':
tf_data_format = 'NHWC'
else:
tf_data_format = 'NCHW'
padding = _preprocess_padding(padding)
if tf_data_format == 'NHWC':
spatial_start_dim = 1
Expand Down

0 comments on commit 9c68977

Please sign in to comment.