Skip to content

Commit

Permalink
[JAX] Remove references to omnistaging.
Browse files Browse the repository at this point in the history
Omnistaging has been the default and only option for a long time.

PiperOrigin-RevId: 455703225
  • Loading branch information
hawkinsp authored and tensorflower-gardener committed Jun 17, 2022
1 parent a7b35af commit cb2a94c
Showing 1 changed file with 0 additions and 3 deletions.
3 changes: 0 additions & 3 deletions spinoffs/oryx/oryx/core/trace_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import threading
from typing import Any, Dict, Generator, List

import jax
from jax import abstract_arrays
from jax import api_util
from jax import core as jax_core
Expand Down Expand Up @@ -63,8 +62,6 @@ def wrapped(*args, **kwargs):
flat_args, in_tree = tree_util.tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
flat_avals = safe_map(get_shaped_aval, flat_args)
if not jax.config.omnistaging_enabled:
raise ValueError('Oryx must be used with JAX omnistaging enabled.')
if dynamic:
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(
flat_fun,
Expand Down

0 comments on commit cb2a94c

Please sign in to comment.