Skip to content

Commit

Permalink
move assertions, exceptions, inspection, verbosity modules to…
Browse files Browse the repository at this point in the history
… `ivy.utils` subpackage.
  • Loading branch information
CatB1t committed Feb 20, 2023
1 parent 3ecbd8c commit e33db02
Show file tree
Hide file tree
Showing 134 changed files with 718 additions and 623 deletions.
74 changes: 42 additions & 32 deletions ivy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def is_local():

class FrameworkStr(str):
def __new__(cls, fw_str):
ivy.assertions.check_elem_in_list(
ivy.utils.assertions.check_elem_in_list(
fw_str, ["jax", "tensorflow", "torch", "numpy"]
)
return str.__new__(cls, fw_str)
Expand Down Expand Up @@ -91,10 +91,10 @@ class Array:
class Device(str):
def __new__(cls, dev_str):
if dev_str != "":
ivy.assertions.check_elem_in_list(dev_str[0:3], ["gpu", "tpu", "cpu"])
ivy.utils.assertions.check_elem_in_list(dev_str[0:3], ["gpu", "tpu", "cpu"])
if dev_str != "cpu":
# ivy.assertions.check_equal(dev_str[3], ":")
ivy.assertions.check_true(
ivy.utils.assertions.check_true(
dev_str[4:].isnumeric(),
message="{} must be numeric".format(dev_str[4:]),
)
Expand All @@ -112,17 +112,19 @@ def __new__(cls, dtype_str):
if dtype_str is builtins.bool:
dtype_str = "bool"
if not isinstance(dtype_str, str):
raise ivy.exceptions.IvyException("dtype must be type str")
raise ivy.utils.exceptions.IvyException("dtype must be type str")
if dtype_str not in _all_ivy_dtypes_str:
raise ivy.exceptions.IvyException(f"{dtype_str} is not supported by ivy")
raise ivy.utils.exceptions.IvyException(
f"{dtype_str} is not supported by ivy"
)
return str.__new__(cls, dtype_str)

def __ge__(self, other):
if isinstance(other, str):
other = Dtype(other)

if not isinstance(other, Dtype):
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"Attempted to compare a dtype with something which"
"couldn't be interpreted as a dtype"
)
Expand All @@ -134,7 +136,7 @@ def __gt__(self, other):
other = Dtype(other)

if not isinstance(other, Dtype):
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"Attempted to compare a dtype with something which"
"couldn't be interpreted as a dtype"
)
Expand All @@ -146,7 +148,7 @@ def __lt__(self, other):
other = Dtype(other)

if not isinstance(other, Dtype):
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"Attempted to compare a dtype with something which"
"couldn't be interpreted as a dtype"
)
Expand All @@ -158,7 +160,7 @@ def __le__(self, other):
other = Dtype(other)

if not isinstance(other, Dtype):
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"Attempted to compare a dtype with something which"
"couldn't be interpreted as a dtype"
)
Expand Down Expand Up @@ -200,7 +202,7 @@ def info(self):
elif self.is_float_dtype:
return finfo(self)
else:
raise ivy.exceptions.IvyError(f"{self} is not supported by info")
raise ivy.utils.exceptions.IvyError(f"{self} is not supported by info")

def can_cast(self, to):
return can_cast(self, to)
Expand All @@ -221,12 +223,12 @@ def __new__(cls, shape_tup):
np.ndarray,
tf.Tensor,
)
ivy.assertions.check_isinstance(shape_tup, valid_types)
ivy.utils.assertions.check_isinstance(shape_tup, valid_types)
if isinstance(shape_tup, int):
shape_tup = (shape_tup,)
elif isinstance(shape_tup, list):
shape_tup = tuple(shape_tup)
ivy.assertions.check_all(
ivy.utils.assertions.check_all(
[isinstance(v, int) or ivy.is_int_dtype(v.dtype) for v in shape_tup],
"shape must take integers only",
)
Expand All @@ -240,13 +242,15 @@ def __new__(cls, dtype_str):
if dtype_str is builtins.int:
dtype_str = default_int_dtype()
if not isinstance(dtype_str, str):
raise ivy.exceptions.IvyException("dtype_str must be type str")
raise ivy.utils.exceptions.IvyException("dtype_str must be type str")
if "int" not in dtype_str:
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"dtype must be string and starts with int"
)
if dtype_str not in _all_ivy_dtypes_str:
raise ivy.exceptions.IvyException(f"{dtype_str} is not supported by ivy")
raise ivy.utils.exceptions.IvyException(
f"{dtype_str} is not supported by ivy"
)
return str.__new__(cls, dtype_str)

@property
Expand All @@ -259,13 +263,15 @@ def __new__(cls, dtype_str):
if dtype_str is builtins.float:
dtype_str = default_float_dtype()
if not isinstance(dtype_str, str):
raise ivy.exceptions.IvyException("dtype_str must be type str")
raise ivy.utils.exceptions.IvyException("dtype_str must be type str")
if "float" not in dtype_str:
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"dtype must be string and starts with float"
)
if dtype_str not in _all_ivy_dtypes_str:
raise ivy.exceptions.IvyException(f"{dtype_str} is not supported by ivy")
raise ivy.utils.exceptions.IvyException(
f"{dtype_str} is not supported by ivy"
)
return str.__new__(cls, dtype_str)

