[Oryx] Add AddN expression and addition rewrite rules
PiperOrigin-RevId: 396895039
sharadmv authored and tensorflower-gardener committed Sep 15, 2021
1 parent 075b22f commit dc45507
Showing 5 changed files with 336 additions and 4 deletions.
32 changes: 32 additions & 0 deletions spinoffs/oryx/oryx/experimental/autoconj/BUILD
@@ -37,6 +37,7 @@ py_library(
srcs = ["canonicalize.py"],
srcs_version = "PY3",
deps = [
":addn",
":einsum",
# jax dep,
"//oryx/experimental/matching:jax_rewrite",
@@ -45,6 +46,18 @@
    ],
)

# pytype_strict
py_library(
    name = "addn",
    srcs = ["addn.py"],
    srcs_version = "PY3",
    deps = [
        # jax dep,
        "//oryx/experimental/matching:jax_rewrite",
        "//oryx/experimental/matching:matcher",
    ],
)

# py_strict
py_test(
    name = "einsum_test",
@@ -70,6 +83,7 @@ py_test(
    python_version = "PY3",
    srcs_version = "PY3",
    deps = [
        ":addn",
        ":canonicalize",
        ":einsum",
        # absl/testing:absltest dep,
@@ -79,3 +93,21 @@
"//oryx/internal:test_util",
],
)

# py_strict
py_test(
name = "addn_test",
srcs = ["addn_test.py"],
python_version = "PY3",
srcs_version = "PY3",
deps = [
":addn",
# absl/testing:absltest dep,
# jax dep,
# numpy dep,
"//oryx/experimental/matching:jax_rewrite",
"//oryx/experimental/matching:matcher",
"//oryx/experimental/matching:rules",
"//oryx/internal:test_util",
],
)
115 changes: 115 additions & 0 deletions spinoffs/oryx/oryx/experimental/autoconj/addn.py
@@ -0,0 +1,115 @@
# Copyright 2021 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Contains the `AddN` expression.
The `AddN` expression represents a sum of operands. JAX only has a binary
`add` primitive, meaning a sequence of adds is represented as an expression
tree of `add` primitives. In `autoconj`, we'd like to roll all the `add`s into
a single expression to simplify rewrite rules and to represent a canonicalized
density function. Thus we use `AddN` to represent a flat sum of operands.
"""
import dataclasses
import functools
import operator

from typing import Any, Dict, Iterator, Tuple, Union

import jax
import jax.numpy as jnp

from oryx.experimental.matching import jax_rewrite as jr
from oryx.experimental.matching import matcher

__all__ = [
    'AddN',
]

Bindings = matcher.Bindings
Continuation = matcher.Continuation
Expr = matcher.Expr
Pattern = matcher.Pattern
Success = matcher.Success


@dataclasses.dataclass(frozen=True)
class AddN(jr.JaxExpression):
"""Adds several children expressions.
JAX's `add` primitive is binary so adding several terms must be represented
as a tree of `add`s. `AddN` is a "flat" expression representation of adding
several subexpressions which is more convenient for pattern matching and
term rewriting.
Attributes:
operands: A tuple of expressions to be added together when evaluating
the `AddN` expression.
"""
  operands: Union[Pattern, Tuple[Any, ...]]

  @functools.lru_cache(None)
  def shape_dtype(self) -> jax.ShapeDtypeStruct:
    """Computes the shape and dtype of the result of this `AddN`.

    Returns:
      A `jax.ShapeDtypeStruct` object describing the shape and dtype of the
      `AddN`.
    """
    operand_shape_dtypes = tuple(
        jax.ShapeDtypeStruct(operand.shape, operand.dtype)
        for operand in self.operands)

    def _eval_fun(*args):
      return functools.reduce(operator.add, args)

    return jax.eval_shape(_eval_fun, *operand_shape_dtypes)

  @property
  def shape(self) -> Tuple[int, ...]:
    return self.shape_dtype().shape

  @property
  def dtype(self) -> jnp.dtype:
    return self.shape_dtype().dtype

  # Matching methods

  def match(self, expr: Expr, bindings: Bindings,
            succeed: Continuation) -> Success:
    """Matches the formula and operands of an `AddN`."""
    if not isinstance(expr, AddN):
      return
    yield from matcher.matcher(self.operands)(expr.operands, bindings, succeed)

  # Rules methods

  def tree_map(self, fn) -> 'AddN':
    """Maps a function across the operands of an `AddN`."""
    return AddN(tuple(map(fn, self.operands)))

  def tree_children(self) -> Iterator[Any]:
    """Returns an iterator over the operands of an `AddN`."""
    yield from self.operands

  # JAX rewriting methods

  def evaluate(self, env: Dict[str, Any]) -> Any:
    """Evaluates an `AddN` in an environment."""
    operands = jr.evaluate(self.operands, env)
    return functools.reduce(operator.add, operands)

  # Builtin methods

  def __str__(self) -> str:
    return f'(addn {self.operands})'
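
For orientation, a small usage sketch (not part of the commit) contrasting the nested binary-`add` tree described in the module docstring with the flat `AddN`. It relies only on constructors that appear elsewhere in this diff (`JaxVar`, `Primitive`, `Params`, `jr.evaluate`); the variable names are invented for the example.

# Illustrative sketch only; not part of the committed files.
import jax.numpy as jnp
from jax import lax

from oryx.experimental.autoconj import addn
from oryx.experimental.matching import jax_rewrite as jr

a = jr.JaxVar('a', (3,), jnp.float32)
b = jr.JaxVar('b', (3,), jnp.float32)
c = jr.JaxVar('c', (3,), jnp.float32)

# (a + b) + c as JAX represents it: a tree of two binary `add` primitives.
nested = jr.Primitive(
    lax.add_p, (jr.Primitive(lax.add_p, (a, b), jr.Params()), c), jr.Params())

# The same sum as a single flat expression, which is easier to match against.
flat = addn.AddN((a, b, c))

env = dict(a=jnp.ones(3), b=2 * jnp.ones(3), c=3 * jnp.ones(3))
# Both forms should evaluate to [6., 6., 6.] in the same environment.
assert (jr.evaluate(nested, env) == flat.evaluate(env)).all()
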
79 changes: 79 additions & 0 deletions spinoffs/oryx/oryx/experimental/autoconj/addn_test.py
@@ -0,0 +1,79 @@
# Copyright 2021 The TensorFlow Probability Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tests for tensorflow_probability.spinoffs.oryx.experimental.autoconj.addn."""
from absl.testing import absltest

import jax.numpy as jnp

import numpy as np

from oryx.experimental.autoconj import addn
from oryx.experimental.matching import jax_rewrite as jr
from oryx.experimental.matching import matcher
from oryx.experimental.matching import rules
from oryx.internal import test_util

Var = matcher.Var
Segment = matcher.Segment
JaxVar = jr.JaxVar
AddN = addn.AddN


class AddNTest(test_util.TestCase):

  def test_can_match_addn_components(self):
    x = JaxVar('x', (5,), jnp.float32)
    op = AddN((x, x))
    pattern = AddN((matcher.Segment('args'),))
    self.assertDictEqual(
        matcher.match(pattern, op), {
            'args': (x, x)
        })

  def test_can_replace_addn_operands(self):
    x = JaxVar('x', (5,), jnp.float32)
    y = JaxVar('y', (5,), jnp.float32)
    z = JaxVar('z', (5,), jnp.float32)
    op = AddN((x, y))
    pattern = AddN((matcher.Segment('args'),))
    def replace_with_z(args):
      del args
      return AddN((z, z))
    replace_rule = rules.make_rule(pattern, replace_with_z)
    replaced_op = replace_rule(op)
    self.assertEqual(replaced_op, AddN((z, z)))

  def test_addn_correctly_infers_shape_and_dtype(self):
    x = JaxVar('x', (5, 2), jnp.float32)
    y = JaxVar('y', (5, 2), jnp.float32)
    op = AddN((x, y))
    self.assertEqual(op.dtype, jnp.float32)
    self.assertTupleEqual(op.shape, (5, 2))

  def test_addn_evaluates_to_correct_value(self):
    x = JaxVar('x', (5, 2), jnp.float32)
    y = JaxVar('y', (5, 2), jnp.float32)
    z = JaxVar('z', (5, 2), jnp.float32)
    op = AddN((x, y, z))
    x_val = jnp.arange(10.).reshape((5, 2))
    y_val = jnp.arange(10., 20.).reshape((5, 2))
    z_val = jnp.arange(20., 30.).reshape((5, 2))
    np.testing.assert_allclose(
        op.evaluate(dict(x=x_val, y=y_val, z=z_val)),
        x_val + y_val + z_val)


if __name__ == '__main__':
  absltest.main()
54 changes: 52 additions & 2 deletions spinoffs/oryx/oryx/experimental/autoconj/canonicalize.py
@@ -26,13 +26,31 @@
Because JAX does not have an einsum primitive, some rules are dedicated to
converting existing linear function primitives in JAX (like dot products,
transposes, etc.) into `Einsum`s.

## Rewrites

Here we broadly explain the canonicalization rules.

* `Einsum` rewrite rules

Linear functions can be expressed as an `Einsum`. An `Einsum` rewrite pass takes
several JAX primitives (`transpose`, `broadcast`, `dot_general`, etc.) and
converts them into `Einsum`s. It also combines nested `Einsum`s into a single
`Einsum`.

* `AddN` rewrite rules

With the eventual goal of rewriting a density into "sum of einsums", we want
to turn nested additions into a single flat addition operation. The `AddN`
rewrite pass turns nested additions into a single `AddN`.
"""
import itertools as it

from typing import Any, Callable

from jax import lax

from oryx.experimental.autoconj import addn
from oryx.experimental.autoconj import einsum
from oryx.experimental.matching import jax_rewrite as jr
from oryx.experimental.matching import matcher
@@ -43,7 +61,7 @@
    'canonicalize',
]


AddN = addn.AddN
Einsum = einsum.Einsum
Var = matcher.Var
Segment = matcher.Segment
@@ -136,12 +154,44 @@ def compose_einsums(parent_formula, left_args, child_formula, child_args,
right_args)


CANONICALIZATION_RULES = (
REWRITE_TO_EINSUM_RULES = (
    transpose_as_einsum,
    dot_as_einsum,
    squeeze_as_einsum,
    reduce_sum_as_einsum,
    compose_einsums,
)


def _add(x, y):
  return Primitive(lax.add_p, (x, y), Params())


_add_to_addn_pattern = _add(Var('x'), Var('y'))


@register_rule(_add_to_addn_pattern)
def add_to_addn(x, y):
  return AddN((x, y))


_addn_of_addn_pattern = AddN(
    (Segment('args1'), AddN(Var('child_args')), Segment('args2')))


@register_rule(_addn_of_addn_pattern)
def addn_of_addn_to_addn(args1, child_args, args2):
  return AddN((*args1, *child_args, *args2))


REWRITE_TO_ADDN_RULES = (
    add_to_addn,
    addn_of_addn_to_addn,
)

CANONICALIZATION_RULES = (
    *REWRITE_TO_EINSUM_RULES,
    *REWRITE_TO_ADDN_RULES,
)

canonicalize = rules.term_rewriter(*CANONICALIZATION_RULES)
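
A companion sketch (again not part of the commit) of what the new rules do end to end: `add_to_addn` turns each binary `add` primitive into an `AddN`, and `addn_of_addn_to_addn` then merges nested `AddN`s, so rewriting a nested addition should leave a single flat sum. It assumes `rules.term_rewriter` returns a callable that rewrites an expression to a fixed point; the helper and variable names are invented.

# Illustrative sketch only; not part of the committed files.
import jax.numpy as jnp
from jax import lax

from oryx.experimental.autoconj import addn, canonicalize
from oryx.experimental.matching import jax_rewrite as jr

def binary_add(x, y):
  # Mirrors the `_add` helper above: a single binary `add` primitive.
  return jr.Primitive(lax.add_p, (x, y), jr.Params())

a = jr.JaxVar('a', (3,), jnp.float32)
b = jr.JaxVar('b', (3,), jnp.float32)
c = jr.JaxVar('c', (3,), jnp.float32)

# (a + b) + c enters as a tree of binary adds and should come out flat.
flat = canonicalize.canonicalize(binary_add(binary_add(a, b), c))
# Expected result: a single flat sum, i.e. addn.AddN((a, b, c)).
print(flat)
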