Skip to content

Commit

Permalink
Make TextVectorization work with list input.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 379565815
  • Loading branch information
yashk2810 authored and tensorflower-gardener committed Jun 15, 2021
1 parent e1a17d3 commit f604d04
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 2 additions & 1 deletion keras/layers/preprocessing/text_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ def build(self, input_shape):
# expression to evaluate to False instead of True if the shape is undefined;
# the expression needs to evaluate to True in that case.
if self._split is not None:
if input_shape.ndims > 1 and not input_shape[-1] == 1: # pylint: disable=g-comparison-negation
if (input_shape is not None and input_shape.ndims > 1 and
not input_shape[-1] == 1): # pylint: disable=g-comparison-negation
raise RuntimeError(
"When using TextVectorization to tokenize strings, the innermost "
"dimension of the input array must be 1, got shape "
Expand Down
6 changes: 6 additions & 0 deletions keras/layers/preprocessing/text_vectorization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,12 @@ def test_layer_dimensionality_handling_with_split(self, data, expected):
output = vectorization(tf.ragged.constant(data, inner_shape=(1,)))
self.assertAllEqual(expected, output)

def test_layer_list_input(self):
layer = text_vectorization.TextVectorization(vocabulary=["a", "b", "c"])
output = layer(["a", "b", "c"])
expected_output = [[2], [3], [4]]
self.assertEqual(output.numpy().tolist(), expected_output)


@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
class TextVectorizationPreprocessingTest(
Expand Down

0 comments on commit f604d04

Please sign in to comment.