From 8417f495e6b81a60833f86a978e5a8080a619aa0 Mon Sep 17 00:00:00 2001 From: Justus Magin Date: Tue, 9 Aug 2022 16:55:20 +0200 Subject: [PATCH] rely on `numpy`'s version of `nanprod` and `nansum` (#6873) Co-authored-by: dcherian --- xarray/core/nanops.py | 15 +++------------ xarray/tests/test_units.py | 25 +++++++++++++++++++------ 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/xarray/core/nanops.py b/xarray/core/nanops.py index d02d129aeba..920fd5a094e 100644 --- a/xarray/core/nanops.py +++ b/xarray/core/nanops.py @@ -17,15 +17,6 @@ dask_array_compat = None # type: ignore[assignment] -def _replace_nan(a, val): - """ - replace nan in a by val, and returns the replaced array and the nan - position - """ - mask = isnull(a) - return where_method(val, mask, a), mask - - def _maybe_null_out(result, axis, mask, min_count=1): """ xarray version of pandas.core.nanops._maybe_null_out @@ -105,8 +96,8 @@ def nanargmax(a, axis=None): def nansum(a, axis=None, dtype=None, out=None, min_count=None): - a, mask = _replace_nan(a, 0) - result = np.sum(a, axis=axis, dtype=dtype) + mask = isnull(a) + result = np.nansum(a, axis=axis, dtype=dtype) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) else: @@ -173,7 +164,7 @@ def nanstd(a, axis=None, dtype=None, out=None, ddof=0): def nanprod(a, axis=None, dtype=None, out=None, min_count=None): - a, mask = _replace_nan(a, 1) + mask = isnull(a) result = nputils.nanprod(a, axis=axis, dtype=dtype, out=out) if min_count is not None: return _maybe_null_out(result, axis, mask, min_count) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index f1b77296b82..52c50e28931 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -6,6 +6,7 @@ import numpy as np import pandas as pd import pytest +from packaging import version import xarray as xr from xarray.core import dtypes, duck_array_ops @@ -1530,8 +1531,12 @@ class TestVariable: ids=repr, ) def test_aggregation(self, func, dtype): - if func.name == "prod" and dtype.kind == "f": - pytest.xfail(reason="nanprod is not supported, yet") + if ( + func.name == "prod" + and dtype.kind == "f" + and version.parse(pint.__version__) < version.parse("0.19") + ): + pytest.xfail(reason="nanprod is not by older `pint` versions") array = np.linspace(0, 1, 10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless @@ -2387,8 +2392,12 @@ def test_repr(self, func, variant, dtype): ids=repr, ) def test_aggregation(self, func, dtype): - if func.name == "prod" and dtype.kind == "f": - pytest.xfail(reason="nanprod is not supported, yet") + if ( + func.name == "prod" + and dtype.kind == "f" + and version.parse(pint.__version__) < version.parse("0.19") + ): + pytest.xfail(reason="nanprod is not by older `pint` versions") array = np.arange(10).astype(dtype) * ( unit_registry.m if func.name != "cumprod" else unit_registry.dimensionless @@ -4082,8 +4091,12 @@ def test_repr(self, func, variant, dtype): ids=repr, ) def test_aggregation(self, func, dtype): - if func.name == "prod" and dtype.kind == "f": - pytest.xfail(reason="nanprod is not supported, yet") + if ( + func.name == "prod" + and dtype.kind == "f" + and version.parse(pint.__version__) < version.parse("0.19") + ): + pytest.xfail(reason="nanprod is not by older `pint` versions") unit_a, unit_b = ( (unit_registry.Pa, unit_registry.degK)