Fix saved model for TextVectorization with a vocab set on init
PiperOrigin-RevId: 390720333
mattdangerw authored and tensorflower-gardener committed Aug 13, 2021
1 parent 72c5ba7 commit 6129e95
Showing 10 changed files with 176 additions and 94 deletions.
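For context, a minimal sketch of the scenario the commit title describes: a TextVectorization layer given its vocabulary at construction time (rather than via adapt()), then saved and reloaded as a SavedModel. The exact failing call is not part of this page, so treat the snippet as an illustration built on the public Keras API of the time, not as the commit's own repro:

import tensorflow as tf

# Vocabulary fixed at construction time, not learned through adapt().
vocab = ["earth", "wind", "and", "fire"]
layer = tf.keras.layers.experimental.preprocessing.TextVectorization(
    vocabulary=vocab)

inputs = tf.keras.Input(shape=(1,), dtype=tf.string)
model = tf.keras.Model(inputs, layer(inputs))

# Before this commit, saving and reloading a model like this could fail
# to restore the lookup table built from the init-time vocabulary.
model.save("/tmp/text_vectorization_model")
restored = tf.keras.models.load_model("/tmp/text_vectorization_model")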
14 changes: 14 additions & 0 deletions keras/layers/preprocessing/BUILD
@@ -49,6 +49,7 @@ py_library(
],
srcs_version = "PY3",
deps = [
":preprocessing_utils",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras/engine",
@@ -108,6 +109,7 @@ py_library(
srcs_version = "PY3",
deps = [
":category_encoding",
":preprocessing_utils",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras:backend",
@@ -122,6 +124,7 @@
],
srcs_version = "PY3",
deps = [
":preprocessing_utils",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
"//keras:backend",
@@ -150,6 +153,7 @@ py_library(
srcs_version = "PY3",
deps = [
":category_encoding",
":preprocessing_utils",
":string_lookup",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
@@ -225,6 +229,16 @@ py_library(
],
)

py_library(
name = "preprocessing_utils",
srcs = ["preprocessing_utils.py"],
srcs_version = "PY3",
deps = [
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
],
)

cuda_py_test(
name = "category_crossing_test",
srcs = ["category_crossing_test.py"],
18 changes: 6 additions & 12 deletions keras/layers/preprocessing/discretization.py
@@ -14,12 +14,14 @@
# ==============================================================================
"""Keras discretization preprocessing layer."""

import tensorflow.compat.v2 as tf
# pylint: disable=g-classes-have-attributes
# pylint: disable=g-direct-tensorflow-import

import numpy as np
from keras.engine import base_preprocessing_layer
from keras.layers.preprocessing import preprocessing_utils as utils
from keras.utils import tf_utils
import numpy as np
import tensorflow.compat.v2 as tf
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import keras_export

@@ -190,7 +192,7 @@ def __init__(self,
raise ValueError("Both `num_bins` and `bin_boundaries` should not be "
"set. You passed `num_bins={}` and "
"`bin_boundaries={}`".format(num_bins, bin_boundaries))
bin_boundaries = self._convert_to_list(bin_boundaries)
bin_boundaries = utils.listify_tensors(bin_boundaries)
self.input_bin_boundaries = bin_boundaries
self.bin_boundaries = bin_boundaries if bin_boundaries is not None else []
self.num_bins = num_bins
@@ -232,7 +234,7 @@ def finalize_state(self):
return

# The bucketize op only support list boundaries.
self.bin_boundaries = self._convert_to_list(
self.bin_boundaries = utils.listify_tensors(
get_bin_boundaries(self.summary, self.num_bins))

def reset_state(self): # pylint: disable=method-hidden
@@ -282,11 +284,3 @@ def bucketize(inputs):
dense_shape=tf.identity(inputs.dense_shape))
else:
return bucketize(inputs)

def _convert_to_list(self, inputs):
if tf.is_tensor(inputs):
inputs = inputs.numpy()
if isinstance(inputs, (np.ndarray)):
inputs = inputs.tolist()
inputs = list(inputs)
return inputs
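The net effect for Discretization: tensor-valued bin boundaries are now normalized to plain Python lists by the shared helper instead of a private method. A short sketch of the assumed behavior (the layer and its bin_boundaries argument are real; the printed result is inferred from the code above and assumes eager execution):

import tensorflow.compat.v2 as tf
from keras.layers.preprocessing.discretization import Discretization

layer = Discretization(bin_boundaries=tf.constant([0.0, 1.0, 2.0]))
# A plain list, safe to serialize in get_config().
print(layer.bin_boundaries)  # [0.0, 1.0, 2.0]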
50 changes: 22 additions & 28 deletions keras/layers/preprocessing/index_lookup.py
@@ -15,13 +15,15 @@
"""Keras index lookup preprocessing layer."""

# pylint: disable=g-classes-have-attributes
# pylint: disable=g-direct-tensorflow-import

import collections

from keras import backend
from keras.engine import base_layer_utils
from keras.engine import base_preprocessing_layer
from keras.layers.preprocessing import category_encoding
from keras.layers.preprocessing import preprocessing_utils as utils
from keras.saving.saved_model import layer_serialization
from keras.utils import layer_utils
from keras.utils import tf_utils
@@ -197,14 +199,15 @@ def __init__(self,
self.output_mode = output_mode
self.sparse = sparse
self.pad_to_max_tokens = pad_to_max_tokens
self.input_vocabulary = None
# IndexLookupLayerSavedModelSaver will clear the config vocabulary to
# restore the lookup table ops directly. We persist this hidden option to
# persist the fact that we have a non-adaptable layer with a manually
# set vocabulary.
self._has_input_vocabulary = kwargs.pop("has_input_vocabulary", False)
self._frozen_vocab_size = None

self.input_vocabulary = vocabulary
# VocabularySavedModelSaver will clear the config vocabulary to restore the
# lookup table ops directly. We persist this hidden option to persist the
# fact that we have a non-adaptable layer with a manually set vocab.
self._has_input_vocabulary = kwargs.pop("has_input_vocabulary",
(vocabulary is not None))

# Drop deprecated config options.
kwargs.pop("vocabulary_size", None)
kwargs.pop("has_static_table", None)
@@ -257,16 +260,18 @@ def __init__(self,
# find and restore a lookup_table attribute on the layer. This table needs
# to be uninitialized as a StaticHashTable cannot be initialized twice.
self.lookup_table = self._uninitialized_lookup_table()
if not self._has_input_vocabulary:
# Add a custom weight handler to return the layer's vocab as its weight.
self._add_trackable(VocabWeightHandler(self), False)
# Set adapt state.
self.token_counts = tf.lookup.experimental.MutableHashTable(

# Only set up adapt state if we did not receive a vocab on construction.
if not self._has_input_vocabulary:
# Add a custom weight handler to return the layer's vocab as its weight.
self._add_trackable(VocabWeightHandler(self), False)
# Set adapt state.
self.token_counts = tf.lookup.experimental.MutableHashTable(
key_dtype=self.dtype, value_dtype=tf.int64, default_value=0)
if self.output_mode == TF_IDF:
self.token_document_counts = tf.lookup.experimental.MutableHashTable(
key_dtype=self.dtype, value_dtype=tf.int64, default_value=0)
if self.output_mode == TF_IDF:
self.token_document_counts = tf.lookup.experimental.MutableHashTable(
key_dtype=self.dtype, value_dtype=tf.int64, default_value=0)
self.num_documents = tf.Variable(0, dtype=tf.int64, trainable=False)
self.num_documents = tf.Variable(0, dtype=tf.int64, trainable=False)

def compute_output_shape(self, input_shape):
if self.output_mode == INT:
@@ -330,7 +335,7 @@ def get_config(self):
"mask_token": self.mask_token,
"output_mode": self.output_mode,
"pad_to_max_tokens": self.pad_to_max_tokens,
"vocabulary": self._make_serializable(self.input_vocabulary),
"vocabulary": utils.listify_tensors(self.input_vocabulary),
}

base_config = super().get_config()
@@ -363,9 +368,6 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
been called.
RuntimeError: If a tensor vocabulary is passed outside of eager execution.
"""
self.input_vocabulary = vocabulary
self._has_input_vocabulary = True

if self.output_mode != TF_IDF and idf_weights is not None:
raise ValueError("`idf_weights` should only be set if output_mode is "
"TF_IDF. output_mode is {}.".format(self.output_mode))
@@ -753,14 +755,6 @@ def _expand_dims(self, inputs, axis):
else:
return tf.expand_dims(inputs, axis)

def _make_serializable(self, x):
if tf.is_tensor(x):
x = x.numpy()
if isinstance(x, (np.ndarray)):
x = x.tolist()
x = list(x)
return x

def _oov_start_index(self):
return 1 if self.mask_token is not None and self.output_mode == INT else 0

@@ -831,7 +825,7 @@ def _inverse_document_frequency(self, token_document_counts, num_documents):

@property
def _trackable_saved_model_saver(self):
return layer_serialization.IndexLookupLayerSavedModelSaver(self)
return layer_serialization.VocabularySavedModelSaver(self)

# Override points for IntegerLookup and StringLookup.
def _tensor_vocab_to_numpy(self, vocabulary):
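The key behavioral change above: _has_input_vocabulary now defaults to True whenever a vocabulary is passed to the constructor, and the adapt state (token_counts, token_document_counts, num_documents) is only created for adaptable layers. A hedged sketch using StringLookup, one of the subclasses built on this layer (attribute names are taken from the diff; the snippet is illustrative, not from the commit):

import numpy as np
from keras.layers.preprocessing.string_lookup import StringLookup

# Vocabulary fixed at construction: no adapt state is created, and
# get_config() serializes the vocabulary via utils.listify_tensors.
fixed = StringLookup(vocabulary=["a", "b", "c"])
assert fixed._has_input_vocabulary

# No vocabulary at construction: the layer stays adaptable, so the
# token_counts hash table and related state are set up.
adaptable = StringLookup()
assert not adaptable._has_input_vocabulary
adaptable.adapt(np.array(["a", "b", "b", "c"]))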
19 changes: 6 additions & 13 deletions keras/layers/preprocessing/normalization.py
@@ -14,15 +14,16 @@
# ==============================================================================
"""Normalization preprocessing layer."""

# pylint: disable=g-classes-have-attributes
# pylint: disable=g-direct-tensorflow-import

from keras import backend
from keras.engine import base_preprocessing_layer
from keras.layers.preprocessing import preprocessing_utils as utils
import numpy as np
import tensorflow.compat.v2 as tf

from tensorflow.python.util.tf_export import keras_export

# pylint: disable=g-classes-have-attributes


@keras_export('keras.layers.Normalization',
'keras.layers.experimental.preprocessing.Normalization')
@@ -262,8 +263,8 @@ def get_config(self):
config = super().get_config()
config.update({
'axis': self.axis,
'mean': self._convert_to_list(self.input_mean),
'variance': self._convert_to_list(self.input_variance),
'mean': utils.listify_tensors(self.input_mean),
'variance': utils.listify_tensors(self.input_variance),
})
return config

@@ -272,11 +273,3 @@ def _standardize_inputs(self, inputs):
if inputs.dtype != self.dtype:
inputs = tf.cast(inputs, self.dtype)
return inputs

def _convert_to_list(self, inputs):
if tf.is_tensor(inputs):
inputs = inputs.numpy()
if isinstance(inputs, (np.ndarray)):
inputs = inputs.tolist()
inputs = list(inputs)
return inputs
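Normalization gets the same treatment: tensor or ndarray statistics are serialized as plain lists, so the config round-trips cleanly. A brief sketch of the assumed behavior:

import numpy as np
from keras.layers.preprocessing.normalization import Normalization

layer = Normalization(mean=np.array([0.5]), variance=np.array([0.25]))
config = layer.get_config()
# Plain, JSON-serializable lists rather than ndarrays.
print(config['mean'], config['variance'])  # [0.5] [0.25]
clone = Normalization.from_config(config)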
5 changes: 2 additions & 3 deletions keras/layers/preprocessing/preprocessing_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Keras' base preprocessing layer."""

import tensorflow.compat.v2 as tf
"""Tests utils for preprocessing layers."""

import collections
import numpy as np
import tensorflow.compat.v2 as tf


class PreprocessingLayerTest(tf.test.TestCase):
27 changes: 27 additions & 0 deletions keras/layers/preprocessing/preprocessing_utils.py
@@ -0,0 +1,27 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utils for preprocessing layers."""

import numpy as np
import tensorflow.compat.v2 as tf


def listify_tensors(x):
"""Convert any tensors or numpy arrays to lists for config serialization."""
if tf.is_tensor(x):
x = x.numpy()
if isinstance(x, np.ndarray):
x = x.tolist()
return x
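A quick illustration of the new helper (assuming eager execution, since it calls .numpy() on tensors). Note one difference from the per-layer _convert_to_list methods it replaces: inputs that are neither tensors nor ndarrays, including None, pass through unchanged instead of being forced through list():

import numpy as np
import tensorflow.compat.v2 as tf
from keras.layers.preprocessing import preprocessing_utils as utils

print(utils.listify_tensors(tf.constant([1.0, 2.0])))  # [1.0, 2.0]
print(utils.listify_tensors(np.array([3, 4])))         # [3, 4]
print(utils.listify_tensors(None))                     # None (no list() coercion)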
(4 more changed files not shown.)
