Support dropout and rngs in new efficient attention code
PiperOrigin-RevId: 293654213
trax-robot authored and copybara-github committed Feb 6, 2020
1 parent 84cdad4 commit e7a599b
Showing 4 changed files with 76 additions and 46 deletions.
2 changes: 1 addition & 1 deletion trax/configs/reformer_enwik8.gin
@@ -84,10 +84,10 @@ train.checkpoints_at = \

# Parameters for SelfAttention:
# ==============================================================================
SelfAttention.attention_dropout = 0.2
# SelfAttention.chunk_len: see top
SelfAttention.n_chunks_after = 0
# SelfAttention.n_chunks_before: see top
# TODO(kitaev): attention dropout, when implemented

# Parameters for LSHSelfAttention:
# ==============================================================================
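For reference, the SelfAttention.attention_dropout binding enabled in this config takes effect through gin's ordinary configurable mechanism. A rough Python equivalent, shown only as a sketch using gin-config's standard parse_config API and assuming SelfAttention is registered as a gin configurable (as trax does); this is not code from the commit:

import gin

gin.parse_config('SelfAttention.attention_dropout = 0.2')
# After parsing, SelfAttention() instances default to attention_dropout=0.2
# unless a call site passes the argument explicitly.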
101 changes: 67 additions & 34 deletions trax/layers/research/efficient_attention_v2.py
@@ -187,6 +187,18 @@ def attend(
return out, dots_logsumexp


def apply_broadcasted_dropout(vecs, dropout_rate, rng):
"""Apply dropout, broadcasted across all but the last dimension of `vecs`."""
if dropout_rate > 0.0:
assert rng is not None
keep_prob = jax.lax.tie_in(vecs, 1.0 - dropout_rate)
keep = jax.random.bernoulli(rng, keep_prob, (vecs.shape[-1],))
multiplier = keep.astype(vecs.dtype) / jax.lax.tie_in(keep, keep_prob)
return vecs * multiplier
else:
return vecs

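A minimal usage sketch of the helper above (illustrative shapes; assumes a JAX version that still provides jax.lax.tie_in, as the helper itself does):

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
vecs = jnp.ones((4, 8))  # e.g. (seq_len, d_model)
out = apply_broadcasted_dropout(vecs, dropout_rate=0.5, rng=rng)
# Every row shares the same zeroed columns: the Bernoulli keep mask has shape
# (d_model,) and broadcasts over the leading dimensions, so only one mask per
# call is materialized instead of one per position.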

def permute_via_gather(val, permutation, inverse_permutation, axis=0):
"""Permutation helper for LSH attention."""
def permute_impl(val):
@@ -335,15 +347,19 @@ def forward_with_state(self, inputs, weights, state, rng=None):
inputs: Layer inputs (subclasses may use different inputs)
weights: Layer weights
state: Complete state of the layer
rng: PRNG key
rng: PRNG key. Note that the RNG is shared across all examples and heads.
This sharing is useful to reduce memory usage for dropout (all dropout
instances are automatically broadcasted across the batch and head
dimensions). Attention types that need separate random numbers for each
example and head may store their own RNG in the model state.
Returns:
A tuple (output, new_state).
"""
if not self.use_reference_code:
# By default, an efficient, batched implementation is used.
output, new_state, _, _ = self.forward_and_or_backward(
inputs, weights, state, compute_output=True, update_state=True)
inputs, weights, state, rng, compute_output=True, update_state=True)
return output, new_state

# The reference implementation below provides a more readable overview of
@@ -365,7 +381,7 @@ def forward_with_state(self, inputs, weights, state, rng=None):
lambda s: s[example_idx * self.n_heads + head_idx], state)
# pylint: enable=cell-var-from-loop
single_out, single_new_state = self.forward_unbatched(
*single_inputs, weights=single_weights, state=single_state,
*single_inputs, weights=single_weights, rng=rng, state=single_state,
update_state=True)
new_state.append(single_new_state)
output_accum[example_idx] = output_accum[example_idx] + single_out
@@ -388,12 +404,12 @@ def backward(self, inputs, output, grad, weights, state, new_state, rng=None,
assert not self.use_reference_code
del output, state, kwargs
_, _, inputs_grad, weights_grad = self.forward_and_or_backward(
inputs, weights, new_state, output_grad=grad,
inputs, weights, new_state, rng, output_grad=grad,
compute_output=False, update_state=False)
return inputs_grad, weights_grad

def forward_and_or_backward(
self, inputs, weights, state, output_grad=None,
self, inputs, weights, state, rng, output_grad=None,
compute_output=True, update_state=True):
"""Performs batched forward and/or backward passes.
@@ -405,6 +421,7 @@ def forward_and_or_backward(
inputs: inputs to the attention layer
weights: weights for the attention layer
state: state of the attention layer
rng: PRNG key for the layer (shared across all examples and heads)
output_grad: gradient of the loss wrt the output of the layer, or None.
This function performs the backward pass iff `output_grad` is not None.
compute_output: bool: whether to return the output of the forward pass
@@ -442,9 +459,6 @@ def forward_and_or_backward(
for example, head in zip(examples, heads):
backward(example, head)
"""
# TODO(kitaev): support non-differentiable inputs (for enc-dec attn masking)
# TODO(kitaev): support RNGs (needed for dropout and LSH). Currently LSH
# hacks around this by storing an RNG in its state
# TODO(kitaev): profile ~4% speed drop compared to previous implementation
# in some conditions. Other conditions (e.g. the enwik8 model) appear
# to have the same overall training speed.
@@ -522,7 +536,7 @@ def run_inner(idx, loop_val):

def forward_fn(i_h, w_h):
return self.forward_unbatched(
*i_h, weights=w_h, state=jax.lax.stop_gradient(s_h),
*i_h, weights=w_h, state=jax.lax.stop_gradient(s_h), rng=rng,
update_state=update_state)

if compute_grad:
@@ -559,7 +573,7 @@ def run_inner(idx, loop_val):
s_mh = jax.tree_map(lambda s: s[state_range], state)
def forward_unbatched(i_h, w_h, s_h):
return self.forward_unbatched(
*i_h, weights=w_h, state=s_h, update_state=update_state)
*i_h, weights=w_h, state=s_h, rng=rng, update_state=update_state)
def forward_fn(i_mh, w_mh):
o_mh, new_s_mh = jax.vmap(
forward_unbatched, in_axes=(None, 0, 0), out_axes=0)(
@@ -588,7 +602,7 @@ def forward_fn(i_mh, w_mh):
def forward_single_example(i_x, w_all, s_x):
def forward_unbatched(i_h, w_h, s_h):
return self.forward_unbatched(
*i_h, weights=w_h, state=s_h, update_state=update_state)
*i_h, weights=w_h, state=s_h, rng=rng, update_state=update_state)
o_x, s_x = jax.vmap(
forward_unbatched, in_axes=(None, 0, 0), out_axes=(0, 0))(
i_x, w_all, s_x)
@@ -682,6 +696,7 @@ def __init__(self,
chunk_len=None, n_chunks_before=0, n_chunks_after=0,
mode='train',
attention_dropout=0.0,
output_dropout=0.0,
n_parallel_heads=None,
use_python_loop=False,
use_reference_code=False,
@@ -707,6 +722,7 @@ def __init__(self,
is never a strict no-op.
mode: 'train' or 'eval'
attention_dropout: Dropout probability for attention mask.
output_dropout: Dropout probability for the layer output.
n_parallel_heads: see EfficientAttentionBase. This option controls the
trade-off between parallelism and memory usage.
use_python_loop: For testing/debugging (see EfficientAttentionBase)
@@ -728,9 +744,12 @@ def __init__(self,
self.n_chunks_before = n_chunks_before
self.n_chunks_after = n_chunks_after
self.mode = mode
self.attention_dropout = attention_dropout
if self.attention_dropout != 0.0:
raise NotImplementedError('RNG support not implemented yet.')
if mode == 'train':
self.attention_dropout = attention_dropout
self.output_dropout = output_dropout
else:
self.attention_dropout = 0.0
self.output_dropout = 0.0

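A short sketch of the behaviour introduced here, with constructor keyword values borrowed from the unit tests in efficient_attention_v2_test.py below; it assumes this class is exposed as SelfAttention in trax.layers.research.efficient_attention_v2:

from trax.layers.research.efficient_attention_v2 import SelfAttention

kwargs = dict(n_heads=6, d_qk=7, d_v=17, share_qk=False, causal=True,
              chunk_len=5, n_chunks_before=1, n_chunks_after=0,
              attention_dropout=0.2, output_dropout=0.1)
train_layer = SelfAttention(mode='train', **kwargs)
eval_layer = SelfAttention(mode='eval', **kwargs)
# Dropout is only active while training; eval mode forces both rates to zero.
assert (train_layer.attention_dropout, train_layer.output_dropout) == (0.2, 0.1)
assert (eval_layer.attention_dropout, eval_layer.output_dropout) == (0.0, 0.0)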
def _kernel_initializer(self, shape, rng):
# Attention uses Glorot uniform initialization with respect to the *total*
@@ -756,8 +775,10 @@ def create_weights_unbatched(self, input_signature, rng):
else:
return (w_q, w_k, w_v, w_o)

def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
def forward_unbatched(self, x, mask=None, *,
weights, state, rng, update_state):
del update_state
attend_rng, output_rng = jax.random.split(rng)
if self.share_qk:
w_q, w_v, w_o = weights
else:
@@ -787,10 +808,11 @@ def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
n_chunks_before=self.n_chunks_before,
n_chunks_after=self.n_chunks_after,
mask_fn=mask_fn, q_info=q_info, kv_info=kv_info,
dropout=self.attention_dropout, rng=None, # TODO(kitaev): support RNG
dropout=self.attention_dropout, rng=attend_rng,
)

out = np.matmul(o, w_o)
out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
return out, state


@@ -805,6 +827,7 @@ def __init__(self,
n_buckets=256,
mode='train',
attention_dropout=0.0,
output_dropout=0.0,
n_parallel_heads=1,
use_python_loop=False,
use_reference_code=False,
@@ -816,20 +839,25 @@ def __init__(self,
chunk_len=chunk_len,
n_chunks_before=n_chunks_before, n_chunks_after=n_chunks_after,
mode=mode,
attention_dropout=0.0, # Base class does not support dropout yet.
attention_dropout=attention_dropout,
output_dropout=output_dropout,
n_parallel_heads=n_parallel_heads,
use_python_loop=use_python_loop,
use_reference_code=use_reference_code,
)
self.n_hashes = n_hashes
self.n_buckets = n_buckets
self.attention_dropout = attention_dropout

def create_state_unbatched(self, input_signature, rng):
if isinstance(input_signature, (tuple, list)):
input_signature = input_signature[0]
buckets = np.zeros(self.n_hashes * input_signature.shape[0], dtype=np.int32)
# TODO(kitaev): storing RNG in the state is a HACK.
# The `rng` argument passed to forward_unbatched is shared across all
# examples and heads. This facilitates using broadcasted dropout, which
# saves memory and hasn't been shown to hurt model quality. Even though the
# same sharing is likely to be safe when selecting random hash functions
# for LSH, we haven't run experiments to demonstrate this. To be on the safe
# side we include a per-head RNG in the state for the purpose of doing LSH.
return (buckets, rng)

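A toy illustration of the per-head RNG threading that forward_unbatched below performs with this state (the loop and names here are illustrative, not the layer's code):

import jax

hash_rng = jax.random.PRNGKey(0)   # the RNG carried in the layer state
for step in range(3):
  hash_rng, hash_subrng = jax.random.split(hash_rng)
  # hash_vectors(q, hash_subrng) runs with a fresh subkey on each state
  # update, while the carried-forward hash_rng is written back into the
  # state so later steps draw independent hash functions.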
def hash_vectors(self, vecs, rng):
@@ -878,22 +906,20 @@ def hash_vectors(self, vecs, rng):

return buckets

def forward_unbatched(self, x, *, weights, state, update_state):
def forward_unbatched(self, x, *, weights, state, rng, update_state):
attend_rng, output_rng = jax.random.split(rng)
w_q, w_v, w_o = weights

q = np.matmul(x, w_q)
v = np.matmul(x, w_v)

if update_state:
_, old_rng = state
rng = jax.random.fold_in(old_rng, 0)
hash_rng = jax.random.fold_in(rng, 1)
buckets = self.hash_vectors(q, hash_rng)
state = (buckets, rng)
_, old_hash_rng = state
hash_rng, hash_subrng = jax.random.split(old_hash_rng)
buckets = self.hash_vectors(q, hash_subrng)
state = (buckets, hash_rng)
else:
buckets, rng = state

rng = jax.random.fold_in(rng, 2)
buckets, _ = state

seqlen = x.shape[0]
assert int(buckets.shape[0]) == self.n_hashes * seqlen
@@ -923,7 +949,7 @@ def forward_unbatched(self, x, *, weights, state, update_state):
n_chunks_before=self.n_chunks_before,
n_chunks_after=self.n_chunks_after,
mask_fn=mask_fn, q_info=q_info,
dropout=self.attention_dropout, rng=rng,
dropout=self.attention_dropout, rng=attend_rng,
)

# np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
@@ -939,6 +965,7 @@ def forward_unbatched(self, x, *, weights, state, update_state):

assert o.shape == (seqlen, w_v.shape[-1])
out = np.matmul(o, w_o)
out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
return out, state


@@ -950,6 +977,7 @@ def __init__(self,
masked=True,
mode='train',
attention_dropout=0.0,
output_dropout=0.0,
n_parallel_heads=None,
use_python_loop=False,
use_reference_code=False,
@@ -965,9 +993,12 @@ def __init__(self,
self.d_v = d_v
self.masked = masked
self.mode = mode
self.attention_dropout = attention_dropout
if self.attention_dropout != 0.0:
raise NotImplementedError('RNG support not implemented yet.')
if mode == 'train':
self.attention_dropout = attention_dropout
self.output_dropout = output_dropout
else:
self.attention_dropout = 0.0
self.output_dropout = 0.0

def _kernel_initializer(self, shape, rng):
# Attention uses Glorot uniform initialization with respect to the *total*
@@ -989,8 +1020,9 @@ def create_weights_unbatched(self, input_signature, rng):
return (w_q, w_k, w_v, w_o)

def forward_unbatched(self, q_antecedent, kv_antecedent, mask=None, *,
weights, state, update_state):
weights, state, rng, update_state):
del update_state
attend_rng, output_rng = jax.random.split(rng)
w_q, w_k, w_v, w_o = weights

q = np.matmul(q_antecedent, w_q)
Expand All @@ -1014,8 +1046,9 @@ def mask_fn(dots, q_info, kv_info):
o, _ = attend(
q, k, v,
mask_fn=mask_fn, q_info=q_info, kv_info=kv_info,
dropout=self.attention_dropout, rng=None, # TODO(kitaev): support RNG
dropout=self.attention_dropout, rng=attend_rng,
)

out = np.matmul(o, w_o)
out = apply_broadcasted_dropout(out, self.output_dropout, output_rng)
return out, state
4 changes: 2 additions & 2 deletions trax/layers/research/efficient_attention_v2_test.py
@@ -95,7 +95,7 @@ def test_batching_self_attention(self):
common_kwargs = dict(
n_heads=6, d_qk=7, d_v=17, share_qk=False, causal=True,
chunk_len=5, n_chunks_before=1, n_chunks_after=0,
attention_dropout=0.0, mode='train',
attention_dropout=0.2, output_dropout=0.1, mode='train',
)
test_kwargs = []
for n_parallel_heads in [1, 3, 6, 12]:
@@ -117,7 +117,7 @@ def test_batching_lsh_self_attention(self):
n_heads=6, d_qk=7, d_v=17, causal=True,
chunk_len=5, n_chunks_before=1, n_chunks_after=0,
n_hashes=2, n_buckets=4,
attention_dropout=0.2, mode='train',
attention_dropout=0.2, output_dropout=0.1, mode='train',
)
test_kwargs = []
for n_parallel_heads in [1, 3, 6, 12]:
15 changes: 6 additions & 9 deletions trax/models/reformer/reformer.py
@@ -542,7 +542,8 @@ def call_compute_residual(x, weights):
inputs = _inputs_from_stack(self.attention_layer, stack)
(residual, _, attn_inputs_ct, attn_weights_ct
) = self.attention_layer.forward_and_or_backward(
inputs, weights[1], new_state[1], output_grad=accumulator_output_ct,
inputs, weights[1], new_state[1], rngs[1],
output_grad=accumulator_output_ct,
compute_output=True, update_state=False)
stack_ct = _outputs_onto_stack(
self.attention_layer, attn_inputs_ct, stack_ct,
@@ -678,11 +679,10 @@ def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
else:
attention = attention_type(
n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
share_qk=share_qk, causal=True, mode=mode)
share_qk=share_qk, causal=True, output_dropout=dropout, mode=mode)
attention_half_residual = ReversibleHalfResidualV2(
tl.LayerNorm(),
attention_layer=attention,
# TODO(kitaev): add output dropout to attention layer.
)

if ff_use_sru:
@@ -962,12 +962,11 @@ def EncoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode):
attention = tl.SelfAttention(
n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
masked=True,
attention_dropout=0.0, # TODO(kitaev): attention dropout
attention_dropout=dropout, output_dropout=dropout,
mode=mode)
attention_half_residual = ReversibleHalfResidualV2(
tl.LayerNorm(),
attention_layer=attention,
# TODO(kitaev): add output dropout to attention layer. rate=dropout
)

# TODO(kitaev): Switch to FeedForward with BroadcastedDropout?
@@ -999,23 +998,21 @@ def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode):
"""
enc_dec_attention = tl.EncDecAttention(
n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
attention_dropout=0.0, # TODO(kitaev): attention dropout
attention_dropout=dropout, output_dropout=dropout,
mode=mode)
enc_dec_attention_half_residual = ReversibleHalfResidualV2(
tl.LayerNorm(),
attention_layer=enc_dec_attention,
# TODO(kitaev): add output dropout to attention layer. rate=dropout
)

causal_attention = tl.SelfAttention(
n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
causal=True,
attention_dropout=0.0, # TODO(kitaev): attention dropout
attention_dropout=dropout, output_dropout=dropout,
mode=mode)
causal_attention_half_residual = ReversibleHalfResidualV2(
tl.LayerNorm(),
attention_layer=causal_attention,
# TODO(kitaev): add output dropout to attention layer. rate=dropout
)

feed_forward = FeedForward(d_model, d_ff, dropout, ff_activation, mode)
