
Implement specialized transformed logp dispatch #7188

Draft · ricardoV94 wants to merge 1 commit into main
Conversation

@ricardoV94 (Member) commented Mar 9, 2024

Description

Adds a specialized dispatch for transformed logps
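
A sketch of what registering such a specialization could look like (the decorator path and signature are inferred from the snippets discussed below; the HalfNormal case is illustrative only and not part of the diff):

import numpy as np
import pymc as pm
import pytensor.tensor as pt

@pm.logprob.abstract._transformed_logprob.register(
    pt.random.basic.HalfNormalRV, pm.logprob.transforms.LogTransform
)
def transformed_halfnormal_logp(op, transform, unconstrained_value, rv_inputs):
    # HalfNormal(sigma) logp written directly on log_x = log(x), with the
    # log-transform jacobian (+ log_x) folded in:
    #   log p(y) = 0.5*log(2/pi) - log(sigma) - exp(2y)/(2*sigma**2) + y
    *_, sigma = rv_inputs
    log_x = unconstrained_value
    return (
        0.5 * np.log(2 / np.pi)
        - pt.log(sigma)
        - 0.5 * (pt.exp(log_x) / sigma) ** 2
        + log_x
    )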

TODO

  • Add a direct test
  • Implement case for ZeroSumNormal

Related Issue

  • Closes #
  • Related to #

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7188.org.readthedocs.build/en/7188/

@@ -153,3 +158,52 @@ def __init__(self, scalar_op, *args, **kwargs):


MeasurableVariable.register(MeasurableElemwise)


class Transform(abc.ABC):
ricardoV94 (Member Author):
Moved to abstract, without any changes

codecov bot commented Mar 9, 2024

Codecov Report

Attention: Patch coverage is 96.72131%, with 2 lines in your changes missing coverage. Please review.

Project coverage is 92.26%. Comparing base (244fb97) to head (af027ad).

Additional details and impacted files


@@           Coverage Diff           @@
##             main    #7188   +/-   ##
=======================================
  Coverage   92.26%   92.26%           
=======================================
  Files         100      100           
  Lines       16880    16900   +20     
=======================================
+ Hits        15574    15593   +19     
- Misses       1306     1307    +1     
Files Coverage Δ
pymc/distributions/multivariate.py 93.84% <100.00%> (+0.03%) ⬆️
pymc/initial_point.py 100.00% <100.00%> (ø)
pymc/logprob/abstract.py 96.92% <100.00%> (+1.46%) ⬆️
pymc/logprob/basic.py 94.36% <ø> (-0.04%) ⬇️
pymc/logprob/transform_value.py 94.18% <100.00%> (+0.43%) ⬆️
pymc/logprob/transforms.py 95.29% <100.00%> (-0.18%) ⬇️
pymc/logprob/utils.py 97.87% <100.00%> (+0.65%) ⬆️
pymc/model/fgraph.py 97.39% <100.00%> (ø)
pymc/model/transform/conditioning.py 95.74% <100.00%> (ø)
pymc/distributions/transforms.py 97.01% <66.66%> (-1.46%) ⬇️

@ricardoV94 (Member Author) commented Mar 9, 2024

Is there any expected advantage to this specialization besides avoiding constraint checks (which the user can already disable via the model check_bounds flag)?

@ricardoV94 ricardoV94 marked this pull request as draft March 9, 2024 21:55
@ricardoV94 ricardoV94 requested a review from maresb March 9, 2024 22:04
@aseyboldt (Member) commented Mar 10, 2024

Nice :D

In the case of the ZeroSumNormal probably not much: once we figured out the expression in the untransformed space it doesn't change too much, though it is nice to have, for instance so that the logp can be computed at all. Just being able to do it on the transformed space would have saved us some math...

I think it is quite common for the logp expressions to be a bit cleaner and possibly numerically more stable on the transformed space.

  • For instance for the beta distribution we have this here in the code currently (before the bounds check):

          res = (
              pt.switch(pt.eq(alpha, 1.0), 0.0, (alpha - 1.0) * pt.log(value))
              + pt.switch(pt.eq(beta, 1.0), 0.0, (beta - 1.0) * pt.log1p(-value))
              - (pt.gammaln(alpha) + pt.gammaln(beta) - pt.gammaln(alpha + beta))
          )
          res = pt.switch(pt.bitwise_and(pt.ge(value, 0.0), pt.le(value, 1.0)), res, -np.inf)

    on the (logit) transformed space it could simply be

    normalizing_factor = pt.gammaln(alpha + beta) - pt.gammaln(alpha) - pt.gammaln(beta) 
    density = - pt.log1p(pt.exp(-logit_x)) * (alpha + beta) - logit_x * beta + normalizing_factor

    All the switches are gone and the graph is much shorter (in the original, value is actually expit(x), so we have things like log(expit(x)) and log(1 - expit(x)) in the graph); see the short derivation after this list.

  • Or for instance for the gamma with log transform:
    Now, we have

          beta = pt.reciprocal(scale)
          res = -pt.gammaln(alpha) + logpow(beta, alpha) - beta * value + logpow(value, alpha - 1)
          res = pt.switch(pt.ge(value, 0.0), res, -np.inf)

    which could be

    beta = pt.reciprocal(scale)
    norm = alpha * pt.log(beta) - pt.gammaln(alpha)
    density = alpha * log_x - beta * pt.exp(log_x) + norm

    (I guess we could also pass the transformed point into the logp function, so that the exp(log_x) could just be x. The original had things like logpow(exp(x), alpha - 1) in the graph, the new one doesn't. Not a big deal I think, but it is cleaner.)

  • For the logit-normal what we have to do right now is even a bit silly:

          res = pt.switch(
              pt.or_(pt.le(value, 0), pt.ge(value, 1)),
              -np.inf,
              (
                  -0.5 * tau * (logit(value) - mu) ** 2
                  + 0.5 * pt.log(tau / (2.0 * np.pi))
                  - pt.log(value * (1 - value))
              ),
          )

    This contains logit(value). So we first compute value = expit(logit_x) and then just undo it. We also have the - pt.log(value * (1 - value)), which is just the jacobian det of the transform. So we take a point in the transformed space, compute the untransformed value and the jac det, then in the logp we (manually) transform the value back again, and (manually) subtract the logdet again...

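For the beta bullet above, the logit-space form follows in two lines. With $x = \operatorname{expit}(z)$ we have $\log x = -\log(1 + e^{-z})$ and $\log(1 - x) = -z - \log(1 + e^{-z})$, and the log-jacobian is $\log x + \log(1 - x)$, which turns the $(\alpha - 1, \beta - 1)$ exponents into $(\alpha, \beta)$:

$$
\log p(z) = \alpha \log x + \beta \log(1 - x) - \log B(\alpha, \beta)
          = -(\alpha + \beta)\log(1 + e^{-z}) - \beta z - \log B(\alpha, \beta).
$$

The gamma/log case works the same way: the jacobian contributes $+y$, giving $\alpha \log \beta - \log\Gamma(\alpha) + \alpha y - \beta e^{y}$ for $y = \log x$.
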
I guess things like this might also be happening in the dirichlet dists, but looking through those is a bit more work...
Nothing here is really a game-changer, but I think it might make quite a few graphs a bit cleaner and a bit more stable.


There is also an additional thing we could do with this that right now we can't do with transformations at all. I'm not 100% sure if we should do this though:

Right now all our transformations are injective, because the whole trick with the jacobian determinant doesn't work otherwise. But if we can compute the logp on the transformed space directly, we could get rid of that requirement.
So for instance we could have a Horseshoe distribution whose transformed space contains both the lambda and the x values, and the transformation just multiplies those.

Or we could have a transformation that maps a point in 2D space to an angle, and use that for instance in the VonMises distribution to avoid the topology problems we have there right now. A sketch of the Horseshoe idea follows.
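
A minimal sketch of what such a non-injective specialization could look like, reusing the dispatch signature from elsewhere in this thread (hedged: the packing of log_lambda and z into one value, and the half-Cauchy(1) choice for lambda, are hypothetical):

import numpy as np
import pytensor.tensor as pt

def transformed_horseshoe_logp(op, transform, unconstrained_value, rv_inputs):
    # Hypothetical: the unconstrained point packs log_lambda and a raw z;
    # the forward map would return x = tau * exp(log_lambda) * z, which is
    # not injective, so the jacobian trick does not apply. Instead the
    # logp is written directly on the (log_lambda, z) pair.
    *_, tau = rv_inputs
    log_lam = unconstrained_value[..., 0]
    z = unconstrained_value[..., 1]
    lam = pt.exp(log_lam)
    # half-Cauchy(1) prior on lambda, including the log-transform jacobian
    lam_logp = np.log(2 / np.pi) - pt.log1p(lam**2) + log_lam
    # standard normal prior on z
    z_logp = -0.5 * z**2 - 0.5 * np.log(2 * np.pi)
    return lam_logp + z_logp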

return _transformed_logprob(
rv_op,
transform,
unconstrained_value=unconstrained_value,
Member:
It would be nice if we could also pass in the transformed value. That way we can avoid computing it twice in a graph if the logp can use that value too. I guess pytensor might get rid of that duplication anyway, but I don't know how reliable that is if the transformation is doing something more complicated.

ricardoV94 (Member Author):
We can pass it too.

You raise a good point. Maybe transforms should be encapsulated in an OpFromGraph so that we can easily reverse them symbolically and not worry about whether they will be simplified in their raw form or not.
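
A hedged illustration of that idea (OpFromGraph is existing PyTensor API, but wiring it into Transform is hypothetical):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph

# Wrapping the backward (unconstrained -> constrained) map in an
# OpFromGraph keeps it as a single recognizable node, so the original
# input can be recovered symbolically via value.owner.inputs[0] no matter
# what rewrites later do to the inner expression.
z = pt.vector("z")
logodds_backward = OpFromGraph([z], [pt.sigmoid(z)], inline=False)
value = logodds_backward(z)
assert value.owner.inputs[0] is z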

ricardoV94 (Member Author):
Actually we have the constrained value as well: it's value.owner.inputs[0]

@ricardoV94 (Member Author) commented Mar 10, 2024

Thanks for the reply

Re: switches/bound checks, most of those could probably be removed with some domain analysis (e.g. exp(x) -> positive, so any ge(x, 0) is true), which could be useful beyond PyMC. This could be coupled with user-provided type hints (which we are kind of doing with transformed values here) to seed initial info that is not extractable from the graph.
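
For illustration, a toy PyTensor rewrite in that spirit (hedged: a real implementation would propagate domain facts through many ops; this only handles the literal ge(exp(x), 0) pattern):

import numpy as np
import pytensor.scalar as ps
import pytensor.tensor as pt
from pytensor.graph.basic import Constant
from pytensor.graph.rewriting.basic import node_rewriter
from pytensor.tensor.elemwise import Elemwise

@node_rewriter([pt.ge])
def local_exp_ge_zero_is_true(fgraph, node):
    # exp(x) is always positive, so ge(exp(x), 0) is a constant True
    left, right = node.inputs
    left_is_exp = (
        left.owner is not None
        and isinstance(left.owner.op, Elemwise)
        and isinstance(left.owner.op.scalar_op, ps.Exp)
    )
    right_is_zero = isinstance(right, Constant) and np.all(right.data == 0)
    if left_is_exp and right_is_zero:
        return [pt.ones_like(node.outputs[0], dtype="bool")]
    return None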

Re: examples, many of those sound like interesting rewrites that PyTensor already does or could do. That need not be a blocker to letting users work directly on the transformed space if they want, but it may be a good source of ideas for future rewrites.

The best argument for the functionality, from my perspective, is "avoiding hard math", which is more than a nice-to-have and could speed up development or allow models that are otherwise impossible to write without a hacky Potential?

For our code-base I think we still want to try and provide constrained-space logps as much as possible, since not everything we do (or may want to do) is logp-based sampling.

@twiecki (Member) commented Mar 11, 2024

What does this do?

@aseyboldt (Member):
I just had a look at the generated graphs of each of those.
Let's say we want the logp of the model:

with pm.Model(check_bounds=False) as model:
    pm.ZeroSumNormal("x", shape=(10, 10), n_zerosum_axes=2)

The logp + grad with the PR:

>>> func1 = model.logp_dlogp_function(mode="NUMBA")
>>> pytensor.dprint(func1._pytensor_function)
Sum{axes=None} [id A] 1
 └─ Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id B] 'x_zerosum___logprob' 0
    └─ x_zerosum__ [id C]
Neg [id D] 'x_zerosum___grad' 2
 └─ x_zerosum__ [id C]

Inner graphs:

Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id B]
 ← sub [id E] 'o0'
    ├─ mul [id F]
    │  ├─ -0.5 [id G]
    │  └─ sqr [id H]
    │     └─ i0 [id I]
    └─ 0.9189385332046727 [id J]

And without:

Switch [id A] 'mean(value, axis=n_zerosum_axes) = 0' 31
 ├─ All{axes=None} [id B] 27
 │  └─ MakeVector{dtype='bool'} [id C] 24
 │     ├─ All{axes=None} [id D] 20
 │     │  └─ Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id E] 16
 │     │     └─ Sum{axis=1} [id F] 11
 │     │        └─ Composite{...}.0 [id G] 9
 │     │           ├─ Join [id H] 8
 │     │           │  ├─ 1 [id I]
 │     │           │  ├─ Sub [id J] 4
 │     │           │  │  ├─ Join [id K] 3
 │     │           │  │  │  ├─ 0 [id L]
 │     │           │  │  │  ├─ x_zerosum__ [id M]
 │     │           │  │  │  └─ Composite{...}.1 [id N] 2
 │     │           │  │  │     └─ ExpandDims{axis=0} [id O] 1
 │     │           │  │  │        └─ Sum{axis=0} [id P] 0
 │     │           │  │  │           └─ x_zerosum__ [id M]
 │     │           │  │  └─ Composite{...}.0 [id N] 2
 │     │           │  │     └─ ···
 │     │           │  └─ Composite{...}.1 [id Q] 7
 │     │           │     └─ ExpandDims{axis=1} [id R] 6
 │     │           │        └─ Sum{axis=1} [id S] 5
 │     │           │           └─ Sub [id J] 4
 │     │           │              └─ ···
 │     │           └─ Composite{...}.0 [id Q] 7
 │     │              └─ ···
 │     └─ All{axes=None} [id T] 22
 │        └─ Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id U] 18
 │           └─ Sum{axis=0} [id V] 12
 │              └─ Composite{...}.0 [id G] 9
 │                 └─ ···
 ├─ Sum{axes=None} [id W] 13
 │  └─ Composite{...}.2 [id G] 9
 │     └─ ···
 └─ -inf [id X]
Add [id Y] 'x_zerosum___grad' 34
 ├─ SpecifyShape [id Z] 30
 │  ├─ Split{2}.0 [id BA] 26
 │  │  ├─ Add [id BB] 23
 │  │  │  ├─ SpecifyShape [id BC] 15
 │  │  │  │  ├─ Split{2}.0 [id BD] 10
 │  │  │  │  │  ├─ Composite{...}.1 [id G] 9
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ 1 [id I]
 │  │  │  │  │  └─ [9 1] [id BE]
 │  │  │  │  ├─ 10 [id BF]
 │  │  │  │  └─ 9 [id BG]
 │  │  │  ├─ Composite{(0.07597469266479578 * (i0 + i1))} [id BH] 21
 │  │  │  │  ├─ SpecifyShape [id BI] 14
 │  │  │  │  │  ├─ Split{2}.1 [id BD] 10
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ NoneConst{None} [id BJ]
 │  │  │  │  │  └─ 1 [id I]
 │  │  │  │  └─ ExpandDims{axis=1} [id BK] 17
 │  │  │  │     └─ Sum{axis=1} [id F] 11
 │  │  │  │        └─ ···
 │  │  │  └─ Mul [id BL] 19
 │  │  │     ├─ [[-0.31622777]] [id BM]
 │  │  │     └─ SpecifyShape [id BI] 14
 │  │  │        └─ ···
 │  │  ├─ 0 [id L]
 │  │  └─ [9 1] [id BE]
 │  ├─ 9 [id BG]
 │  └─ 9 [id BG]
 ├─ Composite{(0.07597469266479578 * (i0 - i1))} [id BN] 32
 │  ├─ SpecifyShape [id BO] 29
 │  │  ├─ Split{2}.1 [id BA] 26
 │  │  │  └─ ···
 │  │  ├─ 1 [id I]
 │  │  └─ NoneConst{None} [id BJ]
 │  └─ ExpandDims{axis=0} [id BP] 28
 │     └─ Sum{axis=0} [id BQ] 25
 │        └─ Add [id BB] 23
 │           └─ ···
 └─ Mul [id BR] 33
    ├─ [[-0.31622777]] [id BM]
    └─ SpecifyShape [id BO] 29
       └─ ···

Inner graphs:

Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id E]
 ← AND [id BS] 'o0'
    ├─ LE [id BT]
    │  ├─ mul [id BU]
    │  │  ├─ t0{0.1} [id BV]
    │  │  └─ Abs [id BW]
    │  │     └─ i0 [id BX]
    │  └─ 1e-09 [id BY]
    └─ Invert [id BZ]
       └─ OR [id CA]
          ├─ IsNan [id CB]
          │  └─ mul [id CC] 't9'
          │     ├─ t0{0.1} [id BV]
          │     └─ i0 [id BX]
          └─ IsInf [id CD]
             └─ mul [id CC] 't9'
                └─ ···

Composite{...} [id G]
 ← sub [id CE] 'o0'
    ├─ i0 [id CF]
    └─ i1 [id CG]
 ← neg [id CH] 'o1'
    └─ sub [id CE] 'o0'
       └─ ···
 ← sub [id CI] 'o2'
    ├─ mul [id CJ]
    │  ├─ -0.5 [id CK]
    │  └─ sqr [id CL]
    │     └─ sub [id CE] 'o0'
    │        └─ ···
    └─ 0.7443402118957849 [id CM]

Composite{...} [id N]
 ← mul [id CN] 'o0'
    ├─ 0.07597469266479578 [id CO]
    └─ i0 [id CP]
 ← sub [id CQ] 'o1'
    ├─ mul [id CN] 'o0'
    │  └─ ···
    └─ mul [id CR]
       ├─ 0.31622776601683794 [id CS]
       └─ i0 [id CP]

Composite{...} [id Q]
 ← mul [id CT] 'o0'
    ├─ 0.07597469266479578 [id CU]
    └─ i0 [id CV]
 ← sub [id CW] 'o1'
    ├─ mul [id CT] 'o0'
    │  └─ ···
    └─ mul [id CX]
       ├─ 0.31622776601683794 [id CY]
       └─ i0 [id CV]

Composite{and(le((0.1 * abs(i0)), 1e-09), invert(or(isnan((0.1 * i0)), isinf((0.1 * i0)))))} [id U]
 ← AND [id CZ] 'o0'
    ├─ LE [id DA]
    │  ├─ mul [id DB]
    │  │  ├─ t7{0.1} [id DC]
    │  │  └─ Abs [id DD]
    │  │     └─ i0 [id DE]
    │  └─ 1e-09 [id DF]
    └─ Invert [id DG]
       └─ OR [id DH]
          ├─ IsNan [id DI]
          │  └─ mul [id DJ] 't11'
          │     ├─ t7{0.1} [id DC]
          │     └─ i0 [id DE]
          └─ IsInf [id DK]
             └─ mul [id DJ] 't11'
                └─ ···

Composite{(0.07597469266479578 * (i0 + i1))} [id BH]
 ← mul [id DL] 'o0'
    ├─ 0.07597469266479578 [id DM]
    └─ add [id DN]
       ├─ i0 [id DO]
       └─ i1 [id DP]

Composite{(0.07597469266479578 * (i0 - i1))} [id BN]
 ← mul [id DQ] 'o0'
    ├─ 0.07597469266479578 [id DR]
    └─ sub [id DS]
       ├─ i0 [id DT]
       └─ i1 [id DU]

I think I like this PR :D

Maybe we can also add an implementation for the beta and the logit-normal dists:

@pm.logprob.abstract._transformed_logprob.register(pt.random.basic.BetaRV, pm.distributions.transforms.LogOddsTransform)
def transformed_beta_logp(op, transform, unconstrained_value, rv_inputs):
    *_, alpha, beta = rv_inputs
    logit_x = unconstrained_value

    normalizing_factor = pt.gammaln(alpha + beta) - pt.gammaln(alpha) - pt.gammaln(beta) 
    logp = normalizing_factor - pt.log1p(pt.exp(-logit_x)) * (alpha + beta) - logit_x * beta
    return pm.distributions.multivariate.check_parameters(
        logp,
        alpha > 0,
        beta > 0,
        "Alpha and beta parameters must be positive in Beta distribution",
    )

@pm.logprob.abstract._transformed_logprob.register(pm.distributions.continuous.LogitNormalRV, pm.distributions.transforms.LogOddsTransform)
def transformed_logit_normal_logp(op, transform, unconstrained_value, rv_inputs):
    *_, mu, sigma = rv_inputs

    return pm.logp(pm.Normal.dist(mu, sigma), unconstrained_value)
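
A quick way to sanity-check the beta formula above against the existing machinery (plain PyTensor/PyMC calls, independent of this PR; the two outputs should agree up to floating point):

import pymc as pm
import pytensor
import pytensor.tensor as pt

# Reference: standard Beta logp at x = expit(z), plus the log-jacobian
# log(x) + log(1 - x) of the logodds transform.
z = pt.dscalar("z")
alpha, beta = 2.5, 1.3
x = pt.sigmoid(z)
reference = pm.logp(pm.Beta.dist(alpha, beta), x) + pt.log(x) + pt.log1p(-x)
closed_form = (
    pt.gammaln(alpha + beta)
    - pt.gammaln(alpha)
    - pt.gammaln(beta)
    - pt.log1p(pt.exp(-z)) * (alpha + beta)
    - z * beta
)
fn = pytensor.function([z], [reference, closed_form])
print(fn(0.7))  # both entries should match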

def transformed_zerosumnormal_logp(op, transform, unconstrained_value, rv_inputs):
    _, sigma, _ = rv_inputs
    zerosum_axes = transform.zerosum_axes
    if len(zerosum_axes) != op.ndim_supp:
Member:

I don't think this check is necessary?

ricardoV94 (Member Author):

It was a sanity check; a user could have defined an invalid transform manually. I don't really care either way.

@ricardoV94 (Member Author):
@aseyboldt perhaps a fairer graph comparison would be with Model(check_bounds=False). The mean switch, for instance, is just a bound check that could equally be removed in a model with the transform and the standard logp.

@aseyboldt (Member) commented Mar 11, 2024

That was with check_bounds=False. For most dists that only disables the checks on the parameters, not on the domain of the value, but I don't think that's the main source of the difference either.
This comparison also wasn't entirely fair, as with the PR the transformation wasn't needed at all anymore, while in most real applications it would still be needed when the variable is used downstream in the model. I guess this one is a fairer comparison:

with pm.Model(check_bounds=False) as model:
    x = pm.ZeroSumNormal("x", shape=(100, 100), n_zerosum_axes=2)
    pm.Normal("y", mu=x, sigma=1, shape=(100, 100))

With PR:

Add [id A] 18
 ├─ Sum{axes=None} [id B] 3
 │  └─ Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id C] 'x_zerosum___logprob' 1
 │     └─ x_zerosum__ [id D]
 └─ Sum{axes=None} [id E] 14
    └─ Composite{...}.2 [id F] 'y_logprob' 11
       ├─ Join [id G] 10
       │  ├─ 1 [id H]
       │  ├─ Sub [id I] 6
       │  │  ├─ Join [id J] 5
       │  │  │  ├─ 0 [id K]
       │  │  │  ├─ x_zerosum__ [id D]
       │  │  │  └─ Composite{...}.1 [id L] 4
       │  │  │     └─ ExpandDims{axis=0} [id M] 2
       │  │  │        └─ Sum{axis=0} [id N] 0
       │  │  │           └─ x_zerosum__ [id D]
       │  │  └─ Composite{...}.0 [id L] 4
       │  │     └─ ···
       │  └─ Composite{...}.1 [id O] 9
       │     └─ ExpandDims{axis=1} [id P] 8
       │        └─ Sum{axis=1} [id Q] 7
       │           └─ Sub [id I] 6
       │              └─ ···
       ├─ Composite{...}.0 [id O] 9
       │  └─ ···
       └─ y [id R]
Composite{((-i0) + i1 + i2 + i3)} [id S] 'x_zerosum___grad' 28
 ├─ x_zerosum__ [id D]
 ├─ Split{2}.0 [id T] 22
 │  ├─ Add [id U] 21
 │  │  ├─ SpecifyShape [id V] 17
 │  │  │  ├─ Split{2}.0 [id W] 13
 │  │  │  │  ├─ Composite{...}.0 [id F] 11
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ 1 [id H]
 │  │  │  │  └─ [99  1] [id X]
 │  │  │  ├─ 100 [id Y]
 │  │  │  └─ 99 [id Z]
 │  │  ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BA] 19
 │  │  │  ├─ SpecifyShape [id BB] 16
 │  │  │  │  ├─ Split{2}.1 [id W] 13
 │  │  │  │  │  └─ ···
 │  │  │  │  ├─ NoneConst{None} [id BC]
 │  │  │  │  └─ 1 [id H]
 │  │  │  └─ ExpandDims{axis=1} [id BD] 15
 │  │  │     └─ Sum{axis=1} [id BE] 12
 │  │  │        └─ Composite{...}.0 [id F] 11
 │  │  │           └─ ···
 │  │  └─ Mul [id BF] 20
 │  │     ├─ [[-0.1]] [id BG]
 │  │     └─ SpecifyShape [id BB] 16
 │  │        └─ ···
 │  ├─ 0 [id K]
 │  └─ [99  1] [id X]
 ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BH] 27
 │  ├─ SpecifyShape [id BI] 24
 │  │  ├─ Split{2}.1 [id T] 22
 │  │  │  └─ ···
 │  │  ├─ 1 [id H]
 │  │  └─ NoneConst{None} [id BC]
 │  └─ ExpandDims{axis=0} [id BJ] 25
 │     └─ Sum{axis=0} [id BK] 23
 │        └─ Add [id U] 21
 │           └─ ···
 └─ Mul [id BL] 26
    ├─ [[-0.1]] [id BG]
    └─ SpecifyShape [id BI] 24
       └─ ···
Composite{...}.1 [id F] 'y_grad' 11
 └─ ···

Inner graphs:

Composite{((-0.5 * sqr(i0)) - 0.9189385332046727)} [id C]
 ← sub [id BM] 'o0'
    ├─ mul [id BN]
    │  ├─ -0.5 [id BO]
    │  └─ sqr [id BP]
    │     └─ i0 [id BQ]
    └─ 0.9189385332046727 [id BR]

Composite{...} [id F]
 ← sub [id BS] 'o0'
    ├─ i2 [id BT]
    └─ sub [id BU]
       ├─ i0 [id BV]
       └─ i1 [id BW]
 ← neg [id BX] 'o1'
    └─ sub [id BS] 'o0'
       └─ ···
 ← sub [id BY] 'o2'
    ├─ mul [id BZ]
    │  ├─ -0.5 [id CA]
    │  └─ sqr [id CB]
    │     └─ sub [id BS] 'o0'
    │        └─ ···
    └─ 0.9189385332046727 [id CC]

Composite{...} [id L]
 ← mul [id CD] 'o0'
    ├─ 0.00909090909090909 [id CE]
    └─ i0 [id CF]
 ← sub [id CG] 'o1'
    ├─ mul [id CD] 'o0'
    │  └─ ···
    └─ mul [id CH]
       ├─ 0.1 [id CI]
       └─ i0 [id CF]

Composite{...} [id O]
 ← mul [id CJ] 'o0'
    ├─ 0.00909090909090909 [id CK]
    └─ i0 [id CL]
 ← sub [id CM] 'o1'
    ├─ mul [id CJ] 'o0'
    │  └─ ···
    └─ mul [id CN]
       ├─ 0.1 [id CO]
       └─ i0 [id CL]

Composite{((-i0) + i1 + i2 + i3)} [id S]
 ← add [id CP] 'o0'
    ├─ neg [id CQ]
    │  └─ i0 [id CR]
    ├─ i1 [id CS]
    ├─ i2 [id CT]
    └─ i3 [id CU]

Composite{(0.00909090909090909 * (i0 - i1))} [id BA]
 ← mul [id CV] 'o0'
    ├─ 0.00909090909090909 [id CW]
    └─ sub [id CX]
       ├─ i0 [id CY]
       └─ i1 [id CZ]

Composite{(0.00909090909090909 * (i0 - i1))} [id BH]
 ← mul [id DA] 'o0'
    ├─ 0.00909090909090909 [id DB]
    └─ sub [id DC]
       ├─ i0 [id DD]
       └─ i1 [id DE]

Without PR

Composite{(switch(i0, i1, i2) + i3)} [id A] 33
 ├─ All{axes=None} [id B] 29
 │  └─ MakeVector{dtype='bool'} [id C] 26
 │     ├─ All{axes=None} [id D] 24
 │     │  └─ Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id E] 20
 │     │     └─ Sum{axis=1} [id F] 15
 │     │        └─ Composite{...}.0 [id G] 9
 │     │           ├─ Join [id H] 8
 │     │           │  ├─ 1 [id I]
 │     │           │  ├─ Sub [id J] 4
 │     │           │  │  ├─ Join [id K] 3
 │     │           │  │  │  ├─ 0 [id L]
 │     │           │  │  │  ├─ x_zerosum__ [id M]
 │     │           │  │  │  └─ Composite{...}.1 [id N] 2
 │     │           │  │  │     └─ ExpandDims{axis=0} [id O] 1
 │     │           │  │  │        └─ Sum{axis=0} [id P] 0
 │     │           │  │  │           └─ x_zerosum__ [id M]
 │     │           │  │  └─ Composite{...}.0 [id N] 2
 │     │           │  │     └─ ···
 │     │           │  └─ Composite{...}.1 [id Q] 7
 │     │           │     └─ ExpandDims{axis=1} [id R] 6
 │     │           │        └─ Sum{axis=1} [id S] 5
 │     │           │           └─ Sub [id J] 4
 │     │           │              └─ ···
 │     │           ├─ Composite{...}.0 [id Q] 7
 │     │           │  └─ ···
 │     │           └─ y [id T]
 │     └─ All{axes=None} [id U] 23
 │        └─ Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id V] 19
 │           └─ Sum{axis=0} [id W] 14
 │              └─ Composite{...}.0 [id G] 9
 │                 └─ ···
 ├─ Sum{axes=None} [id X] 13
 │  └─ Composite{...}.3 [id G] 9
 │     └─ ···
 ├─ -inf [id Y]
 └─ Sum{axes=None} [id Z] 12
    └─ Composite{...}.4 [id G] 'y_logprob' 9
       └─ ···
Add [id BA] 'x_zerosum___grad' 36
 ├─ SpecifyShape [id BB] 32
 │  ├─ Split{2}.0 [id BC] 28
 │  │  ├─ Add [id BD] 25
 │  │  │  ├─ SpecifyShape [id BE] 18
 │  │  │  │  ├─ Split{2}.0 [id BF] 11
 │  │  │  │  │  ├─ Composite{...}.1 [id G] 9
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ 1 [id I]
 │  │  │  │  │  └─ [99  1] [id BG]
 │  │  │  │  ├─ 100 [id BH]
 │  │  │  │  └─ 99 [id BI]
 │  │  │  ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BJ] 21
 │  │  │  │  ├─ SpecifyShape [id BK] 17
 │  │  │  │  │  ├─ Split{2}.1 [id BF] 11
 │  │  │  │  │  │  └─ ···
 │  │  │  │  │  ├─ NoneConst{None} [id BL]
 │  │  │  │  │  └─ 1 [id I]
 │  │  │  │  └─ ExpandDims{axis=1} [id BM] 16
 │  │  │  │     └─ Sum{axis=1} [id BN] 10
 │  │  │  │        └─ Composite{...}.1 [id G] 9
 │  │  │  │           └─ ···
 │  │  │  └─ Mul [id BO] 22
 │  │  │     ├─ [[-0.1]] [id BP]
 │  │  │     └─ SpecifyShape [id BK] 17
 │  │  │        └─ ···
 │  │  ├─ 0 [id L]
 │  │  └─ [99  1] [id BG]
 │  ├─ 99 [id BI]
 │  └─ 99 [id BI]
 ├─ Composite{(0.00909090909090909 * (i0 - i1))} [id BQ] 34
 │  ├─ SpecifyShape [id BR] 31
 │  │  ├─ Split{2}.1 [id BC] 28
 │  │  │  └─ ···
 │  │  ├─ 1 [id I]
 │  │  └─ NoneConst{None} [id BL]
 │  └─ ExpandDims{axis=0} [id BS] 30
 │     └─ Sum{axis=0} [id BT] 27
 │        └─ Add [id BD] 25
 │           └─ ···
 └─ Mul [id BU] 35
    ├─ [[-0.1]] [id BP]
    └─ SpecifyShape [id BR] 31
       └─ ···
Composite{...}.2 [id G] 'y_grad' 9
 └─ ···

Inner graphs:

Composite{(switch(i0, i1, i2) + i3)} [id A]
 ← add [id BV] 'o0'
    ├─ Switch [id BW]
    │  ├─ i0 [id BX]
    │  ├─ i1 [id BY]
    │  └─ i2 [id BZ]
    └─ i3 [id CA]

Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id E]
 ← AND [id CB] 'o0'
    ├─ LE [id CC]
    │  ├─ mul [id CD]
    │  │  ├─ t7{0.01} [id CE]
    │  │  └─ Abs [id CF]
    │  │     └─ i0 [id CG]
    │  └─ 1e-09 [id CH]
    └─ Invert [id CI]
       └─ OR [id CJ]
          ├─ IsNan [id CK]
          │  └─ mul [id CL] 't4'
          │     ├─ t7{0.01} [id CE]
          │     └─ i0 [id CG]
          └─ IsInf [id CM]
             └─ mul [id CL] 't4'
                └─ ···

Composite{...} [id G]
 ← sub [id CN] 'o0'
    ├─ i0 [id CO]
    └─ i1 [id CP]
 ← sub [id CQ] 'o1'
    ├─ sub [id CR] 't13'
    │  ├─ i2 [id CS]
    │  └─ sub [id CN] 'o0'
    │     └─ ···
    └─ sub [id CN] 'o0'
       └─ ···
 ← neg [id CT] 'o2'
    └─ sub [id CR] 't13'
       └─ ···
 ← sub [id CU] 'o3'
    ├─ mul [id CV]
    │  ├─ t6{-0.5} [id CW]
    │  └─ sqr [id CX]
    │     └─ sub [id CN] 'o0'
    │        └─ ···
    └─ 0.9006516563938997 [id CY]
 ← sub [id CZ] 'o4'
    ├─ mul [id DA]
    │  ├─ t6{-0.5} [id CW]
    │  └─ sqr [id DB]
    │     └─ sub [id CR] 't13'
    │        └─ ···
    └─ 0.9189385332046727 [id DC]

Composite{...} [id N]
 ← mul [id DD] 'o0'
    ├─ 0.00909090909090909 [id DE]
    └─ i0 [id DF]
 ← sub [id DG] 'o1'
    ├─ mul [id DD] 'o0'
    │  └─ ···
    └─ mul [id DH]
       ├─ 0.1 [id DI]
       └─ i0 [id DF]

Composite{...} [id Q]
 ← mul [id DJ] 'o0'
    ├─ 0.00909090909090909 [id DK]
    └─ i0 [id DL]
 ← sub [id DM] 'o1'
    ├─ mul [id DJ] 'o0'
    │  └─ ···
    └─ mul [id DN]
       ├─ 0.1 [id DO]
       └─ i0 [id DL]

Composite{and(le((0.01 * abs(i0)), 1e-09), invert(or(isnan((0.01 * i0)), isinf((0.01 * i0)))))} [id V]
 ← AND [id DP] 'o0'
    ├─ LE [id DQ]
    │  ├─ mul [id DR]
    │  │  ├─ t3{0.01} [id DS]
    │  │  └─ Abs [id DT]
    │  │     └─ i0 [id DU]
    │  └─ 1e-09 [id DV]
    └─ Invert [id DW]
       └─ OR [id DX]
          ├─ IsNan [id DY]
          │  └─ mul [id DZ] 't10'
          │     ├─ t3{0.01} [id DS]
          │     └─ i0 [id DU]
          └─ IsInf [id EA]
             └─ mul [id DZ] 't10'
                └─ ···

Composite{(0.00909090909090909 * (i0 - i1))} [id BJ]
 ← mul [id EB] 'o0'
    ├─ 0.00909090909090909 [id EC]
    └─ sub [id ED]
       ├─ i0 [id EE]
       └─ i1 [id EF]

Composite{(0.00909090909090909 * (i0 - i1))} [id BQ]
 ← mul [id EG] 'o0'
    ├─ 0.00909090909090909 [id EH]
    └─ sub [id EI]
       ├─ i0 [id EJ]
       └─ i1 [id EK]

That's 140 vs 200 lines. Got to admit that I'm not 100% sure what the difference really is though...

It's also a bit faster, with 70μs vs 90μs.

@ricardoV94 (Member Author) commented Mar 11, 2024

> That was with check_bounds=False

Then I'm surprised the isclose(mean, 0) switch was not removed.

@aseyboldt (Member):
You are right, it seems check_bounds=False didn't do anything, because I wasn't compiling the pytensor functions in a model context. I had no idea I had to do that...
