Standalone Product Key Memory module for augmenting Transformer models
$ pip install product-key-memory
Replace the feedforward layers in a Transformer with the following.
import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 8,
    num_keys = 512,   # number of subkeys; the number of values will be num_keys ^ 2 (here 262144)
    topk = 10,        # number of top subkeys to select per head
    share_kv = False  # whether to share keys / values across heads
)

x = torch.randn(1, 1024, 512)
values = pkm(x) # (1, 1024, 512)
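For context, the PKM layer takes and returns tensors of the same shape as a feedforward sublayer, so it can be dropped into an existing block. The sketch below is illustrative rather than part of this library's API: the PKMTransformerBlock name and the use of nn.MultiheadAttention are assumptions made for the example.

import torch
from torch import nn
from product_key_memory import PKM

# hypothetical pre-norm Transformer block with the feedforward sublayer
# swapped out for a product-key memory layer
class PKMTransformerBlock(nn.Module):
    def __init__(self, dim = 512, heads = 8, num_keys = 512, topk = 10):
        super().__init__()
        self.attn_norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first = True)
        self.pkm_norm = nn.LayerNorm(dim)
        self.pkm = PKM(dim = dim, heads = heads, num_keys = num_keys, topk = topk)

    def forward(self, x):
        normed = self.attn_norm(x)
        attn_out, _ = self.attn(normed, normed, normed)
        x = x + attn_out
        x = x + self.pkm(self.pkm_norm(x))  # memory lookup replaces the usual feedforward
        return x

block = PKMTransformerBlock()
out = block(torch.randn(1, 1024, 512))  # (1, 1024, 512)

Note that in the paper, memory layers replace the feedforward in only a small number of the Transformer's layers, not in every layer.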
To give the value parameters of the product-key-memory network a different (typically higher) learning rate than the rest of the model, use the following helper function.
from torch.optim import Adam
from product_key_memory import fetch_pkm_value_parameters

# gather the PKM value parameters, then exclude them from the base parameter
# group, since Adam errors if a parameter appears in more than one group
pkm_parameters = fetch_pkm_value_parameters(model)
pkm_parameter_ids = set(map(id, pkm_parameters))
base_parameters = [p for p in model.parameters() if id(p) not in pkm_parameter_ids]

optim = Adam([
    {'params': base_parameters},
    {'params': pkm_parameters, 'lr': 1e-2}
], lr = 1e-3)
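As a quick illustration of how the two parameter groups come together, here is a hypothetical training step (the model and loss are stand-ins). Since only the selected top-k memory values per head are touched on any given step, the values receive sparse gradient updates, which is the reason a higher learning rate for them is recommended.

# hypothetical training step; `model` is any module containing PKM layers,
# e.g. the PKMTransformerBlock sketched earlier
out = model(x)
loss = out.pow(2).mean()  # placeholder loss, purely for illustration
loss.backward()           # only the selected top-k memory values get gradients
optim.step()
optim.zero_grad()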
Special thanks go to Aran for encouraging me to look into this, and to Madison May for his educational blog post, which helped me understand this better.
@misc{lample2019large,
    title         = {Large Memory Layers with Product Keys},
    author        = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year          = {2019},
    eprint        = {1907.05242},
    archivePrefix = {arXiv}
}