Skip to content

Commit

Permalink
fix dtype and shape issue in avg_pool1d numpy backend
Browse files Browse the repository at this point in the history
  • Loading branch information
sherry30 committed Apr 5, 2023
1 parent 5d743bf commit dafba54
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions ivy/functional/backends/numpy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def avg_pool1d(
(pad_w // 2, pad_w - pad_w // 2),
(0, 0),
],
"edge",
constant_values=0.0,
)

x_shape = x.shape
Expand All @@ -289,15 +289,17 @@ def avg_pool1d(
_get_num_padded_values,
constant={
"p": pad_w,
"n": x_shape[1],
"n": x.shape[1] - pad_w,
"k": kernel[0],
"s": strides[0],
},
unique={
"i": np.arange(res.shape[1]),
},
)
res = kernel[0] * res / (kernel[0] - num_padded_values)
res = (kernel[0] * res) / (
kernel[0] - np.array(num_padded_values, dtype=res.dtype)
)[:, None]

if data_format == "NCW":
return res.swapaxes(1, 2)
Expand Down

0 comments on commit dafba54

Please sign in to comment.