Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Truncated normal dispatches #7506

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ jobs:
tests/distributions/test_shape_utils.py
tests/distributions/test_mixture.py
tests/test_testing.py
tests/dispatch/test_jax.py

- |
tests/distributions/test_continuous.py
Expand Down
32 changes: 32 additions & 0 deletions pymc/dispatch/dispatch_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax

from pytensor.link.jax.dispatch import jax_funcify

from pymc.distributions.continuous import TruncatedNormalRV


@jax_funcify.register(TruncatedNormalRV)
def jax_funcify_TruncatedNormalRV(op, **kwargs):
def trunc_normal_fn(key, size, mu, sigma, lower, upper):
rng_key = key["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
key["jax_state"] = rng_key

truncnorm = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't pass sigma or mu as parameters in jax.random.truncated_normal

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was able to use jax.random.truncated_normal, had to transform lower and upper


return key, truncnorm(key["jax_state"], size) + mu
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding mu like this is potentially wrong, because when size is None, mu could be larger and we end up with repeated values

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also you should dispatch on the more specific jax_sample_fn. For the issue with broadcasting, check how we do it here for Normal for example: https://github.com/pymc-devs/pytensor/blob/5d4b0c4b9a1e478dda48e912ee708a9e557e9343/pytensor/link/jax/dispatch/random.py#L147-L173


return trunc_normal_fn
38 changes: 38 additions & 0 deletions tests/dispatch/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import pytest

from pytensor import function

import pymc as pm

from pymc.dispatch import dispatch_jax # noqa: F401

jax = pytest.importorskip("jax")


def test_jax_TruncatedNormal():
with pm.Model() as m:
f_jax = function(
[],
[pm.TruncatedNormal("a", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))],
mode="JAX",
)
f_py = function(
[],
[pm.TruncatedNormal("b", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))],
)

assert jax.numpy.array_equal(a1=f_py(), a2=f_jax())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails with a NotImplementedError: No JAX implementation for the given distribution: truncated_normal

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dispatch file needs to be imported when pymc is imported in order to be registered

Copy link
Member

@ricardoV94 ricardoV94 Oct 1, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The two are not expected to match in values, because JAX uses a different implementation than numpy. You can make a TruncatedNormal with a large sigma, and confirm it does not go beyond the bounds as a check

Loading