-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use jaxified logp for initial point evaluation when sampling via Jax #7610
base: main
Are you sure you want to change the base?
Conversation
] |
|
@ricardoV94 any thoughts on this one? |
pymc/sampling/jax.py
Outdated
model_logp = model.logp() | ||
if not negative_logp: | ||
model_logp = -model_logp | ||
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) | ||
|
||
def logp_fn_wrap(x): | ||
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not correct, it takes jax arrays and outputs jax arrays
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's 100% true. Checking with the interactive debugger confirms that the return type is jax.Array
, but the initial point functions return a dict[str, np.ndarray]
, and we can successfully pass the .values()
of that dict into the jaxified function. So it can seemingly accept anything that's coercible to an array. Maybe it's more correct to annotate it like this:
def logp_fn_wrap(x: ArrayLike) -> jax.Array:
ArrayLike is from numpy.typing
: https://numpy.org/devdocs/reference/typing.html#numpy.typing.ArrayLike
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just pushed a commit to improve this, it's a bit tricky to annotate at the interface with _init_jitter
given that jax
is an optional dependency. I've left the type annotation as returning a np.ndarray
but included that it may return a jax.Array
in the docstring.
@nataziel good catch. The outer/user-facing functions take initvals which later get converted into |
Cleaned that up and added docstrings for the numpyro equivalent :) |
Not sure why the most recent commit didn't trigger the documentation check. edit: it wasn't using the locally installed version. With the local version properly installed it failed. Will try debug |
Seems sphinx autodoc didn't like |
Use jaxified logp for initial point evaluation when sampling via Jax
Description
sample_jax_nuts
nuts_sampler
is specified_get_batched_jittered_initial_points
logp_fn
parameter to function signature_init_jitter
will call it_init_jitter
logp_fn
parameter to function signature_init_jitter
to decide which function to use to evaluate the generated pointsRelated Issue
pt.config.floatX = "float32"
#7608Checklist
sample_blackjax_nuts
function docstring, will put in comment belowType of change
📚 Documentation preview 📚: https://pymc--7610.org.readthedocs.build/en/7610/