Skip to content

Commit

Permalink
rely on numpy's version of nanprod and nansum (pydata#6873)
Browse files Browse the repository at this point in the history
Co-authored-by: dcherian <[email protected]>
  • Loading branch information
keewis and dcherian authored Aug 9, 2022
1 parent 3c8ce0f commit 8417f49
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 18 deletions.
15 changes: 3 additions & 12 deletions xarray/core/nanops.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,6 @@
dask_array_compat = None # type: ignore[assignment]


def _replace_nan(a, val):
"""
replace nan in a by val, and returns the replaced array and the nan
position
"""
mask = isnull(a)
return where_method(val, mask, a), mask


def _maybe_null_out(result, axis, mask, min_count=1):
"""
xarray version of pandas.core.nanops._maybe_null_out
Expand Down Expand Up @@ -105,8 +96,8 @@ def nanargmax(a, axis=None):


def nansum(a, axis=None, dtype=None, out=None, min_count=None):
a, mask = _replace_nan(a, 0)
result = np.sum(a, axis=axis, dtype=dtype)
mask = isnull(a)
result = np.nansum(a, axis=axis, dtype=dtype)
if min_count is not None:
return _maybe_null_out(result, axis, mask, min_count)
else:
Expand Down Expand Up @@ -173,7 +164,7 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0):


def nanprod(a, axis=None, dtype=None, out=None, min_count=None):
a, mask = _replace_nan(a, 1)
mask = isnull(a)
result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out)
if min_count is not None:
return _maybe_null_out(result, axis, mask, min_count)
Expand Down
25 changes: 19 additions & 6 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
import pandas as pd
import pytest
from packaging import version

import xarray as xr
from xarray.core import dtypes, duck_array_ops
Expand Down Expand Up @@ -1530,8 +1531,12 @@ class TestVariable:
ids=repr,
)
def test_aggregation(self, func, dtype):
if func.name == "prod" and dtype.kind == "f":
pytest.xfail(reason="nanprod is not supported, yet")
if (
func.name == "prod"
and dtype.kind == "f"
and version.parse(pint.__version__) < version.parse("0.19")
):
pytest.xfail(reason="nanprod is not by older `pint` versions")

array = np.linspace(0, 1, 10).astype(dtype) * (
unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
Expand Down Expand Up @@ -2387,8 +2392,12 @@ def test_repr(self, func, variant, dtype):
ids=repr,
)
def test_aggregation(self, func, dtype):
if func.name == "prod" and dtype.kind == "f":
pytest.xfail(reason="nanprod is not supported, yet")
if (
func.name == "prod"
and dtype.kind == "f"
and version.parse(pint.__version__) < version.parse("0.19")
):
pytest.xfail(reason="nanprod is not by older `pint` versions")

array = np.arange(10).astype(dtype) * (
unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless
Expand Down Expand Up @@ -4082,8 +4091,12 @@ def test_repr(self, func, variant, dtype):
ids=repr,
)
def test_aggregation(self, func, dtype):
if func.name == "prod" and dtype.kind == "f":
pytest.xfail(reason="nanprod is not supported, yet")
if (
func.name == "prod"
and dtype.kind == "f"
and version.parse(pint.__version__) < version.parse("0.19")
):
pytest.xfail(reason="nanprod is not by older `pint` versions")

unit_a, unit_b = (
(unit_registry.Pa, unit_registry.degK)
Expand Down

0 comments on commit 8417f49

Please sign in to comment.