Skip to content

Commit

Permalink
Adds NotImplemented errors to obm_module and adds related tests to ja…
Browse files Browse the repository at this point in the history
…x_module_test.

PiperOrigin-RevId: 678800824
  • Loading branch information
Orbax Authors committed Sep 25, 2024
1 parent cd4b736 commit c62f130
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
70 changes: 70 additions & 0 deletions export/orbax/export/jax_module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,76 @@


class JaxModuleTest(tf.test.TestCase, parameterized.TestCase):
def test_jax_module_orbax_model_unimplemented_methods(self):
def linear(params, x):
return params['w'] @ x + params['b']

key_w, key_b = jax.random.split(jax.random.PRNGKey(1234), 2)
params = {
'w': jax.random.normal(key_w, shape=(8, 8)),
'b': jax.random.normal(key_b, shape=(8, 1)),
}
j_module = JaxModule(
params=params,
apply_fn={'linear': linear},
export_version=constants.ExportModelType.ORBAX_MODEL,
)

self.assertEqual(
j_module._export_version,
constants.ExportModelType.ORBAX_MODEL,
)

# None of the obm_module methods are implemnted. When the export version is
# ORBAX_MODEL, the obm_module methods should raise a NotImplementedError.
with self.assertRaises(NotImplementedError):
j_module.apply_fn_map # pylint: disable=pointless-statement

with self.assertRaises(NotImplementedError):
j_module.model_params # pylint: disable=pointless-statement

with self.assertRaises(NotImplementedError):
j_module.methods # pylint: disable=pointless-statement

with self.assertRaises(TypeError):
j_module.with_gradient # pylint: disable=pointless-statement

with self.assertRaises(TypeError):
j_module.obm_module_to_jax_exported_map({'x': [1, 2, 3]})

# Several functions are not supported for ORBAX_MODEL export.
# This test ensures that the JaxModule raises a TypeError when these
# functions are called.
def test_jax_module_orbax_model_unsupported_methods(self):
def linear(params, x):
return params['w'] @ x + params['b']

key_w, key_b = jax.random.split(jax.random.PRNGKey(1234), 2)
params = {
'w': jax.random.normal(key_w, shape=(8, 8)),
'b': jax.random.normal(key_b, shape=(8, 1)),
}
j_module = JaxModule(
params=params,
apply_fn={'linear': linear},
export_version=constants.ExportModelType.ORBAX_MODEL,
)

self.assertEqual(
j_module._export_version,
constants.ExportModelType.ORBAX_MODEL,
)

with self.assertRaises(TypeError):
j_module.update_variables(
{'w': jax.random.normal(key_w, shape=(8, 8), dtype=jnp.float32)}
)

with self.assertRaises(TypeError):
j_module.jax2tf_kwargs_map # pylint: disable=pointless-statement

with self.assertRaises(TypeError):
j_module.input_polymorphic_shape_map # pylint: disable=pointless-statement

def test_jax_module_default_export_version(self):
j_module = JaxModule(
Expand Down
9 changes: 4 additions & 5 deletions export/orbax/export/modules/obm_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,20 +35,19 @@ def __init__(
@property
def apply_fn_map(self) -> Mapping[str, ApplyFn]:
"""Returns the apply_fn_map."""
return {}
raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')

@property
def model_params(self) -> PyTree:
"""Returns the model parameters."""
return {}
raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')

@property
def methods(self) -> Mapping[str, Callable[..., Any]]:
"""Named methods in the context of the chosen export pathway."""
return {}
raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')

@property
def jax_methods(self) -> Mapping[str, Callable[..., Any]]:
"""Named methods in JAX context for validation."""
return {}

raise NotImplementedError('apply_fn_map is not implemented for ObmModule.')

0 comments on commit c62f130

Please sign in to comment.