Skip to content

Commit

Permalink
slow.nansum and slow.ss now longer coerce dtype
Browse files Browse the repository at this point in the history
closes pydata#180
  • Loading branch information
kwgoodman committed Dec 12, 2017
1 parent 449d5c3 commit 66b6f78
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 17 deletions.
2 changes: 2 additions & 0 deletions RELEASE.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ Bottleneck 1.3.0

- Move documentation to https://kwgoodman.github.io/bottleneck-doc
- Remove numpydoc package from Bottleneck source distribution
- bn.slow.nansum and bn.slow.ss now longer coerce output to have the same
dtype as input

**Bug Fixes**

Expand Down
13 changes: 1 addition & 12 deletions bottleneck/slow/reduce.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,12 @@
import warnings
import numpy as np
from numpy import nanmean
from numpy import nanmean, nansum

__all__ = ['median', 'nanmedian', 'nansum', 'nanmean', 'nanvar', 'nanstd',
'nanmin', 'nanmax', 'nanargmin', 'nanargmax', 'ss', 'anynan',
'allnan']


def nansum(a, axis=None):
"Slow nansum function used for unaccelerated dtype."
a = np.asarray(a)
y = np.nansum(a, axis=axis)
if y.dtype != a.dtype:
y = y.astype(a.dtype)
return y


def nanargmin(a, axis=None):
"Slow nanargmin function used for unaccelerated dtypes."
with warnings.catch_warnings():
Expand Down Expand Up @@ -76,8 +67,6 @@ def ss(a, axis=None):
"Slow sum of squares used for unaccelerated dtypes."
a = np.asarray(a)
y = np.multiply(a, a).sum(axis)
if y.dtype != a.dtype:
y = y.astype(a.dtype)
return y


Expand Down
11 changes: 6 additions & 5 deletions bottleneck/tests/reduce_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_reduce():
yield unit_maker, func


def unit_maker(func, decimal=5):
def unit_maker(func, decimal=5, skip_dtype=['nansum', 'ss']):
"Test that bn.xxx gives the same output as bn.slow.xxx."
fmt = '\nfunc %s | input %s (%s) | shape %s | axis %s | order %s\n'
fmt += '\nInput array:\n%s\n'
Expand Down Expand Up @@ -62,10 +62,11 @@ def unit_maker(func, decimal=5):
ok_(False, err_msg)
assert_array_almost_equal(actual, desired, decimal, err_msg)
err_msg += '\n dtype mismatch %s %s'
if hasattr(actual, 'dtype') and hasattr(desired, 'dtype'):
da = actual.dtype
dd = desired.dtype
assert_equal(da, dd, err_msg % (da, dd))
if name not in skip_dtype:
if hasattr(actual, 'dtype') and hasattr(desired, 'dtype'):
da = actual.dtype
dd = desired.dtype
assert_equal(da, dd, err_msg % (da, dd))


# ---------------------------------------------------------------------------
Expand Down

0 comments on commit 66b6f78

Please sign in to comment.