Skip to content

Commit

Permalink
numpy frontend casting (ivy-llc#5177)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy authored Oct 8, 2022
1 parent e9a48ac commit a0a4ec6
Show file tree
Hide file tree
Showing 8 changed files with 256 additions and 59 deletions.
137 changes: 137 additions & 0 deletions ivy/functional/frontends/numpy/func_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import ivy
import functools
from typing import Callable


def _is_same_kind_or_safe(t1, t2):
if ivy.is_float_dtype(t1):
return ivy.is_float_dtype(t2) or ivy.can_cast(t1, t2)
elif ivy.is_uint_dtype(t1):
return ivy.is_uint_dtype(t2) or ivy.can_cast(t1, t2)
elif ivy.is_int_dtype(t1):
return ivy.is_int_dtype(t2) or ivy.can_cast(t1, t2)
elif ivy.is_bool_dtype(t1):
return ivy.is_bool_dtype(t2) or ivy.can_cast(t1, t2)
raise ivy.exceptions.IvyException(
"dtypes of input must be float, int, uint, or bool"
)


def _assert_args_and_fn(args, kwargs, dtype, fn):
ivy.assertions.check_all_or_any_fn(
*args,
fn=fn,
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)
ivy.assertions.check_all_or_any_fn(
*kwargs,
fn=fn,
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)


def handle_numpy_casting(fn: Callable) -> Callable:
@functools.wraps(fn)
def new_fn(*args, casting="same_kind", dtype=None, **kwargs):
"""
Check numpy casting type.
Parameters
----------
args
The arguments to be passed to the function.
kwargs
The keyword arguments to be passed to the function.
Returns
-------
The return of the function, or raise IvyException if error is thrown.
"""
ivy.assertions.check_elem_in_list(
casting,
["no", "equiv", "safe", "same_kind", "unsafe"],
message="casting must be one of [no, equiv, safe, same_kind, unsafe]",
)
args = list(args)
args_idxs = ivy.nested_argwhere(args, ivy.is_array)
args_to_check = ivy.multi_index_nest(args, args_idxs)
kwargs_idxs = ivy.nested_argwhere(kwargs, ivy.is_array)
kwargs_to_check = ivy.multi_index_nest(kwargs, kwargs_idxs)
if (args_to_check or kwargs_to_check) and (
casting == "no" or casting == "equiv"
):
first_arg = args_to_check[0] if args_to_check else kwargs_to_check[0]
fn_func = (
ivy.as_ivy_dtype(dtype) if ivy.exists(dtype) else ivy.dtype(first_arg)
)
_assert_args_and_fn(
args_to_check,
kwargs_to_check,
dtype,
fn=lambda x: ivy.dtype(x) == fn_func,
)
elif ivy.exists(dtype):
assert_fn = None
if casting == "safe":
assert_fn = lambda x: ivy.can_cast(x, ivy.as_ivy_dtype(dtype))
elif casting == "same_kind":
assert_fn = lambda x: _is_same_kind_or_safe(
ivy.dtype(x), ivy.as_ivy_dtype(dtype)
)
if ivy.exists(assert_fn):
_assert_args_and_fn(
args_to_check,
kwargs_to_check,
dtype,
fn=assert_fn,
)
ivy.map_nest_at_indices(
args, args_idxs, lambda x: ivy.astype(x, ivy.as_ivy_dtype(dtype))
)
ivy.map_nest_at_indices(
kwargs, kwargs_idxs, lambda x: ivy.astype(x, ivy.as_ivy_dtype(dtype))
)

return fn(*args, **kwargs)

new_fn.handle_numpy_casting = True
return new_fn


def handle_numpy_casting_special(fn: Callable) -> Callable:
@functools.wraps(fn)
def new_fn(*args, casting="same_kind", dtype=None, **kwargs):
"""
Check numpy casting type for special cases where output must be type bool.
Parameters
----------
args
The arguments to be passed to the function.
kwargs
The keyword arguments to be passed to the function.
Returns
-------
The return of the function, or raise IvyException if error is thrown.
"""
ivy.assertions.check_elem_in_list(
casting,
["no", "equiv", "safe", "same_kind", "unsafe"],
message="casting must be one of [no, equiv, safe, same_kind, unsafe]",
)
if ivy.exists(dtype):
ivy.assertions.check_equal(
ivy.as_ivy_dtype(dtype),
"bool",
message="output is compatible with bool only",
)

return fn(*args, **kwargs)

new_fn.handle_numpy_casting_special = True
return new_fn
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# global
import ivy
from ivy.functional.frontends.numpy.func_wrapper import handle_numpy_casting


def outer(a, b, out=None):
Expand All @@ -10,6 +11,7 @@ def inner(a, b, /):
return ivy.inner(a, b)


@handle_numpy_casting
def matmul(
x1, x2, /, out=None, *, casting="same_kind", order="K", dtype=None, subok=True
):
Expand Down
10 changes: 4 additions & 6 deletions ivy/functional/frontends/numpy/logic/array_type_testing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# global
import ivy
from ivy.functional.frontends.numpy.func_wrapper import handle_numpy_casting_special


@handle_numpy_casting_special
def isfinite(
x,
/,
Expand All @@ -13,14 +15,13 @@ def isfinite(
dtype=None,
subok=True,
):
if dtype:
x = ivy.astype(ivy.array(x), ivy.as_ivy_dtype(dtype))
ret = ivy.isfinite(x, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret


@handle_numpy_casting_special
def isinf(
x,
/,
Expand All @@ -32,14 +33,13 @@ def isinf(
dtype=None,
subok=True,
):
if dtype:
x = ivy.astype(ivy.array(x), ivy.as_ivy_dtype(dtype))
ret = ivy.isinf(x, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret


@handle_numpy_casting_special
def isnan(
x,
/,
Expand All @@ -51,8 +51,6 @@ def isnan(
dtype=None,
subok=True,
):
if dtype:
x = ivy.astype(ivy.array(x), ivy.as_ivy_dtype(dtype))
ret = ivy.isnan(x, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
Expand Down
25 changes: 7 additions & 18 deletions ivy/functional/frontends/numpy/logic/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

# local
from ivy.func_wrapper import from_zero_dim_arrays_to_float
from ivy.functional.frontends.numpy.func_wrapper import handle_numpy_casting


@from_zero_dim_arrays_to_float
@handle_numpy_casting
def equal(
x1,
x2,
Expand All @@ -18,9 +20,6 @@ def equal(
dtype=None,
subok=True,
):
if dtype:
x1 = ivy.astype(ivy.array(x1), ivy.as_ivy_dtype(dtype))
x2 = ivy.astype(ivy.array(x2), ivy.as_ivy_dtype(dtype))
ret = ivy.equal(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
Expand All @@ -36,6 +35,7 @@ def array_equal(a1, a2, equal_nan=False):
return ivy.array(ivy.array_equal(a1[~a1nan], a2[~a2nan]))


@handle_numpy_casting
def greater(
x1,
x2,
Expand All @@ -48,15 +48,13 @@ def greater(
dtype=None,
subok=True,
):
if dtype:
x1 = ivy.astype(ivy.array(x1), ivy.as_ivy_dtype(dtype))
x2 = ivy.astype(ivy.array(x2), ivy.as_ivy_dtype(dtype))
ret = ivy.greater(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret


@handle_numpy_casting
def greater_equal(
x1,
x2,
Expand All @@ -69,15 +67,13 @@ def greater_equal(
dtype=None,
subok=True,
):
if dtype:
x1 = ivy.astype(ivy.array(x1), ivy.as_ivy_dtype(dtype))
x2 = ivy.astype(ivy.array(x2), ivy.as_ivy_dtype(dtype))
ret = ivy.greater_equal(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret


@handle_numpy_casting
def less(
x1,
x2,
Expand All @@ -90,15 +86,13 @@ def less(
dtype=None,
subok=True,
):
if dtype:
x1 = ivy.astype(ivy.array(x1), ivy.as_ivy_dtype(dtype))
x2 = ivy.astype(ivy.array(x2), ivy.as_ivy_dtype(dtype))
ret = ivy.less(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret


@handle_numpy_casting
def less_equal(
x1,
x2,
Expand All @@ -111,15 +105,13 @@ def less_equal(
dtype=None,
subok=True,
):
if dtype:
x1 = ivy.astype(ivy.array(x1), ivy.as_ivy_dtype(dtype))
x2 = ivy.astype(ivy.array(x2), ivy.as_ivy_dtype(dtype))
ret = ivy.less_equal(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
return ret


@handle_numpy_casting
def not_equal(
x1,
x2,
Expand All @@ -132,9 +124,6 @@ def not_equal(
dtype=None,
subok=True,
):
if dtype:
x1 = ivy.astype(ivy.array(x1), ivy.as_ivy_dtype(dtype))
x2 = ivy.astype(ivy.array(x2), ivy.as_ivy_dtype(dtype))
ret = ivy.not_equal(x1, x2, out=out)
if ivy.is_array(where):
ret = ivy.where(where, ret, ivy.default(out, ivy.zeros_like(ret)), out=out)
Expand Down
37 changes: 37 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_numpy/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def where(draw):
return draw(st.just(values) | st.just(True))


@st.composite
def get_casting(draw):
return draw(st.sampled_from(["no", "equiv", "safe", "same_kind", "unsafe"]))


@st.composite
def dtype_x_bounded_axis(draw, **kwargs):
dtype, x, shape = draw(helpers.dtype_and_values(**kwargs, ret_shape=True))
Expand Down Expand Up @@ -148,3 +153,35 @@ def handle_where_and_array_bools(where, input_dtype, as_variable, native_array):
input_dtype += ["bool"]
return where, as_variable + [False], native_array + [False]
return where, as_variable, native_array


def handle_dtype_and_casting(
*,
dtypes,
get_dtypes_kind="valid",
get_dtypes_index=0,
get_dtypes_none=True,
get_dtypes_key=None,
):
casting = get_casting()
if casting in ["no", "equiv"]:
dtype = dtypes[0]
dtypes = [dtype for x in dtypes]
return dtype, dtypes, casting
dtype = helpers.get_dtypes(
get_dtypes_kind,
index=get_dtypes_index,
full=False,
none=get_dtypes_none,
key=get_dtypes_key,
)
if casting in ["safe", "same_kind"]:
while not ivy.all([ivy.can_cast(x, dtype) for x in dtypes]):
dtype = helpers.get_dtypes(
get_dtypes_kind,
index=get_dtypes_index,
full=False,
none=get_dtypes_none,
key=get_dtypes_key,
)
return dtype, dtypes, casting
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# local
import ivy_tests.test_ivy.helpers as helpers
import ivy_tests.test_ivy.test_frontends.test_numpy.helpers as np_frontend_helpers
from ivy_tests.test_ivy.helpers import handle_cmd_line_args
from ivy_tests.test_ivy.test_functional.test_core.test_linalg import (
_get_first_matrix_and_dtype,
Expand Down Expand Up @@ -94,8 +95,12 @@ def test_numpy_matmul(
):
dtype1, x1 = x
dtype2, x2 = y
dtype, dtypes, casting = np_frontend_helpers.handle_dtype_and_casting(
dtypes=dtype1 + dtype2,
get_dtypes_kind="numeric",
)
helpers.test_frontend_function(
input_dtypes=dtype1 + dtype2,
input_dtypes=dtypes,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=num_positional_args,
Expand All @@ -104,6 +109,11 @@ def test_numpy_matmul(
fn_tree="matmul",
x1=x1,
x2=x2,
out=None,
casting=casting,
order="K",
dtype=dtype,
subok=True,
)


Expand Down
Loading

0 comments on commit a0a4ec6

Please sign in to comment.