Skip to content

Commit

Permalink
Add tests for wrap_output_like decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
jthielen committed Jan 14, 2020
1 parent 70f3141 commit 60fad89
Show file tree
Hide file tree
Showing 2 changed files with 235 additions and 5 deletions.
12 changes: 9 additions & 3 deletions src/metpy/calc/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1494,7 +1494,7 @@ def wrap_output_like(**wrap_kwargs):
This wrapping/conversion follows the following rules:
- If type of input and output match, do nothing
- If type of input and output match, do not covert type (but maybe convert unit)
- If type of output is ndarray, use ``np.asarray`` with dtype
- If type of output is pint.Quantity and input is...
- ndarray, return Quantity
Expand Down Expand Up @@ -1531,7 +1531,8 @@ def wrapper(*args, **kwargs):
if 'other' in wrap_kwargs:
other = wrap_kwargs['other']
elif 'argument' in wrap_kwargs:
other = signature(func).bind(*args, **kwargs).arguments[wrap_kwargs['argument']]
other = signature(func).bind(*args, **kwargs).arguments[
wrap_kwargs['argument']]
else:
raise ValueError('Must specify keyword "other" or "argument".')

Expand All @@ -1547,6 +1548,11 @@ def wrapper(*args, **kwargs):

# Proceed with wrapping rules
if isinstance(result, type(other)):
if wrap_kwargs.get('match_unit', False) and hasattr(other, 'units'):
if isinstance(result, xr.DataArray):
result.metpy.convert_units(other.units)
else:
result = result.to(other.units)
return result
elif isinstance(other, np.ndarray):
return np.asarray(result, dtype=other.dtype)
Expand All @@ -1562,7 +1568,7 @@ def wrapper(*args, **kwargs):
else:
return result.metpy.unit_array
else:
if wrap_kwargs.get('match_unit', False) and hasattr(other.attrs, 'units'):
if wrap_kwargs.get('match_unit', False) and 'units' in other.attrs:
if isinstance(result, units.Quantity):
return xr.DataArray(result.m_as(other.attrs['units']), dims=other.dims,
coords=other.coords,
Expand Down
228 changes: 226 additions & 2 deletions tests/calc/test_calc_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@
reduce_point_density, resample_nn_1d, second_derivative)
from metpy.calc.tools import (_delete_masked_points, _get_bound_pressure_height,
_greater_or_close, _less_or_close, _next_non_masked_element,
_remove_nans, BASE_DEGREE_MULTIPLIER, DIR_STRS, UND)
_remove_nans, BASE_DEGREE_MULTIPLIER, DIR_STRS, UND,
wrap_output_like)
from metpy.testing import (assert_almost_equal, assert_array_almost_equal, assert_array_equal)
from metpy.units import units
from metpy.units import DimensionalityError, units


FULL_CIRCLE_DEGREES = np.arange(0, 360, BASE_DEGREE_MULTIPLIER.m) * units.degree
Expand Down Expand Up @@ -1260,3 +1261,226 @@ def test_remove_nans():
y_expected = np.array([0, 1, 3, 4])
assert_array_almost_equal(x_expected, x_test, 0)
assert_almost_equal(y_expected, y_test, 0)


