Skip to content

Commit

Permalink
MAINT: stats: avoid spurious warnings in ncx2.pdf
Browse files Browse the repository at this point in the history
  • Loading branch information
sethtroisi committed Jun 29, 2020
1 parent 14e067e commit d0976b3
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
19 changes: 15 additions & 4 deletions scipy/stats/_distn_infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from itertools import zip_longest

from scipy._lib import doccer
from scipy._lib._util import _lazywhere
from ._distr_params import distcont, distdiscrete
from scipy._lib._util import check_random_state
from scipy._lib._util import _valarray as valarray
Expand Down Expand Up @@ -560,12 +561,22 @@ def _ncx2_log_pdf(x, df, nc):
df2 = df/2.0 - 1.0
xs, ns = np.sqrt(x), np.sqrt(nc)
res = xlogy(df2/2.0, x/nc) - 0.5*(xs - ns)**2
res += np.log(ive(df2, xs*ns) / 2.0)
return res
corr = ive(df2, xs*ns) / 2.0
# Return res + np.log(corr) avoiding np.log(0)
return _lazywhere(
corr > 0,
(res, corr),
f=lambda r, c: r + np.log(c),
fillvalue=-np.inf)


def _ncx2_pdf(x, df, nc):
return np.exp(_ncx2_log_pdf(x, df, nc))
# Copy of _ncx2_log_pdf avoiding np.log(0) when corr = 0
df2 = df/2.0 - 1.0
xs, ns = np.sqrt(x), np.sqrt(nc)
res = xlogy(df2/2.0, x/nc) - 0.5*(xs - ns)**2
corr = ive(df2, xs*ns) / 2.0
return np.exp(res) * corr


def _ncx2_cdf(x, df, nc):
Expand Down Expand Up @@ -2519,7 +2530,7 @@ def expect(self, func=None, args=(), loc=0, scale=1, lb=None, ub=None,
--------
To understand the effect of the bounds of integration consider
>>> from scipy.stats import expon
>>> expon(1).expect(lambda x: 1, lb=0.0, ub=2.0)
0.6321205588285578
Expand Down
10 changes: 8 additions & 2 deletions scipy/stats/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3739,13 +3739,19 @@ def test_ncx2_tails_ticket_955():
def test_ncx2_tails_pdf():
# ncx2.pdf does not return nans in extreme tails(example from gh-1577)
# NB: this is to check that nan_to_num is not needed in ncx2.pdf
with suppress_warnings() as sup:
sup.filter(RuntimeWarning, "divide by zero encountered in log")
with warnings.catch_warnings():
warnings.simplefilter('error', RuntimeWarning)
assert_equal(stats.ncx2.pdf(1, np.arange(340, 350), 2), 0)
logval = stats.ncx2.logpdf(1, np.arange(340, 350), 2)

assert_(np.isneginf(logval).all())

# Verify logpdf has extended precision when pdf underflows to 0
with warnings.catch_warnings():
warnings.simplefilter('error', RuntimeWarning)
assert_equal(stats.ncx2.pdf(10000, 3, 12), 0)
assert_allclose(stats.ncx2.logpdf(10000, 3, 12), -4662.444377524883)


@pytest.mark.parametrize('method, expected', [
('cdf', np.array([2.497951336e-09, 3.437288941e-10])),
Expand Down

0 comments on commit d0976b3

Please sign in to comment.