From 7e60331cd124366dee0c962fe5b69d1eea233381 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Thu, 21 Mar 2024 13:34:26 -0700 Subject: [PATCH] [key reuse] print information about key reuse location --- jax/_src/prng.py | 2 ++ jax/experimental/key_reuse/_core.py | 33 ++++++++++++++++++++++++++--- 2 files changed, 32 insertions(+), 3 deletions(-) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index f495e1c99d26..2da8196a3bf4 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -34,6 +34,7 @@ from jax._src import dtypes from jax._src import pretty_printer as pp from jax._src import sharding_specs +from jax._src import source_info_util from jax._src import tree_util as tree_util_internal from jax._src import typing from jax._src import op_shardings @@ -154,6 +155,7 @@ class behave like an array whose base elements are keys, hiding the _impl: PRNGImpl _base_array: typing.Array _consumed: bool | np.ndarray # Used in jax.experimental.key_reuse. + _source_info: None | source_info_util.SourceInfo = None def __init__(self, impl, key_data: Any): assert not isinstance(key_data, core.Tracer) diff --git a/jax/experimental/key_reuse/_core.py b/jax/experimental/key_reuse/_core.py index 0ce387b008f6..5bcffe333944 100644 --- a/jax/experimental/key_reuse/_core.py +++ b/jax/experimental/key_reuse/_core.py @@ -15,6 +15,7 @@ from __future__ import annotations from collections import defaultdict +import contextlib from functools import partial, reduce, total_ordering, wraps from typing import Any, Callable, Iterator, NamedTuple @@ -31,6 +32,7 @@ from jax._src import prng from jax._src import random from jax._src import source_info_util +from jax._src import traceback_util from jax._src import util from jax._src.ad_checkpoint import remat_p from jax._src.debugging import debug_callback_p @@ -41,6 +43,27 @@ import numpy as np +traceback_util.register_exclusion(__file__) + +_source_context_message = ( + 'PRNG key first used at the above location was subsequently reused' + ' at the following location:') + +def key_reuse_error_with_source_traceback( + message: str, traceback: source_info_util.Traceback | None) -> KeyReuseError: + err = KeyReuseError(message) + if traceback is not None: + filtered_tb = traceback_util.filter_traceback(traceback.as_python_traceback()) + if filtered_tb: + context_err = KeyReuseError(_source_context_message).with_traceback(filtered_tb) + context_err.__context__ = err.__context__ + context_err.__cause__ = err.__cause__ + context_err.__suppress_context__ = err.__suppress_context__ + err.__context__ = None + err.__cause__ = context_err + return err + + # Create Source() and Sink() objects which validate inputs, have # correct equality semantics, and are hashable & immutable. @total_ordering @@ -145,19 +168,23 @@ def forwards(self) -> Iterator[Forward]: def check_signature(self, *args, funcname="function", context=None): for sink in self.sinks: - if not isinstance(args[sink.idx], prng.PRNGKeyArray): + key = args[sink.idx] + if not isinstance(key, prng.PRNGKeyArray): continue - if np.any(args[sink.idx]._consumed & sink.mask): + if np.any(key._consumed & sink.mask): msg = f"Previously-consumed key passed to {funcname} at index {sink.idx}" if context: msg += " {context}" - raise KeyReuseError(msg) + raise key_reuse_error_with_source_traceback( + msg, key._source_info and key._source_info.traceback) def update_consumption(self, args_in, args_out): for sink in self.sinks: arg = args_in[sink.idx] if isinstance(arg, prng.PRNGKeyArray): arg._consumed = arg._consumed | sink.mask + if np.any(sink.mask): + arg._source_info = source_info_util.current() for arg in args_out: if isinstance(arg, prng.PRNGKeyArray): arg._consumed = True