Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use jaxified logp for initial point evaluation when sampling via Jax #7610

Open
wants to merge 9 commits into
base: main
Choose a base branch
from

Conversation

nataziel
Copy link

@nataziel nataziel commented Dec 11, 2024

Use jaxified logp for initial point evaluation when sampling via Jax

Description

  • get jaxified logp function in sample_jax_nuts
    • uses different parameters to get jaxified function depending on which nuts_sampler is specified
  • pass jaxified logp function to _get_batched_jittered_initial_points
    • added logp_fn parameter to function signature
    • wrap passed function to conform to how _init_jitter will call it
  • pass wrapped function to _init_jitter
    • added logp_fn parameter to function signature
  • added logic in _init_jitter to decide which function to use to evaluate the generated points
  • added a bunch of type annotations

Related Issue

Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks) - have a question about the sample_blackjax_nuts function docstring, will put in comment below
  • If you are a pro: each commit corresponds to a [relevant logical change]

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7610.org.readthedocs.build/en/7610/

Copy link

welcome bot commented Dec 11, 2024

Thank You Banner]
💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@github-actions github-actions bot added the bug label Dec 11, 2024
@nataziel
Copy link
Author

_sample_blackjax_nuts mentions initvals in the parameters section of the docstrings, is that a kwarg or should it be changed to initial_points?
I think it's the former and I could add an initial_points parameter to the docstring? Happy to do if that's correct.

pymc/initial_point.py Outdated Show resolved Hide resolved
pymc/sampling/mcmc.py Outdated Show resolved Hide resolved
@nataziel
Copy link
Author

_sample_blackjax_nuts mentions initvals in the parameters section of the docstrings, is that a kwarg or should it be changed to initial_points? I think it's the former and I could add an initial_points parameter to the docstring? Happy to do if that's correct.

@ricardoV94 any thoughts on this one?

model_logp = model.logp()
if not negative_logp:
model_logp = -model_logp
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

def logp_fn_wrap(x):
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not correct, it takes jax arrays and outputs jax arrays

Copy link
Author

@nataziel nataziel Dec 12, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's 100% true. Checking with the interactive debugger confirms that the return type is jax.Array, but the initial point functions return a dict[str, np.ndarray], and we can successfully pass the .values() of that dict into the jaxified function. So it can seemingly accept anything that's coercible to an array. Maybe it's more correct to annotate it like this:

def logp_fn_wrap(x: ArrayLike) -> jax.Array:

ArrayLike is from numpy.typing: https://numpy.org/devdocs/reference/typing.html#numpy.typing.ArrayLike

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just pushed a commit to improve this, it's a bit tricky to annotate at the interface with _init_jitter given that jax is an optional dependency. I've left the type annotation as returning a np.ndarray but included that it may return a jax.Array in the docstring.

@ricardoV94
Copy link
Member

@nataziel good catch. The outer/user-facing functions take initvals which later get converted into initial_points in the inner functions. Feel free to update the docstrings of the inner functions if they still refer to initvals

@nataziel
Copy link
Author

@nataziel good catch. The outer/user-facing functions take initvals which later get converted into initial_points in the inner functions. Feel free to update the docstrings of the inner functions if they still refer to initvals

Cleaned that up and added docstrings for the numpyro equivalent :)

@nataziel
Copy link
Author

nataziel commented Dec 12, 2024

Not sure why the most recent commit didn't trigger the documentation check. I was able to run make rtd locally and the build succeeded.

edit: it wasn't using the locally installed version. With the local version properly installed it failed. Will try debug

@nataziel
Copy link
Author

Seems sphinx autodoc didn't like if TYPE_CHECKING. Successfully builds with make rtd on my machine now. Not sure how to manually trigger the check

@nataziel nataziel requested a review from ricardoV94 December 12, 2024 13:22
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: model initial_point fails when pt.config.floatX = "float32"
2 participants