Skip to content

Commit

Permalink
update test decorators (ivy-llc#9066)
Browse files Browse the repository at this point in the history
* remove redundant code, remove incorrect branching.

* update test decorators to return `given` Hypothesis object.
  • Loading branch information
CatB1t authored Dec 24, 2022
1 parent be3544d commit d5cb10a
Showing 1 changed file with 24 additions and 45 deletions.
69 changes: 24 additions & 45 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import importlib
import inspect
import typing
from functools import partial

from hypothesis import given, strategies as st

Expand Down Expand Up @@ -183,8 +182,6 @@ def _get_supported_devices_dtypes(fn_name: str, fn_module: str):

# Decorators

possible_fixtures = ["backend_fw", "on_device"]


def handle_test(
*,
Expand Down Expand Up @@ -219,20 +216,17 @@ def test_wrapper(test_fn):
param_names = inspect.signature(test_fn).parameters.keys()
supported_device_dtypes = _get_supported_devices_dtypes(fn_name, fn_mod)

# No Hypothesis @given is used
# Additional arguments are being passed
if is_hypothesis_test:
if "test_flags" in param_names:
_given_kwargs["test_flags"] = test_flags
wrapped_test = given(**_given_kwargs)(test_fn)
possible_arguments = {
"fn_name": fn_name,
"ground_truth_backend": ground_truth_backend,
"test_flags": test_flags,
"fn_name": st.just(fn_name),
"ground_truth_backend": st.just(ground_truth_backend),
}
filtered_args = set(param_names).intersection(possible_arguments.keys())
partial_kwargs = {k: possible_arguments[k] for k in filtered_args}
_name = wrapped_test.__name__
wrapped_test = partial(wrapped_test, **partial_kwargs)
wrapped_test.__name__ = _name
for key in filtered_args:
_given_kwargs[key] = possible_arguments[key]
wrapped_test = given(**_given_kwargs)(test_fn)
else:
wrapped_test = test_fn

Expand All @@ -249,33 +243,26 @@ def test_wrapper(test_fn):
return test_wrapper


possible_fixtures_frontends = ["on_device", "frontend"]


def handle_frontend_test(*, fn_tree: str, **_given_kwargs):
fn_tree = "ivy.functional.frontends." + fn_tree
is_hypothesis_test = len(_given_kwargs) != 0
given_kwargs = _given_kwargs

def test_wrapper(test_fn):
callable_fn, fn_name, fn_mod = _import_fn(fn_tree)
supported_device_dtypes = _get_supported_devices_dtypes(fn_name, fn_mod)

if is_hypothesis_test:
param_names = inspect.signature(test_fn).parameters.keys()
_given_kwargs = _generate_shared_test_flags(
given_kwargs = _generate_shared_test_flags(
param_names,
given_kwargs,
_given_kwargs,
fn_tree,
)
wrapped_test = given(**_given_kwargs)(test_fn)
if "fn_tree" in param_names:
_name = wrapped_test.__name__
possible_arguments = {"fn_tree": fn_tree}
filtered_args = set(param_names).intersection(possible_arguments.keys())
partial_kwargs = {k: possible_arguments[k] for k in filtered_args}
wrapped_test = partial(wrapped_test, **partial_kwargs)
wrapped_test.__name__ = _name
possible_arguments = {"fn_tree": st.just(fn_tree)}
filtered_args = set(param_names).intersection(possible_arguments.keys())
for key in filtered_args:
given_kwargs[key] = possible_arguments[key]
wrapped_test = given(**given_kwargs)(test_fn)
else:
wrapped_test = test_fn

Expand Down Expand Up @@ -338,21 +325,15 @@ def test_wrapper(test_fn):
)
elif v is pf.BuiltGradientStrategy:
_given_kwargs[k] = v

wrapped_test = given(**_given_kwargs)(test_fn)
possible_arguments = {
"class_name": class_name,
"method_name": method_name,
"ground_truth_backend": ground_truth_backend,
"class_name": st.just(class_name),
"method_name": st.just(method_name),
"ground_truth_backend": st.just(ground_truth_backend),
}
filtered_args = set(param_names).intersection(possible_arguments.keys())
partial_kwargs = {k: possible_arguments[k] for k in filtered_args}
_name = wrapped_test.__name__
wrapped_test = partial(
wrapped_test,
**partial_kwargs,
)
wrapped_test.__name__ = _name
for key in filtered_args:
_given_kwargs[key] = possible_arguments[key]
wrapped_test = given(**_given_kwargs)(test_fn)
else:
wrapped_test = test_fn

Expand Down Expand Up @@ -412,19 +393,17 @@ def test_wrapper(test_fn):
elif v is pf.NumPositionalArgFn:
_given_kwargs[k] = num_positional_args(fn_name=init_tree[4:])

wrapped_test = given(**_given_kwargs)(test_fn)
_name = wrapped_test.__name__
frontend_helper_data = FrontendMethodData(
ivy_init_module=importlib.import_module(ivy_init_module),
framework_init_module=importlib.import_module(framework_init_module),
init_name=init_name,
method_name=method_name,
)
possible_arguments = {"frontend_method_data": frontend_helper_data}
possible_arguments = {"frontend_method_data": st.just(frontend_helper_data)}
filtered_args = set(param_names).intersection(possible_arguments.keys())
partial_kwargs = {k: possible_arguments[k] for k in filtered_args}
wrapped_test = partial(wrapped_test, **partial_kwargs)
wrapped_test.__name__ = _name
for key in filtered_args:
_given_kwargs[key] = possible_arguments[key]
wrapped_test = given(**_given_kwargs)(test_fn)
else:
wrapped_test = test_fn

Expand Down

0 comments on commit d5cb10a

Please sign in to comment.