Skip to content

Commit

Permalink
numpy frontend scalar fix in handling casting (ivy-llc#9955)
Browse files Browse the repository at this point in the history
  • Loading branch information
juliagsy authored Feb 13, 2023
1 parent d3102ce commit fa7ded3
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 51 deletions.
13 changes: 9 additions & 4 deletions ivy/assertions.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,15 @@ def check_exists(x, inverse=False, message=""):
)


def check_elem_in_list(elem, list, message=""):
message = message if message != "" else "{} must be one of {}".format(elem, list)
if elem not in list:
raise ivy.exceptions.IvyException(message)
def check_elem_in_list(elem, list, inverse=False, message=""):
if inverse and elem in list:
raise ivy.exceptions.IvyException(
message if message != "" else "{} must not be one of {}".format(elem, list)
)
elif not inverse and elem not in list:
raise ivy.exceptions.IvyException(
message if message != "" else "{} must be one of {}".format(elem, list)
)


def check_true(expression, message="expression must be True"):
Expand Down
4 changes: 2 additions & 2 deletions ivy/functional/frontends/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,13 @@ def promote_types_of_numpy_inputs(
and x1.shape == ()
and not (hasattr(x2, "shape") and x2.shape == ())
):
x1 = ivy.to_scalar(x1[()])
x1 = ivy.to_scalar(x1)
if (
hasattr(x2, "shape")
and x2.shape == ()
and not (hasattr(x1, "shape") and x1.shape == ())
):
x2 = ivy.to_scalar(x2[()])
x2 = ivy.to_scalar(x2)
type1 = ivy.default_dtype(item=x1).strip("u123456789")
type2 = ivy.default_dtype(item=x2).strip("u123456789")
if hasattr(x1, "dtype") and not hasattr(x2, "dtype") and type1 == type2:
Expand Down
186 changes: 141 additions & 45 deletions ivy/functional/frontends/numpy/func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,125 @@
import ivy.functional.frontends.numpy as np_frontend


# Helpers #
# ------- #

# general casting
def _assert_array(args, dtype, scalar_check=False, casting="safe"):
if args and dtype:
if not scalar_check:
ivy.assertions.check_all_or_any_fn(
*args,
fn=lambda x: np_frontend.can_cast(
x, ivy.as_ivy_dtype(dtype), casting=casting
),
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)
else:
assert_fn = None if casting == "safe" else ivy.exists
if ivy.is_bool_dtype(dtype):
assert_fn = ivy.is_bool_dtype
if ivy.is_int_dtype(dtype):
assert_fn = lambda x: not ivy.is_float_dtype(x)

if assert_fn:
ivy.assertions.check_all_or_any_fn(
*args,
fn=lambda x: assert_fn(x)
if ivy.shape(x) == ()
else np_frontend.can_cast(
x, ivy.as_ivy_dtype(dtype), casting=casting
),
type="all",
message="type of input is incompatible with dtype: {}".format(
dtype
),
)


def _assert_scalar(args, dtype):
if args and dtype:
assert_fn = None
if ivy.is_int_dtype(dtype):
assert_fn = lambda x: type(x) != float
elif ivy.is_bool_dtype(dtype):
assert_fn = lambda x: type(x) == bool

if assert_fn:
ivy.assertions.check_all_or_any_fn(
*args,
fn=assert_fn,
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)


# no casting
def _assert_no_array(args, dtype, scalar_check=False, none=False):
if args:
first_arg = args[0]
fn_func = ivy.as_ivy_dtype(dtype) if ivy.exists(dtype) else ivy.dtype(first_arg)
assert_fn = lambda x: ivy.dtype(x) == fn_func
if scalar_check:
assert_fn = (
lambda x: ivy.dtype(x) == fn_func
if ivy.shape(x) != ()
else _casting_no_special_case(ivy.dtype(x), fn_func, none)
)
ivy.assertions.check_all_or_any_fn(
*args,
fn=assert_fn,
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)


def _casting_no_special_case(dtype1, dtype, none=False):
if dtype == "float16":
allowed_dtypes = ["float32", "float64"]
if not none:
allowed_dtypes += ["float16"]
return dtype1 in allowed_dtypes
if dtype in ["int8", "uint8"]:
if none:
return ivy.is_int_dtype(dtype1) and dtype1 not in ["int8", "uint8"]
return ivy.is_int_dtype(dtype1)
return dtype1 == dtype


