Skip to content

Commit

Permalink
provide a error summary for assert_allclose (pydata#3847)
Browse files Browse the repository at this point in the history
* allow passing a callable as compat to diff_{dataset,array}_repr

* rewrite assert_allclose to provide a failure summary

* make sure we're comparing variables

* remove spurious comments

* override test_aggregate_complex with a test compatible with pint

* expect the asserts to raise

* xfail the tests failing due to isclose not accepting non-quantity tolerances

* mark top-level function tests as xfailing if they use assert_allclose

* mark test_1d_math as runnable but xfail it

* bump dask and distributed

* entry to whats-new.rst

* attempt to fix the failing py36-min-all-deps and py36-min-nep18 CI

* conditionally xfail tests using assert_allclose with pint < 0.12

* xfail more tests depending on which pint version is used

* try using numpy.testing.assert_allclose instead

* try computing if the dask version is too old and dask.array[bool]

* fix the dask version checking

* convert all dask arrays to numpy when using a insufficient dask version
  • Loading branch information
keewis authored Jun 13, 2020
1 parent e26b80f commit 2ba5300
Show file tree
Hide file tree
Showing 9 changed files with 150 additions and 28 deletions.
4 changes: 2 additions & 2 deletions ci/requirements/py36-min-all-deps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ dependencies:
- cfgrib=0.9
- cftime=1.0
- coveralls
- dask=2.2
- distributed=2.2
- dask=2.5
- distributed=2.5
- flake8
- h5netcdf=0.7
- h5py=2.9 # Policy allows for 2.10, but it's a conflict-fest
Expand Down
4 changes: 2 additions & 2 deletions ci/requirements/py36-min-nep18.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ dependencies:
# require drastically newer packages than everything else
- python=3.6
- coveralls
- dask=2.4
- distributed=2.4
- dask=2.5
- distributed=2.5
- msgpack-python=0.6 # remove once distributed is bumped. distributed GH3491
- numpy=1.17
- pandas=0.25
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,8 @@ New Features
:py:meth:`core.groupby.DatasetGroupBy.quantile`, :py:meth:`core.groupby.DataArrayGroupBy.quantile`
(:issue:`3843`, :pull:`3844`)
By `Aaron Spring <https://github.com/aaronspring>`_.
- Add a diff summary for `testing.assert_allclose`. (:issue:`3617`, :pull:`3847`)
By `Justus Magin <https://github.com/keewis>`_.

Bug fixes
~~~~~~~~~
Expand Down
20 changes: 20 additions & 0 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import contextlib
import inspect
import warnings
from distutils.version import LooseVersion
from functools import partial

import numpy as np
Expand All @@ -20,6 +21,14 @@
except ImportError:
dask_array = None # type: ignore

# TODO: remove after we stop supporting dask < 2.9.1
try:
import dask

dask_version = dask.__version__
except ImportError:
dask_version = None


def _dask_or_eager_func(
name,
Expand Down Expand Up @@ -199,8 +208,19 @@ def allclose_or_equiv(arr1, arr2, rtol=1e-5, atol=1e-8):
"""
arr1 = asarray(arr1)
arr2 = asarray(arr2)

lazy_equiv = lazy_array_equiv(arr1, arr2)
if lazy_equiv is None:
# TODO: remove after we require dask >= 2.9.1
sufficient_dask_version = (
dask_version is not None and LooseVersion(dask_version) >= "2.9.1"
)
if not sufficient_dask_version and any(
isinstance(arr, dask_array_type) for arr in [arr1, arr2]
):
arr1 = np.array(arr1)
arr2 = np.array(arr2)

return bool(isclose(arr1, arr2, rtol=rtol, atol=atol, equal_nan=True).all())
else:
return lazy_equiv
Expand Down
16 changes: 14 additions & 2 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,10 @@ def extra_items_repr(extra_keys, mapping, ab_side):
for k in a_keys & b_keys:
try:
# compare xarray variable
compatible = getattr(a_mapping[k], compat)(b_mapping[k])
if not callable(compat):
compatible = getattr(a_mapping[k], compat)(b_mapping[k])
else:
compatible = compat(a_mapping[k], b_mapping[k])
is_variable = True
except AttributeError:
# compare attribute value
Expand Down Expand Up @@ -596,8 +599,13 @@ def extra_items_repr(extra_keys, mapping, ab_side):


def _compat_to_str(compat):
if callable(compat):
compat = compat.__name__

if compat == "equals":
return "equal"
elif compat == "allclose":
return "close"
else:
return compat

Expand All @@ -611,8 +619,12 @@ def diff_array_repr(a, b, compat):
]

summary.append(diff_dim_summary(a, b))
if callable(compat):
equiv = compat
else:
equiv = array_equiv

if not array_equiv(a.data, b.data):
if not equiv(a.data, b.data):
temp = [wrap_indent(short_numpy_repr(obj), start=" ") for obj in (a, b)]
diff_data_repr = [
ab_side + "\n" + ab_data_repr
Expand Down
43 changes: 24 additions & 19 deletions xarray/testing.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
"""Testing functions exposed to the user API"""
import functools
from typing import Hashable, Set, Union

import numpy as np
import pandas as pd

from xarray.core import duck_array_ops, formatting
from xarray.core import duck_array_ops, formatting, utils
from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.indexes import default_indexes
Expand Down Expand Up @@ -118,27 +119,31 @@ def assert_allclose(a, b, rtol=1e-05, atol=1e-08, decode_bytes=True):
"""
__tracebackhide__ = True
assert type(a) == type(b)
kwargs = dict(rtol=rtol, atol=atol, decode_bytes=decode_bytes)

equiv = functools.partial(
_data_allclose_or_equiv, rtol=rtol, atol=atol, decode_bytes=decode_bytes
)
equiv.__name__ = "allclose"

def compat_variable(a, b):
a = getattr(a, "variable", a)
b = getattr(b, "variable", b)

return a.dims == b.dims and (a._data is b._data or equiv(a.data, b.data))

if isinstance(a, Variable):
assert a.dims == b.dims
allclose = _data_allclose_or_equiv(a.values, b.values, **kwargs)
assert allclose, f"{a.values}\n{b.values}"
allclose = compat_variable(a, b)
assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
elif isinstance(a, DataArray):
assert_allclose(a.variable, b.variable, **kwargs)
assert set(a.coords) == set(b.coords)
for v in a.coords.variables:
# can't recurse with this function as coord is sometimes a
# DataArray, so call into _data_allclose_or_equiv directly
allclose = _data_allclose_or_equiv(
a.coords[v].values, b.coords[v].values, **kwargs
)
assert allclose, "{}\n{}".format(a.coords[v].values, b.coords[v].values)
allclose = utils.dict_equiv(
a.coords, b.coords, compat=compat_variable
) and compat_variable(a.variable, b.variable)
assert allclose, formatting.diff_array_repr(a, b, compat=equiv)
elif isinstance(a, Dataset):
assert set(a.data_vars) == set(b.data_vars)
assert set(a.coords) == set(b.coords)
for k in list(a.variables) + list(a.coords):
assert_allclose(a[k], b[k], **kwargs)

allclose = a._coord_names == b._coord_names and utils.dict_equiv(
a.variables, b.variables, compat=compat_variable
)
assert allclose, formatting.diff_dataset_repr(a, b, compat=equiv)
else:
raise TypeError("{} not supported by assertion comparison".format(type(a)))

Expand Down
2 changes: 1 addition & 1 deletion xarray/tests/test_duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def test_reduce(dim_num, dtype, dask, func, skipna, aggdim):

actual = getattr(da, func)(skipna=skipna, dim=aggdim)
assert_dask_array(actual, dask)
assert np.allclose(
np.testing.assert_allclose(
actual.values, np.array(expected), rtol=1.0e-4, equal_nan=True
)
except (TypeError, AttributeError, ZeroDivisionError):
Expand Down
25 changes: 25 additions & 0 deletions xarray/tests/test_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,32 @@
import pytest

import xarray as xr


def test_allclose_regression():
x = xr.DataArray(1.01)
y = xr.DataArray(1.02)
xr.testing.assert_allclose(x, y, atol=0.01)


@pytest.mark.parametrize(
"obj1,obj2",
(
pytest.param(
xr.Variable("x", [1e-17, 2]), xr.Variable("x", [0, 3]), id="Variable",
),
pytest.param(
xr.DataArray([1e-17, 2], dims="x"),
xr.DataArray([0, 3], dims="x"),
id="DataArray",
),
pytest.param(
xr.Dataset({"a": ("x", [1e-17, 2]), "b": ("y", [-2e-18, 2])}),
xr.Dataset({"a": ("x", [0, 2]), "b": ("y", [0, 1])}),
id="Dataset",
),
),
)
def test_assert_allclose(obj1, obj2):
with pytest.raises(AssertionError):
xr.testing.assert_allclose(obj1, obj2)
62 changes: 60 additions & 2 deletions xarray/tests/test_units.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,10 @@ def test_apply_ufunc_dataset(dtype):
assert_identical(expected, actual)


# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"unit,error",
(
Expand Down Expand Up @@ -512,6 +516,10 @@ def test_align_dataarray(fill_value, variant, unit, error, dtype):
assert_allclose(expected_b, actual_b)


# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"unit,error",
(
Expand Down Expand Up @@ -929,6 +937,10 @@ def test_concat_dataset(variant, unit, error, dtype):
assert_identical(expected, actual)


# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"unit,error",
(
Expand Down Expand Up @@ -1036,6 +1048,10 @@ def test_merge_dataarray(variant, unit, error, dtype):
assert_allclose(expected, actual)


# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"unit,error",
(
Expand Down Expand Up @@ -1385,7 +1401,6 @@ def wrapper(cls):
"test_datetime64_conversion",
"test_timedelta64_conversion",
"test_pandas_period_index",
"test_1d_math",
"test_1d_reduce",
"test_array_interface",
"test___array__",
Expand Down Expand Up @@ -1413,6 +1428,13 @@ def example_1d_objects(self):
]:
yield (self.cls("x", data), data)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
def test_real_and_imag(self):
super().test_real_and_imag()

@pytest.mark.parametrize(
"func",
(
Expand Down Expand Up @@ -1450,6 +1472,22 @@ def test_aggregation(self, func, dtype):
assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
def test_aggregate_complex(self):
variable = xr.Variable("x", [1, 2j, np.nan] * unit_registry.m)
expected = xr.Variable((), (0.5 + 1j) * unit_registry.m)
actual = variable.mean()

assert_units_equal(expected, actual)
xr.testing.assert_allclose(expected, actual)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"func",
(
Expand Down Expand Up @@ -1748,6 +1786,10 @@ def test_isel(self, indices, dtype):
assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"unit,error",
(
Expand Down Expand Up @@ -2224,6 +2266,10 @@ def test_repr(self, func, variant, dtype):
# warnings or errors, but does not check the result
func(data_array)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose",
)
@pytest.mark.parametrize(
"func",
(
Expand All @@ -2235,7 +2281,7 @@ def test_repr(self, func, variant, dtype):
function("mean"),
pytest.param(
function("median"),
marks=pytest.mark.xfail(
marks=pytest.mark.skip(
reason="median does not work with dataarrays yet"
),
),
Expand Down Expand Up @@ -3283,6 +3329,10 @@ def test_head_tail_thin(self, func, dtype):
assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize("variant", ("data", "coords"))
@pytest.mark.parametrize(
"func",
Expand Down Expand Up @@ -3356,6 +3406,10 @@ def test_interp_reindex_indexing(self, func, unit, error, dtype):
assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize("variant", ("data", "coords"))
@pytest.mark.parametrize(
"func",
Expand Down Expand Up @@ -3558,6 +3612,10 @@ def test_computation(self, func, dtype):
assert_units_equal(expected, actual)
xr.testing.assert_identical(expected, actual)

# TODO: remove once pint==0.12 has been released
@pytest.mark.xfail(
LooseVersion(pint.__version__) <= "0.11", reason="pint bug in isclose"
)
@pytest.mark.parametrize(
"func",
(
Expand Down

0 comments on commit 2ba5300

Please sign in to comment.