
Merge pull request meta-llama#900 from flu0r1ne/main
Fix key-value caching for seqlen != 1 (Issue meta-llama#899)
ruanslv authored Nov 14, 2023
2 parents 4835a30 + cd0719d commit ef351e9
Showing 1 changed file with 16 additions and 6 deletions.
llama/model.py
```diff
@@ -289,12 +289,12 @@ def forward(
         values = self.cache_v[:bsz, : start_pos + seqlen]

         # repeat k/v heads if n_kv_heads < n_heads
-        keys = repeat_kv(keys, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
-        values = repeat_kv(values, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)

         xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
-        keys = keys.transpose(1, 2)
-        values = values.transpose(1, 2)
+        keys = keys.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2)  # (bs, n_local_heads, cache_len + seqlen, head_dim)
         scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
         if mask is not None:
             scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
```
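Note that this first hunk changes only the shape comments: `keys` and `values` are sliced up to `start_pos + seqlen`, so their second dimension is `cache_len + seqlen`, not `seqlen`. For context, `repeat_kv` tiles each key/value head `n_rep` times so that grouped-query-attention keys and values line up with the `n_local_heads` query heads. A minimal sketch of such a function, consistent with the shape comments above (the body is an assumption, not quoted from this diff):

```python
import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # x: (bs, cache_len + seqlen, n_kv_heads, head_dim)
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]                               # add a repeat axis
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)    # tile views, no copy yet
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)  # fold into the head axis
    )
```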
```diff
@@ -474,9 +474,19 @@ def forward(self, tokens: torch.Tensor, start_pos: int):
         mask = None
         if seqlen > 1:
             mask = torch.full(
-                (1, 1, seqlen, seqlen), float("-inf"), device=tokens.device
+                (seqlen, seqlen), float("-inf"), device=tokens.device
             )
-            mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)

         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
```
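The substantive fix is in this second hunk. The old code built a `(1, 1, seqlen, seqlen)` mask, which cannot be added to the `(bs, n_local_heads, seqlen, cache_len + seqlen)` score matrix whenever `start_pos > 0` and `seqlen > 1`. The new code makes the mask `(seqlen, start_pos + seqlen)` wide and leaves the cached columns unmasked. A small self-contained check, with `start_pos` and `seqlen` values chosen only for illustration:

```python
import torch

start_pos, seqlen = 2, 3  # illustrative: 2 cached tokens, 3 new tokens

# Causal mask over the new tokens...
mask = torch.full((seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=1)
# ...prepended with zero columns so cached positions stay attendable.
mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])

print(mask.shape)  # torch.Size([3, 5]) == (seqlen, start_pos + seqlen)
print(mask)
# tensor([[0., 0., 0., -inf, -inf],
#         [0., 0., 0., 0.,   -inf],
#         [0., 0., 0., 0.,   0.]])
# Entry (i, j) is -inf exactly when j > start_pos + i: token start_pos + i
# may attend to every cached token and to new tokens up to itself.
```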
