Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Numpyro keep warmup #6875

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open

Conversation

waldie11
Copy link

@waldie11 waldie11 commented Aug 23, 2023

What is this PR about?
This PR aims to improve on the experimental feature of the numpyro NUTS sampler within PyMC.
It enables access to the warmup_sample_stats xarray data struct by passing along save_warmup in idata_kwargs. It is a first step to fulfill #6723
It would be nice to eliminate the need of a jax recompilation for numpyro.infer.MCMC.run after .warmup has completed.

Checklist

Major / Breaking Changes

  • ...

New features

  • ...

Bugfixes

  • ...

Documentation

  • ...

Maintenance

  • idata = pm.sample(..., nuts_sampler="numpyro", idata_kwargs=dict( save_warmup=True, )) now provides warmup_sample_stats

📚 Documentation preview 📚: https://pymc--6875.org.readthedocs.build/en/6875/

@welcome
Copy link

welcome bot commented Aug 23, 2023

Thank You Banner
💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @waldie11, thanks for opening a PR. In general we suggest you first open an issue/discussion, so developers can confirm the feature should be implemented before you do any work.

With the changes warmup samples will always be stored and can't be disabled. I don't think that should be the default, as it is not either in pm.sample

I am not sure the extra complexity is warranted either, but maybe others will agree with you (I added a request discussion label).

@waldie11
Copy link
Author

waldie11 commented Aug 24, 2023

In general we suggest you first open an issue/discussion, so developers can confirm the feature should be implemented before you do any work.

I need that feature, so the work was done already. It felt more of restoring a so far unimplemented option than a new feature.

With the changes warmup samples will always be stored and can't be disabled. I don't think that should be the default, as it is not either in pm.sample

I don't aggree with both statements. Considering 2nd: pm.sample for stock pymc with default options does additional steps to drop the warmup samples very late in _sample_return. As the current state of this numpyro NUTS sampler does not use Multitrace anymore, one cant jump that train easily. So while this might be disadvantageous considering memory usage towards the current design of including numpyro, imo it is not towards stock pymc's. Considering 1th, I think one needs to be more precise in discussion: indeed this way warmup samples are kept in memory. However they are not returned automatically within the Inferencedata object. The most comfortable way to extract the stats is by using idata_kwargs to keep everything starting with warmup_*, so by default there is no change in output. And GC should take care of the rest.

I am not sure the extra complexity is warranted either

Indeed it would be nice to have this in a more lightweight way, most apparent remove the need of the recompilation.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants