Status | Accepted |
---|---|
Author(s) | Allen Lavoie ([email protected]), André Susano Pinto ([email protected]), Arno Eigenwillig ([email protected]), Rohan Jain ([email protected]) |
Sponsor | Karmel Allison ([email protected]) |
Updated | 2019-02-28 |
Provide an API for serialization/deserialization in TF-2.0 that supports both serving and reuse use-cases.
TensorFlow 2 will include neither `tf.Session` nor `tf.train.Saver` as public symbols. Current SavedModel export and import workflows rely heavily on these symbols. This document proposes adding `tf.saved_model.save` and `tf.saved_model.load` as 2.x-compatible replacements. These symbols are mentioned at a high level in the Functions, not Sessions RFC.
There are several reasons to serialize state in programs which use TensorFlow, each with different requirements. This proposal addresses serving and sharing.
When writing a new model, the first need users run into is a way to continue training after restarting a program. The original code still exists, although it may be modified slightly.
`tf.train.Checkpoint` handles this case. TensorFlow state referenced by Python objects is recorded and can be restored later. To use a checkpoint, the user needs access to the code that constructs the objects.
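A minimal sketch of that workflow (paths here are illustrative):

```python
# A checkpointing round trip: state is saved, and the original code rebuilds
# the objects before restoring their values.
model_state = tf.train.Checkpoint(step=tf.Variable(0))
model_state.step.assign_add(1)
model_state.save("/tmp/training/ckpt")

# After a restart, the same construction code runs again...
restored = tf.train.Checkpoint(step=tf.Variable(0))
# ...and the values are restored from the latest checkpoint.
restored.restore(tf.train.latest_checkpoint("/tmp/training"))
```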
Once trained, users may want to serve requests using their model. Ideally for this use-case the representation should be a hermetic and stateless black box, usable through a stable interface, and with minimal dependencies.
SavedModel as a format satisfies this use-case. Various APIs for creating SavedModels exist in TensorFlow 1.x (SavedModelBuilder, Estimator, others). Generally the only interface to the model needed for this use-case is a signature specifying what goes into the model and what comes out. This use-case is well supported in TF 1.x and will continue to be supported under this proposal.
Beyond serving, users may want to reuse parts of a trained model in building new models. SavedModel allows saving the computation together with its pre-trained weights, without depending on the model definition in Python and its particular framework. This helps reproducibility (say, for results reported in a paper) and reuse (say, for importing a pre-trained embedding into a model that uses it on a new task). While serving and reproducibility call for a complete model, reuse typically concerns part of a trained model and its composition. That means loading the SavedModel must restore enough Python state to allow building on top of it.
These use-cases are not well addressed in core TensorFlow 1.x APIs. Export APIs have been complicated by concerns relevant to the serving use-case, and not much time has been spent on usability for re-importing models into Python. TensorFlow Hub has solved this (esp. the import workflow) for sharing graphs, but this needs significant redesign for TF 2.x in light of the Functions, not Sessions RFC.
Goals:

- Support use-cases covered by existing SavedModel export/import APIs and existing Hub module APIs: serving and sharing
  - Export a smaller part of a full model.
  - Import a part into a larger model.
  - Fine-tune an imported part (requires defining a "trainable" set of variables, running update ops for batch normalization, exposing regularization losses).
  - Re-export an imported part (including its new fine-tuned state).
  - Programmatically inspect signatures/functions.
  - Import state once but flow data through that state multiple times.
- Export to SavedModel in a way that is compatible with existing serving infrastructure. Extensions to the format may be required for 2.x compatibility, but existing loading procedures will continue to work.
- Reimport TensorFlow functionality of exported objects back into Python.
- With minimal special casing for Keras types, support a SavedModel implementation of `tf.keras.Model.save` and `tf.keras.models.load_model`. Details of this special casing will be left to another document.
- Exporting and importing should be idempotent (reimported representations are saveable).

Non-goals:

- Importing existing SavedModels into Python in TensorFlow 2.x with full object structure. Loading these will share the `tf.saved_model.load` symbol, but the result will lack any object structure.
- Exporting arbitrary SavedModels. Use-cases will be covered, but, for example, SavedModel supports multiple MetaGraphs while the APIs proposed here may only ever export SavedModels with a single MetaGraph.
- Usable Python interfaces for all symbols in the TensorFlow API on import. Everything will of course be usable when graph building, but objects may not have many features when imported back into Python. The set of types with "nice" import representations is expected to increase over time.
- A signature identifies inputs and outputs for a computation, roughly the feeds and fetches of a single `session.run` call in TensorFlow 1.x.
  - In a SavedModel or MetaGraph, SignatureDefs identify input and output tensors in the GraphDef, possibly overlapping.
- Concrete function graph: A subgraph with a single signature which TensorFlow can execute natively to compute Tensor outputs from Tensor inputs.
- Polymorphic function: A Python callable that encapsulates several concrete function graphs behind one, more natural API. For example, it may use non-Tensor arguments to dispatch between the concrete graphs that compute outputs from the Tensor arguments.
- FunctionDef: A protocol buffer representing a concrete TensorFlow function.
- FunctionDefLibrary: A protocol buffer containing multiple FunctionDefs and their gradient functions.
- GraphDef: A serialized "v1-style" TensorFlow graph. Includes a FunctionDefLibrary.
- MetaGraph: A GraphDef + training checkpoint. Contains fields for signatures and other metadata, although these are often blank.
- SavedModel: A collection of MetaGraphs with additional assets. Signatures and other MetaGraph fields are always filled.
- Checkpointable: A mixin which manages dependencies between objects. Checkpointable objects have named dependencies on other checkpointable objects. Most TensorFlow objects already participate in this scheme.
- `tf.train.Checkpoint`: A checkpointable object with a save() method, used for writing training checkpoints. Used in several examples in this document as an easy way to make a checkpointable object, since there are no plans for a separate "Checkpointable" symbol.
- Resource: A data type (dtype) for tensors which point to state. Used to implement variables and tables. In TensorFlow 2.x, resources persist as eager tensors attached to a single Python object (e.g. a `tf.Variable` object), sharing a lifetime with that object. Operations take resource-dtype tensors in order to read or mutate the pointed-to state. The RFC for variables in 2.x has more detail.
- Capture: An implicit tensor input to a function, typically a resource. Variables and other resources which are used or created inside a function are not owned by that function, and are instead lifted out as eager tensors. When a function is called, these tensors are automatically collected and passed in.
The `tf.saved_model` additions proposed here handle serialization of Python objects, attributes, and functions as a graph of objects, variables, resources, and functions backed by polymorphic TensorFlow functions. These objects can be used without access to the original code. The format can also store higher-level information that allows the objects to be reconstructed, assuming a factory mapping "revivable names" to "revivable classes".
New user-facing symbols will be `tf.saved_model.save` and `tf.saved_model.load`. These are mentioned at a high level in the Functions, not Sessions RFC.
```python
# Serialize objects reachable from root into a SavedModel.
tf.saved_model.save(
    obj: Checkpointable,
    export_dir: str,
    signatures=None: map[str -> Function])

# Load the root object from a SavedModel.
tf.saved_model.load(
    export_dir: str,
    tags=None: list[str]) -> Checkpointable
```
The remainder of this section defines their behavior.
This section describes, with examples, the basic primitives needed to load a SavedModel for reuse in Python without depending on its original code. Note that in many cases, reviving the original class would provide much more functionality than what can be serialized.
These primitive types are: Variable, CheckpointableBase, PolymorphicFunction, TrackableResource, and any nest of those, including nesting of Checkpointable attributes in an object.
Individual examples:
Save | Load
---|---
`obj.x = tf.Variable(...)` | `type(obj.x) => tf.Variable`
`obj.x = tf.train.Checkpoint(...)` | `type(obj.x) => CheckpointableBase`
`obj.x = tf.function(...)` | `type(obj.x) => PolymorphicFunction`
`obj.x = tf.lookup.Table(...)` | `type(obj.x) => TrackableResource`
`obj.x = [tf.Variable(), tf.Variable(), ...]` | `type(obj.x) => [tf.Variable]`
As a preview of the rest of this design, consider the following rough outline of how this would handle a basic text embedding model. First, user1 has code that defines a model object, a `CheckpointableBase` that has its resources (variables, tables, and asset files) declared as members. Additionally, the user took care to annotate the "embed" method with the `tf.function` decorator and provide an input signature (not providing a signature is also possible, but leads to many technicalities).
```python
class Model(tf.Module):
  def __init__(self, vocab_file, dim):
    # The table object tracks its asset file automatically.
    self.table = tf.contrib.lookup.index_table_from_file(
        vocabulary_file=vocab_file,
        num_oov_buckets=1,
    )
    vocab_size = self.table.size()
    self.embeddings = tf.Variable(tf.random.uniform(shape=[vocab_size, dim]))

  def tokenize(self, sentences):
    sparse_words = tf.string_split(sentences)
    token_ids = self.table.lookup(sparse_words)
    return token_ids

  @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.string)])
  def embed(self, sentences):
    token_ids = self.tokenize(sentences)
    combined = tf.nn.embedding_lookup_sparse(self.embeddings, token_ids, None)
    return combined

root = Model("/tmp/vocab", 64)
tf.saved_model.save(root, "/user1/tmp/model")
```
A second user, user2, could then inspect and call into parts of this SavedModel without having access to the original code. Note how initialization of resources has happened during `load` without user action.
```python
obj = tf.saved_model.load("/user2/download/model")
obj.embed(["hello world"])  # => <tf.Tensor: shape=(1, 64), dtype=float32, numpy=...>
obj.embeddings              # => <tf.Variable ...>
obj.table                   # => <TrackableResource> (the method .lookup would not be present)
```
`tf.saved_model.save` will save the same set of stateful objects as `tf.train.Checkpoint` would given the same root object. In addition, `save` will iterate over each checkpointable object's attributes and find functions (`self.f = tf.function(...)`) and methods (`@tf.function`-decorated).
Collected functions and methods are polymorphic, having one or more "concrete" functions, each corresponding to a FunctionDef with tensor inputs and outputs. For methods, we look up the function definitions corresponding to the object we're saving. After that they're similar to attribute-assigned functions.
Polymorphic functions with no signature specified, and which have not been called, have zero concrete functions associated with them. Saving an object with such a polymorphic function will raise an exception.
Each concrete function may reference variables or other stateful objects. Any variables referenced this way imply a dependency of the function's object on the variable. If no transitive dependency exists at export time, an exception will be raised. An automatic dependency scheme may be considered if there is a strong use-case. Such a scheme would be a backwards-compatible addition.
Concrete functions corresponding to signatures which cannot be serialized (see Serialization formats) will raise an exception on export.
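As a sketch of the dependency requirement above (names and paths are illustrative; `v` is deliberately not tracked by `root`):

```python
v = tf.Variable(1.)
root = tf.train.Checkpoint()
# The function captures `v`, but `root` has no transitive dependency on it,
# so export raises rather than silently dropping state.
root.f = tf.function(lambda: v + 1., input_signature=())
tf.saved_model.save(root, "/tmp/untracked")  # Raises an exception
```

A fuller example of collecting and saving functions and methods: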
```python
has_fns = tf.Module()
has_fns.v = tf.Variable(1.)
has_fns.a = tf.function(lambda x: x + has_fns.v + 1.)
has_fns.b = tf.function(lambda x: x + has_fns.v + 2.)
has_fns.c_dep = tf.function(lambda x: x + 3.)
has_fns.c = tf.function(
    lambda x: has_fns.v + has_fns.c_dep(x),
    input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
has_fns.a(tf.constant(2.))  # 4.
has_fns.python_attribute = 12  # Not exported

# Error: exporting a function without a trace (for b).
# "c" can be traced for export since it has an input signature specified.
# "c_dep" is traced for export as a result of "c"'s trace.
tf.saved_model.save(has_fns, "/tmp/fns")

has_fns.b(tf.constant(3.))
tf.saved_model.save(has_fns, "/tmp/fns")  # Succeeds now that "b" has a trace

imported_fns = tf.saved_model.load("/tmp/fns")
print(imported_fns.v)  # 1.
imported_fns.a(tf.constant(1.))  # 3.
imported_fns.b(tf.constant(1.))  # 4.
imported_fns.c(tf.constant([1., 2.]))  # [5., 6.]
imported_fns.c_dep(tf.constant([1.]))  # [4.]
imported_fns.c_dep(tf.constant(1.))  # Error: no trace for scalar inputs
imported_fns.python_attribute  # AttributeError; Python attributes are not saved

# Exported SavedModels are also usable as training checkpoints.
training_checkpoint = tf.train.Checkpoint(v=tf.Variable(2.))
training_checkpoint.restore("/tmp/fns")
print(training_checkpoint.v)  # 1.
training_checkpoint.a  # AttributeError; functions are not restored
```
`tf.function`, when decorating a method, allows variable creation for each new `self` argument. When saving and restoring objects, methods behave like functions assigned to attributes (with the `self` argument bound).
```python
class Net(tf.Module):
  def __init__(self):
    self.y = None

  @tf.function
  def add(self, x):
    if self.y is None:
      self.y = tf.Variable(2.)
    return x + self.y

net = Net()
net.add(3.)    # Variable created
net.add([3.])  # A second concrete function
tf.saved_model.save(net, "/tmp/net")

imported_net = tf.saved_model.load("/tmp/net")
print(imported_net)  # <Checkpointable object at 0x7faddc343278>,
                     # type not preserved by default.
print(imported_net.y)   # 2.
imported_net.add(3.)    # 5.
imported_net.add([3.])  # [5.]
imported_net.y.assign(3.)
imported_net.add(3.)    # 6.
imported_net.add([3.])  # [6.]
```
Limited serialization of Python objects will support common non-Tensor types in function signatures.
```python
@tf.function
def f(x, training):
  return x if training else 2.

f(-1., training=True)
f(-1., training=False)
obj = tf.train.Checkpoint(f=f)  # save() exports objects, so we wrap f
tf.saved_model.save(obj, "/tmp/f")

imported = tf.saved_model.load("/tmp/f")
imported.f(10., training=True)   # 10.
imported.f(10., training=False)  # 2.
```
Lists, tuples, namedtuples, and dictionaries may be nested in function input signatures and return values. For example:
```python
@tf.function
def g(x):
  return [x[0] + 0.1, x[1]["a"] + 0.2]

print(g((tf.constant(1.), {"a": tf.constant(2.)})))  # [1.1, 2.2]
obj = tf.train.Checkpoint(g=g)
tf.saved_model.save(obj, "/tmp/g")
imported = tf.saved_model.load("/tmp/g")
print(imported.g((tf.constant(-1.), {"a": tf.constant(-2.)})))  # [-0.9, -1.8]
```
See input signature serialization for details. Additional restrictions are placed on functions which will be used as SavedModel signatures (`save(obj, signatures=...)`), since these must be callable from the C++ loader API.
`tf.saved_model.save` will take an optional argument specifying which functions/methods will be recorded as "signatures" in the SavedModel (allowing them to be used when serving, for example). These functions must have input signatures specified, either when the `tf.function` is created or by calling `get_concrete_function`. SignatureDefs and corresponding function call ops will be generated for each signature function in the SavedModel.
```python
class Net(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([None, 5], tf.float32)])
  def infer(self, x):
    return x

net = Net()
tf.saved_model.save(
    net, "/tmp/serve1",
    # One SignatureDef added with the default serving signature key.
    signatures=net.infer)

# Or if multiple signatures should be exported:
signature_functions = {
    tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY: net.infer,
    "other_signature_key": ...,
}
tf.saved_model.save(net, "/tmp/serve2", signature_functions)
```
Input signatures need not be specified when using `tf.function` decorators, which would generally require them to be static for the whole Python program. Instead, a signature may be specified at export time. If no corresponding trace exists in the function cache, a new trace will be created.
```python
class Net2(tf.Module):
  @tf.function
  def infer(self, labels, training, x1, x2):
    if training:
      return labels, x1, x2
    else:
      return labels, x1 + 1., x2 + 1.

net = Net2()
tf.saved_model.save(
    obj=net, export_dir="/tmp/serve1",
    signatures=net.infer.get_concrete_function(
        tf.TensorSpec(None, tf.int64),
        training=False,
        x1=tf.TensorSpec([None, None, 3], tf.float32),
        x2=tf.TensorSpec(None, tf.float64)))

# Then a serving request could specify:
#   {"labels": [0, 1], "x1": [[[1., 2., 3.]]], "x2": 0.}
# and get a response like
#   {"output_1": [0, 1], "output_2": [[[2., 3., 4.]]], "output_3": 1.}
# The "training" argument is important for picking which trace to serve with,
# but does not affect the signature.
```
Non-`Tensor` arguments to functions used to generate signatures are fine (e.g. `training=False`). The `Tensor` arguments are the only ones serving APIs will know about, and these may not be nested inside lists, tuples, or dictionaries (since there is no way to specify nested `Tensor` arguments when serving). Arguments are identified by name when importing a SavedModel for serving, so there is no ambiguity even if non-`Tensor` Python arguments are interspersed with `Tensor` arguments.
SavedModel SignatureDefs require each output to have a name. If a single dictionary is returned from the traced function, its string keys will be used to name outputs. Otherwise signature functions must return a flat list of `Tensor`s, and the outputs will be numbered ("output_1", "output_2", ...). Flattening nested outputs before numbering them would be trivial, but serving APIs would have no way to reconstruct the structure.
Any return structure other than a dictionary with string keys or a flat sequence of `Tensor`s from a function used to generate a signature will raise an exception on export. Note that this limitation applies only to serving functions; functions attached to objects but not specified as signatures may have other output patterns.
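A minimal sketch of the dictionary-output case (class and path names are illustrative):

```python
class DictNet(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def score(self, x):
    # Returning a dict with string keys names the signature outputs
    # "scores" and "ids" instead of "output_1", "output_2".
    return {"scores": 2. * x, "ids": tf.cast(x, tf.int64)}

net = DictNet()
tf.saved_model.save(net, "/tmp/dict_outputs", signatures=net.score)
```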
The above examples have only numeric `Tensor`s in their exported signatures. If we now decide that `tf.Example` protos are a better way to pass data to our model, we can add parsing functionality by re-exporting.
```python
net = tf.saved_model.load("/tmp/serve1")  # From the previous example

@tf.function(input_signature=[tf.TensorSpec([None], tf.string)])
def _parse(serialized):
  parsed = tf.parse_example(
      serialized,
      features={"labels": tf.FixedLenFeature(shape=[10], dtype=tf.int64),
                "x1": tf.FixedLenFeature(shape=[10, 3], dtype=tf.float32),
                "x2": tf.FixedLenFeature(shape=[10], dtype=tf.float64)})
  return net.infer(**parsed)

net.infer_from_proto = _parse
tf.saved_model.save(obj=net, export_dir="/tmp/serve_proto",
                    signatures=net.infer_from_proto)
```
Of course re-exporting is not a necessary step; adding a proto parsing head before the first export would be the common case.
When unambiguous and not specified with an explicit `signatures=` argument, a default serving signature may be inferred automatically: for example, if the object passed to `tf.saved_model.save` has only one `@tf.function` attached and that function has an input signature. Symbolic `tf.keras.Model` objects (but not subclassed `Model`s in general) similarly have signatures and so may provide an automatic default serving signature.
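A minimal sketch of the single-function heuristic (names are illustrative):

```python
class OneFunction(tf.Module):
  @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
  def double(self, x):
    return 2. * x

# `double` is the only tf.function and it has an input signature, so save()
# can infer a default serving signature without a signatures= argument.
tf.saved_model.save(OneFunction(), "/tmp/auto_signature")
```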
The `.signatures` attribute will be reserved on objects which are saved and restored. On `tf.saved_model.load` it will contain an immutable mapping from signature keys (e.g. `"serving_default"`) to the functions used to implement each signature. This supports introspection, allowing users to import a SavedModel and run exactly the computation a serving API would run.
On `tf.saved_model.save`, signatures in the exported SavedModel come from the first available of the following sources:

- The `tf.saved_model.save(..., signatures=...)` argument
- The `.signatures` attribute of the exported object
- A set of heuristics to determine a default signature, for example exporting a functional Keras model or searching for `@tf.function`-decorated methods with an input signature. This may fail, in which case no signatures will be exported.
For example:
```python
class Net(tf.Module):
  @tf.function
  def infer(self, x):
    return x

net = Net()
net.infer(tf.constant(1.))
tf.saved_model.save(
    net, "/tmp/serve",
    signatures=net.infer.get_concrete_function(tf.TensorSpec(None, tf.float32)))

loaded = tf.saved_model.load("/tmp/serve")
loaded.signatures["serving_default"](x=tf.constant([[1.]]))
# -> {"output_1": <tf.Tensor [[1.]], dtype=float32>}
tf.saved_model.save(loaded, "/tmp/serve1")  # Contains the same signature
```
Attempting to save an object whose `.signatures` attribute contains something other than an immutable signature mapping (for example one created by `tf.saved_model.load`) will raise an exception. This prevents signatures from being silently ignored when the attribute is modified and the `signatures=` argument is also passed. However, internal APIs may make use of the attribute to provide a user-overridable default.
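A sketch of the guarded failure mode described above (the assignment is illustrative):

```python
obj = tf.train.Checkpoint(f=tf.function(lambda: 1., input_signature=()))
# A plain dict is not the immutable mapping created by tf.saved_model.load,
# so save() raises instead of risking that these entries are ignored.
obj.signatures = {"serving_default": obj.f}
tf.saved_model.save(obj, "/tmp/bad_signatures")  # Raises an exception
```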
The Python APIs proposed in this document will be targeted at a TensorFlow 2.x environment. They will not be tested with `Graph` and `Session`, and so will not be usable from TensorFlow 1.x without eager execution enabled.
SavedModels exported from TensorFlow 1.x APIs will be importable using the proposed APIs. SavedModel signatures will be available as callable Python functions. This includes the functionally equivalent `Estimator.export_saved_model`, which will still be available in TensorFlow 2.x.
SavedModels exported from the proposed APIs will be importable using TensorFlow 1.x APIs, including TensorFlow Serving and the C++ loader API. The available computation will be the exported signatures (`tf.saved_model.save(..., signatures=...)`).
Using `tf.saved_model.load` on a SavedModel exported from a TensorFlow 1.x API will import each SignatureDef as an individual `tf.compat.v1.wrap_function` object. This will follow the same style as for signatures exported using `tf.saved_model.save`, with a `.signatures` attribute of the root object containing a mapping from signature keys to `wrap_function` objects. Another attribute will contain variables.
If multiple MetaGraphs exist in the SavedModel, the `tf.saved_model.load(..., tags=...)` argument must be specified and must match exactly one MetaGraph. Only one MetaGraph will be loaded per call to `tf.saved_model.load`.
Loading for a MetaGraph will follow the existing procedure for the C++ and Python loader APIs: a checkpoint restore followed by the main op running with asset paths fed. This procedure will be wrapped in its own `wrap_function` object and executed when `tf.saved_model.load` runs.
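Putting the 1.x-import pieces together, usage might look like this sketch (the path and the "serve" tag are illustrative):

```python
imported = tf.saved_model.load("/tmp/v1_export", tags=["serve"])
infer = imported.signatures["serving_default"]
outputs = infer(x=tf.constant([[1.]]))  # Runs the wrapped 1.x graph function
imported.variables  # Variables restored from the MetaGraph's checkpoint
```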
Not all existing SavedModels will be loadable to start. Some known tricky issues:
- Reference variables (as opposed to resource variables) do not exist in 2.x, so SavedModels using these will require rewriting
- Control flow using collections (while_loop/cond) will need some graph rewriting to import correctly
State (a variable, table, etc.) is represented in TensorFlow 2.x using a resource-dtype eager tensor. Such state is uniquely associated with a Python object (e.g. a `tf.Variable`), and deletion/collection of the Python object triggers deletion of the resource (DestroyResourceOp). Functions reference state through special "capture" inputs, with a resource-dtype placeholder in the function body which is fed the eager resource tensor on each function execution, giving the function a temporary reference to the resource. This is true even if a resource is created while tracing the function, in which case the resource handle is lifted out into the eager context before being captured.
Managing resources involves two operations: the creation of the resource-dtype tensor (e.g. `VarHandleOp`) and its initialization (e.g. `AssignVariableOp`). Both of these operations may be executed eagerly, but an exported SavedModel needs to include the operations themselves. A `TrackableResource` type will associate resource-dtype eager tensors with functions to create and initialize them. On export, resource tensors will be collected through object dependencies and matched to the captured inputs of exported functions.
Objects which reference external files that should be included in the SavedModel will indicate these asset paths by subclassing `TrackableAsset`, serving the same purpose as the v1 assets collection. Paths referenced this way will be copied into the SavedModel's assets directory on export.
`TrackableResource` and `TrackableAsset` may be used together, creating a resource which is initialized from an asset.
```python
class TrackableAsset(CheckpointableBase):
  def __init__(self, asset_path):
    # This variable will be initialized using the absolute path to a resource
    # on SavedModel restore. It will not be checkpointed.
    self._asset_path = tf.Variable(asset_path, dtype=tf.string)

  @property
  def asset_path(self):
    return self._asset_path

class TextFileInitializer(TrackableAsset):
  def __init__(self, asset_path):
    # Lets object-based saving track asset paths. This value is recorded in an
    # AssetFileDef in the SavedModel.
    TrackableAsset.__init__(self, asset_path)

  def initialize(self, table):
    gen_lookup_ops.initialize_table_from_text_file_v2(
        table.resource_handle,
        self.asset_path,
        ...)

class Table(TrackableResource):
  def __init__(self, initializer):
    self._initializer = initializer
    self._track_checkpointable(initializer, name="_initializer")

  @tf.function(input_signature=())
  def create_resource(self):
    return gen_lookup_ops.hash_table_v2(...)

  @tf.function(input_signature=())
  def initialize(self):
    self._initializer.initialize(self)  # May capture an asset variable
```
In the SavedModel protocol buffer, AssetFileDefs will have a restore function taking the full asset path which assigns to the asset path variable. A serving API expecting a 1.x-style SavedModel will feed values for the AssetFileDef tensors and run the referenced function call op, initializing the variable. The 2.x SavedModel import API will run the function directly. Asset variables will then be captured inputs to `TrackableResource` initializer functions.
`TrackableAsset` and `TrackableResource` objects will be recreated by the Python SavedModel import routine to make re-export possible while preserving asset paths.
Functions will be traced outside of any device scope, and we will rely on the placement of the `PartitionedCallOp` for a "default" device. So no special treatment is needed to switch devices between export and import: just call the imported function in a device scope.

Device placements specified within the function body will be hard-coded in the SavedModel. Aside from library code needing to place things on the CPU, we should discourage `tf.device` within graph functions so devices aren't hard-coded for export.
This means that the Python implementation of polymorphic functions (`tf.function`) should not specialize a function's trace based on the device stack where it is called. Instead, it should look up the graph function to call without regard to device placement, tracing outside a device scope if a new graph function must be created. Then the function call itself will be within the enclosing device scope.

This does not protect users who use device-specific operations (e.g. cuDNN) or layouts which are only supported on one type of device. Such SavedModels may only be usable on one type of device.
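A minimal sketch of device placement on import (paths are illustrative, and a GPU is assumed to be available):

```python
imported = tf.saved_model.load("/tmp/net")
# The imported function has no hard-coded device, so placement follows the
# caller's device scope.
with tf.device("/device:GPU:0"):
  imported.add(tf.constant(3.))
```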
A user should eventually be able to export a single-machine computation and import the SavedModel under a `DistributionStrategy` scope. An initial implementation will simply hard-code device placements when a distribution strategy is active, meaning that the `DistributionStrategy` used on export will be the only usable configuration on import.
Options for allowing single-device models to be imported with a `DistributionStrategy` include recording and saving attribute accesses for variable objects (`assign_add`, `assign`, `read`, etc.) and rewiring the graph on import, or supporting templated functions which can be specialized to access variables in a certain way. Solving this will be crucial to support sharing SavedModels for reuse (see Sharing above).
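Hypothetical usage once one of these options is implemented (not part of the initial implementation):

```python
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Variables would be recreated as distributed variables, with the imported
  # functions rewired to access them appropriately.
  imported = tf.saved_model.load("/tmp/net")
```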
By default, imported objects will have unique types inheriting from `Checkpointable`. Objects of these types will have `tf.function` callables in their attributes, along with attributes for variables and other checkpointable dependencies.
In some cases, Python values are important parts of an object's API. For example, the `tf.keras.backend.learning_phase()` global is a Python integer which affects the behavior of `tf.keras.Layer` methods. Such Python values must already be part of a polymorphic function's cache key for correct tracing, regardless of export/import. This will be implemented by a "`tf.function`-compatible" method which explicitly takes all of its inputs as arguments and returns its outputs (e.g. taking a learning phase and returning regularization losses). There will be a way to register a custom base class for a revived type which has arbitrary Python attributes and convenience wrappers for the `tf.function`-compatible TensorFlow methods. The registration will be keyed to a unique string which must be the same at import time as it was at export time.

Registrations for revived types will initially be considered implementation details used to support saving and restoring TensorFlow types, but may eventually be exposed as public APIs.
```python
class HasPython(tf.Module):
  @tf.function
  def do(self, x, learning_phase):
    if learning_phase == 0:
      return x
    else:
      return x + 1.

  def __call__(self, x):
    # Python methods which call TF methods are fine, but need a custom revived
    # type.
    return self.do(x, learning_phase=tf.keras.backend.learning_phase())

has_python = HasPython()
# Trace "do" for both learning phases so both concrete functions are exported.
tf.keras.backend.set_learning_phase(0)
has_python(tf.constant(1.))
tf.keras.backend.set_learning_phase(1)
has_python(tf.constant(1.))
tf.saved_model.save(has_python, "/tmp/haspython")

imported = tf.saved_model.load("/tmp/haspython")
tf.keras.backend.set_learning_phase(1)
imported(1.)  # 2.
```
For the same reasons slot variables are special-cased in `tf.train.Checkpoint`, optimizers will require some special-casing when restored from a SavedModel.
Restored optimizers will be generic `Optimizer` instances with their behavior defined by the SavedModel, and will have their slot variables restored and mapped from the right recreated variable objects (this will be the main `Optimizer`-specific special casing). Non-slot variables will be handled as for any other objects. The functionality in `_prepare`, `_resource_apply_dense`, `_resource_apply_sparse_duplicate_indices`, ..., `_finish` will all be traced, with the restored `Optimizer` using the implementations from the SavedModel.
Restored `Optimizer`s will not be limited to optimizing variables in the imported model. The exported signatures will allow any gradient shape, which should be no problem for the ops used to implement core optimizers. Tracing a slot variable lookup doesn't make much sense from the perspective of a TensorFlow graph (or would limit the `Optimizer` to working with variables which existed at export time), so some refactoring may be required to create pure functions which take (primary, slots, gradient) rather than taking (primary, gradient) and looking up slot variables. Then the `RevivedOptimizer` Python type would be responsible for looking up the correct slot variables.
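A hedged sketch of that refactoring (the function name and momentum update are illustrative, not from the proposal):

```python
@tf.function
def apply_momentum(primary, momentum_slot, gradient, lr, momentum):
  # The slot variable is an explicit argument rather than being looked up
  # from the optimizer's internal slot dictionary, so the traced function
  # also works with variables created after export.
  momentum_slot.assign(momentum * momentum_slot + gradient)
  primary.assign_sub(lr * momentum_slot)
```

`RevivedOptimizer` would then be responsible for passing the correct slot variables when applying gradients.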
On import, objects are restored and variables set to their checkpointed values. However, the imported types will also be usable for constructing new objects, exposing saved initialization graphs.
```python
class Net(tf.Module):
  def __init__(self, units):
    self.units = units
    self.var = None
    self.built = False

  def build(self, x):
    self.var = tf.Variable(x * tf.ones(self.units))
    self.built = True

  @tf.function
  def do(self, x):
    if not self.built:
      self.build(x)
    return x + self.var

net = Net(5)
net.do(1.)
net.var.assign([1., 2., 3., 4., 5.])
tf.saved_model.save(net, "/tmp/net")

imported_net = tf.saved_model.load("/tmp/net")
assert list(imported_net.var.numpy()) == [1., 2., 3., 4., 5.]

net_from_class = type(imported_net)()  # No Tensor constructor arguments
# net_from_class.var is uninitialized
net_from_class.do(2.)
assert list(net_from_class.var.numpy()) == [2., 2., 2., 2., 2.]
```
Constructing a new object from a revived type will also construct new objects for any dependencies. To be usable, the pre-export object associated with this type must not have had a transitive dependency on any function or method unless it also had transitive dependencies on all of that function's referenced variables. So, for example, an object referencing a checkpointable list of functions which reference its variables may be constructed, but the list itself may not be constructed on its own.
Unless `__init__` is decorated, revived objects will not take constructor arguments. Constructing a new object from a revived type creates uninitialized variables of the same shape and dtype as the revived object with that type, and calling the method which created a variable (before export) initializes it. Variable initialization will be automatic and idempotent, as implemented in `tf.function`.

A `@tf.function`-decorated `__init__` before export requires corresponding tensors to be passed to the constructor of the revived type. Dependencies are constructed in an uninitialized state even if they have tensor arguments to their `__init__` methods, with the understanding that initialization will be included in the trace of the method of the parent object which created the depended-on object before export (i.e. using a traced constructor to create a depended-on object outside of a traced method will result in uninitialized variables).
- (Done) Basic export to SavedModel for serving. Requires function(s) to be specified for signatures. Marked as experimental to avoid creating a format that won't import correctly into Python once that's implemented.
- Import of existing SavedModels (v1 compatibility)
- Python object re-import
- Serialized representation for polymorphic functions/methods
- Collect and save functions and methods attached to all objects (not just those specified explicitly as signatures)
- A function-based SaverDef to allow a sessionless restore
- Import generic objects back into Python with functions/methods and variables.
- Mark as not experimental. Export no longer requires signatures.
- Decorate call methods (and loss lambdas, etc.) of Layers so that the TensorFlow parts of their functionality can be referenced individually when deserialized.
- Python wrapper support for imported objects, used to make Keras Layers and Models import nicely.
- Distribution strategies integration: allow imported models to be run on multiple devices (without requiring them to have been saved with a distribution strategy set).
- Assets, any necessary changes to immutable tables to work with the tweaked asset handling.
- Import wrappers for other Python types (e.g. Optimizers)
- Make imported types work for creating new objects using saved initialization graphs
At a high level, every time an `Operation` or `Tensor` is referenced directly in the SavedModel format, it will be helpful to add a function name for use when loading the SavedModel into a sessionless TF 2.x environment. The existing op/tensor references will then reference function call ops in the MetaGraph, telling existing SavedModel loading infrastructure how to call the functions.
The loading procedure for session-based APIs will be the same as it is today, relying on ops in the SavedModel's MetaGraph. Loading into a 2.x-style context without any sessions uses the following procedure:
- Python objects are created from the CheckpointableObjectGraph proto with the dependency structure they had before export. Every created object inherits from CheckpointableBase, but many will have more specific types:
  - Objects with custom types registered will be re-created using those types.
  - Variables are recreated as `tf.Variable` objects. This will use the VariableDef as a serialization format for variable attributes, but unlike the existing `tf.Variable.from_proto` it will always create a new resource handle rather than looking up an operation in a Graph. They will be uninitialized to start.
  - Resources with no more specific type will be revived as `TrackableResource` objects, which ensures that the information required to re-create and initialize their resources is preserved for re-export.
  - Objects which inherited from `TrackableAsset` before serialization will be revived as that type.
  - All other objects will be generic Checkpointable objects.
- `TrackableAsset` objects have string variables created for them, initialized with the absolute path of the corresponding asset.
- Functions for resource handle creation associated with `TrackableResource`s are imported from the SavedModel into the eager context and represented as `Function` objects.
- Resource handle creation runs for `TrackableResource` objects. Variables will have already created resource tensors. We create a map from resource tensors in the SavedModel to the newly created eager resource tensors.
- Remaining "concrete" functions from the SavedModel are imported into the eager context and wrapped as `Function` objects, with captured resources mapped to their corresponding eager resource tensors.
- `Function` objects are gathered into `PolymorphicFunction` objects and assigned to object attributes.
- Variables are restored from the checkpoint by running the restore function referenced in the SaverDef (through the imported `Function` object).
- `TrackableResource` objects have their initializer functions run, which includes for example initializing tables from assets.
Several protocol buffers will require backwards-compatible additions to support loading without a `Graph`/`Session`.
The existing CheckpointableObjectGraph will be augmented with pointers to the SavedModel components necessary to recreate objects. For example, objects representing variables will identify their VariableDef, `TrackableAsset`s will identify their AssetFileDef, and `TrackableResource`s will identify handle creation and initialization functions.
`PolymorphicFunction`s will be nodes in this graph (but without any children), which allows multiple objects to reference the same function.
Each `PolymorphicFunction` is a list of signatures (indicating input and output formatting), each with a corresponding "concrete" FunctionDef. This format will build on existing Hub work which serializes Python function signatures.
There will be no general pickling, so only a limited set of types will be supported in the signatures of serialized functions. This support will include at least the basic types: booleans, strings, integers, floats, `None`, TensorShape, and dtype. Nests of container types may involve lists, tuples, dicts, and namedtuples.
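A minimal sketch of such a nest (the namedtuple is illustrative):

```python
import collections

Inputs = collections.namedtuple("Inputs", ["x", "scale"])

@tf.function
def h(inputs):
  # `inputs.scale` is a plain Python float recorded in the serialized
  # signature; `inputs.x` is a Tensor argument.
  return inputs.scale * inputs.x

h(Inputs(x=tf.constant([1., 2.]), scale=3.0))
```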
For each concrete function in the `PolymorphicFunction`, any Python arguments will be serialized along with indicators specifying tensor inputs. This is necessary to allow selection between concrete functions when the restored `PolymorphicFunction` is called, e.g. `f(..., training=True)` vs. `f(..., training=False)`. Each argument to the function and its output may be an arbitrary nest of the supported Python types and tensors.
SaverDef currently names a feed tensor which takes a checkpoint path, along with operations to run to save and restore. These fields will continue to exist in the SaverDef and will be filled in by `tf.saved_model.save` so that 1.x-style loader APIs can restore variables from checkpoints.
Two fields will be added identifying save and restore functions (FunctionDefs in the GraphDef's FunctionDefLibrary), each taking scalar string tensors. Call operations for these functions will be referenced by the existing `restore_op_name` and `save_tensor_name` fields.
Each object in the CheckpointableObjectGraph will have save and restore functions for each SaveableObject they export, and these functions will be composed into the SaverDef's save and restore functions. The restored Python objects will re-export these save and restore functions so that loading and saving again is idempotent (and subsets of objects are re-saveable).
VariableDefs will be used as-is to store variable attributes. The existing `tf.Variable.from_proto` restore logic will not be used to re-create variables when loading into a 2.x context. They will instead be created with new eager resource handles in an uninitialized state, then restored from the checkpoint.
`TrackableAsset` objects will have corresponding AssetFileDef protos in the SavedModel. When loading using a 1.x-style API, the fed filename tensor will be used to assign to the `TrackableAsset`'s variable. The AssetFileDef proto will not change.
Are there concerns about the concepts and types used in this format being Python-specific? What about trying to load these representations in other languages?
How, once a user has specified a `signatures=` argument on export, should that argument be represented on re-import (if at all)? One idea is to put them in a special `signatures` attribute of the imported object; another is to use the signature keys (e.g. `serving_default`) as attributes on the imported object directly. There seems to be some consensus that being able to access these signatures on import is important.
Probably not. The compatibility section now makes this explicit.
TensorFlow 1.x SavedModel load APIs have a tag-based selection system for choosing which MetaGraphs to load. Should `load()` take arguments to replicate this behavior, even if they're not relevant to `save()`? Or is returning a list and lazily loading MetaGraphs sufficient?
Is returning an unwrapped object in the single-MetaGraph case but a list in the multi-MetaGraph case too surprising? If so, do we need a separate API?
Things that changed since the doc was sent out:

- Import representation of signatures: attach to the root object? Decided to attach to a `.signatures` attribute (see Imported representation of signatures).
Two big issues to discuss:

- Functions in MetaGraph
  - e.g. should train/eval/serving be in the same MetaGraph? What if the train graph has a custom op that can't be loaded into Servo?
  - Is it possible to import a subset of the graph?
    - If all attributes are imported when the SavedModel is loaded, then this is impossible
    - Proposal:
      - Already in the current design: a single recursive export
      - Allow two options during import: entire object vs. set of signatures
    - Is it possible to filter at load time? Yes, but this will need to be implemented
  - Proposal: Export the entire object graph -and- a v1 SavedModel as separate fields
    - Two MetaGraphs, shared FunctionDefs
    - Pros:
      - Allows compatibility with V1 tools (eases the transition)
      - Conceptually consistent for users/library devs: one MetaGraph = restore, the other = serving
    - Cons:
      - Duplicate information
      - Bigger file (there's a 2GB limit?)
      - Additional complexity for a hypothetical situation
    - Decision: Extend the MetaGraph; v1 will only look at some fields
- Compatibility story
  - What happens when a V2 loader runs into a V1 SavedModel?
    - Proposal: Always return an object that represents the single MetaGraph (default case)
      - The object has `.signatures` and `.variables` attributes
      - The user must specify tags if there are multiple MetaGraphs (error out if not specified)
  - Load V2 into Estimator/Graph mode
    - Make sure collections are loaded correctly
    - Implementable, low priority
  - Load V2 into a Keras V2 layer
    - Should be doable, low priority
Distribution Strategy + SavedModels

- Day 1: Functions are not device hard-coded, and are not DistributionStrategy friendly
- Options:
  - Export the distributed graph with devices hard-coded
  - Load inside a DistributionStrategy scope
    - If the variables are created and the forward pass is run in the scope, then the model will be distributed
    - Look out for bugs in BatchNormalization
    - Requires variable annotations in the SavedModel
      - Aggregation mode
      - Sync mode