Skip to content

Commit

Permalink
adding mode_list argument in _interp_args so custom list of modes…
Browse files Browse the repository at this point in the history
… can be drawn for resampling frontend functions
  • Loading branch information
sherry30 committed Mar 8, 2023
1 parent 3aed083 commit c8dd1c5
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,10 @@ def test_torch_pad(

@handle_frontend_test(
fn_tree="torch.nn.functional.interpolate",
dtype_and_input_and_other=_interp_args(scale_factor=True),
dtype_and_input_and_other=_interp_args(
mode_list=["linear", "bilinear", "trilinear", "nearest", "area"],
scale_factor=True,
),
number_positional_args=st.just(2),
)
def test_torch_interpolate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -257,13 +257,15 @@ def test_dct(


@st.composite
def _interp_args(draw, mode=None, scale_factor=False):
if not mode:
def _interp_args(draw, mode=None, mode_list=None, scale_factor=False):
if not mode and not mode_list:
mode = draw(
st.sampled_from(
["linear", "bilinear", "trilinear", "nearest", "area", "tf_area"]
)
)
elif mode_list:
mode = draw(st.sampled_from(mode_list))
align_corners = draw(st.one_of(st.booleans(), st.none()))
if mode == "linear":
num_dims = 3
Expand Down

0 comments on commit c8dd1c5

Please sign in to comment.