forked from jax-ml/jax
Commit
Implements a modular and easily extensible evaluation framework for both TFLite and TFjs.

The evaluation framework has the following features:

* It is easy to add new Modules of examples since each Module is specified using a few lines of code (see `examples.py`).
* It is easy to add new converters since each converter is represented as a function (see `converters.py`). For instance, we could add the MLIR-based converter that the TFLite team is currently working on.
* The framework outputs a Markdown table (see `README.md`).

The framework has the following limitations:

* We only evaluate whether a Module converts, we do not compare any outputs between the converted model and the original model. This will require more effort, and it seems like we can do this as a follow-up if necessary (once a good fraction of ops are converted).
* If an example is missing multiple ops, then only the first missing op is reported. We could improve this by implementing mocked versions of non-working ops, which only output the right shapes. We could also consider doing this as a follow-up.

PiperOrigin-RevId: 402287865
1 parent a47119d, commit 161363d
Showing 7 changed files with 693 additions and 0 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# Converting JAX examples to TFLite/TFjs | ||
|
||
## Overview | ||
|
||
This directory implements a flexible evaluation framework for converting JAX | ||
examples to TFjs/TFLite using jax2tf. It makes it relatively easy to add new | ||
examples and converters, and it writes the results of all tests to Markdown. | ||
|
||
The results of the conversion are written to `converters_results.md`. | ||
|
||
See `examples_test.py` for instructions on how to run the evaluation. | ||
|
||
### Features | ||
|
||
* It is easy to add new Modules of examples since each Module is specified | ||
using a few lines of code (see `examples.py`). | ||
|
||
* It is easy to add new converters since each converter is represented as a | ||
function (see `converters.py`). | ||
|
||
* The framework outputs a Markdown table (see `converters_results.md`), which | ||
provides an overview of the missing ops for all examples and all converters. | ||
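
As a rough illustration of the table output, the following sketch renders a list of per-example results into a Markdown table with the same three columns used in `converters_results.md`. The `Result` tuple and the `to_markdown_table` helper are hypothetical names chosen for this example; the framework's own `write_markdown` function is not reproduced here and may work differently.

```python
from typing import List, NamedTuple, Optional


class Result(NamedTuple):
  """One row of the results table (hypothetical structure)."""
  example: str
  error: Optional[str]  # None means the conversion succeeded.


def to_markdown_table(results: List[Result]) -> str:
  """Renders results as a Markdown table with Example/Result/Error columns."""
  lines = [
      '| Example | Result | Error Message |',
      '| --- | --- | --- |',
  ]
  for r in results:
    status = 'SUCCESS' if r.error is None else 'FAIL'
    # Escape pipes and newlines so an error message stays in one table cell.
    message = (r.error or '').replace('|', '\\|').replace('\n', '\\n')
    lines.append(f'| {r.example} | {status} | {message} |')
  return '\n'.join(lines)


table = to_markdown_table([
    Result('mnist', None),
    Result('vae', 'ModuleNotFoundError("No module named \'utils\'")'),
])
print(table)
```

Escaping pipes and newlines matters because raw error messages can contain both, and either would break the table layout.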
|
||
### Limitations | ||
|
||
* We only evaluate whether a Module converts; we do not compare any outputs | ||
between the converted model and the original model. | ||
|
||
* If an example is missing multiple ops, then only the first missing op is | ||
reported. | ||
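
The second limitation is what one would expect if a conversion attempt is a single call that either returns or raises: the first unsupported op aborts the conversion, so later ones are never reached. A minimal sketch of that pattern follows. Here `try_convert` and the two toy converters are hypothetical stand-ins, not the framework's real functions.

```python
from typing import Any, Callable, Optional


def try_convert(converter: Callable[[Any], None], module: Any) -> Optional[str]:
  """Runs `converter` on `module`; returns None on success, else repr(error)."""
  try:
    converter(module)
  except Exception as e:  # pylint: disable=broad-except
    # Conversion stops at the first failure, so only that error is recorded.
    return repr(e)
  return None


def ok_converter(module: Any) -> None:
  """Toy converter that always succeeds."""
  del module


def failing_converter(module: Any) -> None:
  """Toy converter with two unsupported ops; only the first is ever raised."""
  del module
  for op in ('conv_general_dilated', 'gather'):
    raise NotImplementedError(f'Call to {op} cannot be converted')


print(try_convert(ok_converter, object()))
print(try_convert(failing_converter, object()))
```

Reporting every missing op would require the conversion to continue past a failure, for example with mocked ops that only produce the right shapes.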
|
||
## Code Details | ||
|
||
### `[examples_test.py]` | ||
|
||
This is the binary to run to execute tests. It has flags for various options. | ||
|
||
### `[converters.py]` | ||
|
||
This contains the functions representing different converters. | ||
|
||
### `[all_examples.py]` | ||
|
||
A list of all the examples to test. Each example takes only a few lines, so it | ||
should be quite easy to add new ones. | ||
|
||
The file also contains several data structures: | ||
|
||
* `Arg`: An enum used for arguments in `ModuleSpec` that depend on particular | ||
state (rng, module state). These are instantiated dynamically when the | ||
Modules are constructed in [examples_converter.py]. | ||
|
||
* `ModuleSpec`: An example is represented by a ModuleSpec dataclass, which | ||
contains information for constructing and calling a module. I have designed | ||
this interface by listing for all the Flax examples what the required | ||
arguments are for calling `init` and `apply`, which is in the end all we need | ||
to be able to convert a model. I expect it should be quite easy to add new | ||
models now. | ||
|
||
* `ExampleSuite`: Examples are collected in suites, each of which is output as a | ||
single table in the Markdown file. A suite is simply a group of examples with | ||
some metadata. | ||
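
To make the "few lines per example" claim concrete, here is a trimmed-down, dependency-free stand-in for `Arg` and `ModuleSpec`, with one new entry added to a dictionary of examples. The field names mirror the ones described in this section, but the class is a simplified copy with several fields omitted. The `my_project.models.MLP` path and its `features` argument are made up for illustration.

```python
import dataclasses
import enum
from typing import Any, Dict, Optional, Sequence, Tuple


class Arg(enum.Enum):
  """Placeholders that are filled in when the Module is constructed."""
  RNG = '_RNG'
  ONES = '_ONES'
  VARS = '_VARS'
  INPUT = '_INPUTS'


@dataclasses.dataclass
class ModuleSpec:
  """Simplified stand-in for the real ModuleSpec (fields omitted)."""
  module_path: str
  input_shape: Tuple[int, ...] = ()
  module_kwargs: Optional[Dict[str, Any]] = None
  init_args: Sequence[Any] = (Arg.RNG, Arg.ONES)
  apply_args: Sequence[Any] = (Arg.VARS, Arg.INPUT)

  def __post_init__(self):
    # Normalize None to an empty dict so callers can always use ** on it.
    self.module_kwargs = self.module_kwargs or {}


examples = {
    'mnist': ModuleSpec(module_path='mnist.train.CNN',
                        input_shape=(1, 28, 28, 1)),
    # Hypothetical new example: only the path, input shape, and constructor
    # kwargs are needed when the default init/apply arguments fit.
    'my_mlp': ModuleSpec(module_path='my_project.models.MLP',
                         input_shape=(1, 16),
                         module_kwargs=dict(features=4)),
}
print(sorted(examples))
```

The defaults for `init_args` and `apply_args` cover the common `init(rng, ones)` and `apply(variables, inputs)` call pattern, so only models that deviate from it need to spell these out.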
|
||
### `[examples_converter.py]` | ||
|
||
Takes care of all the `args` and `kwargs` plumbing to create Modules, and tries | ||
converting these Modules using a specified conversion function. | ||
|
||
* This library has two interface functions, `test_convert` and | ||
`write_markdown`, which are both called from `examples_test.py`. | ||
|
||
* The main logic of this library is in the function `make_module`, which | ||
converts a `ModuleSpec` into a `ModuleToConvert`, which is then the input to | ||
the conversion function. This function is called from `test_convert`. |
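
One way the placeholder plumbing could work is a recursive substitution that walks tuples, lists, and dicts and swaps each `Arg` member for a concrete value. The sketch below is an assumption about the general approach, not the actual `make_module` code, which is not reproduced in this README. The `resolve` helper and the string stand-ins for the RNG key and the input array are hypothetical.

```python
import enum
from typing import Any, Dict


class Arg(enum.Enum):
  RNG = '_RNG'
  ONES = '_ONES'
  VARS = '_VARS'
  INPUT = '_INPUTS'


def resolve(value: Any, concrete: Dict[Arg, Any]) -> Any:
  """Recursively replaces Arg placeholders in `value` using `concrete`."""
  if isinstance(value, Arg):
    return concrete[value]
  if isinstance(value, dict):
    return {k: resolve(v, concrete) for k, v in value.items()}
  if isinstance(value, (list, tuple)):
    return type(value)(resolve(v, concrete) for v in value)
  return value  # Plain values (ints, arrays, flags) pass through unchanged.


# Strings stand in for a real PRNG key and an all-ones input array.
concrete = {Arg.RNG: 'rng_key', Arg.ONES: 'ones_array'}

# Nested structure: a dict of RNGs followed by two input placeholders.
init_args = ({'params': Arg.RNG, 'lstm': Arg.RNG}, Arg.ONES, Arg.ONES)
resolved = resolve(init_args, concrete)
print(resolved)
```

Handling nested containers is the important part: a spec may pass a dict of RNGs or a `rngs={...}` keyword, so a flat substitution over the top-level arguments would miss placeholders.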
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""All the examples to convert to TFLite or TFjs.""" | ||
import dataclasses | ||
import enum | ||
from typing import Any, Callable, Dict, Optional, Sequence, Tuple | ||
|
||
from flax import linen as nn | ||
import jax.numpy as jnp | ||
|
||
|
||
class Arg(enum.Enum): | ||
"""This enum is used to automatically generate args dependent on a Module's internal state.""" | ||
RNG = '_RNG' | ||
ONES = '_ONES' | ||
VARS = '_VARS' | ||
INPUT = '_INPUTS' | ||
|
||
|
||
@dataclasses.dataclass | ||
class ModuleSpec: | ||
"""Specification of a Module.""" | ||
module_path: str | ||
input_shape: Tuple[int, ...] = () | ||
module_args: Sequence[Any] = () | ||
module_kwargs: Optional[Dict[str, Any]] = None | ||
init_args: Sequence[Any] = (Arg.RNG, Arg.ONES) | ||
init_kwargs: Optional[Dict[str, Any]] = None | ||
apply_args: Sequence[Any] = (Arg.VARS, Arg.INPUT) | ||
apply_kwargs: Optional[Dict[str, Any]] = None | ||
dtype: jnp.dtype = jnp.float32 | ||
rng_key: int = 0 | ||
apply_method_fn: str = '__call__' | ||
|
||
def __post_init__(self): | ||
self.module_kwargs = self.module_kwargs or {} | ||
self.init_kwargs = self.init_kwargs or {} | ||
self.apply_kwargs = self.apply_kwargs or {} | ||
|
||
|
||
@dataclasses.dataclass | ||
class ExampleSuite: | ||
"""A suite of examples.""" | ||
name: str | ||
description: str | ||
url: str | ||
examples: Dict[str, ModuleSpec] | ||
|
||
|
||
@dataclasses.dataclass | ||
class TransformerConfig: | ||
"""Transformer config.""" | ||
vocab_size: int = 8 | ||
output_vocab_size: int = 8 | ||
share_embeddings: bool = False | ||
logits_via_embedding: bool = False | ||
dtype: jnp.dtype = jnp.float32 | ||
emb_dim: int = 4 | ||
num_heads: int = 1 | ||
num_layers: int = 1 | ||
qkv_dim: int = 2 | ||
mlp_dim: int = 2 | ||
max_len: int = 2 | ||
dropout_rate: float = 0. | ||
attention_dropout_rate: float = 0. | ||
kernel_init: Callable[..., Any] = nn.initializers.xavier_uniform() | ||
bias_init: Callable[..., Any] = nn.initializers.normal(stddev=1e-6) | ||
posemb_init: Optional[Callable[..., Any]] = None | ||
deterministic: bool = True | ||
decode: bool = True | ||
|
||
|
||
def _flax_examples(): | ||
return { | ||
'imagenet': | ||
ModuleSpec( | ||
module_path='imagenet.models.ResNet50', | ||
input_shape=(1, 2, 2, 3), | ||
module_kwargs=dict(num_classes=2, dtype=jnp.float32), | ||
apply_kwargs=dict(train=False, mutable=False)), | ||
'lm1b': | ||
ModuleSpec( | ||
module_path='lm1b.models.TransformerLM', | ||
input_shape=(2, 1), | ||
module_kwargs=dict(config=TransformerConfig()), | ||
apply_kwargs=dict(rngs={'cache': Arg.RNG}, mutable=['cache'])), | ||
'mnist': | ||
ModuleSpec(module_path='mnist.train.CNN', input_shape=(1, 28, 28, 1)), | ||
'nlp_seq': | ||
ModuleSpec( | ||
module_path='nlp_seq.models.Transformer', | ||
input_shape=(2, 1), | ||
init_args=(Arg.RNG,), | ||
init_kwargs=dict(inputs=Arg.ONES, train=False), | ||
module_kwargs=dict(config=TransformerConfig()), | ||
apply_args=(Arg.VARS,), | ||
apply_kwargs=dict(inputs=Arg.INPUT, train=False)), | ||
'pixelcnn++': | ||
ModuleSpec( | ||
module_path='pixelcnn.pixelcnn.PixelCNNPP', | ||
input_shape=(1, 32, 32, 3), | ||
init_kwargs=dict(train=False), | ||
module_kwargs=dict( | ||
depth=1, features=2, logistic_components=2, dropout_p=0.), | ||
apply_kwargs=dict(train=False)), | ||
'ppo': | ||
ModuleSpec( | ||
module_path='ppo.models.ActorCritic', | ||
input_shape=(1, 8, 8, 4), | ||
module_kwargs=dict(num_outputs=2)), | ||
'seq2seq': | ||
ModuleSpec( | ||
module_path='seq2seq.train.Seq2seq', | ||
input_shape=(1, 2, 15), | ||
module_kwargs=dict(teacher_force=True, hidden_size=2), | ||
init_args=({ | ||
'params': Arg.RNG, | ||
'lstm': Arg.RNG | ||
}, Arg.ONES, Arg.ONES), | ||
apply_args=(Arg.VARS, Arg.INPUT, Arg.INPUT), | ||
apply_kwargs=dict(rngs={'lstm': Arg.RNG})), | ||
'sst2': | ||
ModuleSpec( | ||
module_path='sst2.models.TextClassifier', | ||
input_shape=(2, 3), | ||
module_kwargs=dict( | ||
embedding_size=3, | ||
hidden_size=1, | ||
vocab_size=5, | ||
output_size=3, | ||
dropout_rate=0., | ||
word_dropout_rate=0.), | ||
init_args=(Arg.RNG, Arg.ONES, jnp.array([2, 3], dtype=jnp.int32)), | ||
init_kwargs=dict(deterministic=True), | ||
apply_args=(Arg.VARS, Arg.INPUT, jnp.array([2, 3], | ||
dtype=jnp.int32)), | ||
apply_kwargs=dict(deterministic=True), | ||
dtype=jnp.int32), | ||
'vae': | ||
ModuleSpec( | ||
module_path='vae.train.VAE', | ||
input_shape=(1, 8, 8, 3), | ||
module_args=(3,), | ||
init_args=(Arg.RNG, Arg.ONES, Arg.RNG), | ||
apply_method_fn='generate'), | ||
'wmt': | ||
ModuleSpec( | ||
module_path='wmt.models.Transformer', | ||
input_shape=(2, 1), | ||
module_kwargs=dict(config=TransformerConfig()), | ||
init_args=(Arg.RNG, Arg.ONES, Arg.ONES), | ||
apply_args=(Arg.VARS, Arg.INPUT, Arg.INPUT), | ||
apply_kwargs=dict(mutable=['cache'])), | ||
} | ||
|
||
def get_suite(suite_name: str) -> Optional[ExampleSuite]: | ||
"""Returns all examples in `suite_name`.""" | ||
if suite_name == 'flax': | ||
return ExampleSuite( | ||
name='The Flax Examples', | ||
description="""List of examples maintained by the Flax team. | ||
These examples are representative of what the average ML researcher is interested in.""", | ||
url='https://github.com/google/flax/tree/main/examples', | ||
examples=_flax_examples() | ||
) | ||
else: | ||
return None |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
# Copyright 2020 Google LLC | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Converters for jax2tf.""" | ||
import functools | ||
import tempfile | ||
|
||
from jax.experimental import jax2tf | ||
from jax.experimental.jax2tf.examples import saved_model_lib | ||
from jax.experimental.jax2tf.examples_eval import examples_converter | ||
import tensorflow as tf | ||
from tensorflowjs.converters import converter as tfjs_converter | ||
|
||
TempDir = tempfile.TemporaryDirectory | ||
|
||
|
||
def jax2tf_to_tfjs(module: examples_converter.ModuleToConvert): | ||
"""Converts the given `module` using the TFjs converter.""" | ||
with TempDir() as saved_model_path, TempDir() as converted_model_path: | ||
# the model must be converted with with_gradient set to True to be able to | ||
# convert the saved model to TF.js, as "PreventGradient" is not supported | ||
saved_model_lib.convert_and_save_model( | ||
module.apply, | ||
module.variables, | ||
saved_model_path, | ||
input_signatures=[ | ||
tf.TensorSpec( | ||
shape=module.input_shape, | ||
dtype=module.dtype, | ||
name='input') | ||
], | ||
with_gradient=True, | ||
compile_model=False, | ||
enable_xla=False | ||
) | ||
tfjs_converter.convert([saved_model_path, converted_model_path]) | ||
|
||
|
||
def jax2tf_to_tflite(module: examples_converter.ModuleToConvert): | ||
"""Converts the given `module` using the TFLite converter.""" | ||
apply = functools.partial(module.apply, module.variables) | ||
tf_predict = tf.function( | ||
jax2tf.convert(apply, enable_xla=False), | ||
input_signature=[ | ||
tf.TensorSpec( | ||
shape=module.input_shape, | ||
dtype=module.dtype, | ||
name='input') | ||
], | ||
autograph=False) | ||
|
||
# Convert TF function to TF Lite format. | ||
converter = tf.lite.TFLiteConverter.from_concrete_functions( | ||
[tf_predict.get_concrete_function()]) | ||
converter.target_spec.supported_ops = [ | ||
tf.lite.OpsSet.TFLITE_BUILTINS, # enable TensorFlow Lite ops. | ||
tf.lite.OpsSet.SELECT_TF_OPS # enable TensorFlow ops. | ||
] | ||
converter.convert() |
49 changes: 49 additions & 0 deletions
jax/experimental/jax2tf/examples_eval/converters_results.md
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Evaluation Results | ||
|
||
*Last generated on: 2021-10-04* (YYYY-MM-DD) | ||
|
||
## jax2tf --> TFLite | ||
|
||
### The Flax Examples | ||
[URL to examples](https://github.com/google/flax/tree/main/examples) | ||
|
||
Description: List of examples maintained by the Flax team. | ||
These examples are representative of what the average ML researcher is interested in. | ||
|
||
| Example | Result | Error Message | | ||
| --- | --- | --- | | ||
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.') | ||
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64") | ||
| mnist | SUCCESS | | ||
| nlp_seq | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:819:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:836:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:277:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/lax/lax.py:192:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/numpy/lax_numpy.py:661:0: note: called from\n/Users/marcvanzee/github/jax/jax/linear_util.py:166:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:879:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:1645:0: note: called from\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: note: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<2x1x2xf32>) -> (tensor<2x1x2xf32>) : {device = ""}\n\n') | ||
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.') | ||
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.') | ||
| seq2seq | SUCCESS | | ||
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).") | ||
| vae | FAIL | ModuleNotFoundError("No module named 'utils'") | ||
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64") | ||
|
||
## jax2tf --> TFjs | ||
|
||
### The Flax Examples | ||
[URL to examples](https://github.com/google/flax/tree/main/examples) | ||
|
||
Description: List of examples maintained by the Flax team. | ||
These exampls are representative for what the average ML researcher is interested in. | ||
|
||
| Example | Result | Error Message | | ||
| --- | --- | --- | | ||
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.') | ||
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64") | ||
| mnist | SUCCESS | | ||
| nlp_seq | FAIL | ValueError("Error when tracing gradients for SavedModel.\n\nSee the stack trace above to see the error that was raised when converting a gradient function to a concrete function. You may need to update the custom gradient, or disable saving gradients with the option tf.saved_model.SaveOptions(custom_gradients=False).\n\tProblematic op name: IdentityN\n\tGradient inputs: (<tf.Tensor 'AddV2_12:0' shape=(2, 1, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_0:0' shape=(8,) dtype=float32>, <tf.Tensor 'jax2tf_arg_1:0' shape=(4, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_2:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_3:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_4:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_5:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_6:0' shape=(2,) dtype=float32>, <tf.Tensor 'jax2tf_arg_7:0' shape=(4, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_8:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_9:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_10:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_11:0' shape=(1, 2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_12:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_13:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_14:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_15:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_16:0' shape=(8, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_17:0' shape=(2, 1) dtype=float32>)") | ||
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.') | ||
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.') | ||
| seq2seq | FAIL | ValueError('Unsupported Ops in the model before optimization\nBitcast, BitwiseAnd, BitwiseOr, RightShift, LeftShift, BitwiseXor') | ||
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).") | ||
| vae | SUCCESS | | ||
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64") | ||
|
||
## Table generation | ||
|
||
See `examples_test.py` for instructions on how to regenerate this table. |
15 changes: 15 additions & 0 deletions
jax/experimental/jax2tf/examples_eval/converters_results.md.template
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Evaluation Results | ||
|
||
*Last generated on: {{generation_date}}* (YYYY-MM-DD) | ||
|
||
## jax2tf --> TFLite | ||
|
||
{{jax2tf_to_tflite}} | ||
|
||
## jax2tf --> TFjs | ||
|
||
{{jax2tf_to_tfjs}} | ||
|
||
## Table generation | ||
|
||
See `examples_test.py` for instructions on how to regenerate this table. |
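
The `{{...}}` placeholders in this template suggest simple string substitution. A minimal sketch of how such a template could be filled is given below, assuming plain `str.replace` and a hypothetical `render_template` helper. The commit's actual rendering code is not shown here, so this is illustrative only.

```python
import datetime
from typing import Dict

# Shortened copy of the template, for illustration.
TEMPLATE = """# Evaluation Results

*Last generated on: {{generation_date}}* (YYYY-MM-DD)

## jax2tf --> TFLite

{{jax2tf_to_tflite}}
"""


def render_template(template: str, values: Dict[str, str]) -> str:
  """Replaces each `{{key}}` in `template` with the matching value."""
  for key, value in values.items():
    template = template.replace('{{' + key + '}}', value)
  return template


rendered = render_template(TEMPLATE, {
    'generation_date': datetime.date(2021, 10, 4).isoformat(),
    'jax2tf_to_tflite': '| Example | Result | Error Message |',
})
print(rendered)
```

Keeping one placeholder per converter means a new converter only needs a new section in the template plus a matching key in the values passed to the renderer.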