def _assert_no_scalar(args, dtype, none=False):
if args:
first_arg = args[0]
ivy.assertions.check_all_or_any_fn(
*args,
fn=lambda x: type(x) == type(first_arg),
type="all",
message="type of input is incompatible with dtype {}".format(dtype),
)
if dtype:
if ivy.is_int_dtype(dtype):
check_dtype = int
elif ivy.is_float_dtype(dtype):
check_dtype = float
else:
check_dtype = bool
ivy.assertions.check_equal(
type(args[0]),
check_dtype,
message="type of input is incompatible with dtype {}".format(dtype),
)
if ivy.as_ivy_dtype(dtype) not in ["float64", "int8", "int64", "uint8"]:
if type(args[0]) == int:
ivy.assertions.check_elem_in_list(
dtype,
["int16", "int32", "uint16", "uint32", "uint64"],
inverse=True,
)
elif type(args[0]) == float:
ivy.assertions.check_equal(dtype, "float32", inverse=True)


def handle_numpy_dtype(fn: Callable) -> Callable:
@functools.wraps(fn)
def new_fn(*args, dtype=None, **kwargs):
Expand Down Expand Up @@ -37,21 +156,6 @@ def new_fn(*args, dtype=None, **kwargs):
return new_fn


def _assert_args_and_fn(args, kwargs, dtype, fn):
ivy.assertions.check_all_or_any_fn(
*args,
fn=fn,
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)
ivy.assertions.check_all_or_any_fn(
*kwargs,
fn=fn,
type="all",
message="type of input is incompatible with dtype: {}".format(dtype),
)


def handle_numpy_casting(fn: Callable) -> Callable:
@functools.wraps(fn)
def new_fn(*args, casting="same_kind", dtype=None, **kwargs):
Expand All @@ -76,45 +180,37 @@ def new_fn(*args, casting="same_kind", dtype=None, **kwargs):
message="casting must be one of [no, equiv, safe, same_kind, unsafe]",
)
args = list(args)
args_scalar_idxs = ivy.nested_argwhere(
args, lambda x: isinstance(x, (int, float, bool))
)
args_scalar_to_check = ivy.multi_index_nest(args, args_scalar_idxs)
args_idxs = ivy.nested_argwhere(args, ivy.is_array)
args_to_check = ivy.multi_index_nest(args, args_idxs)
kwargs_idxs = ivy.nested_argwhere(kwargs, ivy.is_array)
kwargs_idxs.remove(["out"]) if ["out"] in kwargs_idxs else kwargs_idxs
kwargs_to_check = ivy.multi_index_nest(kwargs, kwargs_idxs)
if (args_to_check or kwargs_to_check) and (
casting == "no" or casting == "equiv"
):
first_arg = args_to_check[0] if args_to_check else kwargs_to_check[0]
fn_func = (
ivy.as_ivy_dtype(dtype) if ivy.exists(dtype) else ivy.dtype(first_arg)

if casting in ["no", "equiv"]:
none = not dtype
if none:
dtype = args_to_check[0].dtype if args_to_check else None
_assert_no_array(
args_to_check,
dtype,
scalar_check=(args_to_check and args_scalar_to_check),
none=none,
)
_assert_args_and_fn(
_assert_no_scalar(args_scalar_to_check, dtype, none=none)
elif casting in ["same_kind", "safe"]:
_assert_array(
args_to_check,
kwargs_to_check,
dtype,
fn=lambda x: ivy.dtype(x) == fn_func,
scalar_check=(args_to_check and args_scalar_to_check),
casting=casting,
)
elif ivy.exists(dtype):
assert_fn = None
if casting == "safe":
assert_fn = lambda x: np_frontend.can_cast(x, ivy.as_ivy_dtype(dtype))
elif casting == "same_kind":
assert_fn = lambda x: np_frontend.can_cast(
x, ivy.as_ivy_dtype(dtype), casting="same_kind"
)
if ivy.exists(assert_fn):
_assert_args_and_fn(
args_to_check,
kwargs_to_check,
dtype,
fn=assert_fn,
)
_assert_scalar(args_scalar_to_check, dtype)

if ivy.exists(dtype):
ivy.map_nest_at_indices(
args, args_idxs, lambda x: ivy.astype(x, ivy.as_ivy_dtype(dtype))
)
ivy.map_nest_at_indices(
kwargs, kwargs_idxs, lambda x: ivy.astype(x, ivy.as_ivy_dtype(dtype))
)

return fn(*args, **kwargs)

Expand Down

0 comments on commit fa7ded3

Please sign in to comment.