Skip to content

Commit

Permalink
update the test of tril_indices to correctly generate the correct dty…
Browse files Browse the repository at this point in the history
…pe argument
  • Loading branch information
Daniel4078 authored Apr 14, 2023
1 parent a25c3e3 commit a00a007
Showing 1 changed file with 18 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,13 @@ def test_triu_indices(
min_num_dims=1,
max_num_dims=1,
),
dtype=helpers.get_dtypes("float", full=False),
test_gradients=st.just(False),
)
def test_vorbis_window(
*,
dtype_and_x,
dtype,
test_flags,
backend_fw,
fn_name,
Expand All @@ -69,7 +71,7 @@ def test_vorbis_window(
fn_name=fn_name,
on_device=on_device,
window_length=x[0],
dtype=input_dtype[0],
dtype=dtype[0],
)


Expand Down Expand Up @@ -163,7 +165,7 @@ def test_kaiser_window(
),
periodic=st.booleans(),
beta=st.floats(min_value=1, max_value=5),
dtype=helpers.get_dtypes("float"),
dtype=helpers.get_dtypes("float", full=False),
test_gradients=st.just(False),
)
def test_kaiser_bessel_derived_window(
Expand All @@ -189,7 +191,7 @@ def test_kaiser_bessel_derived_window(
window_length=x[0],
periodic=periodic,
beta=beta,
dtype=dtype,
dtype=dtype[0],
)


Expand Down Expand Up @@ -238,41 +240,44 @@ def test_hamming_window(
periodic=periodic,
alpha=f[0],
beta=f[1],
dtype=dtype,
dtype=dtype[0],
)


@handle_test(
fn_tree="functional.ivy.experimental.tril_indices",
n_rows=helpers.ints(min_value=0, max_value=10),
n_cols=st.none() | helpers.ints(min_value=0, max_value=10),
dtype_and_n=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("integer"),
num_arrays=0,
shape=(1),
min_value=0,
max_value=10,
),
k=helpers.ints(min_value=-11, max_value=11),
test_with_out=st.just(False),
test_instance_method=st.just(False),
test_gradients=st.just(False),
)
def test_tril_indices(
*,
n_rows,
n_cols,
dtype_and_n,
k,
test_flags,
backend_fw,
fn_name,
on_device,
ground_truth_backend,
):
input_dtype, x = dtype_and_n
helpers.test_function(
input_dtypes=["int64"], # TODO remove
input_dtypes=input_dtype,
ground_truth_backend=ground_truth_backend,
test_flags=test_flags,
fw=backend_fw,
on_device=on_device,
fn_name=fn_name,
n_rows=n_rows,
n_cols=n_cols,
n_rows=x[0],
n_cols=x[1],
k=k,
device=on_device,
)


Expand Down

0 comments on commit a00a007

Please sign in to comment.