Skip to content

Commit

Permalink
Removed sequence padding from ivy transpose convolutions as its results were conceptually incorrect.
Browse files Browse the repository at this point in the history
Performing a transpose convolution from an explicit padding sequence is the opposite approach to performing it from a given "same"/"valid" padding and output_shape. For now, the former will only be available in the torch frontend.
  • Loading branch information
AnnaTz committed Feb 23, 2023
1 parent a67f563 commit 37efbf0
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 58 deletions.
11 changes: 5 additions & 6 deletions ivy/functional/backends/jax/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def conv1d_transpose(
x: JaxArray,
filters: JaxArray,
strides: Union[int, Tuple[int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand Down Expand Up @@ -131,7 +131,7 @@ def conv2d_transpose(
x: JaxArray,
filters: JaxArray,
strides: Union[int, Tuple[int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand Down Expand Up @@ -217,7 +217,7 @@ def conv3d_transpose(
x: JaxArray,
filters: JaxArray,
strides: Union[int, Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand Down Expand Up @@ -322,7 +322,7 @@ def conv_general_transpose(
x: JaxArray,
filters: JaxArray,
strides: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
dims: Optional[int] = 2,
Expand All @@ -342,9 +342,8 @@ def conv_general_transpose(
filter_df = _get_filter_dataformat(dims)
if data_format == "channel_first":
x = jnp.transpose(x, (0, *range(2, dims + 2), 1))
x_shape = list(x.shape[1 : dims + 1])
padding = _get_tranpose_padding(
x_shape, filters.shape, strides, padding, dims, dilations, output_shape
x.shape[1:], filters.shape, strides, padding, dims, dilations, output_shape
)
res = jnp.concatenate(
[
Expand Down
8 changes: 4 additions & 4 deletions ivy/functional/backends/numpy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def conv1d_transpose(
x: np.ndarray,
filters: np.ndarray,
strides: Union[int, Tuple[int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand Down Expand Up @@ -234,7 +234,7 @@ def conv2d_transpose(
x: np.ndarray,
filters: np.ndarray,
strides: Union[int, Tuple[int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand Down Expand Up @@ -371,7 +371,7 @@ def conv3d_transpose(
x: np.ndarray,
filters: np.ndarray,
strides: Union[int, Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand Down Expand Up @@ -475,7 +475,7 @@ def conv_general_transpose(
x: np.ndarray,
filters: np.ndarray,
strides: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
dims: Optional[int] = 2,
Expand Down
11 changes: 4 additions & 7 deletions ivy/functional/backends/tensorflow/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def conv1d_transpose(
x: Union[tf.Tensor, tf.Variable],
filters: Union[tf.Tensor, tf.Variable],
strides: Union[int, Tuple[int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand All @@ -110,7 +110,6 @@ def conv1d_transpose(
output_shape = _output_shape(
x.shape, filters.shape, output_shape, strides, padding, 1, dilations
)
padding = padding if isinstance(padding, str) else "VALID"
res = tf.nn.conv1d_transpose(
x, filters, output_shape, strides, padding, "NWC", dilations
)
Expand Down Expand Up @@ -145,7 +144,7 @@ def conv2d_transpose(
x: Union[tf.Tensor, tf.Variable],
filters: Union[tf.Tensor, tf.Variable],
strides: Union[int, Tuple[int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand All @@ -165,7 +164,6 @@ def conv2d_transpose(
output_shape = _output_shape(
x.shape, filters.shape, output_shape, strides, padding, 2, dilations
)
padding = padding if isinstance(padding, str) else "VALID"
res = tf.nn.conv2d_transpose(
x, filters, output_shape, strides, padding, "NHWC", dilations
)
Expand Down Expand Up @@ -230,7 +228,7 @@ def conv3d_transpose(
x: Tensor,
filters: Tensor,
strides: Union[int, Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand All @@ -254,7 +252,6 @@ def conv3d_transpose(
output_shape = _output_shape(
x.shape, filters.shape, output_shape, strides[1:], padding, 3, dilations
)
padding = padding if isinstance(padding, str) else "VALID"
res = tf.nn.conv3d_transpose(
x, filters, output_shape, strides, padding, "NDHWC", dilations
)
Expand Down Expand Up @@ -377,7 +374,7 @@ def conv_general_transpose(
x: Union[tf.Tensor, tf.Variable],
filters: Union[tf.Tensor, tf.Variable],
strides: Union[int, Tuple[int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
dims: Optional[int] = 2,
Expand Down
20 changes: 8 additions & 12 deletions ivy/functional/backends/torch/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def conv1d_transpose(
x: torch.Tensor,
filters: torch.Tensor,
strides: Union[int, Tuple[int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand All @@ -128,10 +128,9 @@ def conv1d_transpose(
x = x.permute(0, 2, 1)
strides = [strides] if isinstance(strides, int) else strides
dilations = [dilations] if isinstance(dilations, int) else dilations
filter_shape = list(filters.shape[0:1])
filters = filters.permute(1, 2, 0)
not_valid_pad, padding_list, output_padding = _pad_before_conv_tranpose(
x, filters, strides, padding, 1, dilations, output_shape, filter_shape
x, filters, strides, padding, 1, dilations, output_shape, filters.shape[2:]
)
res = torch.nn.functional.conv_transpose1d(
x,
Expand Down Expand Up @@ -190,7 +189,7 @@ def conv2d_transpose(
x: torch.Tensor,
filters: torch.Tensor,
strides: Union[int, Tuple[int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand All @@ -202,10 +201,9 @@ def conv2d_transpose(
x = x.permute(0, 3, 1, 2)
strides = [strides] * 2 if isinstance(strides, int) else strides
dilations = [dilations] * 2 if isinstance(dilations, int) else dilations
filter_shape = list(filters.shape[0:2])
filters = filters.permute(2, 3, 0, 1)
not_valid_pad, padding_list, output_padding = _pad_before_conv_tranpose(
x, filters, strides, padding, 2, dilations, output_shape, filter_shape
x, filters, strides, padding, 2, dilations, output_shape, filters.shape[2:]
)
res = torch.nn.functional.conv_transpose2d(
x,
Expand Down Expand Up @@ -299,7 +297,7 @@ def conv3d_transpose(
x: torch.Tensor,
filters: torch.Tensor,
strides: Union[int, Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.NativeShape, Sequence[int]]] = None,
Expand All @@ -311,10 +309,9 @@ def conv3d_transpose(
x = x.permute(0, 4, 1, 2, 3)
strides = [strides] * 3 if isinstance(strides, int) else strides
dilations = [dilations] * 3 if isinstance(dilations, int) else dilations
filter_shape = list(filters.shape[0:3])
filters = filters.permute(3, 4, 0, 1, 2)
not_valid_pad, padding_list, output_padding = _pad_before_conv_tranpose(
x, filters, strides, padding, 3, dilations, output_shape, filter_shape
x, filters, strides, padding, 3, dilations, output_shape, filters.shape[2:]
)
res = torch.nn.functional.conv_transpose3d(
x,
Expand Down Expand Up @@ -401,7 +398,7 @@ def conv_general_transpose(
x: torch.Tensor,
filters: torch.Tensor,
strides: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
dims: Optional[int] = 2,
Expand All @@ -418,10 +415,9 @@ def conv_general_transpose(
x = x.permute(0, dims + 1, *range(1, dims + 1))
strides = [strides] * dims if isinstance(strides, int) else strides
dilations = [dilations] * dims if isinstance(dilations, int) else dilations
filter_shape = list(filters.shape[0:dims])
filters = filters.permute(dims, dims + 1, *range(dims))
not_valid_pad, padding_list, output_padding = _pad_before_conv_tranpose(
x, filters, strides, padding, dims, dilations, output_shape, filter_shape
x, filters, strides, padding, dims, dilations, output_shape, filters.shape[2:]
)
if dims == 1:
res = torch.nn.functional.conv_transpose1d(
Expand Down
28 changes: 12 additions & 16 deletions ivy/functional/ivy/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,7 +884,7 @@ def conv1d_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
strides: Union[int, Tuple[int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.Shape, ivy.NativeShape]] = None,
Expand All @@ -903,9 +903,8 @@ def conv1d_transpose(
strides
The stride of the sliding window for each dimension of input.
padding
either the string ‘SAME’ (padding with zeros evenly), the string ‘VALID’ (no
padding), or a sequence of n (low, high) integer pairs that give the padding to
apply before and after each spatial dimension.
Either ‘SAME’ (padding so that the output's shape is the same as the
input's), or ‘VALID’ (padding so that the output's shape is `output_shape`).
output_shape
Shape of the output (Default value = None)
data_format
Expand Down Expand Up @@ -1145,7 +1144,7 @@ def conv2d_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
strides: Union[int, Tuple[int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.Shape, ivy.NativeShape]] = None,
Expand All @@ -1164,9 +1163,8 @@ def conv2d_transpose(
strides
The stride of the sliding window for each dimension of input.
padding
either the string ‘SAME’ (padding with zeros evenly), the string ‘VALID’ (no
padding), or a sequence of n (low, high) integer pairs that give the padding to
apply before and after each spatial dimension.
Either ‘SAME’ (padding so that the output's shape is the same as the
input's), or ‘VALID’ (padding so that the output's shape is `output_shape`).
output_shape
Shape of the output (Default value = None)
data_format
Expand Down Expand Up @@ -1518,7 +1516,7 @@ def conv3d_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
strides: Union[int, Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
output_shape: Optional[Union[ivy.Shape, ivy.NativeShape]] = None,
Expand All @@ -1537,9 +1535,8 @@ def conv3d_transpose(
strides
The stride of the sliding window for each dimension of input.
padding
either the string ‘SAME’ (padding with zeros evenly), the string ‘VALID’ (no
padding), or a sequence of n (low, high) integer pairs that give the padding to
apply before and after each spatial dimension.
Either ‘SAME’ (padding so that the output's shape is the same as the
input's), or ‘VALID’ (padding so that the output's shape is `output_shape`).
output_shape
Shape of the output (Default value = None)
data_format
Expand Down Expand Up @@ -1706,7 +1703,7 @@ def conv_general_transpose(
x: Union[ivy.Array, ivy.NativeArray],
filters: Union[ivy.Array, ivy.NativeArray],
strides: Union[int, Tuple[int], Tuple[int, int], Tuple[int, int, int]],
padding: Union[str, Sequence[Tuple[int, int]]],
padding: str,
/,
*,
dims: Optional[int] = 2,
Expand All @@ -1731,9 +1728,8 @@ def conv_general_transpose(
strides
The stride of the sliding window for each dimension of input.
padding
either the string ‘SAME’ (padding with zeros evenly), the string ‘VALID’ (no
padding), or a sequence of n (low, high) integer pairs that give the padding to
apply before and after each spatial dimension.
Either ‘SAME’ (padding so that the output's shape is the same as the
input's), or ‘VALID’ (padding so that the output's shape is `output_shape`).
dims
Either 1, 2, or 3 corresponding to 1-D, 2-D, and 3-D convolution.
output_shape
Expand Down
27 changes: 14 additions & 13 deletions ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,19 +322,6 @@ def x_and_filters(
):
if not isinstance(dim, int):
dim = draw(dim)
padding = draw(
st.one_of(
st.lists(
st.tuples(
st.integers(min_value=0, max_value=3),
st.integers(min_value=0, max_value=3),
),
min_size=dim,
max_size=dim,
),
st.sampled_from(["SAME", "VALID"]),
)
)
batch_size = draw(st.integers(1, 5))
filter_shape = draw(
helpers.get_shape(
Expand Down Expand Up @@ -374,6 +361,7 @@ def x_and_filters(
full_strides = [strides] * dim if isinstance(strides, int) else strides
full_dilations = [dilations] * dim if isinstance(dilations, int) else dilations
if transpose:
padding = draw(st.sampled_from(["SAME", "VALID"]))
x_dim = draw(
helpers.get_shape(
min_num_dims=dim, max_num_dims=dim, min_dim_size=1, max_dim_size=5
Expand All @@ -394,6 +382,19 @@ def x_and_filters(
else:
output_shape = None
else:
padding = draw(
st.one_of(
st.lists(
st.tuples(
st.integers(min_value=0, max_value=3),
st.integers(min_value=0, max_value=3),
),
min_size=dim,
max_size=dim,
),
st.sampled_from(["SAME", "VALID"]),
)
)
x_dim = []
for i in range(dim):
min_x = filter_shape[i] + (filter_shape[i] - 1) * (full_dilations[i] - 1)
Expand Down

0 comments on commit 37efbf0

Please sign in to comment.