Implement specialized transformed logp dispatch #7188
base: main
Conversation
```
@@ -153,3 +158,52 @@ def __init__(self, scalar_op, *args, **kwargs):


 MeasurableVariable.register(MeasurableElemwise)


 class Transform(abc.ABC):
```
Moved to abstract, without any changes
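For context, a condensed sketch of the interface in question, paraphrased from memory of PyMC's `Transform` (this is not the diff's contents, which are moved verbatim; method names are as in the existing class, bodies elided):

```python
import abc

from pytensor.graph.basic import Variable
from pytensor.tensor.variable import TensorVariable


class Transform(abc.ABC):
    @abc.abstractmethod
    def forward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Map a constrained value to the unconstrained space."""

    @abc.abstractmethod
    def backward(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Map an unconstrained value back to the constrained space."""

    def log_jac_det(self, value: TensorVariable, *inputs: Variable) -> TensorVariable:
        """Log absolute determinant of the Jacobian of the backward map."""
        raise NotImplementedError
```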
Codecov Report

Attention: Patch coverage is

Additional details and impacted files

```
@@           Coverage Diff            @@
##             main    #7188    +/-   ##
========================================
  Coverage   92.26%   92.26%
========================================
  Files         100      100
  Lines       16880    16900    +20
========================================
+ Hits        15574    15593    +19
- Misses       1306     1307     +1
```
Force-pushed from f8e83f3 to af027ad (Compare)
Is there any expected advantage to this specialization besides avoiding constraint checks (which the user can already do via the model
Nice :D

In the case of the ZeroSumNormal probably not really much; once we figure out the expression in the untransformed space it doesn't change too much, and I guess it is nice to have, for instance, so that the logp can be computed. Just being able to do it on the transformed space would have saved us doing some math... I think it is quite common for the logp expressions to be a bit cleaner and possibly numerically more stable on the transformed space.

I guess things like this might also be happening in the Dirichlet dists, but looking through that is a bit more work... There is also an additional thing we could do with this that right now we can't do with transformations at all. I'm not 100% sure if we should do this though: right now all our transformations are injective, because the whole trick with the Jacobian determinant doesn't work otherwise. But if we can compute the logp on the transformed space, we could get rid of that requirement. Or we could have a transformation that maps a point in 2D space to an angle, and use that, for instance, in the VonMises distribution to avoid the topology problems we have there right now.
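To make that last idea concrete, here is a purely hypothetical sketch (the class name `AngleTransform` and the import path are assumptions, not PyMC API): a transform whose unconstrained space is 2-D and whose constrained value is the angle of that point. There is no well-defined `log_jac_det` because the dimensionality changes, which is exactly the kind of case the specialized transformed-logp dispatch could cover.

```python
import pytensor.tensor as pt

from pymc.distributions.transforms import Transform  # assumed import path


class AngleTransform(Transform):
    """Hypothetical non-injective transform: 2-D point <-> angle."""

    name = "angle"

    def forward(self, value, *inputs):
        # Constrained angle -> one representative point on the unit circle.
        return pt.stack([pt.cos(value), pt.sin(value)], axis=-1)

    def backward(self, value, *inputs):
        # Unconstrained 2-D point -> angle; many points map to the same angle.
        return pt.arctan2(value[..., 1], value[..., 0])

    def log_jac_det(self, value, *inputs):
        # Undefined for a dimension-changing map; the logp would instead come
        # from a _transformed_logprob registration for this transform.
        raise NotImplementedError
```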
```python
        return _transformed_logprob(
            rv_op,
            transform,
            unconstrained_value=unconstrained_value,
```
It would be nice if we could also pass in the transformed value. That way we can avoid computing it twice in a graph if the logp can use that value too. I guess pytensor might get rid of that duplication anyway, but I don't know how reliable that is if the transformation is doing something more complicated.
We can pass it too.
You raise a good point. Maybe transforms should be encapsulated in an OpFromGraph so that we can easily reverse them symbolically and not worry about whether they will be simplified in their raw form or not.
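A rough sketch of what that could look like (an assumption about the approach, not what this PR does): wrapping the backward map in an `OpFromGraph` keeps the constrained value as a single explicit node, so it can be reused by the logp and inspected or reversed symbolically instead of relying on rewrites to deduplicate it.

```python
import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

# Hypothetical log transform written as an OpFromGraph: the backward map takes
# the unconstrained value and produces the constrained one.
unconstrained = pt.scalar("unconstrained")
constrained = pt.exp(unconstrained)
backward_op = OpFromGraph([unconstrained], [constrained])

z = pt.scalar("z")   # unconstrained value variable
x = backward_op(z)   # constrained value, available as one reusable node
```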
Actually, we have the constrained value as well; it's `value.owner.inputs[0]`.
Thanks for the reply.

Re: switches/bound checks, most of those could probably be removed with some domain analysis, like exp(x) -> positive, so any ge(x, 0) is true; that could be useful beyond PyMC. This could be coupled with user-provided type hints (which we are kind of doing with transformed values here) to seed initial info that is not extractable from the graph.

Re: examples, many of those sound like interesting rewrites that PyTensor already does or could do. That need not be a blocker to allowing users to work directly on the transformed space if they want, but it may be a good source of ideas for future rewrites.

The best argument for the functionality, from my perspective, is the "avoiding hard math", which is more than a nice-to-have, and could speed up development or allow models that are otherwise impossible to write without a hacky Potential. For our code-base I think we still want to try and provide constrained-space logps as much as possible, since not all we do (or may want to do) is logp-based sampling.
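For illustration, a rough sketch of what such a domain-analysis rewrite could look like in PyTensor (the details here are my assumption of how one would write it, not an existing rewrite; it would still need to be registered in a rewrite database to take effect):

```python
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import node_rewriter


@node_rewriter([pt.ge])
def local_exp_is_positive(fgraph, node):
    """ge(exp(x), c) is always true when c is a non-positive scalar constant."""
    left, right = node.inputs
    if (
        left.owner is not None
        and left.owner.op == pt.exp
        and isinstance(right, Constant)
        and right.data.ndim == 0
        and right.data <= 0
    ):
        # exp(x) > 0 everywhere, so the comparison reduces to a constant True
        # with the shape of `left` and the boolean dtype of the original output.
        return [pt.ones_like(left, dtype=node.outputs[0].dtype)]
    return None
```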
What does this do?
I just had a look at the generated graphs of each of those.

```python
with pm.Model(check_bounds=False) as model:
    pm.ZeroSumNormal("x", shape=(10, 10), n_zerosum_axes=2)
```

The logp + grad with the PR:

And without:
I think I like this PR :D

Maybe we can also add an implementation for the beta and the logit-normal dists:

```python
@pm.logprob.abstract._transformed_logprob.register(
    pt.random.basic.BetaRV, pm.distributions.transforms.LogOddsTransform
)
def transformed_beta_logp(op, transform, unconstrained_value, rv_inputs):
    *_, alpha, beta = rv_inputs
    logit_x = unconstrained_value
    normalizing_factor = pt.gammaln(alpha + beta) - pt.gammaln(alpha) - pt.gammaln(beta)
    logp = normalizing_factor - pt.log1p(pt.exp(-logit_x)) * (alpha + beta) - logit_x * beta
    return pm.distributions.multivariate.check_parameters(
        logp,
        alpha > 0,
        beta > 0,
        msg="Alpha and beta parameters must be positive in Beta distribution",
    )


@pm.logprob.abstract._transformed_logprob.register(
    pm.distributions.continuous.LogitNormalRV, pm.distributions.transforms.LogOddsTransform
)
def transformed_logit_normal_logp(op, transform, unconstrained_value, rv_inputs):
    *_, mu, sigma = rv_inputs
    return pm.logp(pm.Normal.dist(mu, sigma), unconstrained_value)
```
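As a quick sanity check on the Beta formula above (assuming the snippet has been run so `transformed_beta_logp` exists), one can compare it at a single point against the usual change-of-variables result, i.e. the Beta logpdf plus the log-Jacobian of the logistic map:

```python
import numpy as np
import pytensor.tensor as pt
from scipy import special, stats

a, b, z = 2.0, 3.0, 0.7   # z lives in the unconstrained (logit) space
x = special.expit(z)      # corresponding constrained value

specialized = transformed_beta_logp(
    None, None, pt.constant(z), [pt.constant(a), pt.constant(b)]
).eval()
expected = stats.beta.logpdf(x, a, b) + np.log(x) + np.log1p(-x)
np.testing.assert_allclose(specialized, expected)
```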
```python
def transformed_zerosumnormal_logp(op, transform, unconstrained_value, rv_inputs):
    _, sigma, _ = rv_inputs
    zerosum_axes = transform.zerosum_axes
    if len(zerosum_axes) != op.ndim_supp:
```
I don't think this check is necessary?
It was a sanity check. A user could have defined an invalid transform manually. I don't really care either way
@aseyboldt perhaps a fairer graph comparison would be with
That was with

```python
with pm.Model(check_bounds=False) as model:
    x = pm.ZeroSumNormal("x", shape=(100, 100), n_zerosum_axes=2)
    pm.Normal("y", mu=x, sigma=1, shape=(100, 100))
```

With the PR:

Without the PR:
That's 140 vs 200 lines. Got to admit that I'm not 100% sure what the difference really is though... It's also a bit faster, with 70μs vs 90μs.
Then I'm surprised the `isclose(mean, 0)` switch was not removed.
You are right, it seems `check_bounds=False` didn't do anything, because I wasn't compiling the pytensor functions in a model context. I had no idea I had to do that...
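As far as I understand the mechanism (an assumption based on the exchange above, not verified against the source here), `check_bounds` only takes effect when functions are compiled through the model, e.g.:

```python
import pymc as pm

with pm.Model(check_bounds=False) as model:
    pm.ZeroSumNormal("x", shape=(10, 10), n_zerosum_axes=2)

# Compiled through the model, so check_bounds=False is honoured and the
# parameter checks are stripped from the graph.
logp_fn = model.compile_logp()

# Compiling the same graph directly with pytensor.function would keep the checks.
```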
Description
Adds a specialized dispatch for transformed logps
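A minimal usage sketch, mirroring the examples discussed in the conversation above (the registration path and transform class name are assumptions based on this PR; the formula is just the HalfNormal logp evaluated at `exp(z)` plus the log-Jacobian term `z`):

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt


@pm.logprob.abstract._transformed_logprob.register(
    pt.random.basic.HalfNormalRV, pm.distributions.transforms.LogTransform
)
def transformed_halfnormal_logp(op, transform, unconstrained_value, rv_inputs):
    # Assumes loc == 0, which is what pm.HalfNormal uses.
    *_, loc, sigma = rv_inputs
    z = unconstrained_value
    # HalfNormal logp at exp(z) plus the log-Jacobian of the exp backward map.
    return (
        0.5 * np.log(2.0 / np.pi)
        - pt.log(sigma)
        + z
        - pt.exp(2.0 * z) / (2.0 * sigma**2)
    )
```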
TODO
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7188.org.readthedocs.build/en/7188/