Skip to content

Commit

Permalink
cache value to avoid tensor conversion in torch (keras-team#654)
Browse files Browse the repository at this point in the history
Co-authored-by: Haifeng Jin <[email protected]>
  • Loading branch information
2 people authored and fchollet committed Aug 1, 2023
1 parent 871f03e commit 710c933
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion keras_core/layers/attention/multi_head_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import numpy as np

from keras_core import backend
from keras_core import constraints
from keras_core import initializers
from keras_core import ops
Expand Down Expand Up @@ -115,6 +116,8 @@ def __init__(
self.supports_masking = True
self._num_heads = num_heads
self._key_dim = key_dim
# Cache 1.0 / math.sqrt(self._key_dim).
self._inverse_sqrt_key_dim = None
self._value_dim = value_dim if value_dim else key_dim
self._dropout = dropout
self._use_bias = use_bias
Expand Down Expand Up @@ -311,6 +314,9 @@ def _build_attention(self, rank):
)
self._softmax = Softmax(axis=norm_axes)
self._dropout_layer = Dropout(rate=self._dropout)
self._inverse_sqrt_key_dim = backend.convert_to_tensor(
1.0 / math.sqrt(float(self._key_dim))
)

def _masked_softmax(self, attention_scores, attention_mask=None):
# Normalize the attention scores to probabilities.
Expand Down Expand Up @@ -355,7 +361,7 @@ def _compute_attention(
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
query = ops.multiply(query, 1.0 / math.sqrt(float(self._key_dim)))
query = ops.multiply(query, self._inverse_sqrt_key_dim)

# Take the dot product between "query" and "key" to get the raw
# attention scores.
Expand Down

0 comments on commit 710c933

Please sign in to comment.