Skip to content

Commit

Permalink
Merge pull request jax-ml#20236 from jakevdp:key-reuse-stack
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 618014760
  • Loading branch information
jax authors committed Mar 22, 2024
2 parents d57bb8c + 7e60331 commit 07e45c3
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
2 changes: 2 additions & 0 deletions jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from jax._src import dtypes
from jax._src import pretty_printer as pp
from jax._src import sharding_specs
from jax._src import source_info_util
from jax._src import tree_util as tree_util_internal
from jax._src import typing
from jax._src import op_shardings
Expand Down Expand Up @@ -154,6 +155,7 @@ class behave like an array whose base elements are keys, hiding the
_impl: PRNGImpl
_base_array: typing.Array
_consumed: bool | np.ndarray # Used in jax.experimental.key_reuse.
_source_info: None | source_info_util.SourceInfo = None

def __init__(self, impl, key_data: Any):
assert not isinstance(key_data, core.Tracer)
Expand Down
33 changes: 30 additions & 3 deletions jax/experimental/key_reuse/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from __future__ import annotations

from collections import defaultdict
import contextlib
from functools import partial, reduce, total_ordering, wraps
from typing import Any, Callable, Iterator, NamedTuple

Expand All @@ -31,6 +32,7 @@
from jax._src import prng
from jax._src import random
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.ad_checkpoint import remat_p
from jax._src.debugging import debug_callback_p
Expand All @@ -41,6 +43,27 @@
import numpy as np


traceback_util.register_exclusion(__file__)

_source_context_message = (
'PRNG key first used at the above location was subsequently reused'
' at the following location:')

def key_reuse_error_with_source_traceback(
message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError:
err = KeyReuseError(message)
if traceback is not None:
filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback())
if filtered_tb:
context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb)
context_err.__context__ = err.__context__
context_err.__cause__ = err.__cause__
context_err.__suppress_context__ = err.__suppress_context__
err.__context__ = None
err.__cause__ = context_err
return err


# Create Source() and Sink() objects which validate inputs, have
# correct equality semantics, and are hashable & immutable.
@total_ordering
Expand Down Expand Up @@ -145,19 +168,23 @@ def forwards(self) -> Iterator[Forward]:

def check_signature(self, *args, funcname="function", context=None):
for sink in self.sinks:
if not isinstance(args[sink.idx], prng.PRNGKeyArray):
key = args[sink.idx]
if not isinstance(key, prng.PRNGKeyArray):
continue
if np.any(args[sink.idx]._consumed & sink.mask):
if np.any(key._consumed & sink.mask):
msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}"
if context:
msg += " {context}"
raise KeyReuseError(msg)
raise key_reuse_error_with_source_traceback(
msg, key._source_info and key._source_info.traceback)

def update_consumption(self, args_in, args_out):
for sink in self.sinks:
arg = args_in[sink.idx]
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = arg._consumed | sink.mask
if np.any(sink.mask):
arg._source_info = source_info_util.current()
for arg in args_out:
if isinstance(arg, prng.PRNGKeyArray):
arg._consumed = True
Expand Down

0 comments on commit 07e45c3

Please sign in to comment.