Skip to content

Commit

Permalink
Create hk.transform_with_state(f).
Browse files Browse the repository at this point in the history
Equivalent to transform `hk.transform(f, state=True, apply_rng=True)`. The long
term plan is to just have `hk.transform` and `hk.transform_with_state` (both
will imply `apply_rng=True`).

I've renamed the returned values to `hk.Transformed` and
`hk.TransformedWithState` so they are a bit shorter for type hints.

PiperOrigin-RevId: 295744930
Change-Id: I8fba9eb61d163fdb17f24c39fdecd64d0862fd45
  • Loading branch information
tomhennigan authored and copybara-github committed Feb 18, 2020
1 parent 6e77ce7 commit d347a54
Show file tree
Hide file tree
Showing 9 changed files with 108 additions and 60 deletions.
9 changes: 8 additions & 1 deletion haiku/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from haiku._src.base import PRNGSequence
from haiku._src.base import set_state
from haiku._src.base import transform
from haiku._src.base import TransformedPair
from haiku._src.base import transform_with_state
from haiku._src.base import Transformed
from haiku._src.base import TransformedWithState
from haiku._src.base import with_rng
from haiku._src.base import without_state
from haiku._src.basic import BatchApply
Expand Down Expand Up @@ -80,6 +82,8 @@
from haiku._src.typing import Params
from haiku._src.typing import State

TransformedPair = Transformed # TODO(tomhennigan) Remove deprecated alias.

__version__ = "0.0.1a0"

