Skip to content

Commit

Permalink
[hk] Expose hk.DepthwiseConv2D and change default padding to SAME
Browse files Browse the repository at this point in the history
Also renames the home file to depthwise_conv.

PiperOrigin-RevId: 294754337
Change-Id: Ic47e9f8bfac83095f3ddde8a08c42e5e6ac3fe4d
  • Loading branch information
trevorcai authored and copybara-github committed Feb 12, 2020
1 parent 0092e1f commit 62b19e6
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 29 deletions.
1 change: 1 addition & 0 deletions haiku/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ py_library(
"//haiku/_src:bias",
"//haiku/_src:conv",
"//haiku/_src:data_structures",
"//haiku/_src:depthwise_conv",
"//haiku/_src:embed",
"//haiku/_src:initializers",
"//haiku/_src:layer_norm",
Expand Down
2 changes: 2 additions & 0 deletions haiku/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from haiku._src.conv import Conv2DTranspose
from haiku._src.conv import Conv3D
from haiku._src.conv import Conv3DTranspose
from haiku._src.depthwise_conv import DepthwiseConv2D
from haiku._src.embed import Embed
from haiku._src.embed import EmbedLookupStyle
from haiku._src.layer_norm import InstanceNorm
Expand Down Expand Up @@ -92,6 +93,7 @@
"Conv3D",
"Conv3DTranspose",
"DeepRNN",
"DepthwiseConv2D",
"EMAParamsTree",
"Embed",
"EmbedLookupStyle",
Expand Down
20 changes: 5 additions & 15 deletions haiku/_src/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ hk_py_library(
":analytics",
":data_structures",
":typing",
":utils",
# pip: jax
# pip: numpy
],
Expand All @@ -44,7 +43,6 @@ hk_py_library(
":module",
":utils",
# pip: jax
# pip: numpy
],
)

Expand All @@ -67,10 +65,8 @@ hk_py_library(
srcs = ["module.py"],
deps = [
":base",
":data_structures",
":utils",
# pip: jax
# pip: numpy
],
)

Expand Down Expand Up @@ -139,25 +135,24 @@ hk_py_test(
)

hk_py_library(
name = "depthwiseconv",
srcs = ["depthwiseconv.py"],
name = "depthwise_conv",
srcs = ["depthwise_conv.py"],
deps = [
":base",
":initializers",
":module",
":pad",
":utils",
# pip: jax
# pip: numpy
],
)

hk_py_test(
name = "depthwiseconv_test",
srcs = ["depthwiseconv_test.py"],
name = "depthwise_conv_test",
srcs = ["depthwise_conv_test.py"],
deps = [
":base",
":depthwiseconv",
":depthwise_conv",
":initializers",
# pip: absl/testing:absltest
# pip: absl/testing:parameterized
Expand Down Expand Up @@ -284,9 +279,6 @@ hk_py_test(
hk_py_library(
name = "utils",
srcs = ["utils.py"],
deps = [
# pip: jax
],
)

hk_py_test(
Expand Down Expand Up @@ -386,7 +378,6 @@ hk_py_library(
":initializers",
":module",
# pip: jax
# pip: tree
],
)

Expand Down Expand Up @@ -464,7 +455,6 @@ hk_py_library(
":basic",
":initializers",
":module",
# pip: enum
# pip: jax
],
)
Expand Down
4 changes: 2 additions & 2 deletions haiku/_src/depthwiseconv.py → haiku/_src/depthwise_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def __init__(self,
channel_multiplier,
kernel_shape,
stride=1,
padding="VALID",
padding="SAME",
with_bias=True,
w_init=None,
b_init=None,
data_format="channels_last",
data_format="NHWC",
name=None):
"""Construct a 2D Depthwise Convolution.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for haiku._src.depthwiseconv."""
"""Tests for haiku._src.depthwise_conv."""

from absl.testing import absltest
from absl.testing import parameterized
from haiku._src import base
from haiku._src import depthwiseconv
from haiku._src import depthwise_conv
from haiku._src import initializers
from jax import random
import jax.numpy as jnp
Expand All @@ -44,11 +44,11 @@ def f():
data[0, :, :, 1] += 1
data[0, :, :, 2] += 2
data = jnp.array(data)
net = depthwiseconv.DepthwiseConv2D(
net = depthwise_conv.DepthwiseConv2D(
channel_multiplier=1,
kernel_shape=3,
stride=1,
padding="Valid",
padding="VALID",
with_bias=with_bias,
data_format="channels_last",
**create_constant_initializers(1.0, 1.0, with_bias))
Expand Down Expand Up @@ -76,11 +76,11 @@ def f():
data[0, :, :, 1] += 1
data[0, :, :, 2] += 2
data = jnp.array(data)
net = depthwiseconv.DepthwiseConv2D(
net = depthwise_conv.DepthwiseConv2D(
channel_multiplier=1,
kernel_shape=3,
stride=1,
padding="Same",
padding="SAME",
with_bias=with_bias,
data_format="channels_last",
**create_constant_initializers(1.0, 0.0, with_bias))
Expand All @@ -97,11 +97,11 @@ def f():
data[0, :, :, 1] += 1
data[0, :, :, 2] += 2
data = jnp.array(data)
net = depthwiseconv.DepthwiseConv2D(
net = depthwise_conv.DepthwiseConv2D(
channel_multiplier=3,
kernel_shape=3,
stride=1,
padding="Valid",
padding="VALID",
with_bias=with_bias,
data_format="channels_last",
**create_constant_initializers(1.0, 0.0, with_bias))
Expand All @@ -118,11 +118,11 @@ def f():
data[0, 1, :, :] += 1
data[0, 2, :, :] += 2
data = jnp.array(data)
net = depthwiseconv.DepthwiseConv2D(
net = depthwise_conv.DepthwiseConv2D(
channel_multiplier=1,
kernel_shape=3,
stride=1,
padding="Valid",
padding="VALID",
with_bias=with_bias,
data_format="channels_first",
**create_constant_initializers(1.0, 1.0, with_bias))
Expand All @@ -147,11 +147,11 @@ def f():
data[0, 1, :, :] += 1
data[0, 2, :, :] += 2
data = jnp.array(data)
net = depthwiseconv.DepthwiseConv2D(
net = depthwise_conv.DepthwiseConv2D(
channel_multiplier=9,
kernel_shape=3,
stride=1,
padding="Valid",
padding="VALID",
with_bias=with_bias,
data_format="channels_first",
**create_constant_initializers(1.0, 0.0, with_bias))
Expand Down
4 changes: 4 additions & 0 deletions haiku/_src/integration/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ class ModuleDescriptor(NamedTuple):
name="Conv3DTranspose",
create=lambda: hk.Conv3DTranspose(3, 3),
shape=(BATCH_SIZE, 2, 2, 2, 2)),
ModuleDescriptor(
name="DepthwiseConv2D",
create=lambda: hk.DepthwiseConv2D(1, 3),
shape=(BATCH_SIZE, 2, 2, 2)),
)


Expand Down

0 comments on commit 62b19e6

Please sign in to comment.