Skip to content

Commit

Permalink
Creates non-breaking changes where necessary in preparation for switc…
Browse files Browse the repository at this point in the history
…hing all of Keras to new serialization format.

PiperOrigin-RevId: 509645806
  • Loading branch information
nkovela1 authored and tensorflower-gardener committed Feb 14, 2023
1 parent 01c1f68 commit 6291ecc
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions tensorflow_probability/python/layers/distribution_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,15 +1963,16 @@ def _transposed_variational_loss(y, kl_weight=1.):


def _serialize(convert_to_tensor_fn):
return tf.keras.utils.serialize_keras_object(convert_to_tensor_fn)
return tf.keras.utils.legacy.serialize_keras_object(convert_to_tensor_fn)


def _deserialize(name, custom_objects=None):
return tf.keras.utils.deserialize_keras_object(
return tf.keras.utils.legacy.deserialize_keras_object(
name,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='convert-to-tensor function')
printable_module_name='convert-to-tensor function',
)


def _get_convert_to_tensor_fn(identifier):
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_probability/python/layers/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def deserialize_function(serial, function_type):
"""
if function_type == 'function':
# Simple lookup in custom objects
function = tf.keras.utils.deserialize_keras_object(serial)
function = tf.keras.utils.legacy.deserialize_keras_object(serial)
elif function_type == 'lambda':
# Unsafe deserialization from bytecode
function = generic_utils.func_load(serial)
Expand Down

0 comments on commit 6291ecc

Please sign in to comment.