Skip to content

Commit

Permalink
Update handle_test and handle_method to work without a function t…
Browse files Browse the repository at this point in the history
…ree is set. (ivy-llc#10908)
  • Loading branch information
CatB1t authored Feb 23, 2023
1 parent ce76c3f commit 024d937
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 64 deletions.
18 changes: 9 additions & 9 deletions ivy_tests/test_ivy/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,25 +130,25 @@ def pytest_configure(config):

@pytest.fixture(autouse=True)
def run_around_tests(request, on_device, backend_fw, compile_graph, implicit):
if hasattr(request.function, "test_data"):
ivy_test = hasattr(request.function, "_ivy_test")
if ivy_test:
try:
test_globals.setup_api_test(
request.function.test_data,
backend_fw.backend,
request.function.ground_truth_backend,
on_device,
request.function.test_data
if hasattr(request.function, "test_data")
else None,
)
except Exception as e:
test_globals.teardown_api_test()
raise RuntimeError(f"Setting up test for {request.function} failed.") from e
with backend_fw.use:
with DefaultDevice(on_device):
yield
with backend_fw.use:
with DefaultDevice(on_device):
yield
if ivy_test:
test_globals.teardown_api_test()
else:
with backend_fw.use:
with DefaultDevice(on_device):
yield


def pytest_generate_tests(metafunc):
Expand Down
8 changes: 6 additions & 2 deletions ivy_tests/test_ivy/helpers/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,13 @@ def _get_ivy_torch(version=None):


def setup_api_test(
test_data: TestData, backend: str, ground_truth_backend: str, device: str
backend: str,
ground_truth_backend: str,
device: str,
test_data: TestData = None,
):
_set_test_data(test_data)
if test_data is not None:
_set_test_data(test_data)
_set_backend(backend)
_set_device(device)
_set_ground_truth_backend(ground_truth_backend)
Expand Down
109 changes: 56 additions & 53 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _partition_dtypes_into_kinds(framework, dtypes):

def handle_test(
*,
fn_tree: str,
fn_tree: str = None,
ground_truth_backend: str = ground_truth,
number_positional_args=None,
test_instance_method=BuiltInstanceStrategy,
Expand Down Expand Up @@ -300,15 +300,18 @@ def handle_test(
A search strategy that generates a list of boolean flags for array inputs to be
passed as a Container
"""
fn_tree = "ivy." + fn_tree
is_fn_tree_provided = fn_tree is not None
if is_fn_tree_provided:
fn_tree = "ivy." + fn_tree
is_hypothesis_test = len(_given_kwargs) != 0

if is_hypothesis_test:
possible_arguments = {"ground_truth_backend": st.just(ground_truth_backend)}
if is_hypothesis_test and is_fn_tree_provided:
# Use the default strategy
if number_positional_args is None:
number_positional_args = num_positional_args(fn_name=fn_tree)
# Generate the test flags strategy
test_flags = pf.function_flags(
possible_arguments["test_flags"] = pf.function_flags(
num_positional_args=number_positional_args,
instance_method=test_instance_method,
with_out=test_with_out,
Expand All @@ -319,18 +322,15 @@ def handle_test(
)

def test_wrapper(test_fn):
callable_fn, fn_name, fn_mod = _import_fn(fn_tree)
supported_device_dtypes = _get_supported_devices_dtypes(fn_name, fn_mod)
if is_fn_tree_provided:
callable_fn, fn_name, fn_mod = _import_fn(fn_tree)
supported_device_dtypes = _get_supported_devices_dtypes(fn_name, fn_mod)
possible_arguments["fn_name"] = st.just(fn_name)

# If a test is not a Hypothesis test, we only set the test global data
if is_hypothesis_test:
param_names = inspect.signature(test_fn).parameters.keys()
# Check if these arguments are being asked for
possible_arguments = {
"test_flags": test_flags,
"fn_name": st.just(fn_name),
"ground_truth_backend": st.just(ground_truth_backend),
}
filtered_args = set(param_names).intersection(possible_arguments.keys())
for key in filtered_args:
_given_kwargs[key] = possible_arguments[key]
Expand All @@ -340,13 +340,15 @@ def test_wrapper(test_fn):
wrapped_test = test_fn

# Set the test data to be used by test helpers
wrapped_test.test_data = TestData(
test_fn=wrapped_test,
fn_tree=fn_tree,
fn_name=fn_name,
supported_device_dtypes=supported_device_dtypes,
)
if is_fn_tree_provided:
wrapped_test.test_data = TestData(
test_fn=wrapped_test,
fn_tree=fn_tree,
fn_name=fn_name,
supported_device_dtypes=supported_device_dtypes,
)
wrapped_test.ground_truth_backend = ground_truth_backend
wrapped_test._ivy_test = True

return wrapped_test

Expand Down Expand Up @@ -460,7 +462,7 @@ def _import_method(method_tree: str):

def handle_method(
*,
method_tree,
method_tree: str = None,
ground_truth_backend: str = ground_truth,
test_gradients=BuiltGradientStrategy,
init_num_positional_args=None,
Expand All @@ -485,10 +487,16 @@ def handle_method(
ground_truth_backend
The framework to assert test results are equal to
"""
method_tree = "ivy." + method_tree
is_method_tree_provided = method_tree is not None
if is_method_tree_provided:
method_tree = "ivy." + method_tree
is_hypothesis_test = len(_given_kwargs) != 0
possible_arguments = {
"ground_truth_backend": st.just(ground_truth_backend),
"test_gradients": test_gradients,
}

if is_hypothesis_test:
if is_hypothesis_test and is_method_tree_provided:
callable_method, method_name, _, class_name, method_mod = _import_method(
method_tree
)
Expand All @@ -498,42 +506,35 @@ def handle_method(
fn_name=class_name + ".__init__"
)

possible_arguments["init_flags"] = pf.method_flags(
num_positional_args=init_num_positional_args,
as_variable=init_as_variable_flags,
native_arrays=init_native_arrays,
container_flags=init_container_flags,
)

if method_num_positional_args is None:
method_num_positional_args = num_positional_args_method(
method=callable_method
)

def test_wrapper(test_fn):
supported_device_dtypes = _get_method_supported_devices_dtypes(
method_name, method_mod, class_name
possible_arguments["method_flags"] = pf.method_flags(
num_positional_args=method_num_positional_args,
as_variable=method_as_variable_flags,
native_arrays=method_native_arrays,
container_flags=method_container_flags,
)

if is_hypothesis_test:
param_names = inspect.signature(test_fn).parameters.keys()

init_flags = pf.method_flags(
num_positional_args=init_num_positional_args,
as_variable=init_as_variable_flags,
native_arrays=init_native_arrays,
container_flags=init_container_flags,
)

method_flags = pf.method_flags(
num_positional_args=method_num_positional_args,
as_variable=method_as_variable_flags,
native_arrays=method_native_arrays,
container_flags=method_container_flags,
def test_wrapper(test_fn):
if is_method_tree_provided:
supported_device_dtypes = _get_method_supported_devices_dtypes(
method_name, method_mod, class_name
)
possible_arguments["class_name"] = st.just(class_name)
possible_arguments["method_name"] = st.just(method_name)

possible_arguments = {
"class_name": st.just(class_name),
"init_flags": init_flags,
"method_flags": method_flags,
"test_gradients": test_gradients,
"method_name": st.just(method_name),
"ground_truth_backend": st.just(ground_truth_backend),
}

if is_hypothesis_test:
param_names = inspect.signature(test_fn).parameters.keys()
filtered_args = set(param_names).intersection(possible_arguments.keys())

for key in filtered_args:
Expand All @@ -544,13 +545,15 @@ def test_wrapper(test_fn):
else:
wrapped_test = test_fn

wrapped_test.test_data = TestData(
test_fn=wrapped_test,
fn_tree=method_tree,
fn_name=method_name,
supported_device_dtypes=supported_device_dtypes,
)
if is_method_tree_provided:
wrapped_test.test_data = TestData(
test_fn=wrapped_test,
fn_tree=method_tree,
fn_name=method_name,
supported_device_dtypes=supported_device_dtypes,
)
wrapped_test.ground_truth_backend = ground_truth_backend
wrapped_test._ivy_test = True

return wrapped_test

Expand Down

0 comments on commit 024d937

Please sign in to comment.