offer an option to concat the values across heads (with head dimension) and then project out, like multi-head attention
lucidrains committed Jun 30, 2023
1 parent f0be2db commit ddcb10f
Showing 2 changed files with 33 additions and 6 deletions.
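In use, the change amounts to one new constructor flag. A minimal usage sketch, assuming the `PKM` class and constructor arguments documented in this repository's README; the sizes below are illustrative, not taken from this commit:

```python
import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 4,
    dim_head = 128,
    num_keys = 256,                      # num_keys ** 2 = 65536 addressable memory values
    topk = 32,
    concat_values_and_combine = True     # new: sum values per head, concat heads, project back to dim
)

x = torch.randn(1, 1024, 512)
out = pkm(x)                             # (1, 1024, 512)
```

With the flag left at its default of False, the original EmbeddingBag value path is used.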
README.md: 2 changes (1 addition, 1 deletion)
@@ -67,9 +67,9 @@ Special thanks go to <a href="https://github.com/AranKomat">Aran</a> for encoura
 ## Todo

 - [x] offer stochasticity with annealed gumbel noise. seen dramatic effects in vector-quantization setting
+- [x] offer a way for smaller value dimensions + concat and linear combination of heads (like multi-head attention)

 - [ ] get caught up on latest literature on product key memories, if any
-- [ ] offer a way for smaller value dimensions + concat and linear combination of heads (like multi-head attention)
 - [ ] instead of additive scores, try multiplicative using coordinate descent routing

 ## Citations

product_key_memory/product_key_memory.py: 37 changes (32 additions, 5 deletions)
@@ -3,6 +3,7 @@
 from torch import nn, einsum

 from einops import rearrange
+from einops.layers.torch import Rearrange, Reduce

 from colt5_attention import topk as coor_descent_topk

@@ -83,7 +84,8 @@ def __init__(
         attn_dropout = 0.,
         use_layernorm = True,
         pre_layernorm = False,
-        differentiable_topk = False
+        differentiable_topk = False,
+        concat_values_and_combine = False
     ):
         super().__init__()
         self.topk = topk
@@ -106,10 +108,32 @@
         else:
             self.norm = MaskedBatchNorm1D(nn.BatchNorm1d(dim_head))

+        # keys

         self.keys = nn.Parameter(torch.zeros(heads, num_keys, 2, dim_head))
-        self.values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
         init_(self.keys)
-        init_(self.values.weight)

+        # values

+        self.concat_values_and_combine = concat_values_and_combine

+        if concat_values_and_combine:
+            values = nn.Embedding(num_keys ** 2, dim_head)

+            self.values = nn.Sequential(
+                values,
+                Reduce('b (h k) d -> b h d', 'sum', h = heads),
+                Rearrange('b n d -> b (n d)'),
+                nn.Linear(dim_head * heads, dim, bias = False)
+            )
+        else:
+            values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')
+            self.values = values

+        init_(values.weight)

+        # dropouts

         self.input_dropout = nn.Dropout(input_dropout)
         self.query_dropout = nn.Dropout(query_dropout)
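To spell out the shapes in the new branch above, here is a self-contained sketch of the concatenate-then-project value path; the batch, head, and top-k sizes are made up for illustration, and the weight init is omitted:

```python
import torch
from torch import nn
from einops.layers.torch import Rearrange, Reduce

heads, topk, dim_head, dim, num_keys = 4, 32, 64, 512, 256

values = nn.Embedding(num_keys ** 2, dim_head)           # per-memory-slot values of size dim_head

to_out = nn.Sequential(
    values,                                               # (b, heads * topk) -> (b, heads * topk, dim_head)
    Reduce('b (h k) d -> b h d', 'sum', h = heads),       # sum the top-k retrieved values within each head
    Rearrange('b n d -> b (n d)'),                        # concat the per-head value dimensions
    nn.Linear(dim_head * heads, dim, bias = False)        # project back out to the model dimension
)

value_indices = torch.randint(0, num_keys ** 2, (2, heads * topk))   # (batch * seq, heads * topk)
print(to_out(value_indices).shape)                        # torch.Size([2, 512])
```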
@@ -192,7 +216,10 @@ def forward(

         # aggregate

-        out = self.values(value_indices, per_sample_weights=attn)
-        out = self.value_dropout(out)
+        if self.concat_values_and_combine:
+            out = self.values(value_indices)
+        else:
+            out = self.values(value_indices, per_sample_weights = attn)

+        out = self.value_dropout(out)
         return rearrange(out, '(b t) d -> b t d', b = b)
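For comparison, the default branch above keeps `nn.EmbeddingBag` in 'sum' mode, where `per_sample_weights` applies the routing scores as a weighted sum directly in the model dimension (note that the concat branch above indexes its values without `per_sample_weights`). A small sketch with stand-in scores:

```python
import torch
from torch import nn

heads, topk, dim, num_keys = 4, 32, 512, 256

values = nn.EmbeddingBag(num_keys ** 2, dim, mode = 'sum')

value_indices = torch.randint(0, num_keys ** 2, (2, heads * topk))   # (batch * seq, heads * topk)
attn = torch.randn(2, heads * topk).softmax(dim = -1)                # stand-in for the routing scores

# attention-weighted sum over all heads * topk retrieved values, straight to (batch * seq, dim)
out = values(value_indices, per_sample_weights = attn)
print(out.shape)                                                     # torch.Size([2, 512])
```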
