Skip to content

Commit

Permalink
Update groupby attrs tests (pydata#6787)
Browse files Browse the repository at this point in the history
Co-authored-by: Anderson Banihirwe <[email protected]>
  • Loading branch information
dcherian and andersy005 authored Jul 15, 2022
1 parent 5678b75 commit e086015
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions xarray/tests/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1152,21 +1152,33 @@ def test_groupby_count(self):
expected = DataArray([1, 1, 2], coords=[("cat", ["a", "b", "c"])])
assert_identical(actual, expected)

@pytest.mark.skip("needs to be fixed for shortcut=False, keep_attrs=False")
def test_groupby_reduce_attrs(self):
@pytest.mark.parametrize("shortcut", [True, False])
@pytest.mark.parametrize("keep_attrs", [None, True, False])
def test_groupby_reduce_keep_attrs(self, shortcut, keep_attrs):
array = self.da
array.attrs["foo"] = "bar"

actual = array.groupby("abc").reduce(
np.mean, keep_attrs=keep_attrs, shortcut=shortcut
)
with xr.set_options(use_flox=False):
expected = array.groupby("abc").mean(keep_attrs=keep_attrs)
assert_identical(expected, actual)

@pytest.mark.parametrize("keep_attrs", [None, True, False])
def test_groupby_keep_attrs(self, keep_attrs):
array = self.da
array.attrs["foo"] = "bar"

for shortcut in [True, False]:
for keep_attrs in [True, False]:
print(f"shortcut={shortcut}, keep_attrs={keep_attrs}")
actual = array.groupby("abc").reduce(
np.mean, keep_attrs=keep_attrs, shortcut=shortcut
)
expected = array.groupby("abc").mean()
if keep_attrs:
expected.attrs["foo"] = "bar"
assert_identical(expected, actual)
with xr.set_options(use_flox=False):
expected = array.groupby("abc").mean(keep_attrs=keep_attrs)
with xr.set_options(use_flox=True):
actual = array.groupby("abc").mean(keep_attrs=keep_attrs)

# values are tested elsewhere, here we jsut check data
# TODO: add check_attrs kwarg to assert_allclose
actual.data = expected.data
assert_identical(expected, actual)

def test_groupby_map_center(self):
def center(x):
Expand Down

0 comments on commit e086015

Please sign in to comment.