Implements a modular and easily extensible evaluation framework for both TFLite and TFjs. The evaluation framework has the following features:

* It is easy to add new Modules of examples since each Module is specified using a few lines of code (see `all_examples.py`).

* It is easy to add new converters since each converter is represented as a function (see `converters.py`). For instance, we could add the MLIR-based converter that the TFLite team is currently working on.

* The framework outputs a Markdown table (see `README.md`).

The framework has the following limitations:

* We only evaluate whether a Module converts; we do not compare any outputs between the converted model and the original model. Comparing outputs will require more effort, and it seems like we can do this as a follow-up if necessary (once a good fraction of ops are converted).

* If an example is missing multiple ops, then only the first missing op is reported. We could improve this by implementing mocked versions of non-working ops, which only output the right shapes. We could also consider doing this as a follow-up.

PiperOrigin-RevId: 402287865
marcvanzee authored and jax authors committed Oct 11, 2021
1 parent a47119d commit 161363d
Showing 7 changed files with 693 additions and 0 deletions.
74 changes: 74 additions & 0 deletions jax/experimental/jax2tf/examples_eval/README.md
@@ -0,0 +1,74 @@
# Converting JAX examples to TFLite/TFjs

## Overview

This directory implements a flexible evaluation framework for converting JAX
examples to TFjs/TFLite using jax2tf. It makes it relatively easy to add new
examples and converters, and it writes the results of all tests to Markdown.

The results of the conversion are written to `converters_results.md`.

See `examples_test.py` for instructions on how to run the evaluation.

### Features

* It is easy to add new Modules of examples since each Module is specified
using a few lines of code (see `all_examples.py`).

* It is easy to add new converters since each converter is represented as a
function (see `converters.py`).

* The framework outputs a Markdown table (see `converters_results.md`), which
provides an overview of the missing ops for all examples and all converters.
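
As a rough sketch of what "each converter is represented as a function" can look like: a converter takes a module description and, under the convention assumed here, either returns normally (conversion succeeded) or raises an exception (conversion failed). The `ModuleToConvert` stand-in below only mirrors the fields that `converters.py` reads (`apply`, `variables`, `input_shape`, `dtype`); the two toy converters and the `CONVERTERS` registry are hypothetical names used for illustration and are not part of this directory.

```python
import dataclasses
from typing import Any, Callable, Dict, Tuple


@dataclasses.dataclass
class ModuleToConvert:
  """Simplified stand-in; the real class lives in examples_converter.py."""
  input_shape: Tuple[int, ...]
  apply: Callable[..., Any]
  variables: Any
  dtype: Any = float


def always_succeeds(module: ModuleToConvert) -> None:
  """Hypothetical converter: returning normally is treated as success."""
  del module  # Nothing to convert in this toy example.


def rank2_only(module: ModuleToConvert) -> None:
  """Hypothetical converter: raising an exception is treated as failure."""
  rank = len(module.input_shape)
  if rank > 2:
    raise NotImplementedError(f'inputs of rank {rank} are not supported')


# Hypothetical registry: adding a converter means adding one function here.
CONVERTERS: Dict[str, Callable[[ModuleToConvert], None]] = {
    'always_succeeds': always_succeeds,
    'rank2_only': rank2_only,
}
```

With this shape, supporting a new conversion path only requires writing one more function with the same signature.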

### Limitations

* We only evaluate whether a Module converts; we do not compare any outputs
between the converted model and the original model.

* If an example is missing multiple ops, then only the first missing op is
reported.

## Code Details

### `[examples_test.py]`

This is the binary to run to execute tests. It has flags for various options.

### `[converters.py]`

This contains the functions representing different converters.

### `[all_examples.py]`

A list of all the examples to test. Each example takes only a few lines, so it
should be quite easy to add new ones.

The file also contains several data structures:

* `Arg`: An enum of placeholders for `ModuleSpec` arguments that depend on
particular state (RNG, module state). These are instantiated dynamically when
the Modules are constructed in `examples_converter.py`.

* `ModuleSpec`: An example is represented by a ModuleSpec dataclass, which
contains information for constructing and calling a module. I have designed
this interface by listing for all the Flax examples what the required
arguments are for calling `init` and `apply`, which is in the end all we need
to be able to convert a model. I expect it should be quite easy to add new
models now.

* `ExampleSuite`: Examples are collected in suites, each of which is written
as a single table in the Markdown file. A suite is simply a group of examples
with some metadata.
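
To make the role of `Arg` more concrete, here is a minimal sketch of how placeholders might be swapped for concrete values before calling `init` or `apply`. The `Arg` enum is copied from `all_examples.py`; the `resolve` helper and the string stand-ins for RNG keys and arrays are assumptions made for illustration and do not reflect the actual logic in `examples_converter.py`.

```python
import enum
from typing import Any, Dict


class Arg(enum.Enum):
  """Placeholders for state-dependent arguments (as in all_examples.py)."""
  RNG = '_RNG'
  ONES = '_ONES'
  VARS = '_VARS'
  INPUT = '_INPUTS'


def resolve(value: Any, concrete: Dict[Arg, Any]) -> Any:
  """Hypothetical helper: recursively replaces Arg placeholders."""
  if isinstance(value, Arg):
    return concrete[value]
  if isinstance(value, dict):
    return {k: resolve(v, concrete) for k, v in value.items()}
  if isinstance(value, (list, tuple)):
    return type(value)(resolve(v, concrete) for v in value)
  return value  # Literal arguments are passed through unchanged.


# Strings stand in for a real PRNG key and a real array of ones.
concrete = {Arg.RNG: 'rng_key', Arg.ONES: 'ones_array'}
init_args = ({'params': Arg.RNG, 'lstm': Arg.RNG}, Arg.ONES, Arg.ONES)
resolved = resolve(init_args, concrete)
```

The recursion matters because some specs nest placeholders inside dicts, for example `rngs={'cache': Arg.RNG}`.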

### `[examples_converter.py]`

Takes care of all the `args` and `kwargs` plumbing to create Modules, and tries
converting these Modules using a specified conversion function.

* This library has two interface functions, `test_convert` and
`write_markdown`, which are both called from `examples_test.py`.

* The main logic of this library is in the function `make_module`, which
converts a `ModuleSpec` into a `ModuleToConvert`, which is then the input to
the conversion function. This function is called from `test_convert`.
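
The overall flow described above (try a conversion, record the outcome, render a table) could be sketched roughly as follows. This is not the code of `test_convert` or `write_markdown`; `try_convert` and `render_table` are hypothetical helpers, and the sketch assumes that a converter signals failure by raising an exception.

```python
from typing import Any, Callable, Dict, Optional


def try_convert(convert_fn: Callable[[Any], None],
                module: Any) -> Optional[str]:
  """Returns None on success, or the repr of the raised exception."""
  try:
    convert_fn(module)
    return None
  except Exception as e:  # Any failure is reported, not re-raised.
    return repr(e)


def render_table(errors: Dict[str, Optional[str]]) -> str:
  """Renders one Markdown row per example."""
  rows = ['| Example | Result | Error Message |', '| --- | --- | --- |']
  for name, error in errors.items():
    result = 'SUCCESS' if error is None else 'FAIL'
    rows.append(f'| {name} | {result} | {error or ""} |')
  return '\n'.join(rows)


def _failing_converter(module: Any) -> None:
  raise NotImplementedError('op not supported')


errors = {
    'ok_example': try_convert(lambda module: None, object()),
    'bad_example': try_convert(_failing_converter, object()),
}
table = render_table(errors)
```

Because the first raised exception ends the conversion attempt, a design like this reports only one error per example, which matches the limitation noted earlier.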
178 changes: 178 additions & 0 deletions jax/experimental/jax2tf/examples_eval/all_examples.py
@@ -0,0 +1,178 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""All the examples to convert to TFLite or TFjs."""
import dataclasses
import enum
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

from flax import linen as nn
import jax.numpy as jnp


class Arg(enum.Enum):
  """This enum is used to automatically generate args dependent on a Module's internal state."""
  RNG = '_RNG'
  ONES = '_ONES'
  VARS = '_VARS'
  INPUT = '_INPUTS'


@dataclasses.dataclass
class ModuleSpec:
  """Specification of a Module."""
  module_path: str
  input_shape: Tuple[int, ...] = ()
  module_args: Sequence[Any] = ()
  module_kwargs: Optional[Dict[str, Any]] = None
  init_args: Sequence[Any] = (Arg.RNG, Arg.ONES)
  init_kwargs: Optional[Dict[str, Any]] = None
  apply_args: Sequence[Any] = (Arg.VARS, Arg.INPUT)
  apply_kwargs: Optional[Dict[str, Any]] = None
  dtype: jnp.dtype = jnp.float32
  rng_key: int = 0
  apply_method_fn: str = '__call__'

  def __post_init__(self):
    self.module_kwargs = self.module_kwargs or {}
    self.init_kwargs = self.init_kwargs or {}
    self.apply_kwargs = self.apply_kwargs or {}


@dataclasses.dataclass
class ExampleSuite:
  """A suite of examples."""
  name: str
  description: str
  url: str
  examples: Dict[str, ModuleSpec]


@dataclasses.dataclass
class TransformerConfig:
  """Transformer config."""
  vocab_size: int = 8
  output_vocab_size: int = 8
  share_embeddings: bool = False
  logits_via_embedding: bool = False
  dtype: jnp.dtype = jnp.float32
  emb_dim: int = 4
  num_heads: int = 1
  num_layers: int = 1
  qkv_dim: int = 2
  mlp_dim: int = 2
  max_len: int = 2
  dropout_rate: float = 0.
  attention_dropout_rate: float = 0.
  kernel_init: Callable[..., Any] = nn.initializers.xavier_uniform()
  bias_init: Callable[..., Any] = nn.initializers.normal(stddev=1e-6)
  posemb_init: Optional[Callable[..., Any]] = None
  deterministic: bool = True
  decode: bool = True


def _flax_examples():
  return {
      'imagenet':
          ModuleSpec(
              module_path='imagenet.models.ResNet50',
              input_shape=(1, 2, 2, 3),
              module_kwargs=dict(num_classes=2, dtype=jnp.float32),
              apply_kwargs=dict(train=False, mutable=False)),
      'lm1b':
          ModuleSpec(
              module_path='lm1b.models.TransformerLM',
              input_shape=(2, 1),
              module_kwargs=dict(config=TransformerConfig()),
              apply_kwargs=dict(rngs={'cache': Arg.RNG}, mutable=['cache'])),
      'mnist':
          ModuleSpec(module_path='mnist.train.CNN', input_shape=(1, 28, 28, 1)),
      'nlp_seq':
          ModuleSpec(
              module_path='nlp_seq.models.Transformer',
              input_shape=(2, 1),
              init_args=(Arg.RNG,),
              init_kwargs=dict(inputs=Arg.ONES, train=False),
              module_kwargs=dict(config=TransformerConfig()),
              apply_args=(Arg.VARS,),
              apply_kwargs=dict(inputs=Arg.INPUT, train=False)),
      'pixelcnn++':
          ModuleSpec(
              module_path='pixelcnn.pixelcnn.PixelCNNPP',
              input_shape=(1, 32, 32, 3),
              init_kwargs=dict(train=False),
              module_kwargs=dict(
                  depth=1, features=2, logistic_components=2, dropout_p=0.),
              apply_kwargs=dict(train=False)),
      'ppo':
          ModuleSpec(
              module_path='ppo.models.ActorCritic',
              input_shape=(1, 8, 8, 4),
              module_kwargs=dict(num_outputs=2)),
      'seq2seq':
          ModuleSpec(
              module_path='seq2seq.train.Seq2seq',
              input_shape=(1, 2, 15),
              module_kwargs=dict(teacher_force=True, hidden_size=2),
              init_args=({
                  'params': Arg.RNG,
                  'lstm': Arg.RNG
              }, Arg.ONES, Arg.ONES),
              apply_args=(Arg.VARS, Arg.INPUT, Arg.INPUT),
              apply_kwargs=dict(rngs={'lstm': Arg.RNG})),
      'sst2':
          ModuleSpec(
              module_path='sst2.models.TextClassifier',
              input_shape=(2, 3),
              module_kwargs=dict(
                  embedding_size=3,
                  hidden_size=1,
                  vocab_size=5,
                  output_size=3,
                  dropout_rate=0.,
                  word_dropout_rate=0.),
              init_args=(Arg.RNG, Arg.ONES, jnp.array([2, 3], dtype=jnp.int32)),
              init_kwargs=dict(deterministic=True),
              apply_args=(Arg.VARS, Arg.INPUT, jnp.array([2, 3],
                                                         dtype=jnp.int32)),
              apply_kwargs=dict(deterministic=True),
              dtype=jnp.int32),
      'vae':
          ModuleSpec(
              module_path='vae.train.VAE',
              input_shape=(1, 8, 8, 3),
              module_args=(3,),
              init_args=(Arg.RNG, Arg.ONES, Arg.RNG),
              apply_method_fn='generate'),
      'wmt':
          ModuleSpec(
              module_path='wmt.models.Transformer',
              input_shape=(2, 1),
              module_kwargs=dict(config=TransformerConfig()),
              init_args=(Arg.RNG, Arg.ONES, Arg.ONES),
              apply_args=(Arg.VARS, Arg.INPUT, Arg.INPUT),
              apply_kwargs=dict(mutable=['cache'])),
  }

def get_suite(suite_name: str) -> Optional[ExampleSuite]:
  """Returns all examples in `suite_name`."""
  if suite_name == 'flax':
    return ExampleSuite(
        name='The Flax Examples',
        description="""List of examples maintained by the Flax team.
These examples are representative of what the average ML researcher is interested in.""",
        url='https://github.com/google/flax/tree/main/examples',
        examples=_flax_examples()
    )
  else:
    return None
69 changes: 69 additions & 0 deletions jax/experimental/jax2tf/examples_eval/converters.py
@@ -0,0 +1,69 @@
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Converters for jax2tf."""
import functools
import tempfile

from jax.experimental import jax2tf
from jax.experimental.jax2tf.examples import saved_model_lib
from jax.experimental.jax2tf.examples_eval import examples_converter
import tensorflow as tf
from tensorflowjs.converters import converter as tfjs_converter

TempDir = tempfile.TemporaryDirectory


def jax2tf_to_tfjs(module: examples_converter.ModuleToConvert):
  """Converts the given `module` using the TFjs converter."""
  with TempDir() as saved_model_path, TempDir() as converted_model_path:
    # the model must be converted with with_gradient set to True to be able to
    # convert the saved model to TF.js, as "PreventGradient" is not supported
    saved_model_lib.convert_and_save_model(
        module.apply,
        module.variables,
        saved_model_path,
        input_signatures=[
            tf.TensorSpec(
                shape=module.input_shape,
                dtype=module.dtype,
                name='input')
        ],
        with_gradient=True,
        compile_model=False,
        enable_xla=False
    )
    tfjs_converter.convert([saved_model_path, converted_model_path])


def jax2tf_to_tflite(module: examples_converter.ModuleToConvert):
  """Converts the given `module` using the TFLite converter."""
  apply = functools.partial(module.apply, module.variables)
  tf_predict = tf.function(
      jax2tf.convert(apply, enable_xla=False),
      input_signature=[
          tf.TensorSpec(
              shape=module.input_shape,
              dtype=module.dtype,
              name='input')
      ],
      autograph=False)

  # Convert TF function to TF Lite format.
  converter = tf.lite.TFLiteConverter.from_concrete_functions(
      [tf_predict.get_concrete_function()])
  converter.target_spec.supported_ops = [
      tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
      tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
  ]
  converter.convert()
49 changes: 49 additions & 0 deletions jax/experimental/jax2tf/examples_eval/converters_results.md
@@ -0,0 +1,49 @@
# Evaluation Results

*Last generated on: 2021-10-04* (YYYY-MM-DD)

## jax2tf --> TFLite

### The Flax Examples
[URL to examples](https://github.com/google/flax/tree/main/examples)

Description: List of examples maintained by the Flax team.
These examples are representative of what the average ML researcher is interested in.

| Example | Result | Error Message |
| --- | --- | --- |
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
| mnist | SUCCESS |
| nlp_seq | FAIL | ConverterError('/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: error: \'tf.Expm1\' op is neither a custom op nor a flex op\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:3798:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:819:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:836:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:277:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/lax/lax.py:192:0: note: called from\n/Users/marcvanzee/github/jax/jax/_src/numpy/lax_numpy.py:661:0: note: called from\n/Users/marcvanzee/github/jax/jax/linear_util.py:166:0: note: called from\n/Users/marcvanzee/github/jax/jax/experimental/jax2tf/jax2tf.py:879:0: note: called from\n/Users/marcvanzee/github/jax/jax/core.py:1645:0: note: called from\n/Users/marcvanzee/.pyenv/versions/3.7.10/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750:0: note: Error code: ERROR_NEEDS_CUSTOM_OPS\n<unknown>:0: error: failed while converting: \'main\': \nSome ops in the model are custom ops, See instructions to implement custom ops: https://www.tensorflow.org/lite/guide/ops_custom \nCustom ops: Expm1\nDetails:\n\ttf.Expm1(tensor<2x1x2xf32>) -> (tensor<2x1x2xf32>) : {device = ""}\n\n')
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.')
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| seq2seq | SUCCESS |
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).")
| vae | FAIL | ModuleNotFoundError("No module named 'utils'")
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")

## jax2tf --> TFjs

### The Flax Examples
[URL to examples](https://github.com/google/flax/tree/main/examples)

Description: List of examples maintained by the Flax team.
These examples are representative of what the average ML researcher is interested in.

| Example | Result | Error Message |
| --- | --- | --- |
| imagenet | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| lm1b | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")
| mnist | SUCCESS |
| nlp_seq | FAIL | ValueError("Error when tracing gradients for SavedModel.\n\nSee the stack trace above to see the error that was raised when converting a gradient function to a concrete function. You may need to update the custom gradient, or disable saving gradients with the option tf.saved_model.SaveOptions(custom_gradients=False).\n\tProblematic op name: IdentityN\n\tGradient inputs: (<tf.Tensor 'AddV2_12:0' shape=(2, 1, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_0:0' shape=(8,) dtype=float32>, <tf.Tensor 'jax2tf_arg_1:0' shape=(4, 8) dtype=float32>, <tf.Tensor 'jax2tf_arg_2:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_3:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_4:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_5:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_6:0' shape=(2,) dtype=float32>, <tf.Tensor 'jax2tf_arg_7:0' shape=(4, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_8:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_9:0' shape=(2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_10:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_11:0' shape=(1, 2, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_12:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_13:0' shape=(4, 1, 2) dtype=float32>, <tf.Tensor 'jax2tf_arg_14:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_15:0' shape=(4,) dtype=float32>, <tf.Tensor 'jax2tf_arg_16:0' shape=(8, 4) dtype=float32>, <tf.Tensor 'jax2tf_arg_17:0' shape=(2, 1) dtype=float32>)")
| pixelcnn++ | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Input padding not supported in TensorFlow. - See source code for the precise conditions under which convolutions can be converted without XLA.')
| ppo | FAIL | NotImplementedError('Call to conv_general_dilated cannot be converted with enable_xla=False. Unimplemented support for window_strides != (1, 1) - See source code for the precise conditions under which convolutions can be converted without XLA.')
| seq2seq | FAIL | ValueError('Unsupported Ops in the model before optimization\nBitcast, BitwiseAnd, BitwiseOr, RightShift, LeftShift, BitwiseXor')
| sst2 | FAIL | NotImplementedError("Call to gather cannot be converted with enable_xla=False. unsupported dimension_numbers 'GatherDimensionNumbers(offset_dims=(1, 2), collapsed_slice_dims=(0,), start_index_map=(0, 1, 2))'; op_shape=(2, 6, 3).")
| vae | SUCCESS |
| wmt | FAIL | TypeError("Value passed to parameter 'begin' has DataType uint32 not in list of allowed values: int32, int64")

## Table generation

See `examples_test.py` for instructions on how to regenerate this table.
@@ -0,0 +1,15 @@
# Evaluation Results

*Last generated on: {{generation_date}}* (YYYY-MM-DD)

## jax2tf --> TFLite

{{jax2tf_to_tflite}}

## jax2tf --> TFjs

{{jax2tf_to_tfjs}}

## Table generation

See `examples_test.py` for instructions on how to regenerate this table.
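
A template like the one above can be filled with plain string substitution. The sketch below is only an assumption about how the `{{...}}` placeholders might be replaced; `fill_template` is a hypothetical helper and not necessarily how `write_markdown` works.

```python
import datetime
from typing import Dict


def fill_template(template: str, values: Dict[str, str]) -> str:
  """Replaces each `{{key}}` placeholder with its value."""
  for key, value in values.items():
    template = template.replace('{{' + key + '}}', value)
  return template


template = (
    '*Last generated on: {{generation_date}}* (YYYY-MM-DD)\n\n'
    '## jax2tf --> TFLite\n\n{{jax2tf_to_tflite}}\n')
filled = fill_template(
    template, {
        # isoformat() yields the YYYY-MM-DD form used in the results file.
        'generation_date': datetime.date(2021, 10, 4).isoformat(),
        'jax2tf_to_tflite': '| Example | Result | Error Message |',
    })
```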