Skip to content

Commit

Permalink
fix convolve in the numpy frontend (ivy-llc#10860)
Browse files Browse the repository at this point in the history
  • Loading branch information
fnhirwa authored Feb 23, 2023
1 parent 7b61af2 commit a6e6a9d
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,30 @@

@to_ivy_arrays_and_back
def convolve(a, v, mode="full"):
if len(a) == 0:
raise ValueError("'a' cannot be empty.")
if len(v) == 0:
raise ValueError("'v' cannot be empty.")
return ivy.frontends.numpy.correlate(a, v[::-1], mode)
if a.ndim != 1 or v.ndim != 1:
raise ValueError("convolve() only support 1-dimensional inputs.")
if a.shape[0] < v.shape[0]:
a, v = v, a
v = ivy.flip(v)

out_order = slice(None)

if mode == "valid":
padding = [(0, 0)]
elif mode == "same":
padding = [(v.shape[0] // 2, v.shape[0] - v.shape[0] // 2 - 1)]
elif mode == "full":
padding = [(v.shape[0] - 1, v.shape[0] - 1)]

result = ivy.conv_general_dilated(
a[None, None, :],
v[:, None, None],
(1,),
padding,
dims=1,
data_format="channel_first",
)
return result[0, 0, out_order]


@handle_numpy_out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,10 +553,12 @@ def test_numpy_interp(
@handle_frontend_test(
fn_tree="numpy.convolve",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float", full=True),
available_dtypes=helpers.get_dtypes("float"),
min_num_dims=1,
max_num_dims=1,
num_arrays=2,
min_value=-10,
max_value=10,
shared_dtype=True,
),
mode=st.sampled_from(["valid", "same", "full"]),
Expand All @@ -571,7 +573,7 @@ def test_numpy_convolve(
on_device,
):
input_dtypes, xs = dtype_and_x
np_frontend_helpers.test_frontend_function(
helpers.test_frontend_function(
input_dtypes=input_dtypes,
frontend=frontend,
test_flags=test_flags,
Expand Down

0 comments on commit a6e6a9d

Please sign in to comment.