Commit

Internal change

PiperOrigin-RevId: 465060096

tensorflower-gardener committed Aug 3, 2022
1 parent c2ddc0b commit 0028cbe

Showing 2 changed files with 114 additions and 19 deletions.
105 changes: 89 additions & 16 deletions official/nlp/modeling/layers/kernel_attention.py
@@ -98,11 +98,69 @@ def split_tensor_into_chunks(tensor, axis, chunk_length):
  return tf.reshape(tensor, new_shape)


def rectangular_window_sum(tensor, window_length):
  """Summarizes tensor elements over a sliding rectangular window.

  Sums elements of the input tensor of shape [B, T', C', H, dim]
  across a rectangular window sliding along the dimension T'.

  Args:
    tensor: Tensor of shape `[B, T', C', H, dim]`.
    window_length: The length of the rectangular window.

  Returns:
    A tensor of shape [B, T', C', H, dim] containing sums over the window.
  """
  tensor_cumsum = tf.cumsum(tensor, axis=-4)
  tensor_winsum = tensor_cumsum - tf.pad(
      tensor_cumsum,
      [[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
  return tensor_winsum
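For intuition, the cumulative-sum trick above is equivalent to the naive sliding-window loop sketched below. This sketch is for exposition only: it is not part of the commit, it assumes a statically known T' dimension, and the helper name is hypothetical.

import tensorflow as tf

def naive_rectangular_window_sum(tensor, window_length):
  # Reference semantics of rectangular_window_sum: for each position t along
  # T', sum the last `window_length` entries (fewer at the start of the axis).
  outputs = []
  for t in range(tensor.shape[1]):  # assumes a statically known T'
    start = max(0, t - window_length + 1)
    outputs.append(tf.reduce_sum(tensor[:, start:t + 1], axis=1))
  return tf.stack(outputs, axis=1)  # back to [B, T', C', H, dim]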


def weighted_window_sum(tensor, window_length, window_weights):
  """Summarizes tensor elements over a sliding weighted window.

  Computes a weighted sum of elements of the input tensor of shape
  [B, T', C', H, dim] across a window sliding along the dimension T'.

  Args:
    tensor: Tensor of shape `[B, T', C', H, dim]`.
    window_length: The length of the window.
    window_weights: Tensor of shape [window_length] containing window weights.

  Returns:
    A tensor of shape [B, T', C', H, dim] containing sums over the window.
  """
  # Flatten the last three dimensions of the [B, T', C', H, dim] shape
  # into a single channels dimension.
  tensor_shape = tf.shape(tensor)
  tensor_2d = tf.reshape(tensor, [tensor_shape[0], tensor_shape[1], 1, -1])

  # Apply the same weights to all channels.
  conv_filter = tf.tile(
      tf.reshape(window_weights, [-1, 1, 1, 1]),
      multiples=[1, 1, tf.shape(tensor_2d)[-1], 1])
  tensor_winsum_2d = tf.nn.depthwise_conv2d(
      tensor_2d,
      conv_filter,
      strides=[1, 1, 1, 1],
      padding=[[0, 0], [window_length - 1, 0], [0, 0], [0, 0]])

  # Unflatten the channels dimension into the original shape.
  tensor_winsum = tf.reshape(tensor_winsum_2d, tensor_shape)
  return tensor_winsum
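Likewise, the depthwise-convolution formulation can be read as the naive weighted loop sketched below, where window_weights[-1] multiplies the current position and window_weights[0] the oldest position still inside the window. Again for exposition only: not part of the commit, with a hypothetical helper name and a statically known T' assumed.

import tensorflow as tf

def naive_weighted_window_sum(tensor, window_length, window_weights):
  # Reference semantics of weighted_window_sum: a causal, weighted sliding sum
  # along T', where newer positions get the trailing entries of window_weights.
  weights = tf.convert_to_tensor(window_weights, dtype=tensor.dtype)
  outputs = []
  for t in range(tensor.shape[1]):  # assumes a statically known T'
    start = max(0, t - window_length + 1)
    window = tensor[:, start:t + 1]                  # [B, w, C', H, dim]
    w = weights[window_length - (t + 1 - start):]    # last w weights
    outputs.append(tf.einsum("w,BwCHD->BCHD", w, window))
  return tf.stack(outputs, axis=1)  # [B, T', C', H, dim]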


def causal_windowed_performer_attention(query_matrix,
                                        key_matrix,
                                        value_matrix,
                                        chunk_length,
                                        window_length,
                                        window_decay=None,
                                        padding=None):
  """Applies windowed causal kernel attention with query, key, value tensors.
@@ -133,19 +191,22 @@ def causal_windowed_performer_attention(query_matrix,
  or right respectively and make it divisible by chunk_length.

  Args:
    query_matrix: Kernel query `Tensor` of shape `[B, T, N, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, N, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, N, out_dim]`.
    query_matrix: Kernel query `Tensor` of shape `[B, T, H, dim]`.
    key_matrix: Kernel key `Tensor` of shape `[B, T, H, dim]`.
    value_matrix: Value `Tensor` of shape `[B, T, H, out_dim]`.
    chunk_length: Length of each chunk in tokens.
    window_length: Length of attention window in chunks.
    window_decay: Float window decay factor or `None`. If set, exponentially
      decay past attention window values by this factor before summation.
    padding: Pad the query, value and key input tensors across the axis from
      either left or right if padding is set to "left" or "right"; apply no
      padding if padding is set to None. In the latter case, the axis
      dimension of the query, value and key input tensors must be divisible
      by the chunk_length.

  Returns:
    Window causal performer attention of shape `[B, T, N, out_dim]`.
    Window causal performer attention of shape `[B, T, H, out_dim]`.
  """
  old_shape = tf.shape(value_matrix)

@@ -164,19 +225,26 @@ def causal_windowed_performer_attention(query_matrix,
      value_matrix, -3,
      chunk_length)  # [-1, T//chunk_length, chunk_length, N, out_dim]

kp_v = tf.einsum("BNCHD,BNCHO->BNHDO", chunked_key_matrix,
kp_v = tf.einsum("BTCHD,BTCHO->BTHDO", chunked_key_matrix,
chunked_value_matrix)
kp_v_cumsum = tf.cumsum(kp_v, axis=-4)
kp_v_winsum = kp_v_cumsum - tf.pad(
kp_v_cumsum,
[[0, 0], [window_length, 0], [0, 0], [0, 0], [0, 0]])[:, :-window_length]
numerator = tf.einsum("BNCHD,BNHDO->BNCHO", chunked_query_matrix, kp_v_winsum)

k_sum = tf.reduce_sum(chunked_key_matrix, axis=-3)
k_cumsum = tf.cumsum(k_sum, axis=-3)
k_winsum = k_cumsum - tf.pad(k_cumsum, [[0, 0], [window_length, 0], [0, 0],
[0, 0]])[:, :-window_length]
denominator = tf.einsum("BNCHD,BNHD->BNCH", chunked_query_matrix, k_winsum)
k_sum = tf.math.reduce_sum(chunked_key_matrix, axis=-3, keepdims=True)

if window_decay is None:
kp_v_winsum = rectangular_window_sum(kp_v, window_length)
k_winsum = rectangular_window_sum(k_sum, window_length)
else:
# Compute exponentially decaying weights.
decaying_weights = tf.math.pow(
tf.convert_to_tensor(window_decay, dtype=value_matrix.dtype),
tf.range(window_length - 1, -1, delta=-1, dtype=value_matrix.dtype))
kp_v_winsum = weighted_window_sum(kp_v, window_length, decaying_weights)
k_winsum = weighted_window_sum(k_sum, window_length, decaying_weights)
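    # For example (comment added for exposition, not in the original commit):
    # with window_decay = 0.9 and window_length = 3, decaying_weights is
    # [0.81, 0.9, 1.0], so the oldest chunk in the window gets the smallest
    # weight and the current chunk keeps weight 1.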

numerator = tf.einsum("BTCHD,BTHDO->BTCHO", chunked_query_matrix, kp_v_winsum)

k_winsum = tf.squeeze(k_winsum, -3)
denominator = tf.einsum("BTCHD,BTHD->BTCH", chunked_query_matrix, k_winsum)
denominator = tf.expand_dims(denominator, -1) + _NUMERIC_STABLER

attention = numerator / denominator
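To make the tensor shapes concrete, the following hedged usage sketch (not part of the commit) calls the function directly; the import path follows the file location above, and the uniform query/key values merely stand in for the non-negative kernel features the layer would normally produce.

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention as attention

q = tf.random.uniform([2, 32, 4, 8])   # [B, T, H, dim], non-negative stand-in
k = tf.random.uniform([2, 32, 4, 8])
v = tf.random.normal([2, 32, 4, 16])   # [B, T, H, out_dim]
out = attention.causal_windowed_performer_attention(
    q, k, v, chunk_length=8, window_length=3, window_decay=0.97)
print(out.shape)  # expected: (2, 32, 4, 16); T = 32 is divisible by chunk_length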
@@ -351,7 +419,6 @@ def expplus(data_orig,
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = tf.expand_dims(diag_omega, axis=0)
  diag_omega = a_coeff * diag_omega
  #

  if numerical_renormalizer:
    if is_query:
@@ -454,6 +521,7 @@ def __init__(self,
               use_causal_windowed=False,
               causal_chunk_length=1,
               causal_window_length=3,
               causal_window_decay=None,
               causal_padding=None,
               **kwargs):
    r"""Constructor of KernelAttention.
@@ -485,6 +553,9 @@ def __init__(self,
        causal_windowed_performer_attention function docstring for more details.
      causal_chunk_length: Length of each chunk in tokens.
      causal_window_length: Length of attention window in chunks.
      causal_window_decay: Float window decay factor or `None`. If set,
        exponentially decay past attention window values by this factor
        before summation.
      causal_padding: Pad the query, value and key input tensors across the
        axis from either left or right if padding is set to "left" or
        "right"; apply no padding if padding is set to None.
@@ -524,6 +595,7 @@ def __init__(self,
    self.use_causal_windowed = use_causal_windowed
    self.causal_chunk_length = causal_chunk_length
    self.causal_window_length = causal_window_length
    self.causal_window_decay = causal_window_decay
    self.causal_padding = causal_padding
    if self.use_causal_windowed and self._is_short_seq:
      raise ValueError(
@@ -608,6 +680,7 @@ def _compute_attention(self,
          query_prime, key_prime, value,
          chunk_length=self.causal_chunk_length,
          window_length=self.causal_window_length,
          window_decay=self.causal_window_decay,
          padding=self.causal_padding)
    else:
      kv = tf.einsum("BSNH,BSND->BNDH", key_prime, value)
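For completeness, here is a hedged end-to-end sketch of the layer with the new option (not part of the commit; hyper-parameters and shapes are illustrative, and the call mirrors tf.keras.layers.MultiHeadAttention):

import tensorflow as tf
from official.nlp.modeling.layers import kernel_attention

layer = kernel_attention.KernelAttention(
    num_heads=8,
    key_dim=64,
    feature_transform="exp",
    num_random_features=128,
    use_causal_windowed=True,
    causal_chunk_length=8,
    causal_window_length=3,
    causal_window_decay=0.97,  # new in this commit
    causal_padding=None)       # seq_length must then be divisible by causal_chunk_length
x = tf.random.normal([2, 64, 64])  # [batch, seq_length, hidden]
y = layer(query=x, value=x, training=False)
print(y.shape)  # expected: (2, 64, 64)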
28 changes: 25 additions & 3 deletions official/nlp/modeling/layers/kernel_attention_test.py
@@ -61,11 +61,11 @@ def test_attention_projection(
    self.assertEqual(output.shape, [batch_size, seq_length, key_dim])

  @parameterized.parameters(
      itertools.product(_FEATURE_TRANSFORM, [127], _TRAINING, [True, False],
                        [0], [None, "left", "right"]))
      itertools.product(["relu", "exp"], [127], _TRAINING, [True, False],
                        [0], [None, 0.97], [None, "left", "right"]))
  def test_causal_windowed_attention_projection(
      self, feature_transform, num_random_features, training, redraw,
      begin_kernel, causal_padding):
      begin_kernel, causal_window_decay, causal_padding):
    num_heads = 12
    key_dim = 64
    seq_length = 1024
@@ -81,6 +81,7 @@ def test_causal_windowed_attention_projection(
        use_causal_windowed=True,
        causal_chunk_length=8,
        causal_window_length=3,
        causal_window_decay=causal_window_decay,
        causal_padding=causal_padding)
    query = tf.random.normal(
        shape=(batch_size, seq_length, key_dim))
@@ -175,5 +176,26 @@ def test_config(self):
    # If the serialization was successful, the new config should match the old.
    self.assertAllEqual(test_layer.get_config(), new_layer.get_config())

  def test_rectangular_window_sum(self):
    x = tf.ones([2, 5, 2, 2, 2])
    winsum = attention.rectangular_window_sum(x, 3)
    self.assertEqual(winsum.shape, x.shape)
    self.assertAllClose(
        tf.tile(
            tf.reshape([1., 2., 3., 3., 3.], [1, -1, 1, 1, 1]),
            [2, 1, 2, 2, 2]),
        winsum)

  def test_weighted_window_sum(self):
    x = tf.ones([2, 5, 2, 2, 2])
    winsum = attention.weighted_window_sum(x, 3, [0.01, 0.1, 1.])
    self.assertEqual(winsum.shape, x.shape)
    self.assertAllClose(
        tf.tile(
            tf.reshape([1., 1.1, 1.11, 1.11, 1.11], [1, -1, 1, 1, 1]),
            [2, 1, 2, 2, 2]),
        winsum)


if __name__ == "__main__":
  tf.test.main()
