# Product Key Memory

Standalone Product Key Memory module for augmenting Transformer models

## Install

```bash
$ pip install product-key-memory
```

## Usage

Replace the feedforward layers in a Transformer with the following:

```python
import torch
from product_key_memory import PKM

pkm = PKM(
    dim = 512,
    heads = 8,
    num_keys = 512,       # number of subkeys; the number of values will be num_keys ** 2 (here 262144)
    topk = 10,            # number of top subkeys to select
    share_kv = False      # whether to share keys / values across heads
)

x = torch.randn(1, 1024, 512)
values = pkm(x) # (1, 1024, 512)
```
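The library performs the product-key lookup internally. As a rough illustration of the mechanism only (this hypothetical `SimplePKM` is a single-head toy, not the package's implementation), the core idea — score two small subkey tables, combine the top candidates from their Cartesian product, and attend over the selected values — can be sketched as:

```python
import torch
from torch import nn

class SimplePKM(nn.Module):
    """Toy single-head product-key memory, for illustration only."""
    def __init__(self, dim, num_keys = 16, topk = 4):
        super().__init__()
        self.topk = topk
        self.num_keys = num_keys
        half = dim // 2
        # two subkey tables; their Cartesian product indexes num_keys ** 2 value slots
        self.keys1 = nn.Parameter(torch.randn(num_keys, half))
        self.keys2 = nn.Parameter(torch.randn(num_keys, half))
        self.values = nn.Embedding(num_keys ** 2, dim)

    def forward(self, x):
        q1, q2 = x.chunk(2, dim = -1)             # split the query in half
        s1 = q1 @ self.keys1.t()                  # (b, n, num_keys) subkey scores
        s2 = q2 @ self.keys2.t()
        t1, i1 = s1.topk(self.topk, dim = -1)     # top-k per subkey table
        t2, i2 = s2.topk(self.topk, dim = -1)
        # score all topk x topk candidate pairs from the product space
        scores = t1.unsqueeze(-1) + t2.unsqueeze(-2)            # (b, n, k, k)
        idx = i1.unsqueeze(-1) * self.num_keys + i2.unsqueeze(-2)
        scores, idx = scores.flatten(-2), idx.flatten(-2)       # (b, n, k * k)
        top_scores, top_pos = scores.topk(self.topk, dim = -1)  # final top-k slots
        top_idx = idx.gather(-1, top_pos)
        attn = top_scores.softmax(dim = -1)
        # weighted sum of the selected memory values
        return (self.values(top_idx) * attn.unsqueeze(-1)).sum(dim = -2)

pkm = SimplePKM(8, num_keys = 4, topk = 2)
out = pkm(torch.randn(2, 5, 8))   # (2, 5, 8)
```

Because only `2 * num_keys` subkeys are scored directly, the memory scales to `num_keys ** 2` values at sub-linear query cost.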

## Learning Rates

To give the value parameters of the product-key-memory network their own (typically higher) learning rate, use the following helper function.

```python
from torch.optim import Adam
from product_key_memory import fetch_pkm_value_parameters

# separate the PKM value parameters from the rest of the model's parameters,
# since a parameter may not appear in more than one optimizer group
pkm_parameters = fetch_pkm_value_parameters(model)
pkm_ids = set(map(id, pkm_parameters))
other_parameters = [p for p in model.parameters() if id(p) not in pkm_ids]

optim = Adam([
    {'params': other_parameters},
    {'params': pkm_parameters, 'lr': 1e-2}
], lr=1e-3)
```
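The reason the value parameters benefit from a larger learning rate is that only the top-k selected slots receive gradient on any given step, so each individual value is updated rarely. A tiny self-contained demonstration of this sparsity (using a plain `nn.Embedding` as a stand-in for the value table):

```python
import torch
from torch import nn

values = nn.Embedding(100, 8)      # stand-in for the memory value table
idx = torch.tensor([[3, 97]])      # only two of the 100 slots are selected this step
values(idx).sum().backward()

# only the selected rows carry any gradient
touched = (values.weight.grad.abs().sum(dim = -1) != 0)
```

Here `touched.sum()` is 2: the 98 unselected rows get zero gradient, which is why a higher per-group learning rate (as in the optimizer setup above) helps the values keep pace with the densely updated parameters.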

## Appreciation

Special thanks go to Aran for encouraging me to look into this, and to Madison May for his educational blog post, which helped me understand this better.

## Citations

```bibtex
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
```