@property
Expand All @@ -276,13 +282,15 @@ def info(self):
class UintDtype(IntDtype):
def __new__(cls, dtype_str):
if not isinstance(dtype_str, str):
raise ivy.exceptions.IvyException("dtype_str must be type str")
raise ivy.utils.exceptions.IvyException("dtype_str must be type str")
if "uint" not in dtype_str:
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"dtype must be string and starts with uint"
)
if dtype_str not in _all_ivy_dtypes_str:
raise ivy.exceptions.IvyException(f"{dtype_str} is not supported by ivy")
raise ivy.utils.exceptions.IvyException(
f"{dtype_str} is not supported by ivy"
)
return str.__new__(cls, dtype_str)

@property
Expand All @@ -293,13 +301,15 @@ def info(self):
class ComplexDtype(Dtype):
def __new__(cls, dtype_str):
if not isinstance(dtype_str, str):
raise ivy.exceptions.IvyException("dtype_str must be type str")
raise ivy.utils.exceptions.IvyException("dtype_str must be type str")
if "complex" not in dtype_str:
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"dtype must be string and starts with complex"
)
if dtype_str not in _all_ivy_dtypes_str:
raise ivy.exceptions.IvyException(f"{dtype_str} is not supported by ivy")
raise ivy.utils.exceptions.IvyException(
f"{dtype_str} is not supported by ivy"
)
return str.__new__(cls, dtype_str)

@property
Expand Down Expand Up @@ -730,14 +740,14 @@ class Node(str):
choose_random_backend,
clear_backend_stack,
)
from . import assertions, func_wrapper, exceptions
from . import func_wrapper
from .utils import assertions, exceptions, verbosity
from .utils.backend import handler
from . import functional
from .functional import *
from . import stateful
from .stateful import *
from . import verbosity
from .inspection import fn_array_spec, add_array_specs
from ivy.utils.inspection import fn_array_spec, add_array_specs

add_array_specs()

Expand Down Expand Up @@ -919,8 +929,8 @@ def del_global_attr(attr_name):


def _assert_array_significant_figures_formatting(sig_figs):
ivy.assertions.check_isinstance(sig_figs, int)
ivy.assertions.check_greater(sig_figs, 0)
ivy.utils.assertions.check_isinstance(sig_figs, int)
ivy.utils.assertions.check_greater(sig_figs, 0)


# ToDo: SF formating for complex number
Expand Down Expand Up @@ -991,8 +1001,8 @@ def unset_array_significant_figures():


def _assert_array_decimal_values_formatting(dec_vals):
ivy.assertions.check_isinstance(dec_vals, int)
ivy.assertions.check_greater(dec_vals, 0, allow_equal=True)
ivy.utils.assertions.check_isinstance(dec_vals, int)
ivy.utils.assertions.check_greater(dec_vals, 0, allow_equal=True)


def array_decimal_values(dec_vals=None):
Expand Down Expand Up @@ -1114,7 +1124,7 @@ def set_nan_policy(warn_level):
"""
global nan_policy_stack
if warn_level not in ["nothing", "warns", "raise_exception"]:
raise ivy.exceptions.IvyException(
raise ivy.utils.exceptions.IvyException(
"nan_policy must be one of 'nothing', 'warns', 'raise_exception'"
)
nan_policy_stack.append(warn_level)
Expand Down
8 changes: 4 additions & 4 deletions ivy/array/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def _init(self, data, dynamic_backend=None):
if ivy.is_ivy_array(data):
self._data = data.data
else:
ivy.assertions.check_true(
ivy.utils.assertions.check_true(
ivy.is_native_array(data), "data must be native array"
)
self._data = data
Expand Down Expand Up @@ -215,7 +215,7 @@ def mT(self) -> ivy.Array:
``(..., M, N)``, the returned array must have shape ``(..., N, M)``).
The returned array must have the same data type as the original array.
"""
ivy.assertions.check_greater(len(self._data.shape), 2, allow_equal=True)
ivy.utils.assertions.check_greater(len(self._data.shape), 2, allow_equal=True)
return ivy.matrix_transpose(self._data)

@property
Expand Down Expand Up @@ -244,15 +244,15 @@ def T(self) -> ivy.Array:
two-dimensional array whose first and last dimensions (axes) are
permuted in reverse order relative to original array.
"""
ivy.assertions.check_equal(len(self._data.shape), 2)
ivy.utils.assertions.check_equal(len(self._data.shape), 2)
return ivy.matrix_transpose(self._data)

# Setters #
# --------#

@data.setter
def data(self, data):
ivy.assertions.check_true(
ivy.utils.assertions.check_true(
ivy.is_native_array(data), "data must be native array"
)
self._init(data)
Expand Down
Loading

0 comments on commit e33db02

Please sign in to comment.