-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Truncated normal dispatches #7506
base: main
Are you sure you want to change the base?
Changes from 5 commits
e797cd9
8a9f17e
fd3b6d2
a7cb67b
1de633c
b31dc5d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Copyright 2024 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import jax | ||
|
||
from pytensor.link.jax.dispatch import jax_funcify | ||
|
||
from pymc.distributions.continuous import TruncatedNormalRV | ||
|
||
|
||
@jax_funcify.register(TruncatedNormalRV) | ||
def jax_funcify_TruncatedNormalRV(op, **kwargs): | ||
def trunc_normal_fn(key, size, mu, sigma, lower, upper): | ||
rng_key = key["jax_state"] | ||
rng_key, sampling_key = jax.random.split(rng_key, 2) | ||
key["jax_state"] = rng_key | ||
|
||
truncnorm = jax.nn.initializers.truncated_normal(sigma, lower=lower, upper=upper) | ||
|
||
return key, truncnorm(key["jax_state"], size) + mu | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Adding mu like this is potentially wrong, because when size is None, mu could be larger and we end up with repeated values There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also you should dispatch on the more specific |
||
|
||
return trunc_normal_fn |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# Copyright 2024 The PyMC Developers | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import numpy as np | ||
import pytest | ||
|
||
from pytensor import function | ||
|
||
import pymc as pm | ||
|
||
from pymc.dispatch import dispatch_jax # noqa: F401 | ||
|
||
jax = pytest.importorskip("jax") | ||
|
||
|
||
def test_jax_TruncatedNormal(): | ||
with pm.Model() as m: | ||
f_jax = function( | ||
[], | ||
[pm.TruncatedNormal("a", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))], | ||
mode="JAX", | ||
) | ||
f_py = function( | ||
[], | ||
[pm.TruncatedNormal("b", 0, 1, lower=-1, upper=2, rng=np.random.default_rng(seed=123))], | ||
) | ||
|
||
assert jax.numpy.array_equal(a1=f_py(), a2=f_jax()) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This fails with a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The dispatch file needs to be imported when pymc is imported in order to be registered There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The two are not expected to match in values, because JAX uses a different implementation than numpy. You can make a TruncatedNormal with a large sigma, and confirm it does not go beyond the bounds as a check |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why not this? https://jax.readthedocs.io/en/latest/_autosummary/jax.random.truncated_normal.html
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We can't pass sigma or mu as parameters in
jax.random.truncated_normal
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was able to use
jax.random.truncated_normal
, had to transformlower
andupper