Skip to content

Commit

Permalink
set mixed_function attribute to True for mixed functions, removed red…
Browse files Browse the repository at this point in the history
…undant @inputs_to_ivy_arrays and @outputs_to_ivy_arrays, and updated the docs to explain when these are redundant
  • Loading branch information
AnnaTz committed Apr 20, 2023
1 parent 7a8fc1e commit 462ceaa
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 19 deletions.
3 changes: 2 additions & 1 deletion docs/overview/deep_dive/function_wrapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ Function Wrapping
.. _`corresponding flags`: https://github.com/unifyai/ivy/blob/f1cf9cee62d162fbbd2a4afccd3a90e0cedd5d1f/ivy_tests/test_ivy/conftest.py#L174
.. _`handle_mixed_function`: https://github.com/unifyai/ivy/blob/6a57477daa87e3b3c6d157f10b935ba4fa21c39f/ivy/func_wrapper.py#L923
.. _`stored as an attribute`: https://github.com/unifyai/ivy/blob/6a57477daa87e3b3c6d157f10b935ba4fa21c39f/ivy/func_wrapper.py#L701
.. _`ivy.linear`: https://github.com/unifyai/ivy/blob/7a8fc1ea4eca6d061ae7a3efd1814518d4a6016f/ivy/functional/ivy/layers.py#L172

When a backend framework is set by calling :code:`ivy.set_backend(backend_name)`, then all Ivy functions are `wrapped`_.
This is achieved by calling `_wrap_function`_, which will apply the appropriate wrapping to the given function, based on what decorators it has.
Expand Down Expand Up @@ -61,7 +62,7 @@ Following are some of the wrapping functions currently used:
Another place wherein this decorator is helpful is when we perform configurable argument testing for the parameters :code:`(as_variable, with_out, native_array, container, instance_method, test_gradients)` through the command line.
The `corresponding flags`_ are used to set these values.

#. `handle_mixed_function`_: This wrapping function enables switching between compositional and primary implementations of :ref:`Mixed Functions` based on some condition on the arguments of the function. The condition is specified through a lambda function which when evaluates to `True` the primary implementation is run and otherwise the compositional implementation is executed. For backends that have a primary implementation of a mixed function, the reference to the compositional implementation is `stored as an attribute`_ inside the backend function during backend setting.
#. `handle_mixed_function`_: This wrapping function enables switching between compositional and primary implementations of :ref:`Mixed Functions` based on some condition on the arguments of the function. The condition is specified through a lambda function which when evaluates to `True` the primary implementation is run and otherwise the compositional implementation is executed. For backends that have a primary implementation of a mixed function, the reference to the compositional implementation is `stored as an attribute`_ inside the backend function during backend setting. To make use of this wrapper, it is necessary to set the `mixed_function` attribute of the function to True, as is done for example in `ivy.linear`_. This attribute also automatically activates the :code:`inputs_to_ivy_arrays` wrapper when calling the compositional implementation of the mixed function, and the :code:`outputs_to_ivy_arrays` and :code:`inputs_to_native_arrays` wrappers when calling the backend implementation. It is therefore preferable to avoid adding those decorators to mixed functions manually.

When calling `_wrap_function`_ during :ref:`Backend Setting`, firstly the attributes of the functions are checked to get all the wrapping functions for a particular functions.
Then all the wrapping functions applicable to a function are used to wrap the function.
Expand Down
2 changes: 0 additions & 2 deletions ivy/functional/ivy/elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
handle_nestable,
integer_arrays_to_float,
handle_array_like_without_promotion,
inputs_to_ivy_arrays,
)
from ivy.utils.exceptions import handle_exceptions

