Skip to content

Commit

Permalink
Add 'one_hot' output mode to StringLookup and IntegerLookup
Browse files Browse the repository at this point in the history
This output mode will encode every element in an input batch individually, and append a new
output dimension (for the encoded arrays) to the input shape if necessary.

PiperOrigin-RevId: 379624593
  • Loading branch information
mattdangerw authored and tensorflower-gardener committed Jun 16, 2021
1 parent 70c085c commit 9f6a8ec
Show file tree
Hide file tree
Showing 6 changed files with 224 additions and 39 deletions.
74 changes: 54 additions & 20 deletions keras/layers/preprocessing/index_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

INT = "int"
MULTI_HOT = "multi_hot"
ONE_HOT = "one_hot"
COUNT = "count"
TF_IDF = "tf_idf"

Expand Down Expand Up @@ -106,12 +107,19 @@ class IndexLookup(base_preprocessing_layer.CombinerPreprocessingLayer):
map indices to vocabulary items instead of mapping vocabulary items to
indices. Default to False.
output_mode: Specification for the output of the layer. Defaults to `"int"`.
Values can be `"int"`, `"multi_hot"`, `"count"`, or `"tf_idf"` configuring
the layer as follows:
Values can be `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or
`"tf_idf"` configuring the layer as follows:
- `"int"`: Return the raw integer indices of the input tokens.
- `"multi_hot"`: Outputs a single int array per sample, of either
vocab_size or max_tokens size, containing 1s in all elements where the
token mapped to that index exists at least once in the sample.
- `"one_hot"`: Encodes each individual element in the input into an
array the same size as the vocabulary, containing a 1 at the element
index. If the last dimension is size 1, will encode on that dimension.
If the last dimension is not size 1, will append a new dimension for
the encoded output.
- `"multi_hot"`: Encodes each sample in the input into a single array
the same size as the vocabulary, containing a 1 for each vocabulary
term present in the sample. Treats the last dimension as the sample
dimension, if input shape is (..., sample_length), output shape will
be (..., num_tokens).
- `"count"`: As `"multi_hot"`, but the int array contains a count of the
number of times the token at that index appeared in the sample.
- `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is applied to
Expand Down Expand Up @@ -152,10 +160,10 @@ def __init__(self,
output_mode = MULTI_HOT
if output_mode == "tf-idf":
output_mode = TF_IDF
# 'output_mode' must be one of (INT, MULTI_HOT, COUNT, TF_IDF)
# 'output_mode' must be one of (INT, ONE_HOT, MULTI_HOT, COUNT, TF_IDF)
layer_utils.validate_string_arg(
output_mode,
allowable_strings=(INT, MULTI_HOT, COUNT, TF_IDF),
allowable_strings=(INT, ONE_HOT, MULTI_HOT, COUNT, TF_IDF),
layer_name=self.__class__.__name__,
arg_name="output_mode")

Expand Down Expand Up @@ -653,23 +661,49 @@ def call(self, inputs):
with tf.control_dependencies(lookup_checks):
if self.output_mode == INT:
return tf.identity(lookup_result)

multi_hot_output = (self.output_mode == MULTI_HOT)
if self._vocab_size and not self.pad_to_max_tokens:
out_depth = self._vocab_size
else:
out_depth = self.max_tokens
if self.sparse:
bincounts = category_encoding.sparse_bincount(lookup_result, out_depth,
multi_hot_output)
return self._encode_output(lookup_result)

def _encode_output(self, lookup_result):
def expand_dims(inputs, axis):
if tf_utils.is_sparse(inputs):
return tf.sparse.expand_dims(inputs, axis)
else:
bincounts = category_encoding.dense_bincount(lookup_result, out_depth,
multi_hot_output)
return tf.compat.v1.expand_dims(inputs, axis)

original_shape = lookup_result.shape
# In all cases, we should uprank scalar input to a single sample.
if lookup_result.shape.rank == 0:
lookup_result = expand_dims(lookup_result, -1)
# One hot will uprank only if the final output dimension is not already 1.
if self.output_mode == ONE_HOT:
if lookup_result.shape[-1] != 1:
lookup_result = expand_dims(lookup_result, -1)

# TODO(b/190445202): remove output rank restriction.
if lookup_result.shape.rank > 2:
raise ValueError(
"Received input shape {}, which would result in output rank {}. "
"Currently only outputs up to rank 2 are supported for "
"`output_mode={}`.".format(original_shape, lookup_result.shape.rank,
self.output_mode))

if self.output_mode == TF_IDF:
return tf.multiply(bincounts, self.tf_idf_weights)
binary_output = self.output_mode in (MULTI_HOT, ONE_HOT)
if self._vocab_size and not self.pad_to_max_tokens:
out_depth = self._vocab_size
else:
out_depth = self.max_tokens
if self.sparse:
bincounts = category_encoding.sparse_bincount(lookup_result, out_depth,
binary_output)
else:
bincounts = category_encoding.dense_bincount(lookup_result, out_depth,
binary_output)

if self.output_mode == TF_IDF:
return tf.multiply(bincounts, self.tf_idf_weights)

return bincounts
return bincounts

def _convert_to_ndarray(self, x):
return np.array(x) if isinstance(x, (list, tuple)) else x
Expand Down
80 changes: 74 additions & 6 deletions keras/layers/preprocessing/index_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,7 +833,75 @@ def test_int_output_explicit_vocab(self):
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_binary_output_hard_maximum(self):
def test_one_hot_output_hard_maximum(self):
"""Check one_hot output when pad_to_max_tokens=True."""
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array(["earth", "wind", "and", "fire", "michigan", ""])
expected_output = [
[0, 1, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0],
[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[1, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
]

input_data = keras.Input(shape=(1,), dtype=tf.string)
layer = index_lookup.IndexLookup(
max_tokens=6,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.ONE_HOT,
pad_to_max_tokens=True,
dtype=tf.string)
layer.set_vocabulary(vocab_data)
binary_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=binary_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_one_hot_output_soft_maximum(self):
"""Check one_hot output when pad_to_max_tokens=False."""
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array(["earth", "wind", "and", "fire", "michigan", ""])
expected_output = [
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
[0, 0, 0, 0, 0],
]

input_data = keras.Input(shape=(1,), dtype=tf.string)
layer = index_lookup.IndexLookup(
max_tokens=None,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.ONE_HOT,
dtype=tf.string)
layer.set_vocabulary(vocab_data)
binary_data = layer(input_data)
model = keras.Model(inputs=input_data, outputs=binary_data)
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_one_hot_output_shape(self):
inputs = keras.Input(batch_size=16, shape=(1,), dtype=tf.string)
layer = index_lookup.IndexLookup(
vocabulary=["earth"],
max_tokens=2,
num_oov_indices=1,
mask_token="",
oov_token="[OOV]",
output_mode=index_lookup.ONE_HOT,
dtype=tf.string)
outputs = layer(inputs)
self.assertAllEqual(outputs.shape.as_list(), [16, 2])

def test_multi_hot_output_hard_maximum(self):
"""Check multi_hot output when pad_to_max_tokens=True."""
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire", ""],
Expand All @@ -858,7 +926,7 @@ def test_binary_output_hard_maximum(self):
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_binary_output_no_oov(self):
def test_multi_hot_output_no_oov(self):
"""Check multi_hot output when pad_to_max_tokens=True."""
vocab_data = ["earth", "wind", "and", "fire"]
valid_input = np.array([["earth", "wind", "and", "fire"],
Expand Down Expand Up @@ -888,7 +956,7 @@ def test_binary_output_no_oov(self):
"found OOV values.*michigan"):
_ = model.predict(invalid_input)

def test_binary_output_hard_maximum_multiple_adapts(self):
def test_multi_hot_output_hard_maximum_multiple_adapts(self):
input_array = np.array([["earth", "wind", "and", "earth"],
["ohio", "and", "earth", "michigan"]])
adapt_data = ["earth", "earth", "earth", "earth", "wind", "wind", "wind"]
Expand Down Expand Up @@ -926,8 +994,8 @@ def test_binary_output_hard_maximum_multiple_adapts(self):
self.assertAllEqual(first_expected_output, first_output)
self.assertAllEqual(second_expected_output, second_output)

def test_binary_output_soft_maximum(self):
"""Check binary output when pad_to_max_tokens=False."""
def test_multi_hot_output_soft_maximum(self):
"""Check multi_hot output when pad_to_max_tokens=False."""
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire", ""],
["fire", "and", "earth", "michigan", ""]])
Expand All @@ -950,7 +1018,7 @@ def test_binary_output_soft_maximum(self):
output_dataset = model.predict(input_array)
self.assertAllEqual(expected_output, output_dataset)

def test_binary_output_shape(self):
def test_multi_hot_output_shape(self):
input_data = keras.Input(batch_size=16, shape=(4,), dtype=tf.string)
layer = index_lookup.IndexLookup(
max_tokens=2,
Expand Down
34 changes: 28 additions & 6 deletions keras/layers/preprocessing/integer_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,19 @@ class IntegerLookup(index_lookup.IndexLookup):
map indices to vocabulary items instead of mapping vocabulary items to
indices. Default to False.
output_mode: Specification for the output of the layer. Defaults to `"int"`.
Values can be `"int"`, `"multi_hot"`, `"count"`, or `"tf_idf"` configuring
the layer as follows:
Values can be `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or
`"tf_idf"` configuring the layer as follows:
- `"int"`: Return the vocabulary indices of the input tokens.
- `"multi_hot"`: Outputs a single int array per sample, of either
vocabulary size or `max_tokens` size, containing 1s in all elements
where the token mapped to that index exists at least once in the
sample.
- `"one_hot"`: Encodes each individual element in the input into an
array the same size as the vocabulary, containing a 1 at the element
index. If the last dimension is size 1, will encode on that dimension.
If the last dimension is not size 1, will append a new dimension for
the encoded output.
- `"multi_hot"`: Encodes each sample in the input into a single array
the same size as the vocabulary, containing a 1 for each vocabulary
term present in the sample. Treats the last dimension as the sample
dimension, if input shape is (..., sample_length), output shape will
be (..., num_tokens).
- `"count"`: As `"multi_hot"`, but the int array contains a count of the
number of times the token at that index appeared in the sample.
- `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is applied to
Expand Down Expand Up @@ -156,6 +162,22 @@ class IntegerLookup(index_lookup.IndexLookup):
earlier examples (12 maps to 2, etc) in order to make space for the extra OOV
token.
**One-hot output**
Configure the layer with `output_mode='one_hot'`. Note that the first
`num_oov_indices` dimensions in the one_hot encoding represent OOV values.
>>> vocab = [12, 36, 1138, 42]
>>> data = tf.constant([12, 36, 1138, 42, 7]) # Note OOV tokens
>>> layer = IntegerLookup(vocabulary=vocab, output_mode='one_hot')
>>> layer(data)
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.]], dtype=float32)>
**Multi-hot output**
Configure the layer with `output_mode='multi_hot'`. Note that the first
Expand Down
21 changes: 20 additions & 1 deletion keras/layers/preprocessing/integer_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,26 @@ def test_no_vocab(self):
layer = integer_lookup.IntegerLookup()
layer([[1]])

