-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Numpyro keep warmup #6875
base: main
Are you sure you want to change the base?
Numpyro keep warmup #6875
Conversation
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @waldie11, thanks for opening a PR. In general we suggest you first open an issue/discussion, so developers can confirm the feature should be implemented before you do any work.
With the changes warmup samples will always be stored and can't be disabled. I don't think that should be the default, as it is not either in pm.sample
I am not sure the extra complexity is warranted either, but maybe others will agree with you (I added a request discussion label).
I need that feature, so the work was done already. It felt more of restoring a so far unimplemented option than a new feature.
I don't aggree with both statements. Considering 2nd: pm.sample for stock pymc with default options does additional steps to drop the warmup samples very late in
Indeed it would be nice to have this in a more lightweight way, most apparent remove the need of the recompilation. |
What is this PR about?
This PR aims to improve on the experimental feature of the numpyro NUTS sampler within PyMC.
It enables access to the warmup_sample_stats xarray data struct by passing along save_warmup in
idata_kwargs
. It is a first step to fulfill #6723It would be nice to eliminate the need of a jax recompilation for
numpyro.infer.MCMC.run
after.warmup
has completed.Checklist
Major / Breaking Changes
New features
Bugfixes
Documentation
Maintenance
idata = pm.sample(..., nuts_sampler="numpyro", idata_kwargs=dict( save_warmup=True, ))
now provides warmup_sample_stats📚 Documentation preview 📚: https://pymc--6875.org.readthedocs.build/en/6875/