Skip to content

Commit

Permalink
Refactorings to the jit implementation.
Browse files Browse the repository at this point in the history
Notably:
* We can share more code between jit/pjit. There's no significant difference between the two, other than the handling of the resource environment, so we can share more of the code.
* Rather than having an infer_params callback, we can just teach common_infer_params (now named _infer_params) to handle the resource environment, which is the only meaningful difference. common_infer_params already had to understand the two cases, so there's no reason we need to hoist part of that logic into a callback.
* If we slightly alter the role of PjitInfo so it contains only the things we know about a jit() or can deduce from its arguments, we can construct it ahead of time. This does require that we split out a couple of things that we cannot deduce at that time — namely the resource environment and the two layout parameters — into separate arguments, but the result reads more cleanly to me.

No functional changes intended, this is just to improve readability.

PiperOrigin-RevId: 617812557
  • Loading branch information
hawkinsp authored and jax authors committed Mar 21, 2024
1 parent 2bd579b commit d3e03ff
Show file tree
Hide file tree
Showing 2 changed files with 124 additions and 126 deletions.
30 changes: 3 additions & 27 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,34 +300,10 @@ def jit(
>>> g(jnp.arange(4), 3)
Array([ 0, 1, 256, 6561], dtype=int32)
"""
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pjit.pre_infer_params(
return pjit.make_jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)

fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)

def infer_params(*args, **kwargs):
# TODO(yashkatariya): Remove this when it's added on jit.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = pjit.PjitInfo(
fun=fun, fun_sourceinfo=fun_sourceinfo, fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=None,
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
out_layouts=out_layouts)
return pjit.common_infer_params(pjit_info_args, *args, **kwargs)

has_explicit_sharding = pjit._pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)
return pjit.post_infer_params(fun, infer_params, static_argnums,
static_argnames, donate_argnums,
abstracted_axes, has_explicit_sharding)
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env=False)


@contextmanager
Expand Down
220 changes: 121 additions & 99 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,33 @@ def _device_assignment_mismatch_error(fun_name, fails, args_flat, api_name,
return msg


def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
class PjitInfo(NamedTuple):
  """Things that we know about a jit instance before it is called.

  In other words, this structure contains arguments to jit()/pjit(),
  preprocessed and validated.
  """
  fun: Callable  # the user function being jitted
  fun_sourceinfo: str | None  # from api_util.fun_sourceinfo(fun)
  fun_signature: inspect.Signature | None  # from api_util.fun_signature(fun)
  in_shardings: Any
  out_shardings: Any
  static_argnums: tuple[int, ...]  # already resolved/validated argnums
  static_argnames: tuple[str, ...]  # already resolved/validated argnames
  donate_argnums: tuple[int, ...]  # already resolved/validated argnums
  donate_argnames: tuple[str, ...]  # already resolved/validated argnames
  device: xc.Device | None  # explicit device requested by the caller, if any
  backend: str | None  # explicit backend requested by the caller, if any
  keep_unused: bool
  inline: bool
  abstracted_axes: Any | None  # requires --jax_dynamic_shapes — TODO confirm
  has_explicit_sharding: bool  # precomputed via _pjit_explicit_sharding
  use_resource_env: bool  # False for jit, True for pjit


def _python_pjit_helper(jit_info, *args, **kwargs):
args_flat, _, params, _, out_tree, _, _, _, arg_names, attrs_tracked = \
infer_params_fn(*args, **kwargs)
_infer_params(jit_info, args, kwargs)
for arg in args_flat:
dispatch.check_arg(arg)
if attrs_tracked:
Expand All @@ -145,6 +169,7 @@ def _python_pjit_helper(fun, infer_params_fn, *args, **kwargs):
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
fun = jit_info.fun
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
Expand All @@ -165,14 +190,16 @@ def _get_states(attrs_tracked):
return [jax_getattr(obj, attr) for (obj, attr) in attrs_tracked]


def _python_pjit(fun: Callable, infer_params_fn):
def _python_pjit(jit_info: PjitInfo):

fun = jit_info.fun

@wraps(fun)
@api_boundary
def wrapped(*args, **kwargs):
if config.disable_jit.value:
return fun(*args, **kwargs)
return _python_pjit_helper(fun, infer_params_fn, *args, **kwargs)[0]
return _python_pjit_helper(jit_info, *args, **kwargs)[0]

def _python_pjit_evict_fn():
_create_pjit_jaxpr.evict_function(fun) # type: ignore
Expand Down Expand Up @@ -254,42 +281,61 @@ def _get_cpp_global_cache(pjit_has_explicit_sharding):
return _cpp_pjit_cache


def _cpp_pjit(fun: Callable, infer_params_fn, static_argnums, static_argnames,
donate_argnums, pjit_has_explicit_sharding):
def _cpp_pjit(jit_info: PjitInfo):
  """Wraps `jit_info.fun` with the C++ pjit fast path.

  The scraped diff interleaved removed and added call-site lines here; this
  is the post-change definition, which reads the static/donate argnums and
  the explicit-sharding flag from `jit_info` rather than loose parameters.
  """

  @api_boundary
  def cache_miss(*args, **kwargs):
    # Slow path: trace/compile via the Python helper, then hand the C++
    # caller enough data to service future calls without re-entering Python.
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
        jit_info, *args, **kwargs)
    executable = _read_most_recent_pjit_call_executable(jaxpr)
    maybe_fastpath_data = _get_fastpath_data(
        executable, out_tree, args_flat, out_flat, attrs_tracked, jaxpr.effects)
    return outs, maybe_fastpath_data

  fun = jit_info.fun
  # The newer xla extension takes an extra shard_arg callable.
  if xla_extension_version >= 226:
    cpp_pjit_f = xc._xla.pjit(  # type: ignore
        getattr(fun, "__name__", "<unnamed function>"),
        fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
        jit_info.donate_argnums, tree_util.dispatch_registry,
        pxla.shard_arg if xla_extension_version >= 229 else pxla.temp_shard_arg,  # type: ignore
        _get_cpp_global_cache(jit_info.has_explicit_sharding))  # type: ignore
  else:
    cpp_pjit_f = xc._xla.pjit(  # type: ignore
        getattr(fun, "__name__", "<unnamed function>"),
        fun, cache_miss, jit_info.static_argnums, jit_info.static_argnames,
        jit_info.donate_argnums, tree_util.dispatch_registry,
        _get_cpp_global_cache(jit_info.has_explicit_sharding))

  cpp_pjitted_f = wraps(fun)(cpp_pjit_f)
  cpp_pjitted_f._fun = fun
  type(cpp_pjitted_f).clear_cache = _cpp_pjit_evict_fn
  return cpp_pjitted_f


def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames,
static_argnums, static_argnames, device,
backend, abstracted_axes):
def _pjit_explicit_sharding(in_shardings, out_shardings, device,
                            backend) -> bool:
  """Returns True if the caller explicitly constrained placement or sharding.

  Any of: an explicit device, an explicit backend, or at least one
  specified entry in the flattened in/out shardings.
  """
  if device is not None or backend is not None:
    return True
  flat_in, _ = tree_flatten(in_shardings)
  if any(not is_unspecified(s) for s in flat_in):
    return True
  flat_out, _ = tree_flatten(out_shardings)
  return any(not is_unspecified(s) for s in flat_out)


def _parse_jit_arguments(fun: Callable, in_shardings: Any, out_shardings: Any,
donate_argnums: int | Sequence[int] | None,
donate_argnames: str | Iterable[str] | None,
static_argnums: int | Sequence[int] | None,
static_argnames: str | Iterable[str] | None,
device: xc.Device | None, backend: str | None,
abstracted_axes: Any | None, keep_unused: bool,
inline: bool, use_resource_env: bool) -> PjitInfo:
"""Parses the arguments to jit/pjit.
Performs any preprocessing and validation of the arguments that we can do
ahead of time before the jit()-ed function is invoked.
"""
if abstracted_axes and not config.dynamic_shapes.value:
raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")

Expand Down Expand Up @@ -326,18 +372,31 @@ def pre_infer_params(fun, in_shardings, out_shardings,
donate_argnums, donate_argnames, static_argnums, static_argnames = resolve_argnums(
fun, donate_argnums, donate_argnames, static_argnums, static_argnames)

return (in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames)
fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)

has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)

def post_infer_params(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums, abstracted_axes,
pjit_has_explicit_sharding):
if abstracted_axes is None:
wrapped = _cpp_pjit(fun, infer_params_fn, static_argnums, static_argnames,
donate_argnums, pjit_has_explicit_sharding)
return PjitInfo(
fun=fun,
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline,
abstracted_axes=abstracted_axes,
has_explicit_sharding=has_explicit_sharding,
use_resource_env=use_resource_env)


def _make_jit_wrapper(jit_info: PjitInfo):
if jit_info.abstracted_axes is None:
wrapped = _cpp_pjit(jit_info)
else:
wrapped = _python_pjit(fun, infer_params_fn)
wrapped = _python_pjit(jit_info)

@api_boundary
def lower(*args, **kwargs):
Expand All @@ -348,8 +407,8 @@ def lower(*args, **kwargs):
out_layouts = kwargs.pop('_out_layouts', None)
(args_flat, flat_global_in_avals, params, in_tree, out_tree,
donated_invars, in_layouts_flat, out_layouts_flat,
arg_names, ()) = infer_params_fn(
*args, **kwargs, _in_layouts=in_layouts, _out_layouts=out_layouts)
arg_names, ()) = _infer_params(
jit_info, args, kwargs, in_layouts=in_layouts, out_layouts=out_layouts)
resource_env = params['resource_env']
mesh = None if resource_env is None else resource_env.physical_mesh
try:
Expand All @@ -363,7 +422,9 @@ def lower(*args, **kwargs):
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
api_name = 'jit' if params['resource_env'] is None else 'pjit'
fun_name = getattr(fun, '__qualname__', getattr(fun, '__name__', str(fun)))
fun = jit_info.fun
fun_name = getattr(fun, '__qualname__',
getattr(fun, '__name__', str(fun)))
msg = _device_assignment_mismatch_error(
fun_name, fails, args_flat, api_name, arg_names)
raise ValueError(msg) from None
Expand All @@ -375,8 +436,9 @@ def lower(*args, **kwargs):

@api_boundary
def eval_shape(*args, **kwargs):
_, _, params, _, out_tree, _, _, _, _, _ = infer_params_fn(
*args, **kwargs, _in_layouts=None, _out_layouts=None)
_, _, params, _, out_tree, _, _, _, _, _ = _infer_params(
jit_info, args, kwargs, in_layouts=None, out_layouts=None
)
out_s = [None if is_unspecified(s) else getattr(s, '_original_sharding', s)
for s in params['out_shardings']]
out = [api.ShapeDtypeStruct(x.shape, x.dtype, x.named_shape, sharding=s)
Expand All @@ -387,52 +449,43 @@ def eval_shape(*args, **kwargs):
wrapped.eval_shape = eval_shape
return wrapped


def _pjit_explicit_sharding(in_shardings, out_shardings, device,
backend) -> bool:
in_shardings_flat, _ = tree_flatten(in_shardings)
out_shardings_flat, _ = tree_flatten(out_shardings)
return (device is not None or
backend is not None or
any(not is_unspecified(i) for i in in_shardings_flat) or
any(not is_unspecified(i) for i in out_shardings_flat))


class PjitInfo(NamedTuple):
fun: Callable
fun_sourceinfo: str | None
fun_signature: inspect.Signature
in_shardings: Any
out_shardings: Any
static_argnums: tuple[int, ...]
static_argnames: tuple[str, ...]
donate_argnums: tuple[int, ...]
donate_argnames: tuple[str, ...]
device: xc.Device | None
backend: str | None
keep_unused: bool
inline: bool
resource_env: Any
abstracted_axes: Any | None
in_layouts: Any # pytree[XlaCompatibleLayout] | None
out_layouts: Any # pytree[XlaCompatibleLayout] | None


def common_infer_params(pjit_info_args, *args, **kwargs):
def make_jit(fun: Callable, in_shardings: Any, out_shardings: Any,
             donate_argnums: int | Sequence[int] | None,
             donate_argnames: str | Iterable[str] | None,
             static_argnums: int | Sequence[int] | None,
             static_argnames: str | Iterable[str] | None,
             device: xc.Device | None, backend: str | None,
             abstracted_axes: Any | None, keep_unused: bool,
             inline: bool, use_resource_env: bool) -> Any:
  """Shared implementation behind jit() and pjit().

  Preprocesses and validates the arguments into a PjitInfo, then builds
  the user-facing wrapped callable from it.
  """
  info = _parse_jit_arguments(
      fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
      static_argnums, static_argnames, device, backend, abstracted_axes,
      keep_unused, inline, use_resource_env)
  return _make_jit_wrapper(info)


def _infer_params(jit_info, args, kwargs, *, in_layouts=None, out_layouts=None):
(fun, fun_sourceinfo, fun_signature, user_in_shardings, user_out_shardings,
static_argnums, static_argnames,
donate_argnums, donate_argnames, device, backend, keep_unused, inline,
resource_env, abstracted_axes, in_layouts, out_layouts) = pjit_info_args
abstracted_axes, _, use_resource_env) = jit_info

if (kwargs and user_in_shardings is not None and
not is_unspecified(user_in_shardings)):
raise ValueError(
"pjit does not support kwargs when in_shardings is specified.")

if resource_env is not None:
if use_resource_env:
# We need to fetch the mesh from inside the wrapped function, because
# meshes are dynamically scoped (i.e., with a context manager).
resource_env = mesh_lib.thread_resources.env
pjit_mesh = resource_env.physical_mesh
jit_name = 'pjit'
else:
resource_env = None
pjit_mesh = None
jit_name = 'jit'

if (backend or device) and pjit_mesh is not None and not pjit_mesh.empty:
raise ValueError(
Expand All @@ -441,8 +494,6 @@ def common_infer_params(pjit_info_args, *args, **kwargs):

axes_specs = _flat_axes_specs(abstracted_axes, *args, **kwargs)

jit_name = 'jit' if resource_env is None else 'pjit'

dbg = debug_info(jit_name, fun_sourceinfo, fun_signature, args, kwargs,
static_argnums, static_argnames)
f = lu.wrap_init(fun)
Expand Down Expand Up @@ -782,39 +833,10 @@ def pjit(
... print(f(x)) # doctest: +SKIP
[ 0.5 2. 4. 6. 8. 10. 12. 10. ]
"""
(in_shardings, out_shardings, donate_argnums, donate_argnames, static_argnums,
static_argnames) = pre_infer_params(
return make_jit(
fun, in_shardings, out_shardings, donate_argnums, donate_argnames,
static_argnums, static_argnames, device, backend, abstracted_axes)

fun_sourceinfo = api_util.fun_sourceinfo(fun)
fun_signature = api_util.fun_signature(fun)

def infer_params(*args, **kwargs):
# Putting this outside of wrapped would make resources lexically scoped
resource_env = mesh_lib.thread_resources.env
# TODO(yashkatariya): Remove this when it's added on jit. Also default to
# layout.DefaultLayout() when out of experimental.
in_layouts = kwargs.pop('_in_layouts', None)
out_layouts = kwargs.pop('_out_layouts', None)
pjit_info_args = PjitInfo(
fun=fun,
fun_sourceinfo=fun_sourceinfo,
fun_signature=fun_signature,
in_shardings=in_shardings,
out_shardings=out_shardings, static_argnums=static_argnums,
static_argnames=static_argnames, donate_argnums=donate_argnums,
donate_argnames=donate_argnames, device=device, backend=backend,
keep_unused=keep_unused, inline=inline, resource_env=resource_env,
abstracted_axes=abstracted_axes, in_layouts=in_layouts,
out_layouts=out_layouts)
return common_infer_params(pjit_info_args, *args, **kwargs)

has_explicit_sharding = _pjit_explicit_sharding(
in_shardings, out_shardings, device, backend)
return post_infer_params(fun, infer_params, static_argnums, static_argnames,
donate_argnums, abstracted_axes,
has_explicit_sharding)
static_argnums, static_argnames, device, backend, abstracted_axes,
keep_unused, inline, use_resource_env=True)


def hashable_pytree(pytree):
Expand Down

0 comments on commit d3e03ff

Please sign in to comment.