def test_binary_output(self):
def test_one_hot_output(self):
vocab_data = [2, 3, 4, 5]
input_array = np.array([2, 3, 4, 5, 6])
expected_output = [
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
]

input_data = keras.Input(shape=(1,), dtype=tf.int64)
layer = integer_lookup.IntegerLookup(
vocabulary=vocab_data, output_mode="one_hot")
res = layer(input_data)
model = keras.Model(inputs=input_data, outputs=res)
output_data = model.predict(input_array)
self.assertAllEqual(expected_output, output_data)

def test_multi_hot_output(self):
vocab_data = [2, 3, 4, 5]
input_array = np.array([[2, 2, 3, 4], [0, 1, 5, 2]])
expected_output = [[0, 1, 1, 1, 0], [1, 1, 0, 0, 1]]
Expand Down
33 changes: 28 additions & 5 deletions keras/layers/preprocessing/string_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,19 @@ class StringLookup(index_lookup.IndexLookup):
map indices to vocabulary items instead of mapping vocabulary items to
indices. Default to False.
output_mode: Specification for the output of the layer. Defaults to `"int"`.
Values can be `"int"`, `"multi_hot"`, `"count"`, or `"tf_idf"` configuring
the layer as follows:
Values can be `"int"`, `"one_hot"`, `"multi_hot"`, `"count"`, or
`"tf_idf"` configuring the layer as follows:
- `"int"`: Return the raw integer indices of the input tokens.
- `"multi_hot"`: Outputs a single int array per sample, of either
vocab_size or max_tokens size, containing 1s in all elements where the
token mapped to that index exists at least once in the sample.
- `"one_hot"`: Encodes each individual element in the input into an
array the same size as the vocabulary, containing a 1 at the element
index. If the last dimension is size 1, will encode on that dimension.
If the last dimension is not size 1, will append a new dimension for
the encoded output.
- `"multi_hot"`: Encodes each sample in the input into a single array
the same size as the vocabulary, containing a 1 for each vocabulary
term present in the sample. Treats the last dimension as the sample
dimension, if input shape is (..., sample_length), output shape will
be (..., num_tokens).
- `"count"`: As `"multi_hot"`, but the int array contains a count of the
number of times the token at that index appeared in the sample.
- `"tf_idf"`: As `"multi_hot"`, but the TF-IDF algorithm is applied to
Expand Down Expand Up @@ -153,6 +160,22 @@ class StringLookup(index_lookup.IndexLookup):
earlier examples (a maps to 2, etc) in order to make space for the extra OOV
value.
**One-hot output**
Configure the layer with `output_mode='one_hot'`. Note that the first
`num_oov_indices` dimensions in the one_hot encoding represent OOV values.
>>> vocab = ["a", "b", "c", "d"]
>>> data = tf.constant(["a", "b", "c", "d", "z"])
>>> layer = StringLookup(vocabulary=vocab, output_mode='one_hot')
>>> layer(data)
<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.],
[0., 0., 0., 1., 0.],
[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.]], dtype=float32)>
**Multi-hot output**
Configure the layer with `output_mode='multi_hot'`. Note that the first
Expand Down
21 changes: 20 additions & 1 deletion keras/layers/preprocessing/string_lookup_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,26 @@ def test_no_vocab(self):
layer = string_lookup.StringLookup()
layer([["a"]])

def test_binary_output(self):
def test_one_hot_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array(["earth", "wind", "and", "fire", "michigan"])
expected_output = [
[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0],
[0, 0, 0, 0, 1],
[1, 0, 0, 0, 0],
]

input_data = keras.Input(shape=(1,), dtype=tf.string)
layer = string_lookup.StringLookup(
vocabulary=vocab_data, output_mode="one_hot")
res = layer(input_data)
model = keras.Model(inputs=input_data, outputs=res)
output_data = model.predict(input_array)
self.assertAllEqual(expected_output, output_data)

def test_multi_hot_output(self):
vocab_data = ["earth", "wind", "and", "fire"]
input_array = np.array([["earth", "wind", "and", "fire"],
["fire", "and", "earth", "michigan"]])
Expand Down

0 comments on commit 9f6a8ec

Please sign in to comment.