FunMCMC: Add Simple Dual Averaging implementation.
As far as I can tell, this algorithm is implemented correctly (cross-checked against publications as well as tfp.mcmc.DualAveragingSSA), but it's pretty clear that it's extremely sensitive to the Lipschitz constant of the function being optimized, and is *extremely* slow to converge. It probably makes sense for baselines, but I wouldn't ever reach for it over Adam or other modern subgradient descent algorithms.

PiperOrigin-RevId: 293081652
SiegeLordEx authored and tensorflower-gardener committed Feb 4, 2020
1 parent da77d8a commit deaa8ab
Showing 4 changed files with 152 additions and 10 deletions.
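
To make the algorithm described above concrete, here is a minimal, self-contained sketch of simple dual averaging on a one-dimensional quadratic. This is editorial illustration only: the function and parameter names (`simple_dual_averages`, `subgrad_fn`, `x0`, `shrink_weight`) are invented for this note and are not the commit's API, which lives in `fun_mcmc_lib.py` below.

```python
# Illustration only (not part of the commit): simple dual averaging on a 1-D
# convex problem, using the update x_{k+1} = x0 - sqrt(k) / gamma * mean(g_1..g_k).
import numpy as np


def simple_dual_averages(subgrad_fn, x0=0.0, shrink_weight=1.0, num_steps=1000):
  """Minimizes a convex function given a (sub)gradient oracle."""
  x = x0
  grad_sum = 0.0
  iterate_mean = 0.0
  for k in range(1, num_steps + 1):
    grad_sum += subgrad_fn(x)
    # Shrink towards x0; the sqrt(k) factor relaxes the shrinkage over time.
    x = x0 - np.sqrt(k) / shrink_weight * (grad_sum / k)
    # The raw iterates oscillate; their running mean is what converges.
    iterate_mean += (x - iterate_mean) / k
  return iterate_mean


# Example: minimize (x - 1)**2 via its gradient 2 * (x - 1).
print(simple_dual_averages(lambda x: 2.0 * (x - 1.0)))  # slowly approaches 1.0
```

The slow convergence and the need to average the iterates are exactly why the new test below traces the output of `running_mean_step` rather than the raw state.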
8 changes: 8 additions & 0 deletions discussion/fun_mcmc/__init__.py
@@ -80,6 +80,10 @@
from discussion.fun_mcmc.fun_mcmc_lib import RunningVarianceState
from discussion.fun_mcmc.fun_mcmc_lib import ruth4_step
from discussion.fun_mcmc.fun_mcmc_lib import sign_adaptation
from discussion.fun_mcmc.fun_mcmc_lib import simple_dual_averages_init
from discussion.fun_mcmc.fun_mcmc_lib import simple_dual_averages_step
from discussion.fun_mcmc.fun_mcmc_lib import SimpleDualAveragesExtra
from discussion.fun_mcmc.fun_mcmc_lib import SimpleDualAveragesState
from discussion.fun_mcmc.fun_mcmc_lib import spliting_integrator_step
from discussion.fun_mcmc.fun_mcmc_lib import State
from discussion.fun_mcmc.fun_mcmc_lib import trace
@@ -147,6 +151,10 @@
'ruth4_step',
'set_backend',
'sign_adaptation',
'simple_dual_averages_init',
'simple_dual_averages_step',
'SimpleDualAveragesExtra',
'SimpleDualAveragesState',
'spliting_integrator_step',
'State',
'TENSORFLOW',
126 changes: 118 additions & 8 deletions discussion/fun_mcmc/fun_mcmc_lib.py
@@ -60,9 +60,9 @@
'call_fn',
'call_potential_fn',
'call_potential_fn_with_grads',
'call_transition_operator',
'call_transport_map',
'call_transport_map_with_ldj',
'call_transition_operator',
'gaussian_momentum_sample',
'gradient_descent_step',
'GradientDescentExtra',
@@ -107,6 +107,10 @@
'RunningVarianceState',
'ruth4_step',
'sign_adaptation',
'simple_dual_averages_init',
'simple_dual_averages_step',
'SimpleDualAveragesExtra',
'SimpleDualAveragesState',
'spliting_integrator_step',
'State',
'trace',
@@ -451,6 +455,7 @@ def call_transport_map_with_ldj(
extra: Second output of `fn`.
ldj: Log-det jacobian of `fn`.
"""

def wrapper(args):
return call_transport_map(fn, args)

@@ -574,8 +579,7 @@ def wrapper(*args, **kwargs):
state, map_extra, ldj = call_transport_map_with_ldj(
transport_map_fn, transformed_state)
else:
state, map_extra = call_transport_map(transport_map_fn,
transformed_state)
state, map_extra = call_transport_map(transport_map_fn, transformed_state)

potential, extra = call_potential_fn(potential_fn, state)

@@ -1538,7 +1542,6 @@ def _one_part(state, g, learning_rate):
RandomWalkMetropolisState = collections.namedtuple(
'RandomWalkMetropolisState', 'state, target_log_prob, state_extra')


RandomWalkMetropolisExtra = collections.namedtuple(
'RandomWalkMetropolisExtra',
'is_accepted, log_accept_ratio, proposal_extra, proposed_rwm_state')
@@ -1757,15 +1760,18 @@ def running_covariance_init(shape: IntTensor,
num_points=util.map_tree(lambda _: tf.zeros([], tf.int32), dtype),
mean=util.map_tree_up_to(dtype, tf.zeros, shape, dtype),
covariance=util.map_tree_up_to(
dtype, lambda shape, dtype: tf.zeros( # pylint: disable=g-long-lambda
dtype,
lambda shape, dtype: tf.zeros( # pylint: disable=g-long-lambda
tf.concat(
[
tf.convert_to_tensor(shape),
tf.convert_to_tensor(shape[-1:]),
],
axis=0,
),
dtype=dtype), shape, dtype),
dtype=dtype),
shape,
dtype),
)


@@ -1973,8 +1979,7 @@ def potential_scale_reduction_init(shape,
# We are wrapping running variance so that the user doesn't get the chance to
# set the reduction axis, which would break the assumptions of
# `potential_scale_reduction_extract`.
return PotentialScaleReductionState(
*running_variance_init(shape, dtype))
return PotentialScaleReductionState(*running_variance_init(shape, dtype))


def potential_scale_reduction_step(
@@ -2273,3 +2278,108 @@ def inner_grad_fn(*_):
return grad_wrapper(*util.flatten_tree((args, kwargs)))

return loss_fn


SimpleDualAveragesState = collections.namedtuple(
'SimpleDualAveragesState', 'state, step, grad_running_mean_state')
SimpleDualAveragesExtra = collections.namedtuple('SimpleDualAveragesExtra',
'loss, loss_extra, grads')


def simple_dual_averages_init(
state: FloatNest,
grad_mean_smoothing_steps: IntNest = 0,
) -> SimpleDualAveragesState:
"""Initialize Simple Dual Averages state.
Note that the `state` argument only affects the initial value read from the
state, it has no effect on any other step of the algorithm. Typically, you'd
set this to the same value as `shrink_point`.
Args:
state: The state of the problem.
grad_mean_smoothing_steps: Smoothes out the initial gradient running mean.
For some algorithms it improves stability to make this non-zero.
Returns:
sda_state: `SimpleDualAveragesState`.
"""
grad_rms = running_mean_init(
util.map_tree(lambda s: s.shape, state),
util.map_tree(lambda s: s.dtype, state))
grad_rms = grad_rms._replace(
num_points=util.map_tree(lambda _: grad_mean_smoothing_steps,
grad_rms.num_points))

return SimpleDualAveragesState(
state=state,
# The algorithm assumes this starts at 1.
step=1,
grad_running_mean_state=grad_rms,
)


def simple_dual_averages_step(
sda_state: SimpleDualAveragesState,
loss_fn: PotentialFn,
shrink_weight: FloatNest,
shrink_point: State = 0.,
) -> Tuple[SimpleDualAveragesState, SimpleDualAveragesExtra]:
"""Performs one step of the Simple Dual Averages algorithm [1].
This function implements equation 3.4 from [1], with the following choices:
```none
d(x) = 0.5 * (x - shrink_point)**2
mu_k = shrink_weight / step**0.5
```
Strictly speaking, this algorithm only applies to convex problems. The
`loss_fn` need not have true gradients: sub-gradients are sufficient. The
sequence of `state` is not actually convergent. To get a convergent sequence,
you can compute a running mean of `state` (e.g. using `running_mean_step`),
although that is not the sole choice.
Args:
sda_state: `SimpleDualAveragesState`.
loss_fn: A function whose output will be minimized.
shrink_weight: Weight of the shrinkage term. Must broadcast with `state`.
shrink_point: Where the algorithm initially shrinks `state` to. Must
broadcast with `state`.
Returns:
sda_state: `SimpleDualAveragesState`.
sda_extra: `SimpleDualAveragesExtra`.
#### References
[1]: Nesterov, Y. (2009). Primal-dual subgradient methods for convex problems.
Mathematical Programming, 120(1), 221-259.
"""
state = sda_state.state
step = sda_state.step
step_f = tf.cast(step, tf.float32)
shrink_point = maybe_broadcast_structure(shrink_point, state)
shrink_weight = maybe_broadcast_structure(shrink_weight, state)

loss, loss_extra, grads = call_potential_fn_with_grads(loss_fn, state)

grad_rms, _ = running_mean_step(sda_state.grad_running_mean_state, grads)

def _one_part(shrink_point, shrink_weight, grad_running_mean):
return shrink_point - tf.sqrt(step_f) / shrink_weight * grad_running_mean

state = util.map_tree(_one_part, shrink_point, shrink_weight, grad_rms.mean)

sda_state = SimpleDualAveragesState(
state=state,
step=step + 1,
grad_running_mean_state=grad_rms,
)
sda_extra = SimpleDualAveragesExtra(
loss=loss,
loss_extra=loss_extra,
grads=grads,
)

return sda_state, sda_extra
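
For reference (editorial note, not part of the commit): the closed-form update performed by `simple_dual_averages_step` follows directly from the choices in its docstring. Minimizing `mean(g_1, ..., g_k) * x + mu_k * d(x)` over `x`, with `d(x) = 0.5 * (x - shrink_point)**2` and `mu_k = shrink_weight / step**0.5`, gives

```none
state_{k+1} = shrink_point - sqrt(k) / shrink_weight * mean(g_1, ..., g_k)
```

which is exactly what `_one_part` computes from the gradient running mean. The `grad_mean_smoothing_steps` argument of `simple_dual_averages_init` sets the initial `num_points` of that running mean, so (since the mean starts at zero) it behaves as if that many zero gradients had already been averaged in, damping the first few updates.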
23 changes: 23 additions & 0 deletions discussion/fun_mcmc/fun_mcmc_test.py
@@ -836,6 +836,29 @@ def loss_fn(x, y):
self.assertAllClose(2., y[-1], atol=1e-3)
self.assertAllClose(0., loss[-1], atol=1e-3)

def testSimpleDualAverages(self):

def loss_fn(x, y):
return tf.square(x - 1.) + tf.square(y - 2.), []

def kernel(sda_state, rms_state):
sda_state, _ = fun_mcmc.simple_dual_averages_step(sda_state, loss_fn, 1.)
rms_state, _ = fun_mcmc.running_mean_step(rms_state, sda_state.state)
return (sda_state, rms_state), rms_state.mean

_, (x, y) = fun_mcmc.trace(
(
fun_mcmc.simple_dual_averages_init([tf.zeros([]),
tf.zeros([])]),
fun_mcmc.running_mean_init([[], []], [tf.float32, tf.float32]),
),
kernel,
num_steps=1000,
)

self.assertAllClose(1., x[-1], atol=1e-1)
self.assertAllClose(2., y[-1], atol=1e-1)

def testRandomWalkMetropolis(self):
num_steps = 1000
state = tf.ones([16], dtype=tf.int32)
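
As an aside, the `fun_mcmc.trace` call in `testSimpleDualAverages` above is equivalent to an explicit loop. A rough eager-mode sketch, assuming the public API added in this commit (the `from discussion import fun_mcmc` import path is an assumption of this note):

```python
import tensorflow as tf
from discussion import fun_mcmc  # import path assumed for this sketch


def loss_fn(x, y):
  return tf.square(x - 1.) + tf.square(y - 2.), []


sda_state = fun_mcmc.simple_dual_averages_init([tf.zeros([]), tf.zeros([])])
rms_state = fun_mcmc.running_mean_init([[], []], [tf.float32, tf.float32])

for _ in range(1000):
  # One dual-averages update with shrink_weight = 1.
  sda_state, _ = fun_mcmc.simple_dual_averages_step(sda_state, loss_fn, 1.)
  # Track the running mean of the iterates; this is the convergent estimate.
  rms_state, _ = fun_mcmc.running_mean_step(rms_state, sda_state.state)

x_hat, y_hat = rms_state.mean  # approach 1. and 2., as the test asserts
```

`fun_mcmc.trace` additionally stacks the traced values (`rms_state.mean` here) across steps, which is why the test can index `x[-1]` and `y[-1]`.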
5 changes: 3 additions & 2 deletions discussion/fun_mcmc/tf_on_jax.py
@@ -156,14 +156,15 @@ def _range(*args, **kwargs):
_impl_np()(np.einsum)
_impl_np()(np.float32)
_impl_np()(np.int32)
_impl_np()(np.maximum)
_impl_np()(np.minimum)
_impl_np()(np.ones)
_impl_np()(np.reshape)
_impl_np()(np.shape)
_impl_np()(np.sqrt)
_impl_np()(np.where)
_impl_np()(np.zeros)
_impl_np()(np.zeros_like)
_impl_np()(np.maximum)
_impl_np()(np.minimum)
_impl_np(['math'])(np.log)
_impl_np(['math'])(np.sqrt)
_impl_np(['math'], name='pow')(np.power)
