forked from keras-team/keras
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
249 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
import numpy as np | ||
|
||
from keras_core.api_export import keras_core_export | ||
|
||
|
||
@keras_core_export( | ||
[ | ||
"keras_core.utils.pad_sequences", | ||
"keras_core.preprocessing.sequence.pad_sequences", | ||
] | ||
) | ||
def pad_sequences( | ||
sequences, | ||
maxlen=None, | ||
dtype="int32", | ||
padding="pre", | ||
truncating="pre", | ||
value=0.0, | ||
): | ||
"""Pads sequences to the same length. | ||
This function transforms a list (of length `num_samples`) | ||
of sequences (lists of integers) | ||
into a 2D NumPy array of shape `(num_samples, num_timesteps)`. | ||
`num_timesteps` is either the `maxlen` argument if provided, | ||
or the length of the longest sequence in the list. | ||
Sequences that are shorter than `num_timesteps` | ||
are padded with `value` until they are `num_timesteps` long. | ||
Sequences longer than `num_timesteps` are truncated | ||
so that they fit the desired length. | ||
The position where padding or truncation happens is determined by | ||
the arguments `padding` and `truncating`, respectively. | ||
Pre-padding or removing values from the beginning of the sequence is the | ||
default. | ||
>>> sequence = [[1], [2, 3], [4, 5, 6]] | ||
>>> keras_core.utils.pad_sequences(sequence) | ||
array([[0, 0, 1], | ||
[0, 2, 3], | ||
[4, 5, 6]], dtype=int32) | ||
>>> keras_core.utils.pad_sequences(sequence, value=-1) | ||
array([[-1, -1, 1], | ||
[-1, 2, 3], | ||
[ 4, 5, 6]], dtype=int32) | ||
>>> keras_core.utils.pad_sequences(sequence, padding='post') | ||
array([[1, 0, 0], | ||
[2, 3, 0], | ||
[4, 5, 6]], dtype=int32) | ||
>>> keras_core.utils.pad_sequences(sequence, maxlen=2) | ||
array([[0, 1], | ||
[2, 3], | ||
[5, 6]], dtype=int32) | ||
Args: | ||
sequences: List of sequences (each sequence is a list of integers). | ||
maxlen: Optional Int, maximum length of all sequences. If not provided, | ||
sequences will be padded to the length of the longest individual | ||
sequence. | ||
dtype: (Optional, defaults to `"int32"`). Type of the output sequences. | ||
To pad sequences with variable length strings, you can use `object`. | ||
padding: String, "pre" or "post" (optional, defaults to `"pre"`): | ||
pad either before or after each sequence. | ||
truncating: String, "pre" or "post" (optional, defaults to `"pre"`): | ||
remove values from sequences larger than | ||
`maxlen`, either at the beginning or at the end of the sequences. | ||
value: Float or String, padding value. (Optional, defaults to 0.) | ||
Returns: | ||
NumPy array with shape `(len(sequences), maxlen)` | ||
""" | ||
if not hasattr(sequences, "__len__"): | ||
raise ValueError("`sequences` must be iterable.") | ||
num_samples = len(sequences) | ||
|
||
lengths = [] | ||
sample_shape = () | ||
flag = True | ||
|
||
# take the sample shape from the first non empty sequence | ||
# checking for consistency in the main loop below. | ||
|
||
for x in sequences: | ||
try: | ||
lengths.append(len(x)) | ||
if flag and len(x): | ||
sample_shape = np.asarray(x).shape[1:] | ||
flag = False | ||
except TypeError as e: | ||
raise ValueError( | ||
"`sequences` must be a list of iterables. " | ||
f"Found non-iterable: {str(x)}" | ||
) from e | ||
|
||
if maxlen is None: | ||
maxlen = np.max(lengths) | ||
|
||
is_dtype_str = np.issubdtype(dtype, np.str_) or np.issubdtype( | ||
dtype, np.unicode_ | ||
) | ||
if isinstance(value, str) and dtype != object and not is_dtype_str: | ||
raise ValueError( | ||
f"`dtype` {dtype} is not compatible with `value`'s type: " | ||
f"{type(value)}\nYou should set `dtype=object` for variable length " | ||
"strings." | ||
) | ||
|
||
x = np.full((num_samples, maxlen) + sample_shape, value, dtype=dtype) | ||
for idx, s in enumerate(sequences): | ||
if not len(s): | ||
continue # empty list/array was found | ||
if truncating == "pre": | ||
trunc = s[-maxlen:] | ||
elif truncating == "post": | ||
trunc = s[:maxlen] | ||
else: | ||
raise ValueError(f'Truncating type "{truncating}" not understood') | ||
|
||
# check `trunc` has expected shape | ||
trunc = np.asarray(trunc, dtype=dtype) | ||
if trunc.shape[1:] != sample_shape: | ||
raise ValueError( | ||
f"Shape of sample {trunc.shape[1:]} of sequence at " | ||
f"position {idx} is different from expected shape " | ||
f"{sample_shape}" | ||
) | ||
|
||
if padding == "post": | ||
x[idx, : len(trunc)] = trunc | ||
elif padding == "pre": | ||
x[idx, -len(trunc) :] = trunc | ||
else: | ||
raise ValueError(f'Padding type "{padding}" not understood') | ||
return x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from keras_core import testing | ||
from keras_core.utils import sequence_utils | ||
|
||
|
||
class PadSequencesTest(testing.TestCase): | ||
def test_pad_sequences(self): | ||
a = [[1], [1, 2], [1, 2, 3]] | ||
|
||
# test padding | ||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre") | ||
self.assertAllClose(b, [[0, 0, 1], [0, 1, 2], [1, 2, 3]]) | ||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="post") | ||
self.assertAllClose(b, [[1, 0, 0], [1, 2, 0], [1, 2, 3]]) | ||
|
||
# test truncating | ||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre") | ||
self.assertAllClose(b, [[0, 1], [1, 2], [2, 3]]) | ||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post") | ||
self.assertAllClose(b, [[0, 1], [1, 2], [1, 2]]) | ||
|
||
# test value | ||
b = sequence_utils.pad_sequences(a, maxlen=3, value=1) | ||
self.assertAllClose(b, [[1, 1, 1], [1, 1, 2], [1, 2, 3]]) | ||
|
||
def test_pad_sequences_str(self): | ||
a = [["1"], ["1", "2"], ["1", "2", "3"]] | ||
|
||
# test padding | ||
b = sequence_utils.pad_sequences( | ||
a, maxlen=3, padding="pre", value="pad", dtype=object | ||
) | ||
self.assertAllEqual( | ||
b, [["pad", "pad", "1"], ["pad", "1", "2"], ["1", "2", "3"]] | ||
) | ||
b = sequence_utils.pad_sequences( | ||
a, maxlen=3, padding="post", value="pad", dtype="<U3" | ||
) | ||
self.assertAllEqual( | ||
b, [["1", "pad", "pad"], ["1", "2", "pad"], ["1", "2", "3"]] | ||
) | ||
|
||
# test truncating | ||
b = sequence_utils.pad_sequences( | ||
a, maxlen=2, truncating="pre", value="pad", dtype=object | ||
) | ||
self.assertAllEqual(b, [["pad", "1"], ["1", "2"], ["2", "3"]]) | ||
b = sequence_utils.pad_sequences( | ||
a, maxlen=2, truncating="post", value="pad", dtype="<U3" | ||
) | ||
self.assertAllEqual(b, [["pad", "1"], ["1", "2"], ["1", "2"]]) | ||
|
||
with self.assertRaisesRegex( | ||
ValueError, "`dtype` int32 is not compatible with " | ||
): | ||
sequence_utils.pad_sequences( | ||
a, maxlen=2, truncating="post", value="pad" | ||
) | ||
|
||
def test_pad_sequences_vector(self): | ||
a = [[[1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2], [3, 3]]] | ||
|
||
# test padding | ||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="pre") | ||
self.assertAllClose( | ||
b, | ||
[ | ||
[[0, 0], [0, 0], [1, 1]], | ||
[[0, 0], [2, 1], [2, 2]], | ||
[[3, 1], [3, 2], [3, 3]], | ||
], | ||
) | ||
b = sequence_utils.pad_sequences(a, maxlen=3, padding="post") | ||
self.assertAllClose( | ||
b, | ||
[ | ||
[[1, 1], [0, 0], [0, 0]], | ||
[[2, 1], [2, 2], [0, 0]], | ||
[[3, 1], [3, 2], [3, 3]], | ||
], | ||
) | ||
|
||
# test truncating | ||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="pre") | ||
self.assertAllClose( | ||
b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 2], [3, 3]]] | ||
) | ||
|
||
b = sequence_utils.pad_sequences(a, maxlen=2, truncating="post") | ||
self.assertAllClose( | ||
b, [[[0, 0], [1, 1]], [[2, 1], [2, 2]], [[3, 1], [3, 2]]] | ||
) | ||
|
||
# test value | ||
b = sequence_utils.pad_sequences(a, maxlen=3, value=1) | ||
self.assertAllClose( | ||
b, | ||
[ | ||
[[1, 1], [1, 1], [1, 1]], | ||
[[1, 1], [2, 1], [2, 2]], | ||
[[3, 1], [3, 2], [3, 3]], | ||
], | ||
) |