__all__ = (
Expand All @@ -101,6 +105,8 @@
"ExponentialMovingAverage",
"Flatten",
"GRU",
"Transformed",
"TransformedWithState",
"TransformedPair",
"InstanceNorm",
"LSTM",
Expand Down Expand Up @@ -143,6 +149,7 @@
"transparent",
"to_module",
"transform",
"transform_with_state",
"value_and_grad",
"VanillaRNN",
"with_rng",
Expand Down
94 changes: 71 additions & 23 deletions haiku/_src/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import functools
from typing import (Any, Callable, Iterator, MutableMapping, NamedTuple,
Optional, Text, Tuple, Set, TypeVar, Union)
import warnings

from haiku._src import analytics
from haiku._src import data_structures
Expand All @@ -37,11 +38,25 @@
T = TypeVar("T")

ModuleState = namedtuple("ModuleState", ("module", "method_name"))
TransformedPair = namedtuple("TransformedPair", ("init", "apply"))
StatePair = namedtuple("StatePair", ("initial", "current"))
MutableParams = MutableMapping[Text, MutableMapping[ParamName, jnp.ndarray]]
MutableState = MutableMapping[Text, MutableMapping[Text, StatePair]]

# TODO(tomhennigan) Add types to attributes once state/apply_rng are removed.
Transformed = namedtuple("Transformed", ("init", "apply"))


class TransformedWithState(NamedTuple):
"""Holds a pair of pure functions."""

# TODO(tomhennigan): Use protocols to describe *args when we are 3.8+.
# https://www.python.org/dev/peps/pep-0544/#callback-protocols

# Args: [Optional[PRNGKey], ...]
init: Callable[..., Tuple[Params, State]]

# Args: [Params, State, Optional[PRNGKey], ...]
apply: Callable[..., Tuple[Any, State]]

# TODO(tomhennigan) Should creator_stack be part of frame?
frame_stack = ThreadLocalStack() # type: ThreadLocalStack["Frame"]
Expand Down Expand Up @@ -330,14 +345,17 @@ def apply_fn(
return apply_fn


def without_state(f: TransformedPair) -> TransformedPair:
# TODO(tomhennigan) Should be TransformedWithState -> Transformed.
def without_state(
f: Union[Transformed, TransformedWithState],
) -> Transformed:
"""Wraps a transformed tuple and ignores state in/out.
>>> def f(x):
... mod = hk.Linear(10)
... return mod(x)
>>> f = hk.without_state(hk.transform(f, apply_rng=True, state=True))
>>> f = hk.without_state(hk.transform_with_state(f))
>>> rng = jax.random.PRNGKey(42)
>>> x = jnp.zeros([1, 1])
Expand Down Expand Up @@ -365,22 +383,25 @@ def apply_fn(params, *args, **kwargs):
raise ValueError("Function wrapped with `hk.without_state` used state.")
return out

return TransformedPair(init=init_fn, apply=apply_fn)
return Transformed(init=init_fn, apply=apply_fn)


def without_apply_rng(f: TransformedPair) -> TransformedPair:
def without_apply_rng(
f: Union[Transformed, TransformedWithState],
) -> Transformed:

def apply_fn(params, state, *args, **kwargs):
return f.apply(params, state, None, *args, **kwargs)

return TransformedPair(init=f.init, apply=apply_fn)
return Transformed(init=f.init, apply=apply_fn)


# TODO(tomhennigan) Remove apply_rng and state.
def transform(
f,
apply_rng=False,
state=False,
) -> TransformedPair:
) -> Transformed:
"""Transforms a function using Haiku modules into a pair of pure functions.
The first thing to do is to define a `Module`. A module encapsulates some
Expand Down Expand Up @@ -428,40 +449,67 @@ def transform(
>>> f.apply(new_params, 2)
DeviceArray(9., dtype=float32)
It is possible for the transformed function to maintain internal state (e.g.
for a module like `BatchNorm` that may want to maintain a moving average) see
:func:`get_state`, :func:`set_state`:
If your transformed function needs to maintain internal state (e.g. moving
averages in batch norm) then see :func:`transform_with_state`.
Args:
f: A function closing over `Module` instances.
apply_rng: Whether `apply` should accept `rng` as an argument.
state: *Deprecated:* use `hk.transform_with_state`.
Returns:
A named tuple with `init` and `apply` pure functions.
"""
analytics.log_once("transform")

if state:
warnings.warn(
"Prefer using hk.transform_with_state(f) vs. passing state=True.",
DeprecationWarning)

if apply_rng:
warnings.warn("Apply_rng will soon be removed and defaulted to True",
DeprecationWarning)

pair = transform_with_state(f) # type: Transformed
if not apply_rng:
pair = without_apply_rng(pair)
if not state:
pair = without_state(pair)
return pair


def transform_with_state(f) -> TransformedWithState:
"""Transforms a function using Haiku modules into a pair of pure functions.
See :func:`transform` for general details on Haiku transformations.
This function is equivalent to :func:`transform`, however it allows you to
maintain and update internal state (e.g. moving averages in batch norm) via
:func:`get_state` and :func:`set_state`.
>>> def f():
... counter = hk.get_state("counter", shape=[], dtype=jnp.int32,
... init=jnp.zeros)
... hk.set_state("counter", counter + 1)
... return counter
>>> f = hk.transform(f, state=True)
>>> f = hk.transform_with_state(f)
>>> params, state = f.init(None)
>>> for _ in range(10):
... counter, state = f.apply(params, state)
... counter, state = f.apply(params, state, None)
>>> counter
DeviceArray(9, dtype=int32)
Args:
f: A function closing over `Module` instances.
apply_rng: Whether `apply` should accept `rng` as an argument.
state: Whether the resulting functions should accept state as input and
and output.
Returns:
A named tuple with `init` and `apply` properties. object if `f` is not None.
A named tuple with `init` and `apply` properties.
"""
analytics.log_once("transform")
pair = TransformedPair(mk_init_fn(f), mk_apply_fn(f))
if not apply_rng:
pair = without_apply_rng(pair)
if not state:
pair = without_state(pair)
return pair
analytics.log_once("transform_with_state")
return TransformedWithState(mk_init_fn(f), mk_apply_fn(f))


class PRNGSequence(Iterator[PRNGKey]):
Expand Down
39 changes: 16 additions & 23 deletions haiku/_src/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,55 +152,50 @@ def _logging_creator(next_creator, name, shape, dtype, init):

self.assertEqual(log, ["a", "b", "c"])

@parameterized.parameters(True, False)
def test_rng_arg(self, rng):
init_fn, apply_fn = base.transform(
lambda: None, apply_rng=rng, state=True)
def test_argspec(self):
init_fn, apply_fn = base.transform_with_state(lambda: None)
init_fn_spec = inspect.getfullargspec(init_fn)
apply_fn_spec = inspect.getfullargspec(apply_fn)

self.assertEqual(init_fn_spec.args, ["rng"])
rng_args = ["rng"] if rng else []
self.assertEqual(apply_fn_spec.args, ["params", "state"] + rng_args)
self.assertEqual(apply_fn_spec.args, ["params", "state", "rng"])

def test_get_state_no_init_raises(self):
init_fn, apply_fn = base.transform(lambda: base.get_state("i"), state=True)
init_fn, apply_fn = base.transform_with_state(lambda: base.get_state("i"))
with self.assertRaisesRegex(ValueError, "set an init function"):
init_fn(None)
state = params = {"~": {}}
with self.assertRaisesRegex(ValueError, "set an init function"):
apply_fn(params, state)
apply_fn(params, state, None)

def test_get_state_no_shape_raises(self):
init_fn, apply_fn = base.transform(
lambda: base.get_state("i", init=jnp.zeros),
state=True,
apply_rng=False)
init_fn, apply_fn = base.transform_with_state(
lambda: base.get_state("i", init=jnp.zeros))
with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
init_fn(None)
state = params = {"~": {}}
with self.assertRaisesRegex(ValueError, "provide shape and dtype"):
apply_fn(params, state)
apply_fn(params, state, None)

def test_get_state_no_init(self):
_, apply_fn = base.transform(lambda: base.get_state("i"), state=True)
_, apply_fn = base.transform_with_state(lambda: base.get_state("i"))
for i in range(10):
state_in = {"~": {"i": i}}
_, state_out = apply_fn({}, state_in)
_, state_out = apply_fn({}, state_in, None)
self.assertEqual(state_in, state_out)

def test_set_then_get(self):
def net():
base.set_state("i", 1)
return base.get_state("i")

init_fn, apply_fn = base.transform(net, state=True)
init_fn, apply_fn = base.transform_with_state(net)
params, state = init_fn(None)
self.assertEqual(state, {"~": {"i": 1}})

for i in range(10):
state_in = {"~": {"i": i}}
y, state_out = apply_fn(params, state_in)
y, state_out = apply_fn(params, state_in, None)
self.assertEqual(y, 1)
self.assertEqual(state_out, {"~": {"i": 1}})

Expand All @@ -211,19 +206,18 @@ def f():
base.set_state("count", count + 1)
return count

init_fn, apply_fn = base.transform(f, state=True)
init_fn, apply_fn = base.transform_with_state(f)
params, state = init_fn(None)
self.assertEqual(state, {"~": {"count": 0}})
_, state = apply_fn(params, state)
_, state = apply_fn(params, state, None)
self.assertEqual(state, {"~": {"count": 10}})

def test_without_state(self):
def f():
w = base.get_parameter("w", [], init=jnp.zeros)
return w

init_fn, apply_fn = base.without_state(
base.transform(f, apply_rng=True, state=True))
init_fn, apply_fn = base.without_state(base.transform_with_state(f))
params = init_fn(None)
out = apply_fn(params, None)
self.assertEqual(out, 0)
Expand All @@ -235,8 +229,7 @@ def f():
base.set_state("count", count + 1)
return count

init_fn, _ = base.without_state(
base.transform(f, apply_rng=True, state=True))
init_fn, _ = base.without_state(base.transform_with_state(f))

with self.assertRaisesRegex(ValueError, "without_state.*used state"):
init_fn(None)
Expand Down
2 changes: 1 addition & 1 deletion haiku/_src/integration/bfloat16_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def g(x):
mod = module_fn()
return mod(x)

init_fn, apply_fn = hk.transform(g, apply_rng=True, state=True)
init_fn, apply_fn = hk.transform_with_state(g)

# Create state in f32 to start.
# NOTE: We need to do this since some initializers (e.g. random.uniform) do
Expand Down
4 changes: 2 additions & 2 deletions haiku/_src/integration/hk_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def g(x, jit=False):
mod = hk.jit(mod)
return mod(x)

f = hk.transform(g, state=True, apply_rng=True)
f = hk.transform_with_state(g)

assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

Expand Down Expand Up @@ -89,7 +89,7 @@ def g(x, remat=False):
mod = hk.remat(mod)
return jnp.mean(mod(x))

f = hk.transform(g, state=True, apply_rng=True)
f = hk.transform_with_state(g)

assert_allclose = functools.partial(np.testing.assert_allclose, atol=1e-5)

Expand Down
4 changes: 2 additions & 2 deletions haiku/_src/integration/jax_transforms_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def test_jit(
def g(x):
return module_fn()(x)

f = hk.transform(g, state=True, apply_rng=True)
f = hk.transform_with_state(g)

atol = CUSTOM_ATOL.get(module_type(module_fn), DEFAULT_ATOL)
assert_allclose = functools.partial(np.testing.assert_allclose, atol=atol)
Expand Down Expand Up @@ -91,7 +91,7 @@ def test_vmap(
def g(x):
return module_fn()(x)

f = hk.transform(g, state=True, apply_rng=True)
f = hk.transform_with_state(g)

# Ensure application under vmap is the same.
params, state = f.init(rng, sample)
Expand Down
4 changes: 2 additions & 2 deletions haiku/_src/lift.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class LiftingModule(module.Module):
and state from the outer transform's dictionaries.
Must be called inside hk.transform, and be passed the `init` member of a
`hk.TransformedPair`.
`hk.Transformed`.
Currently, the given `init_fn` must not use state.
"""
Expand All @@ -60,7 +60,7 @@ def __init__(self, init_fn, name=None):
"""Initializes the LiftingModule.
Args:
init_fn: The init_fn from a hk.TransformedPair. Requires state=True.
init_fn: The init_fn from a hk.Transformed. Requires state=True.
name: Module name.
"""
if name is None:
Expand Down
8 changes: 4 additions & 4 deletions haiku/_src/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,22 +146,22 @@ def net():
self.assertEmpty(log)

def test_stateful_module(self):
init_fn, apply_fn = base.transform(lambda: CountingModule()(), state=True) # pylint: disable=unnecessary-lambda
init_fn, apply_fn = base.transform_with_state(lambda: CountingModule()()) # pylint: disable=unnecessary-lambda
params, state = init_fn(None)
self.assertEqual(state, {"counting_module": {"count": 0}})
_, state = apply_fn(params, state)
_, state = apply_fn(params, state, None)
self.assertEqual(state, {"counting_module": {"count": 10}})

def test_without_state(self):
init_fn, apply_fn = base.without_state(
base.transform(lambda: ScalarModule()(), apply_rng=True, state=True)) # pylint: disable=unnecessary-lambda
base.transform_with_state(lambda: ScalarModule()())) # pylint: disable=unnecessary-lambda
params = init_fn(None)
out = apply_fn(params, None)
self.assertEqual(out, 0)

def test_without_state_raises_if_state_used(self):
init_fn, _ = base.without_state(
base.transform(lambda: CountingModule()(), apply_rng=True, state=True)) # pylint: disable=unnecessary-lambda
base.transform_with_state(lambda: CountingModule()())) # pylint: disable=unnecessary-lambda
with self.assertRaisesRegex(ValueError, "without_state.*used state"):
init_fn(None)

Expand Down
Loading

0 comments on commit d347a54

Please sign in to comment.