Expand Down Expand Up @@ -5451,7 +5450,6 @@ def rad2deg(


@handle_array_function
@inputs_to_ivy_arrays
@handle_out_argument
@handle_nestable
@handle_exceptions
Expand Down
12 changes: 8 additions & 4 deletions ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def kaiser_window(


@infer_dtype
@inputs_to_ivy_arrays
@handle_out_argument
@handle_nestable
@handle_exceptions
Expand Down Expand Up @@ -289,10 +288,10 @@ def kaiser_bessel_derived_window(
>>> ivy.kaiser_bessel_derived_window(5)
ivy.array([0.00713103, 0.70710677, 0.99997455, 0.99997455, 0.70710677])
>>> ivy.kaiser_derived_window(5, False)
>>> ivy.kaiser_bessel_derived_window(5, False)
ivy.array([0.00726415, 0.9999736 , 0.9999736 , 0.00726415])
>>> ivy.kaiser_derived_window(5, False, 5)
>>> ivy.kaiser_bessel_derived_window(5, False, 5)
ivy.array([0.18493208, 0.9827513 , 0.9827513 , 0.18493208])
"""
window_length = window_length // 2
Expand All @@ -315,8 +314,10 @@ def sum_2N_1_n(n):
return ivy.array(dn_low + dn_mid, dtype=dtype, out=out)


kaiser_bessel_derived_window.mixed_function = True


@infer_dtype
@inputs_to_ivy_arrays
@handle_out_argument
@handle_nestable
@handle_exceptions
Expand Down Expand Up @@ -388,6 +389,9 @@ def hamming_window(
)


hamming_window.mixed_function = True


@infer_device
@outputs_to_ivy_arrays
@handle_nestable
Expand Down
18 changes: 14 additions & 4 deletions ivy/functional/ivy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,7 +831,6 @@ def ifft(
return ivy.current_backend(x).ifft(x, dim, norm=norm, n=n, out=out)


@inputs_to_ivy_arrays
@handle_out_argument
@handle_nestable
@handle_exceptions
Expand Down Expand Up @@ -889,6 +888,9 @@ def embedding(
return ret


embedding.mixed_function = True


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
Expand Down Expand Up @@ -1309,7 +1311,6 @@ def get_x_interp(y):
return result


@inputs_to_ivy_arrays
@handle_out_argument
@handle_nestable
def interpolate(
Expand Down Expand Up @@ -1590,6 +1591,9 @@ def interpolate(
return ivy.astype(ret, ivy.dtype(x), out=out)


interpolate.mixed_function = True


def _get_size(scale_factor, size, dims, x_shape):
if scale_factor is not None:
if isinstance(scale_factor, (float, int)):
Expand Down Expand Up @@ -1685,7 +1689,7 @@ def _mask(vals, length, range_max, dim):
return vals, length


@inputs_to_ivy_arrays
@handle_nestable
def adaptive_avg_pool1d(
input: Union[ivy.Array, ivy.NativeArray],
output_size: int,
Expand Down Expand Up @@ -1753,7 +1757,10 @@ def adaptive_avg_pool1d(
return pooled_output


@inputs_to_ivy_arrays
adaptive_avg_pool1d.mixed_function = True


@handle_nestable
def adaptive_avg_pool2d(
input: Union[ivy.Array, ivy.NativeArray],
output_size: Union[Sequence[int], int],
Expand Down Expand Up @@ -1829,3 +1836,6 @@ def adaptive_avg_pool2d(
if squeeze:
return ivy.squeeze(pooled_output, axis=0)
return pooled_output


adaptive_avg_pool2d.mixed_function = True
2 changes: 0 additions & 2 deletions ivy/functional/ivy/experimental/linear_algebra.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
handle_nestable,
handle_array_like_without_promotion,
handle_array_function,
inputs_to_ivy_arrays,
)
from ivy.utils.exceptions import handle_exceptions

Expand All @@ -23,7 +22,6 @@ def _check_valid_dimension_size(std):


@handle_array_function
@inputs_to_ivy_arrays
@handle_array_like_without_promotion
@handle_nestable
@handle_exceptions
Expand Down
12 changes: 7 additions & 5 deletions ivy/functional/ivy/experimental/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@
handle_array_like_without_promotion,
handle_view,
inputs_to_ivy_arrays,
handle_array_function, outputs_to_ivy_arrays,
handle_array_function,
)
from ivy.utils.backend import current_backend
from ivy.utils.exceptions import handle_exceptions


@inputs_to_ivy_arrays
@handle_out_argument
@handle_view
@handle_array_like_without_promotion
Expand Down Expand Up @@ -953,7 +952,6 @@ def _check_arguments(
)


@inputs_to_ivy_arrays
@handle_out_argument
@handle_array_like_without_promotion
@handle_nestable
Expand Down Expand Up @@ -1201,6 +1199,9 @@ def pad(
return padded


pad.mixed_function = True


@to_native_arrays_and_back
@handle_out_argument
@handle_view
Expand Down Expand Up @@ -1615,8 +1616,6 @@ def expand(
return ivy.current_backend(x).expand(x, shape, out=out, copy=copy)


@inputs_to_ivy_arrays
@outputs_to_ivy_arrays
@handle_array_like_without_promotion
@handle_nestable
@handle_exceptions
Expand Down Expand Up @@ -1673,6 +1672,9 @@ def as_strided(
)


as_strided.mixed_function = True


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
Expand Down
4 changes: 3 additions & 1 deletion ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,6 @@ def dropout(
@handle_array_function
@handle_array_like_without_promotion
@handle_exceptions
@inputs_to_ivy_arrays
def scaled_dot_product_attention(
q: Union[ivy.Array, ivy.NativeArray],
k: Union[ivy.Array, ivy.NativeArray],
Expand Down Expand Up @@ -554,6 +553,9 @@ def scaled_dot_product_attention(
return ivy.einsum("... q k, ... k f -> ... q f", attn, v, out=out)


scaled_dot_product_attention.mixed_function = True


@handle_array_function
@handle_array_like_without_promotion
@handle_exceptions
Expand Down

0 comments on commit 462ceaa

Please sign in to comment.