Skip to content

Commit

Permalink
fix formatting for backend_handler.
Browse files Browse the repository at this point in the history
  • Loading branch information
CatB1t committed Feb 9, 2023
1 parent 0630f1e commit 7b29742
Showing 1 changed file with 49 additions and 61 deletions.
110 changes: 49 additions & 61 deletions ivy/backend_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ def __exit__(self, exc_type, exc_val, exc_tb):
_array_types["numpy"] = "ivy.functional.backends.numpy"
_array_types["jax.interpreters.xla"] = "ivy.functional.backends.jax"
_array_types["jaxlib.xla_extension"] = "ivy.functional.backends.jax"
_array_types["tensorflow.python.framework.ops"] = \
"ivy.functional.backends.tensorflow"
_array_types["tensorflow.python.ops.resource_variable_ops"] = \
"ivy.functional.backends.tensorflow"
_array_types["tensorflow.python.framework.ops"] = "ivy.functional.backends.tensorflow"
_array_types[
"tensorflow.python.ops.resource_variable_ops"
] = "ivy.functional.backends.tensorflow"
_array_types["torch"] = "ivy.functional.backends.torch"
_array_types["torch.nn.parameter"] = "ivy.functional.backends.torch"

Expand Down Expand Up @@ -126,23 +126,23 @@ def fn_name_from_version_specific_fn_name(name, version):
if "_to_" in name:
i = name.index("_v_")
e = name.index("_to_")
version_start = name[i + 3: e]
version_start = name[i + 3 : e]
version_start = tuple(map(int, version_start.split("p")))
version_end = name[e + 4:]
version_end = name[e + 4 :]
version_end = tuple(map(int, version_end.split("p")))
if version_start <= version <= version_end:
return name[0:i]
elif "_and_above" in name:
i = name.index("_v_")
e = name.index("_and_")
version_start = name[i + 3: e]
version_start = name[i + 3 : e]
version_start = tuple(map(int, version_start.split("p")))
if version >= version_start:
return name[0:i]
else:
i = name.index("_v_")
e = name.index("_and_")
version_start = name[i + 3: e]
version_start = name[i + 3 : e]
version_start = tuple(map(int, version_start.split("p")))
if version <= version_start:
return name[0:i]
Expand All @@ -160,7 +160,7 @@ def set_backend_to_specific_version(backend):
"""
# TODO: add functionality and tests
f = str(backend.__name__)
f = f[f.index("backends") + 9:]
f = f[f.index("backends") + 9 :]

f = importlib.import_module(f)
f_version = f.__version__
Expand Down Expand Up @@ -227,21 +227,19 @@ def current_backend(*args, **kwargs):


def convert_from_source_backend_to_numpy(variable_ids, numpy_objs):

# Dynamic Backend
from ivy.functional.ivy.gradients import _is_variable, _variable_data

def _is_var(obj):

if isinstance(obj, ivy.Container):
def _map_fn(x):

def _map_fn(x):
x = x.data if isinstance(x, ivy.Array) else x
if x.__class__.__module__ in \
("numpy",
"jax.interpreters.xla",
"jaxlib.xla_extension"
):
if x.__class__.__module__ in (
"numpy",
"jax.interpreters.xla",
"jaxlib.xla_extension",
):
return False

return _is_variable(x)
Expand All @@ -250,30 +248,36 @@ def _map_fn(x):

else:
obj = obj.data if isinstance(obj, ivy.Array) else obj
if obj.__class__.__module__ in \
("numpy", "jax.interpreters.xla", "jaxlib.xla_extension"):
if obj.__class__.__module__ in (
"numpy",
"jax.interpreters.xla",
"jaxlib.xla_extension",
):
return False
return _is_variable(obj)

def _remove_intermediate_arrays(arr_list, cont_list):
cont_list = [cont.cont_to_flat_list() for cont in cont_list]

cont_ids = [id(item.data) if isinstance(item, ivy.Array) else id(item)
for cont in cont_list for item in cont]
arr_ids = [id(item.data) if isinstance(item, ivy.Array) else id(item)
for item in arr_list]
cont_ids = [
id(item.data) if isinstance(item, ivy.Array) else id(item)
for cont in cont_list
for item in cont
]
arr_ids = [
id(item.data) if isinstance(item, ivy.Array) else id(item)
for item in arr_list
]

new_objs = {k: v for k, v in zip(arr_ids, arr_list)
if k not in cont_ids
}
new_objs = {k: v for k, v in zip(arr_ids, arr_list) if k not in cont_ids}

return list(new_objs.values())

# get all ivy array and container instances in the project scope
array_list, container_list = [[obj for obj in gc.get_objects()
if isinstance(obj, obj_type)]
for obj_type in (ivy.Array, ivy.Container)
]
array_list, container_list = [
[obj for obj in gc.get_objects() if isinstance(obj, obj_type)]
for obj_type in (ivy.Array, ivy.Container)
]

# filter uninitialized arrays
array_list = [arr for arr in array_list if arr.__dict__]
Expand All @@ -285,9 +289,7 @@ def _remove_intermediate_arrays(arr_list, cont_list):
# now convert all ivy.Array and ivy.Container instances
# to numpy using the current backend
for obj in new_objs:

if obj.dynamic_backend:

numpy_objs.append(obj)
if _is_var(obj):
# add variable object id to set
Expand All @@ -307,31 +309,23 @@ def _remove_intermediate_arrays(arr_list, cont_list):


def convert_from_numpy_to_target_backend(variable_ids, numpy_objs):

# Dynamic Backend
from ivy.functional.ivy.gradients import _variable

# convert all ivy.Array and ivy.Container instances from numpy
# to native arrays using the newly set backend
for obj in numpy_objs:

np_arr = obj.data if isinstance(obj, ivy.Array) else obj
# check if object was originally a variable
if id(obj) in variable_ids:
native_arr = ivy.nested_map(
np_arr,
current_backend().asarray,
include_derived=True,
shallow=False
np_arr, current_backend().asarray, include_derived=True, shallow=False
)
new_data = _variable(native_arr)

else:
new_data = ivy.nested_map(
np_arr,
current_backend().asarray,
include_derived=True,
shallow=False
np_arr, current_backend().asarray, include_derived=True, shallow=False
)

if isinstance(obj, ivy.Container):
Expand Down Expand Up @@ -361,7 +355,7 @@ def set_backend(backend: str, dynamic: bool = False):
>>> native = ivy.native_array([1])
>>> print(type(native))
<class 'jaxlib.xla_extension.DeviceArray'>
""" # noqa
""" # noqa
ivy.assertions.check_false(
isinstance(backend, str) and backend not in _backend_dict,
"backend must be one from {}".format(list(_backend_dict.keys())),
Expand All @@ -372,8 +366,9 @@ def set_backend(backend: str, dynamic: bool = False):
# created during 1st conversion step

if dynamic:
variable_ids, numpy_objs = \
convert_from_source_backend_to_numpy(variable_ids, numpy_objs)
variable_ids, numpy_objs = convert_from_source_backend_to_numpy(
variable_ids, numpy_objs
)

# update the global dict with the new backend
ivy.locks["backend_setter"].acquire()
Expand Down Expand Up @@ -401,10 +396,7 @@ def set_backend(backend: str, dynamic: bool = False):
continue
backend.__dict__[k] = v
ivy.__dict__[k] = _wrap_function(
key=k,
to_wrap=backend.__dict__[k],
original=v,
compositional=compositional
key=k, to_wrap=backend.__dict__[k], original=v, compositional=compositional
)

if dynamic:
Expand All @@ -416,12 +408,12 @@ def set_backend(backend: str, dynamic: bool = False):


def set_numpy_backend():
"""Sets NumPy to be the global backend. equivalent to `ivy.set_backend("numpy")`.""" # noqa
"""Sets NumPy to be the global backend. equivalent to `ivy.set_backend("numpy")`.""" # noqa
set_backend("numpy")


def set_jax_backend():
"""Sets JAX to be the global backend. equivalent to `ivy.set_backend("jax")`.""" # noqa
"""Sets JAX to be the global backend. equivalent to `ivy.set_backend("jax")`.""" # noqa
set_backend("jax")


Expand All @@ -434,7 +426,7 @@ def set_tensorflow_backend():


def set_torch_backend():
"""Sets torch to be the global backend. equivalent to `ivy.set_backend("torch")`.""" # noqa
"""Sets torch to be the global backend. equivalent to `ivy.set_backend("torch")`.""" # noqa
set_backend("torch")


Expand Down Expand Up @@ -467,8 +459,8 @@ def get_backend(backend: Optional[str] = None):
>>> ivy.set_backend("jax")
>>> ivy_jax = ivy.get_backend()
>>> print(ivy_jax)
<module 'ivy.functional.backends.jax' from '/ivy/ivy/functional/backends/jax/__init__.py'>
""" # noqa
<module 'ivy.functional.backends.jax' from '/ivy/ivy/functional/backends/jax/__init__.py'>
""" # noqa
# ToDo: change this so that it doesn't depend at all on the global ivy.
# Currently all backend-agnostic implementations returned in this
# module will still use the global ivy backend.
Expand Down Expand Up @@ -519,7 +511,7 @@ def unset_backend():
>>> x = ivy.native_array([1])
>>> print(type(x))
<class'tensorflow.python.framework.ops.EagerTensor'>
""" # noqa
""" # noqa
backend = None
# if the backend stack is empty, nothing is done then we just return `None`
if backend_stack:
Expand All @@ -536,10 +528,7 @@ def unset_backend():
if new_backend.current_backend_str() == "numpy":
ivy.set_default_device("cpu")
elif new_backend.current_backend_str() == "jax":
ivy.set_global_attr(
"RNG",
ivy.functional.backends.jax.random.RNG
)
ivy.set_global_attr("RNG", ivy.functional.backends.jax.random.RNG)
new_backend_dict = (
backend_stack[-1].__dict__ if backend_stack else ivy_original_dict
)
Expand Down Expand Up @@ -571,8 +560,7 @@ def choose_random_backend(excluded=None):
or not installed.""",
)
f = np.random.choice(
[f_srt for f_srt in list(_backend_dict.keys())
if f_srt not in excluded]
[f_srt for f_srt in list(_backend_dict.keys()) if f_srt not in excluded]
)
if f is None:
excluded.append(f)
Expand Down

0 comments on commit 7b29742

Please sign in to comment.