Skip to content

Commit

Permalink
Refactor Ivy API method testing and frontend method testing (ivy-llc#…
Browse files Browse the repository at this point in the history
…10187)

Refactoring Ivy methods and frontend methods tests to use test objects.

this includes:
- Add `frontend_method_flags` and `method_flags` test classes to store data about the flag
- Update the `handle_method` decorator and `handle_frontend_method` to generate `init_flags` and `method_flags` objects.
- Update the `test_method` decorator and `test_frontend_method` function to accept `init_flags` and `method_flags` object.
- Removed old legacy code used for type hint detection of flags.
- Refactored all functions that use `handle_method` and `handle_frontend_method` to use the updated signature.
  • Loading branch information
CatB1t authored Jan 30, 2023
1 parent baae0bf commit 4b8c8be
Show file tree
Hide file tree
Showing 14 changed files with 1,691 additions and 3,840 deletions.
198 changes: 75 additions & 123 deletions ivy_tests/test_ivy/helpers/function_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,17 +941,12 @@ def grad_fn(all_args):

def test_method(
*,
init_input_dtypes: Union[ivy.Dtype, List[ivy.Dtype]] = None,
init_as_variable_flags: Union[List[bool], pf.AsVariableFlags] = None,
init_num_positional_args: Union[int, pf.NumPositionalArg] = 0,
init_native_array_flags: Union[List[bool], pf.NativeArrayFlags] = None,
init_input_dtypes: List[ivy.Dtype] = None,
method_input_dtypes: List[ivy.Dtype] = None,
init_all_as_kwargs_np: dict = None,
method_input_dtypes: Union[ivy.Dtype, List[ivy.Dtype]],
method_as_variable_flags: Union[List[bool], pf.AsVariableFlags],
method_num_positional_args: Union[int, pf.NumPositionalArg],
method_native_array_flags: Union[List[bool], pf.NativeArrayFlags],
method_container_flags: Union[List[bool], pf.ContainerFlags],
method_all_as_kwargs_np: dict,
method_all_as_kwargs_np: dict = None,
init_flags: pf.MethodTestFlags,
method_flags: pf.MethodTestFlags,
class_name: str,
method_name: str = "__call__",
init_with_v: bool = False,
Expand Down Expand Up @@ -1041,31 +1036,14 @@ def test_method(
optional, return value from the Ground Truth function
"""
_assert_dtypes_are_valid(method_input_dtypes)
# split the arguments into their positional and keyword components

# Constructor arguments #
(init_input_dtypes, init_as_variable_flags, init_native_array_flags,) = (
ivy.default(init_input_dtypes, []),
ivy.default(init_as_variable_flags, []),
ivy.default(init_native_array_flags, []),
)
_assert_dtypes_are_valid(init_input_dtypes)
init_input_dtypes = ivy.default(init_input_dtypes, [])

# Constructor arguments #
init_all_as_kwargs_np = ivy.default(init_all_as_kwargs_np, dict())
(
method_input_dtypes,
method_as_variable_flags,
method_native_array_flags,
method_container_flags,
) = as_lists(
method_input_dtypes,
method_as_variable_flags,
method_native_array_flags,
method_container_flags,
)

# split the arguments into their positional and keyword components
args_np_constructor, kwargs_np_constructor = kwargs_to_args_n_kwargs(
num_positional_args=init_num_positional_args,
num_positional_args=init_flags.num_positional_args,
kwargs=init_all_as_kwargs_np,
)

Expand All @@ -1083,19 +1061,19 @@ def test_method(
init_input_dtypes = [
init_input_dtypes[0] for _ in range(num_arrays_constructor)
]
if len(init_as_variable_flags) < num_arrays_constructor:
init_as_variable_flags = [
init_as_variable_flags[0] for _ in range(num_arrays_constructor)
if len(init_flags.as_variable) < num_arrays_constructor:
init_flags.as_variable = [
init_flags.as_variable[0] for _ in range(num_arrays_constructor)
]
if len(init_native_array_flags) < num_arrays_constructor:
init_native_array_flags = [
init_native_array_flags[0] for _ in range(num_arrays_constructor)
if len(init_flags.native_arrays) < num_arrays_constructor:
init_flags.native_arrays = [
init_flags.native_arrays[0] for _ in range(num_arrays_constructor)
]

# update variable flags to be compatible with float dtype
init_as_variable_flags = [
init_flags.as_variable = [
v if ivy.is_float_dtype(d) else False
for v, d in zip(init_as_variable_flags, init_input_dtypes)
for v, d in zip(init_flags.as_variable, init_input_dtypes)
]

# Create Args
Expand All @@ -1107,14 +1085,16 @@ def test_method(
kwarg_np_vals=con_kwarg_np_vals,
kwargs_idxs=con_kwargs_idxs,
input_dtypes=init_input_dtypes,
as_variable_flags=init_as_variable_flags,
native_array_flags=init_native_array_flags,
as_variable_flags=init_flags.as_variable,
native_array_flags=init_flags.native_arrays,
)
# End constructor #
# end constructor #

# Method arguments #
# method arguments #
method_input_dtypes = ivy.default(method_input_dtypes, [])
args_np_method, kwargs_np_method = kwargs_to_args_n_kwargs(
num_positional_args=method_num_positional_args, kwargs=method_all_as_kwargs_np
num_positional_args=method_flags.num_positional_args,
kwargs=method_all_as_kwargs_np,
)

# extract all arrays from the arguments and keyword arguments
Expand All @@ -1129,22 +1109,22 @@ def test_method(
num_arrays_method = met_c_arg_vals + met_c_kwarg_vals
if len(method_input_dtypes) < num_arrays_method:
method_input_dtypes = [method_input_dtypes[0] for _ in range(num_arrays_method)]
if len(method_as_variable_flags) < num_arrays_method:
method_as_variable_flags = [
method_as_variable_flags[0] for _ in range(num_arrays_method)
if len(method_flags.as_variable) < num_arrays_method:
method_flags.as_variable = [
method_flags.as_variable[0] for _ in range(num_arrays_method)
]
if len(method_native_array_flags) < num_arrays_method:
method_native_array_flags = [
method_native_array_flags[0] for _ in range(num_arrays_method)
if len(method_flags.native_arrays) < num_arrays_method:
method_flags.native_arrays = [
method_flags.native_arrays[0] for _ in range(num_arrays_method)
]
if len(method_container_flags) < num_arrays_method:
method_container_flags = [
method_container_flags[0] for _ in range(num_arrays_method)
if len(method_flags.container_flags) < num_arrays_method:
method_flags.container_flags = [
method_flags.container_flags[0] for _ in range(num_arrays_method)
]

method_as_variable_flags = [
method_flags.as_variable = [
v if ivy.is_float_dtype(d) else False
for v, d in zip(method_as_variable_flags, method_input_dtypes)
for v, d in zip(method_flags.as_variable, method_input_dtypes)
]

# Create Args
Expand All @@ -1156,9 +1136,9 @@ def test_method(
kwarg_np_vals=met_kwarg_np_vals,
kwargs_idxs=met_kwargs_idxs,
input_dtypes=method_input_dtypes,
as_variable_flags=method_as_variable_flags,
native_array_flags=method_native_array_flags,
container_flags=method_container_flags,
as_variable_flags=method_flags.as_variable,
native_array_flags=method_flags.native_arrays,
container_flags=method_flags.container_flags,
)
# End Method #

Expand Down Expand Up @@ -1189,8 +1169,8 @@ def test_method(
kwarg_np_vals=con_kwarg_np_vals,
kwargs_idxs=con_kwargs_idxs,
input_dtypes=init_input_dtypes,
as_variable_flags=init_as_variable_flags,
native_array_flags=init_native_array_flags,
as_variable_flags=init_flags.as_variable,
native_array_flags=init_flags.native_arrays,
)
args_gt_method, kwargs_gt_method, _, _, _ = create_args_kwargs(
args_np=args_np_method,
Expand All @@ -1200,9 +1180,9 @@ def test_method(
kwarg_np_vals=met_kwarg_np_vals,
kwargs_idxs=met_kwargs_idxs,
input_dtypes=method_input_dtypes,
as_variable_flags=method_as_variable_flags,
native_array_flags=method_native_array_flags,
container_flags=method_container_flags,
as_variable_flags=method_flags.as_variable,
native_array_flags=method_flags.native_arrays,
container_flags=method_flags.container_flags,
)
ins_gt = ivy.__dict__[class_name](*args_gt_constructor, **kwargs_gt_constructor)
if isinstance(ins_gt, ivy.Module):
Expand Down Expand Up @@ -1245,9 +1225,9 @@ def test_method(
args_np=args_np_method,
kwargs_np=kwargs_np_method,
input_dtypes=method_input_dtypes,
as_variable_flags=method_as_variable_flags,
native_array_flags=method_native_array_flags,
container_flags=method_container_flags,
as_variable_flags=method_flags.as_variable,
native_array_flags=method_flags.native_arrays,
container_flags=method_flags.container_flags,
rtol_=rtol_,
atol_=atol_,
xs_grad_idxs=xs_grad_idxs,
Expand All @@ -1265,9 +1245,9 @@ def test_method(
args_np=args_np_method,
kwargs_np=kwargs_np_method,
input_dtypes=method_input_dtypes,
as_variable_flags=method_as_variable_flags,
native_array_flags=method_native_array_flags,
container_flags=method_container_flags,
as_variable_flags=method_flags.as_variable,
native_array_flags=method_flags.native_arrays,
container_flags=method_flags.container_flags,
rtol_=rtol_,
atol_=atol_,
xs_grad_idxs=xs_grad_idxs,
Expand Down Expand Up @@ -1298,14 +1278,10 @@ def test_method(
def test_frontend_method(
*,
init_input_dtypes: Union[ivy.Dtype, List[ivy.Dtype]] = None,
init_as_variable_flags: Union[List[bool], pf.AsVariableFlags] = None,
init_num_positional_args: Union[int, pf.NumPositionalArgFn] = 0,
init_native_array_flags: Union[List[bool], pf.NativeArrayFlags] = None,
init_all_as_kwargs_np: dict = None,
method_input_dtypes: Union[ivy.Dtype, List[ivy.Dtype]],
method_as_variable_flags: Union[List[bool], pf.AsVariableFlags],
method_num_positional_args: Union[int, pf.NumPositionalArgMethod],
method_native_array_flags: Union[List[bool], pf.NativeArrayFlags],
init_flags,
method_flags,
init_all_as_kwargs_np: dict = None,
method_all_as_kwargs_np: dict,
frontend: str,
frontend_method_data: FrontendMethodData,
Expand Down Expand Up @@ -1333,15 +1309,6 @@ def test_frontend_method(
input arguments to the constructor as keyword arguments.
method_input_dtypes
data types of the input arguments to the method in order.
method_as_variable_flags
dictates whether the corresponding input argument passed to the method should
be treated as a variable.
method_num_positional_args
number of input arguments that must be passed as positional arguments to the
method.
method_native_array_flags
dictates whether the corresponding input argument passed to the method should
be treated as a native array.
method_all_as_kwargs_np:
input arguments to the method as keyword arguments.
frontend
Expand All @@ -1367,25 +1334,9 @@ def test_frontend_method(
# split the arguments into their positional and keyword components

# Constructor arguments #
# convert single values to length 1 lists
(init_input_dtypes, init_as_variable_flags, init_native_array_flags,) = as_lists(
ivy.default(init_input_dtypes, []),
ivy.default(init_as_variable_flags, []),
ivy.default(init_native_array_flags, []),
)
init_all_as_kwargs_np = ivy.default(init_all_as_kwargs_np, dict())
(
method_input_dtypes,
method_as_variable_flags,
method_native_array_flags,
) = as_lists(
method_input_dtypes,
method_as_variable_flags,
method_native_array_flags,
)

args_np_constructor, kwargs_np_constructor = kwargs_to_args_n_kwargs(
num_positional_args=init_num_positional_args,
num_positional_args=init_flags.num_positional_args,
kwargs=init_all_as_kwargs_np,
)

Expand All @@ -1403,19 +1354,19 @@ def test_frontend_method(
init_input_dtypes = [
init_input_dtypes[0] for _ in range(num_arrays_constructor)
]
if len(init_as_variable_flags) < num_arrays_constructor:
init_as_variable_flags = [
init_as_variable_flags[0] for _ in range(num_arrays_constructor)
if len(init_flags.as_variable) < num_arrays_constructor:
init_flags.as_variable = [
init_flags.as_variable[0] for _ in range(num_arrays_constructor)
]
if len(init_native_array_flags) < num_arrays_constructor:
init_native_array_flags = [
init_native_array_flags[0] for _ in range(num_arrays_constructor)
if len(init_flags.native_arrays) < num_arrays_constructor:
init_flags.native_arrays = [
init_flags.native_arrays[0] for _ in range(num_arrays_constructor)
]

# update variable flags to be compatible with float dtype
init_as_variable_flags = [
init_flags.as_variable = [
v if ivy.is_float_dtype(d) else False
for v, d in zip(init_as_variable_flags, init_input_dtypes)
for v, d in zip(init_flags.as_variable, init_input_dtypes)
]

# Create Args
Expand All @@ -1427,14 +1378,15 @@ def test_frontend_method(
kwarg_np_vals=con_kwarg_np_vals,
kwargs_idxs=con_kwargs_idxs,
input_dtypes=init_input_dtypes,
as_variable_flags=init_as_variable_flags,
native_array_flags=init_native_array_flags,
as_variable_flags=init_flags.as_variable,
native_array_flags=init_flags.native_arrays,
)
# End constructor #

# Method arguments #
args_np_method, kwargs_np_method = kwargs_to_args_n_kwargs(
num_positional_args=method_num_positional_args, kwargs=method_all_as_kwargs_np
num_positional_args=method_flags.num_positional_args,
kwargs=method_all_as_kwargs_np,
)

# extract all arrays from the arguments and keyword arguments
Expand All @@ -1449,18 +1401,18 @@ def test_frontend_method(
num_arrays_method = met_c_arg_vals + met_c_kwarg_vals
if len(method_input_dtypes) < num_arrays_method:
method_input_dtypes = [method_input_dtypes[0] for _ in range(num_arrays_method)]
if len(method_as_variable_flags) < num_arrays_method:
method_as_variable_flags = [
method_as_variable_flags[0] for _ in range(num_arrays_method)
if len(method_flags.as_variable) < num_arrays_method:
method_flags.as_variable = [
method_flags.as_variable[0] for _ in range(num_arrays_method)
]
if len(method_native_array_flags) < num_arrays_method:
method_native_array_flags = [
method_native_array_flags[0] for _ in range(num_arrays_method)
if len(method_flags.native_arrays) < num_arrays_method:
method_flags.native_arrays = [
method_flags.native_arrays[0] for _ in range(num_arrays_method)
]

method_as_variable_flags = [
method_flags.as_variable = [
v if ivy.is_float_dtype(d) else False
for v, d in zip(method_as_variable_flags, method_input_dtypes)
for v, d in zip(method_flags.as_variable, method_input_dtypes)
]

# Create Args
Expand All @@ -1472,8 +1424,8 @@ def test_frontend_method(
kwarg_np_vals=met_kwarg_np_vals,
kwargs_idxs=met_kwargs_idxs,
input_dtypes=method_input_dtypes,
as_variable_flags=method_as_variable_flags,
native_array_flags=method_native_array_flags,
as_variable_flags=method_flags.as_variable,
native_array_flags=method_flags.native_arrays,
)
# End Method #

Expand Down
Loading

0 comments on commit 4b8c8be

Please sign in to comment.