Skip to content

Commit

Permalink
add docstrings for test decorators, add comments to explain the logic…
Browse files Browse the repository at this point in the history
… for `handle_test` decorator.
  • Loading branch information
CatB1t committed Jan 8, 2023
1 parent 9e0058a commit 2ea6b93
Showing 1 changed file with 135 additions and 4 deletions.
139 changes: 135 additions & 4 deletions ivy_tests/test_ivy/helpers/testing_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,22 @@

@st.composite
def num_positional_args_method(draw, *, method):
"""
Draws an integers randomly from the minimum and maximum number of positional
arguments a given method can take.
Parameters
----------
draw
special function that draws data randomly (but is reproducible) from a given
data-set (ex. list).
method
callable method
Returns
-------
A strategy that can be used in the @given hypothesis decorator.
"""
total, num_positional_only, num_keyword_only, = (
0,
0,
Expand All @@ -60,14 +76,15 @@ def num_positional_args_method(draw, *, method):

@st.composite
def num_positional_args(draw, *, fn_name: str = None):
"""Draws an integers randomly from the minimum and maximum number of positional
"""
Draws an integers randomly from the minimum and maximum number of positional
arguments a given function can take.
Parameters
----------
draw
special function that draws data randomly (but is reproducible) from a given
data-set (ex. list).
special function that draws data randomly (but is reproducible) from a
given data-set (ex. list).
fn_name
name of the function.
Expand Down Expand Up @@ -158,6 +175,24 @@ def _generate_shared_test_flags(
def _get_method_supported_devices_dtypes(
method_name: str, class_module: str, class_name: str
):
"""
Get supported devices and data types for a method in Ivy API
Parameters
----------
method_name
Name of the method in the class
class_module
Name of the class module
class_name
Name of the class
Returns
-------
Returns a dictonary containing supported device types and its supported data types
for the method
"""
supported_device_dtypes = {}
backends = available_frameworks
for b in backends: # ToDo can optimize this ?
Expand All @@ -169,6 +204,21 @@ def _get_method_supported_devices_dtypes(


def _get_supported_devices_dtypes(fn_name: str, fn_module: str):
"""
Get supported devices and data types for a function in Ivy API
Parameters
----------
fn_name
Name of the function
fn_module
Full import path of the function module
Returns
-------
Returns a dictonary containing supported device types and its supported data types
for the function
"""
supported_device_dtypes = {}
backends = available_frameworks
for b in backends: # ToDo can optimize this ?
Expand Down Expand Up @@ -196,11 +246,53 @@ def handle_test(
container_flags=BuiltContainerStrategy,
**_given_kwargs,
):
"""
A test wrapper for Ivy functions.
Sets the required test globals and creates test flags strategies.
Parameters
----------
fn_tree
Full function import path
ground_truth_backend
The framework to assert test results are equal to
number_positional_args
A search strategy for determining the number of positional arguments to be
passed to the function
test_instance_method
A search strategy that generates a boolean to test instance methods
test_with_out
A search strategy that generates a boolean to test the function with an `out`
parameter
test_gradients
A search strategy that generates a boolean to test the function with arrays as
gradients
as_variable_flags
A search strategy that generates a list of boolean flags for array inputs to be
passed as a Variable array
native_array_flags
A search strategy that generates a list of boolean flags for array inputs to be
passed as a native array
container_flags
A search strategy that generates a list of boolean flags for array inputs to be
passed as a Container
"""
fn_tree = "ivy." + fn_tree
is_hypothesis_test = len(_given_kwargs) != 0

if is_hypothesis_test:
# Use the default strategy
if number_positional_args is None:
number_positional_args = num_positional_args(fn_name=fn_tree)
# Generate the test flags strategy
test_flags = pf.function_flags(
num_positional_args=number_positional_args,
instance_method=test_instance_method,
Expand All @@ -216,8 +308,9 @@ def test_wrapper(test_fn):
param_names = inspect.signature(test_fn).parameters.keys()
supported_device_dtypes = _get_supported_devices_dtypes(fn_name, fn_mod)

# Additional arguments are being passed
# If a test is not a Hypothesis test, we only set the test global data
if is_hypothesis_test:
# Check if these arguments are being asked for
possible_arguments = {
"test_flags": test_flags,
"fn_name": st.just(fn_name),
Expand All @@ -226,10 +319,12 @@ def test_wrapper(test_fn):
filtered_args = set(param_names).intersection(possible_arguments.keys())
for key in filtered_args:
_given_kwargs[key] = possible_arguments[key]
# Wrap the test with the @given decorator
wrapped_test = given(**_given_kwargs)(test_fn)
else:
wrapped_test = test_fn

# Set the test data to be used by test helpers
wrapped_test.test_data = TestData(
test_fn=wrapped_test,
fn_tree=fn_tree,
Expand All @@ -244,6 +339,15 @@ def test_wrapper(test_fn):


def handle_frontend_test(*, fn_tree: str, **_given_kwargs):
"""
A test wrapper for Ivy frontend functions.
Sets the required test globals and creates test flags strategies.
Parameters
----------
fn_tree
Full function import path
"""
fn_tree = "ivy.functional.frontends." + fn_tree
is_hypothesis_test = len(_given_kwargs) != 0

Expand Down Expand Up @@ -292,6 +396,18 @@ def _import_method(method_tree: str):
def handle_method(
*, method_tree, ground_truth_backend: str = ground_truth, **_given_kwargs
):
"""
A test wrapper for Ivy methods.
Sets the required test globals and creates test flags strategies.
Parameters
----------
method_tree
Full method import path
ground_truth_backend
The framework to assert test results are equal to
"""
method_tree = "ivy." + method_tree
is_hypothesis_test = len(_given_kwargs) != 0

Expand Down Expand Up @@ -353,6 +469,21 @@ def test_wrapper(test_fn):
def handle_frontend_method(
*, class_tree: str, init_tree: str, method_name: str, **_given_kwargs
):
"""
A test wrapper for Ivy frontends methods.
Sets the required test globals and creates test flags strategies.
Parameters
----------
class_tree
Full class import path
init_tree
Full import path for the function used to create the class
method_name
Name of the method
"""
split_index = init_tree.rfind(".")
framework_init_module = init_tree[:split_index]
ivy_init_module = f"ivy.functional.frontends.{init_tree[:split_index]}"
Expand Down

0 comments on commit 2ea6b93

Please sign in to comment.