Skip to content

Commit

Permalink
fix failing tests in frontends related to changes in commit f84bea1
Browse files Browse the repository at this point in the history
  • Loading branch information
CatB1t committed Feb 22, 2023
1 parent 23c8066 commit b304c85
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 16 deletions.
24 changes: 18 additions & 6 deletions ivy_tests/test_ivy/test_frontends/test_numpy/test_func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def _fn(x=None, check_default=False, dtype=None):

@given(
dtype_x_shape=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
available_dtypes=helpers.get_dtypes("valid", prune_function=False),
ret_shape=True,
),
)
Expand Down Expand Up @@ -61,7 +61,9 @@ def test_inputs_to_ivy_arrays(dtype_x_shape):


@given(
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")),
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
),
)
def test_outputs_to_numpy_arrays(dtype_and_x):
x_dtype, x = dtype_and_x
Expand All @@ -78,7 +80,7 @@ def test_outputs_to_numpy_arrays(dtype_and_x):

@given(
dtype_x_shape=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
available_dtypes=helpers.get_dtypes("valid", prune_function=False),
ret_shape=True,
),
)
Expand Down Expand Up @@ -113,7 +115,9 @@ def test_to_ivy_arrays_and_back(dtype_x_shape):
@st.composite
def _zero_dim_to_scalar_helper(draw):
dtype = draw(
helpers.get_dtypes("valid", full=False).filter(lambda x: "bfloat16" not in x)
helpers.get_dtypes("valid", prune_function=False, full=False).filter(
lambda x: "bfloat16" not in x
)
)[0]
shape = draw(helpers.get_shape())
return draw(
Expand Down Expand Up @@ -151,8 +155,16 @@ def _dtype_helper(draw):
st.sampled_from(
[
draw(st.sampled_from([int, float, bool])),
ivy.as_native_dtype(draw(helpers.get_dtypes("valid", full=False))[0]),
np_frontend.dtype(draw(helpers.get_dtypes("valid", full=False))[0]),
ivy.as_native_dtype(
draw(helpers.get_dtypes("valid", full=False, prune_function=False))[
0
]
),
np_frontend.dtype(
draw(helpers.get_dtypes("valid", full=False, prune_function=False))[
0
]
),
draw(st.sampled_from(list(np_frontend.numpy_scalar_to_dtype.keys()))),
draw(st.sampled_from(list(np_frontend.numpy_str_to_type_table.keys()))),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

@st.composite
def _array_mask(draw):
dtype = draw(helpers.get_dtypes("valid", full=False))
dtype = draw(helpers.get_dtypes("valid", prune_function=False, full=False))
dtypes, x_mask = draw(
helpers.dtype_and_values(
num_arrays=2,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def _fn(x=None, dtype=None):


@given(
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")),
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
),
)
def test_inputs_to_ivy_arrays(dtype_and_x):
x_dtype, x = dtype_and_x
Expand Down Expand Up @@ -50,7 +52,9 @@ def test_inputs_to_ivy_arrays(dtype_and_x):


@given(
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")),
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
),
)
def test_outputs_to_frontend_arrays(dtype_and_x):
x_dtype, x = dtype_and_x
Expand All @@ -64,7 +68,9 @@ def test_outputs_to_frontend_arrays(dtype_and_x):


@given(
dtype_and_x=helpers.dtype_and_values(available_dtypes=helpers.get_dtypes("valid")),
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
),
)
def test_to_ivy_arrays_and_back(dtype_and_x):
x_dtype, x = dtype_and_x
Expand Down Expand Up @@ -96,13 +102,21 @@ def _dtype_helper(draw):
return draw(
st.sampled_from(
[
draw(helpers.get_dtypes("valid", full=False))[0],
ivy.as_native_dtype(draw(helpers.get_dtypes("valid", full=False))[0]),
draw(helpers.get_dtypes("valid", prune_function=False, full=False))[0],
ivy.as_native_dtype(
draw(helpers.get_dtypes("valid", prune_function=False, full=False))[
0
]
),
draw(
st.sampled_from(list(tf_frontend.tensorflow_enum_to_type.values()))
),
draw(st.sampled_from(list(tf_frontend.tensorflow_enum_to_type.keys()))),
np_frontend.dtype(draw(helpers.get_dtypes("valid", full=False))[0]),
np_frontend.dtype(
draw(helpers.get_dtypes("valid", prune_function=False, full=False))[
0
]
),
draw(st.sampled_from(list(np_frontend.numpy_scalar_to_dtype.keys()))),
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def _fn(x, check_default=False):

@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid")
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
).filter(lambda x: "bfloat16" not in x[0]),
)
def test_inputs_to_ivy_arrays(dtype_and_x):
Expand Down Expand Up @@ -55,7 +55,7 @@ def test_inputs_to_ivy_arrays(dtype_and_x):

@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid")
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
).filter(lambda x: "bfloat16" not in x[0]),
)
def test_outputs_to_frontend_arrays(dtype_and_x):
Expand All @@ -73,7 +73,7 @@ def test_outputs_to_frontend_arrays(dtype_and_x):

@given(
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid")
available_dtypes=helpers.get_dtypes("valid", prune_function=False)
).filter(lambda x: "bfloat16" not in x[0]),
)
def test_to_ivy_arrays_and_back(dtype_and_x):
Expand Down

0 comments on commit b304c85

Please sign in to comment.