@pytest.mark.parametrize('test, other, match_unit, expected', [
(np.arange(4), np.ones(3), False, np.arange(4)),
(np.arange(4), np.ones(3), True, np.arange(4)),
(np.arange(4), [0] * units.m, False, np.arange(4) * units('dimensionless')),
(np.arange(4), [0] * units.m, True, np.arange(4) * units.m),
(
np.arange(4),
xr.DataArray(
np.zeros(4),
dims=('x',),
coords={'x': np.linspace(0, 1, 4)},
attrs={'units': 'meter', 'description': 'Just some zeros'}
),
False,
xr.DataArray(
np.arange(4),
dims=('x',),
coords={'x': np.linspace(0, 1, 4)}
)
),
(
np.arange(4),
xr.DataArray(
np.zeros(4),
dims=('x',),
coords={'x': np.linspace(0, 1, 4)},
attrs={'units': 'meter', 'description': 'Just some zeros'}
),
True,
xr.DataArray(
np.arange(4),
dims=('x',),
coords={'x': np.linspace(0, 1, 4)},
attrs={'units': 'meter'}
)
),
([2, 4, 8] * units.kg, np.ones(3), False, np.array([2, 4, 8])),
([2, 4, 8] * units.kg, np.ones(3), True, np.array([2, 4, 8])),
([2, 4, 8] * units.kg, [0] * units.m, False, [2, 4, 8] * units.kg),
([2, 4, 8] * units.kg, [0] * units.g, True, [2000, 4000, 8000] * units.g),
(
[2, 4, 8] * units.kg,
xr.DataArray(
np.zeros(3),
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'meter'}
),
False,
xr.DataArray(
[2, 4, 8],
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'kilogram'}
)
),
(
[2, 4, 8] * units.kg,
xr.DataArray(
np.zeros(3),
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'gram'}
),
True,
xr.DataArray(
[2000, 4000, 8000],
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'gram'}
)
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
),
np.arange(4, dtype=np.float64),
False,
np.linspace(0, 1, 5)
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
),
np.arange(4, dtype=np.float64),
True,
np.linspace(0, 1, 5)
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
),
[0] * units.kg,
False,
np.linspace(0, 1, 5) * units.m
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
),
[0] * units.cm,
True,
np.linspace(0, 100, 5) * units.cm
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
),
xr.DataArray(
np.zeros(3),
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'kilogram', 'description': 'Alternative data'}
),
False,
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
)
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter', 'description': 'A range of values'}
),
xr.DataArray(
np.zeros(3),
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'centimeter', 'description': 'Alternative data'}
),
True,
xr.DataArray(
np.linspace(0, 100, 5),
attrs={'units': 'centimeter', 'description': 'A range of values'}
)
),
])
def test_wrap_output_like_with_other_kwarg(test, other, match_unit, expected):
"""Test the wrap output like decorator when using the output kwarg."""
@wrap_output_like(other=other, match_unit=match_unit)
def almost_identity(arg):
return arg

result = almost_identity(test)

if hasattr(expected, 'units'):
assert expected.units == result.units
if isinstance(expected, xr.DataArray):
xr.testing.assert_identical(result, expected)
else:
assert_array_equal(result, expected)


@pytest.mark.parametrize('test, other', [
([2, 4, 8] * units.kg, [0] * units.m),
(
[2, 4, 8] * units.kg,
xr.DataArray(
np.zeros(3),
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'meter'}
)
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter'}
),
[0] * units.kg
),
(
xr.DataArray(
np.linspace(0, 1, 5),
attrs={'units': 'meter'}
),
xr.DataArray(
np.zeros(3),
dims=('x',),
coords={'x': np.linspace(0, 1, 3)},
attrs={'units': 'kilogram'}
)
)
])
def test_wrap_output_like_with_other_kwarg_raising_dimensionality_error(test, other):
"""Test the wrap output like decorator when when a dimensionality error is raised."""
@wrap_output_like(other=other, match_unit=True)
def almost_identity(arg):
return arg

with pytest.raises(DimensionalityError):
almost_identity(test)


def test_wrap_output_like_with_argument_kwarg():
"""Test the wrap output like decorator with signature recognition."""
@wrap_output_like(argument='a')
def double(a):
return units.Quantity(2) * a.metpy.unit_array

test = xr.DataArray([1, 3, 5, 7], attrs={'units': 'm'})
expected = xr.DataArray([2, 6, 10, 14], attrs={'units': 'meter'})

xr.testing.assert_identical(double(test), expected)


def test_wrap_output_like_without_control_kwarg():
"""Test that the wrap output like decorator fails when not provided a control param."""
@wrap_output_like()
def func(arg):
return np.array(arg)

with pytest.raises(ValueError) as exc:
func(0)
assert 'Must specify keyword' in str(exc)

0 comments on commit 60fad89

Please sign in to comment.