Skip to content

Commit

Permalink
adding divisor_override to avg_pool2d and avg_pool3d (ivy-llc#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
sherry30 authored Apr 20, 2023
1 parent 215176f commit 7a8fc1e
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 34 deletions.
10 changes: 10 additions & 0 deletions ivy/data_classes/array/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ def avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand All @@ -292,6 +293,9 @@ def avg_pool2d(
Whether to include padding in the averaging calculation.
ceil_mode
Whether to use ceil or floor for creating the output shape.
divisor_override
If given, it will be used as the divisor,
otherwise kernel_size will be used.
out
optional output array, for writing the result to. It must have a shape that
the inputs broadcast to.
Expand Down Expand Up @@ -327,6 +331,7 @@ def avg_pool2d(
data_format=data_format,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
divisor_override=divisor_override,
out=out,
)

Expand All @@ -340,6 +345,7 @@ def avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""
Expand All @@ -362,6 +368,9 @@ def avg_pool3d(
Whether to include padding in the averaging calculation.
ceil_mode
Whether to use ceil or floor for creating the output shape.
divisor_override
If specified, it will be used as divisor,
otherwise kernel_size will be used.
out
optional output array, for writing the result to. It must have
a shape that the inputs broadcast to.
Expand Down Expand Up @@ -391,6 +400,7 @@ def avg_pool3d(
data_format=data_format,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
divisor_override=divisor_override,
out=out,
)

Expand Down
18 changes: 18 additions & 0 deletions ivy/data_classes/container/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,6 +578,7 @@ def static_avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand Down Expand Up @@ -605,6 +606,9 @@ def static_avg_pool2d(
Whether to include padding in the averaging calculation.
ceil_mode
Whether to use ceil or floor for creating the output shape.
divisor_override
If specified, it will be used as divisor,
otherwise kernel_size will be used.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -634,6 +638,7 @@ def static_avg_pool2d(
data_format=data_format,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
divisor_override=divisor_override,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
Expand All @@ -651,6 +656,7 @@ def avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand Down Expand Up @@ -678,6 +684,9 @@ def avg_pool2d(
Whether to include padding in the averaging calculation.
ceil_mode
Whether to use ceil or floor for creating the output shape.
divisor_override
If specified, it will be used as divisor,
otherwise kernel_size will be used.
out
optional output array, for writing the result to. It must have a shape
that the inputs broadcast to.
Expand Down Expand Up @@ -706,6 +715,7 @@ def avg_pool2d(
data_format=data_format,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
divisor_override=divisor_override,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
Expand All @@ -724,6 +734,7 @@ def static_avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand Down Expand Up @@ -751,6 +762,8 @@ def static_avg_pool3d(
Whether to include padding in the averaging calculation.
ceil_mode
Whether to use ceil or floor for creating the output shape.
divisor_override
If specified, it will be used as the divisor, otherwise
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
Expand Down Expand Up @@ -783,6 +796,7 @@ def static_avg_pool3d(
data_format=data_format,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
divisor_override=divisor_override,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
Expand All @@ -800,6 +814,7 @@ def avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
Expand Down Expand Up @@ -827,6 +842,8 @@ def avg_pool3d(
Whether to include padding in the averaging calculation.
ceil_mode
Whether to use ceil or floor for creating the output shape.
divisor_override
If specified, it will be used as the divisor, otherwise
out
optional output array, for writing the result to. It must
have a shape that the inputs broadcast to.
Expand Down Expand Up @@ -858,6 +875,7 @@ def avg_pool3d(
data_format=data_format,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
divisor_override=divisor_override,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
Expand Down
54 changes: 32 additions & 22 deletions ivy/functional/backends/jax/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:

Expand All @@ -287,17 +288,21 @@ def avg_pool2d(
div_shape = x.shape[:-1] + (1,)
if len(div_shape) - 2 == len(kernel):
div_shape = (1,) + div_shape[1:]
res = res / general_pool(
jnp.ones(div_shape, dtype=res.dtype),
0.0,
jlax.add,
kernel,
strides,
padding,
2,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
)
if divisor_override is not None:
divisor = divisor_override
else:
divisor = general_pool(
jnp.ones(div_shape, dtype=res.dtype),
0.0,
jlax.add,
kernel,
strides,
padding,
2,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
)
res = res / divisor
if data_format == "NCHW":
return jnp.transpose(res, (0, 3, 1, 2))
return res
Expand All @@ -313,6 +318,7 @@ def avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:

Expand All @@ -333,17 +339,21 @@ def avg_pool3d(
x, 0.0, jlax.add, kernel, strides, padding, 3, ceil_mode=ceil_mode
)

res = res / general_pool(
jnp.ones_like(x, dtype=res.dtype),
0.0,
jlax.add,
kernel,
strides,
padding,
3,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
)
if divisor_override is not None:
divisor = divisor_override
else:
divisor = general_pool(
jnp.ones_like(x, dtype=res.dtype),
0.0,
jlax.add,
kernel,
strides,
padding,
3,
count_include_pad=count_include_pad,
ceil_mode=ceil_mode,
)
res = res / divisor

if data_format == "NCDHW":
res = jnp.transpose(x, (0, 2, 3, 4, 1))
Expand Down
25 changes: 21 additions & 4 deletions ivy/functional/backends/numpy/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ def avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if isinstance(kernel, int):
Expand Down Expand Up @@ -398,8 +399,15 @@ def avg_pool2d(
)

# B x OH x OW x O
res = np.mean(sub_matrices, axis=(3, 4))
if (not count_include_pad or ceil_mode) and any(pad_specific):
if divisor_override is not None:
res = np.sum(sub_matrices, axis=(3, 4)) / divisor_override
else:
res = np.mean(sub_matrices, axis=(3, 4))
if (
(not count_include_pad or ceil_mode)
and any(pad_specific)
and not divisor_override
):
if not count_include_pad:
num_padded_values = [
np.array(
Expand Down Expand Up @@ -450,6 +458,7 @@ def avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:

Expand Down Expand Up @@ -502,8 +511,16 @@ def avg_pool3d(
)

# B x OH x OW x O
res = np.mean(sub_matrices, axis=(4, 5, 6))
if (not count_include_pad or ceil_mode) and any(pad_specific):
if divisor_override is not None:
res = np.sum(sub_matrices, axis=(4, 5, 6)) / divisor_override
else:
res = np.mean(sub_matrices, axis=(4, 5, 6))

if (
(not count_include_pad or ceil_mode)
and any(pad_specific)
and not divisor_override
):
if not count_include_pad:
num_padded_values = [
np.array(
Expand Down
2 changes: 2 additions & 0 deletions ivy/functional/backends/paddle/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
raise IvyNotImplementedException()
Expand All @@ -87,6 +88,7 @@ def avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[paddle.Tensor] = None,
) -> paddle.Tensor:
raise IvyNotImplementedException()
Expand Down
29 changes: 25 additions & 4 deletions ivy/functional/backends/tensorflow/experimental/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def avg_pool2d(
data_format: str = "NHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if isinstance(kernel, int):
Expand All @@ -241,10 +242,17 @@ def avg_pool2d(
manual_padding = True
padding = "VALID"

res = tf.nn.avg_pool2d(x, kernel, strides, padding)
if divisor_override is not None:
# sum pooling then dividing by divisor_override if it is provided
res = tf.nn.depthwise_conv2d(
x, tf.ones(kernel + [x.shape[-1], 1]), [1] + strides + [1], padding
)
res = res / divisor_override
else:
res = tf.nn.avg_pool2d(x, kernel, strides, padding)

# removing any manual padding added because of ceil_mode or count_include_pad
if (manual_padding and not count_include_pad) or ceil_mode:
if (manual_padding and not count_include_pad) or ceil_mode and not divisor_override:
if not count_include_pad:
num_padded_values = [
tf.convert_to_tensor(
Expand Down Expand Up @@ -301,6 +309,7 @@ def avg_pool3d(
data_format: str = "NDHWC",
count_include_pad: bool = False,
ceil_mode: bool = False,
divisor_override: Optional[int] = None,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
if isinstance(kernel, int):
Expand All @@ -326,10 +335,22 @@ def avg_pool3d(
manual_padding = True
padding = "VALID"

res = tf.nn.avg_pool3d(x, kernel, strides, padding)
if divisor_override is not None:
# sum pooling then dividing by divisor_override if it is provided
res = ivy.conv_general_dilated(
x,
tf.ones(kernel + [1, x.shape[-1]]),
strides,
padding,
dims=3,
feature_group_count=x.shape[-1],
)
res = res / divisor_override
else:
res = tf.nn.avg_pool3d(x, kernel, strides, padding)

# removing any manual padding added because of ceil_mode or count_include_pad
if (manual_padding and not count_include_pad) or ceil_mode:
if (manual_padding and not count_include_pad) or ceil_mode and not divisor_override:
if not count_include_pad:
num_padded_values = [
tf.convert_to_tensor(
Expand Down
Loading

0 comments on commit 7a8fc1e

Please sign in to comment.