Skip to content

Commit

Permalink
Added data_format to flatten layer. (keras-team#9696)
Browse files Browse the repository at this point in the history
* Added data_format to flatten

* Added flatten tests

* Fixed Tests

* Added more dimension tests

* Reverted TF backend change

* Reverted

* Fixed CI Problems

* Altered to K.ndim for compatability

* Updated CNTK backend

* Updated to match comments

* Updated Docs
  • Loading branch information
joeyearsley authored and fchollet committed Apr 1, 2018
1 parent 1ee31ee commit aedad39
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 5 deletions.
5 changes: 4 additions & 1 deletion keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1108,7 +1108,10 @@ def reshape(x, shape):
def permute_dimensions(x, pattern):
dims = len(int_shape(x))
num_dynamic_axis = _get_dynamic_axis_num(x)
current_layout = tuple([i for i in range(dims)])
if isinstance(pattern, list):
current_layout = [i for i in range(dims)]
else:
current_layout = tuple([i for i in range(dims)])

if num_dynamic_axis > 0 and pattern[:num_dynamic_axis] != current_layout[:num_dynamic_axis]:
raise ValueError('CNTK backend: the permute pattern %s '
Expand Down
24 changes: 23 additions & 1 deletion keras/layers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from ..utils.generic_utils import func_load
from ..utils.generic_utils import deserialize_keras_object
from ..utils.generic_utils import has_arg
from ..utils import conv_utils
from ..legacy import interfaces


Expand Down Expand Up @@ -465,6 +466,13 @@ def get_config(self):
class Flatten(Layer):
"""Flattens the input. Does not affect the batch size.
# Arguments
data_format: A string, one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
`(batch, ..., channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, ...)`.
# Example
```python
Expand All @@ -479,9 +487,10 @@ class Flatten(Layer):
```
"""

def __init__(self, **kwargs):
def __init__(self, data_format='channels_last', **kwargs):
super(Flatten, self).__init__(**kwargs)
self.input_spec = InputSpec(min_ndim=3)
self.data_format = conv_utils.normalize_data_format(data_format)

def compute_output_shape(self, input_shape):
if not all(input_shape[1:]):
Expand All @@ -494,8 +503,21 @@ def compute_output_shape(self, input_shape):
return (input_shape[0], np.prod(input_shape[1:]))

def call(self, inputs):
if self.data_format == 'channels_first':
# Ensure works for any dim
permutation = [0]
permutation.extend([i for i in
range(2, K.ndim(inputs))])
permutation.append(1)
inputs = K.permute_dimensions(inputs, permutation)

return K.batch_flatten(inputs)

def get_config(self):
config = {'data_format': self.data_format}
base_config = super(Flatten, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class RepeatVector(Layer):
"""Repeats the input n times.
Expand Down
60 changes: 57 additions & 3 deletions tests/keras/layers/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,63 @@ def test_permute():

@keras_test
def test_flatten():
layer_test(layers.Flatten,
kwargs={},
input_shape=(3, 2, 4))

def test_4d():
np_inp_channels_last = np.arange(24, dtype='float32').reshape(
(1, 4, 3, 2))

np_output_cl = layer_test(layers.Flatten,
kwargs={'data_format':
'channels_last'},
input_data=np_inp_channels_last)

np_inp_channels_first = np.transpose(np_inp_channels_last,
[0, 3, 1, 2])

np_output_cf = layer_test(layers.Flatten,
kwargs={'data_format':
'channels_first'},
input_data=np_inp_channels_first,
expected_output=np_output_cl)

def test_3d():
np_inp_channels_last = np.arange(12, dtype='float32').reshape(
(1, 4, 3))

np_output_cl = layer_test(layers.Flatten,
kwargs={'data_format':
'channels_last'},
input_data=np_inp_channels_last)

np_inp_channels_first = np.transpose(np_inp_channels_last,
[0, 2, 1])

np_output_cf = layer_test(layers.Flatten,
kwargs={'data_format':
'channels_first'},
input_data=np_inp_channels_first,
expected_output=np_output_cl)

def test_5d():
np_inp_channels_last = np.arange(120, dtype='float32').reshape(
(1, 5, 4, 3, 2))

np_output_cl = layer_test(layers.Flatten,
kwargs={'data_format':
'channels_last'},
input_data=np_inp_channels_last)

np_inp_channels_first = np.transpose(np_inp_channels_last,
[0, 4, 1, 2, 3])

np_output_cf = layer_test(layers.Flatten,
kwargs={'data_format':
'channels_first'},
input_data=np_inp_channels_first,
expected_output=np_output_cl)
test_3d()
test_4d()
test_5d()


@keras_test
Expand Down

0 comments on commit aedad39

Please sign in to comment.