
Enable sampling in chunks with external jax samplers #7465

Open · wants to merge 19 commits into main
Conversation

@andrewdipper (Contributor) commented Aug 16, 2024

Initial take on extending the blackjax and numpyro samplers to sequentially sample multiple chunks. This removes the requirement that the GPU have enough memory to store all samples at once - they only need to fit in CPU memory.

Changes / features:

  • Sampling with one chunk (current behavior) and sampling with multiple chunks return exactly the same samples. This is how the chunking is tested.
  • As long as the first chunk compiles and samples, the remaining chunks should not cause an out-of-memory error. I have pretty high confidence this holds for blackjax. Numpyro is harder due to its higher-level API, but there won't be a big memory jump. Hopefully this prevents annoying errors deep into sampling.
  • Postprocessing is done on a per-chunk basis (and is compiled together with sampling for blackjax).
  • When num_chunks==1, samples are stored on the sampling device, consistent with current behavior. With multiple chunks they are transferred to CPU memory; see the sketch after this list.
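
To make the chunking concrete, here is a minimal sketch of the idea, not the PR's actual implementation; `sample_in_chunks` and `step_fn` are hypothetical names and the transition kernel is left abstract:

```python
import jax
import numpy as np

def sample_in_chunks(step_fn, init_state, key, draws, num_chunks):
    # step_fn(state, key) -> (state, sample); hypothetical signature.
    chunk_len = draws // num_chunks
    state = init_state
    host_chunks = []
    for _ in range(num_chunks):
        key, subkey = jax.random.split(key)
        keys = jax.random.split(subkey, chunk_len)
        # One compiled scan per chunk; the compilation is reused across
        # chunks because every chunk has identical shapes.
        state, samples = jax.lax.scan(step_fn, state, keys)
        # Move the finished chunk to host memory so at most one chunk
        # occupies the accelerator at a time.
        host_chunks.append(jax.device_get(samples))
    return np.concatenate(host_chunks, axis=0)
```

The device transfer is the key point: only one chunk of samples ever needs to fit on the GPU.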

Some question marks:

  • Currently, progress through the chunks is just written to the log - I'm not sure this is the most reasonable solution.
  • The postprocessing_backend option is removed. I think this is reasonable, as any postprocessing memory requirements should be dominated by the already necessary transpose of the chains and samples dimensions (vmap(scan) materializes the scan dimension first and subsequently transposes; see the shape illustration after this list). Unless I'm missing another reason to force the postprocessing backend?
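
A toy shape illustration of the transpose argument (the claim about the vmap-of-scan lowering is the author's observation, restated here in comments):

```python
import jax
import jax.numpy as jnp

def one_chain(key):
    # Toy per-chain sampler: 1000 standard-normal "draws" via scan.
    def step(carry, k):
        return carry, jax.random.normal(k)
    keys = jax.random.split(key, 1000)
    _, draws = jax.lax.scan(step, None, keys)
    return draws  # shape (draw,) for one chain

keys = jax.random.split(jax.random.PRNGKey(0), 4)
chains = jax.vmap(one_chain)(keys)  # shape (chain, draw) == (4, 1000)
# Per the observation above, the scan axis is materialized first, so
# producing this (chain, draw) layout already entails a full copy of
# every leaf when the axes are swapped into place.
```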

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7465.org.readthedocs.build/en/7465/

@andrewdipper (Contributor, Author)

The test failure seems a bit random - I haven't been able to trigger a failure locally. I get some acceptance_rates pretty far from 0.5 so I'm not sure how stable it's expected to be.

```diff
@@ -229,7 +229,7 @@ def test_get_log_likelihood():
     b_true = trace.log_likelihood.b.values
     a = np.array(trace.posterior.a)
     sigma_log_ = np.log(np.array(trace.posterior.sigma))
-    b_jax = _get_log_likelihood(model, [a, sigma_log_])["b"]
+    b_jax = jax.vmap(_get_log_likelihood_fn(model))([a, sigma_log_])["b"]
```
Member:

Why did the behavior (have to) change?

@andrewdipper (Author):

For postprocessing I needed to be able to calculate the log_likelihood without the final wrapping vmap. It's possible to have it calculate the likelihood directly instead of returning a function; the extra vmap will still be necessary either way. A toy illustration of the pattern is below.
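
A toy illustration of the pattern under discussion (make_loglik and the model are made up, not pymc internals): return the per-draw function and let the caller place the wrapping vmap.

```python
import jax
import jax.numpy as jnp

def make_loglik(sigma=1.0):
    # Returns the log-likelihood of a *single* draw; no vmap baked in.
    def loglik(theta):
        return -0.5 * jnp.sum((theta / sigma) ** 2)
    return loglik

posterior = jnp.ones((4, 100, 3))  # toy (chain, draw, dim) samples
per_draw = make_loglik()
# The caller applies the wrapping vmaps wherever it needs them, e.g.
# inside a per-chunk jitted postprocessing step.
ll = jax.vmap(jax.vmap(per_draw))(posterior)  # shape (4, 100)
```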

@andrewdipper (Author):

Changed to calculate the likelihood instead of using a returned likelihood calculator function.

@ricardoV94 (Member)

> The test failure seems a bit random - I haven't been able to trigger a failure locally. I get some acceptance_rates pretty far from 0.5 so I'm not sure how stable it's expected to be.

I don't think this was failing before so it might be related to the changes

@ricardoV94 (Member)

@ferrine any opinion on the removal of postprocessing_backend?

```python
import warnings

warnings.warn(
    "postprocessing_backend={'cpu', 'gpu'} will be removed in a future release, "
```
Member:

We should deprecate before rendering the argument useless, or raise already. Also, can the message mention that the alternative is now num_chunks?

@andrewdipper (Author):

Ok, makes sense - I can add back that functionality; it's just a few extra branches to keep track of.

@andrewdipper (Author):

Added back logic for postprocessing_backend. It doesn't have any integration with chunked sampling but restores postprocessing all at once on a different backend.

@andrewdipper (Contributor, Author)

> > The test failure seems a bit random - I haven't been able to trigger a failure locally. I get some acceptance_rates pretty far from 0.5 so I'm not sure how stable it's expected to be.
>
> I don't think this was failing before so it might be related to the changes

It was splitting the random key one extra time compared to the current behavior. Removing the extra split fixes the failure, and I believe the numpyro samples generated will now be identical to those currently generated. However, this means that choosing the wrong key can still trigger the test failure. Based on pyro-ppl/numpyro#1786, it seems acceptance_rates won't be super stable. A toy demonstration of the key-splitting issue is below.
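
A toy demonstration (not the actual test) of why one extra key split perturbs every subsequent draw:

```python
import jax

key = jax.random.PRNGKey(0)

# Current behavior: split once, then sample.
_, sub_a = jax.random.split(key)

# One extra split in between shifts the entire key stream.
key_b, _unused = jax.random.split(key)
_, sub_b = jax.random.split(key_b)

print(jax.random.normal(sub_a))  # differs from the draw below
print(jax.random.normal(sub_b))
```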

@ricardoV94 (Member)

> > > The test failure seems a bit random - I haven't been able to trigger a failure locally. I get some acceptance_rates pretty far from 0.5 so I'm not sure how stable it's expected to be.
> >
> > I don't think this was failing before so it might be related to the changes
>
> It was splitting the random key one extra time compared to the current behavior. Removing the extra split fixes the failure, and I believe the numpyro samples generated will now be identical to those currently generated. However, this means that choosing the wrong key can still trigger the test failure. Based on pyro-ppl/numpyro#1786, it seems acceptance_rates won't be super stable.

Okay, if it's not stable feel free to choose the best code and pick a seed that happens to work.

@ferrine (Member) commented Aug 18, 2024

> @ferrine any opinion on the removal of postprocessing_backend?

What will be different there, and what will the memory consumption be? What overhead is put on the GPU/RAM?

@ferrine (Member) commented Aug 18, 2024

What if a single sample does not compile on the GPU? Is that realistic? What about a num_samples_in_chunk parameter?

codecov bot commented Aug 18, 2024

Codecov Report

Attention: Patch coverage is 86.48649% with 15 lines in your changes missing coverage. Please review.

Project coverage is 92.40%. Comparing base (cdcdb58) to head (4e6b1fa).
Report is 60 commits behind head on main.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| pymc/sampling/jax.py | 86.48% | 15 Missing ⚠️ |

```diff
@@            Coverage Diff             @@
##             main    #7465      +/-   ##
==========================================
- Coverage   92.44%   92.40%   -0.04%     
==========================================
  Files         103      103              
  Lines       17119    17153      +34     
==========================================
+ Hits        15825    15850      +25     
- Misses       1294     1303       +9     
```

| Files with missing lines | Coverage Δ |
|---|---|
| pymc/sampling/jax.py | 91.83% <86.48%> (-2.96%) ⬇️ |

@andrewdipper (Contributor, Author)

> > @ferrine any opinion on the removal of postprocessing_backend?
>
> What will be different there, and what will the memory consumption be? What overhead is put on the GPU/RAM?

I don't know in general what the memory complexity of the postprocessing transformations can be. However, when sampling with chain_method="vectorized", the vmap(scan) seems to always be turned into scan(vmap) with a subsequent transpose of the (chain, samples) dimensions. That requires a copy of the leaves of the pytree. My (perhaps invalid) assumption is that most transformations fit within that space, but I'm not sure about the chain_method="parallel" case.

Practically speaking, the postprocessing is jit-compiled with the sampling step, so if sampling starts then the memory is sufficient (for numpyro, if tuning starts then memory is sufficient). I can't see a case where postprocessing memory requirements are very high and CPU memory so far exceeds GPU memory that num_chunks cannot bring usage down. And remember that currently all samples must fit on the GPU together. A sketch of the fused sample-and-postprocess step is below.
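
A minimal sketch of the "compiled together" point (step_fn and postprocess_fn are toy stand-ins): because postprocessing lives in the same jitted computation as the sampling scan, any memory blow-up surfaces on the first chunk rather than deep into a run.

```python
import jax
import jax.numpy as jnp

def step_fn(state, key):
    # Toy transition kernel.
    new_state = state + jax.random.normal(key)
    return new_state, new_state

def postprocess_fn(raw):
    # Stand-in for a constrained-space transform.
    return jnp.exp(raw)

@jax.jit
def sample_and_postprocess(state, keys):
    # Sampling and postprocessing compile as one unit, so memory
    # pressure is visible as soon as the first chunk runs.
    state, raw = jax.lax.scan(step_fn, state, keys)
    return state, postprocess_fn(raw)

keys = jax.random.split(jax.random.PRNGKey(0), 100)
_, draws = sample_and_postprocess(jnp.float32(0.0), keys)
```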

> What if a single sample does not compile on the GPU? Is that realistic? What about a num_samples_in_chunk parameter?

I'm not sure whether that happens or what the current resolution would be.

The parameterization is draws together with num_chunks, with num_samples_in_chunk = draws / num_chunks. Is there a reason to prefer num_samples_in_chunk?

@andrewdipper (Contributor, Author)

Is there a proper way to run tests with a gpu backend enabled? My test for postprocessing_backend gets skipped since the backend is not available.

@ricardoV94 (Member)

> Is there a proper way to run tests with a gpu backend enabled? My test for postprocessing_backend gets skipped since the backend is not available.

No, GitHub Actions doesn't include a GPU in the free plan.

@twiecki (Member) commented Sep 14, 2024

What's the status here - can we merge?

@andrewdipper (Contributor, Author)

Let me know if anything else needs to be done on my end

@ricardoV94 (Member)

I am leaning a bit toward "this is too much complexity on our side".

@andrewdipper (Contributor, Author)

I believe you're talking about higher-level complexity. But IIRC, for blackjax the multi_step function can be replaced by their new (potentially unreleased) run_inference_algorithm - that should simplify part of the code a fair bit. The samples would just be generated with different random seeds.

Either way let me know what you decide

@ricardoV94 (Member)

@andrewdipper wanna give a try at that simpler approach?

@andrewdipper (Contributor, Author)

Sure, I'll give it a go

@andrewdipper (Contributor, Author)

Apologies for the delay - I got caught up.

Switched to using blackjax.util.run_inference_algorithm and tried to clarify things a bit. Let me know if you think it's viable.

I removed the postprocessing test as it doesn't get run, and blackjax chunked sampling will no longer be identical to single-chunk sampling. I plan to swap in some other sampling tests so the code has test coverage.
