release h-transformer-1d for non-causal case only
lucidrains committed Jul 29, 2021
1 parent 629723d commit d778f10
Showing 3 changed files with 93 additions and 45 deletions.
32 changes: 31 additions & 1 deletion README.md
@@ -1,9 +1,39 @@
 <img src="./h-transformer.png" width="300px"></img>
 
-## H-Transformer-1D (wip)
+## H-Transformer-1D
 
 Implementation of <a href="https://arxiv.org/abs/2107.11906">H-Transformer-1D</a>, Transformer using hierarchical Attention for sequence learning with subquadratic costs.
 
+For now, the H-Transformer will only act as a long-context encoder
+
+## Install
+
+```bash
+$ pip install h-transformer-1d
+```
+
+## Usage
+
+```python
+import torch
+from h_transformer_1d import HTransformer1D
+
+model = HTransformer1D(
+    num_tokens = 256,          # number of tokens
+    dim = 512,                 # dimension
+    depth = 2,                 # depth
+    max_seq_len = 1024,        # maximum sequence length
+    heads = 8,                 # heads
+    dim_head = 64,             # dimension per head
+    block_size = 128           # block size
+)
+
+x = torch.randint(0, 256, (1, 1024))
+mask = torch.ones((1, 1024)).bool()
+
+logits = model(x, mask = mask) # (1, 1024, 256)
+```
+
 ## Citations
 
 ```bibtex
…
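
The mask support added in this commit is what makes variable-length batches usable; a minimal sketch of masking out padding, reusing the public API from the usage snippet above (shapes illustrative only):

```python
import torch
from h_transformer_1d import HTransformer1D

model = HTransformer1D(
    num_tokens = 256,
    dim = 512,
    depth = 2,
    max_seq_len = 1024,
    heads = 8,
    dim_head = 64,
    block_size = 128
)

x = torch.randint(0, 256, (2, 1024))    # batch of two sequences
lengths = torch.tensor([1024, 500])     # the second sequence ends at position 500

# boolean mask: True marks real tokens, False marks padding
mask = torch.arange(1024)[None, :] < lengths[:, None]

logits = model(x, mask = mask)          # (2, 1024, 256)
```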
105 changes: 62 additions & 43 deletions h_transformer_1d/h_transformer_1d.py
@@ -10,6 +10,24 @@
 def exists(val):
     return val is not None
 
+def masked_aggregate(tensor, mask = None, dim = -1, average = True):
+    if not exists(mask):
+        fn = torch.sum if not average else torch.mean
+        return fn(tensor, dim = dim)
+
+    diff_len = len(tensor.shape) - len(mask.shape)
+    mask = mask[(..., *((None,) * diff_len))]
+    tensor = tensor.masked_fill(~mask, 0.)
+
+    total_el = mask.sum(dim = dim)
+    agg = tensor.sum(dim = dim)
+
+    if average:
+        agg = agg / total_el.clamp(min = 1.)
+
+    agg.masked_fill_(total_el == 0, 0.)
+    return agg
+
 # helper classes
 
 class PreNorm(nn.Module):
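
The semantics of the new masked_aggregate helper are easy to sanity-check; an illustration with the definitions above in scope (not part of the diff):

```python
import torch

t = torch.tensor([[1., 2., 3., 4.]])
m = torch.tensor([[True, True, False, False]])

masked_aggregate(t, m, dim = -1)                    # tensor([1.5000]) - mean over the unmasked 1. and 2. only
masked_aggregate(t, m, dim = -1, average = False)   # tensor([3.])     - sum over the unmasked entries
```

Rows whose mask is entirely False aggregate to 0 rather than NaN (note the clamp and the final masked_fill_), which is what lets fully-padded coarse blocks pass through the hierarchy safely.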
@@ -44,15 +62,13 @@ def __init__(
         self,
         dim,
         *,
-        causal = False,
         heads = 8,
         dim_head = 64,
         block_size = 16,
         eps = 1e-8
     ):
         super().__init__()
         self.eps = eps
-        self.causal = causal
         self.heads = heads
         self.scale = dim_head ** -0.5
         self.block_size = block_size
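
With causal gone, HAttention1D is configured purely by head and block hyperparameters; a direct-instantiation sketch, importing from the package internals (shapes assumed for illustration):

```python
import torch
from h_transformer_1d.h_transformer_1d import HAttention1D

attn = HAttention1D(dim = 512, heads = 8, dim_head = 64, block_size = 16)

x = torch.randn(1, 1024, 512)   # (batch, sequence, dim)
out = attn(x)                   # (1, 1024, 512)
```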
@@ -62,7 +78,7 @@ def __init__(
         self.to_out = nn.Linear(inner_dim, dim)
 
     def forward(self, x, mask = None):
-        b, n, h, device, bsz, causal, eps = *x.shape[:2], self.heads, x.device, self.block_size, self.causal, self.eps
+        b, n, h, device, bsz, eps = *x.shape[:2], self.heads, x.device, self.block_size, self.eps
 
         # derive queries, keys, values
 
@@ -82,32 +98,33 @@ def forward(self, x, mask = None):

         # coarsening
 
-        coarsened_qkvs = [(q, k, v)]
+        qkvs = [(q, k, v, mask)]
 
         for level in range(num_levels):
-            q = reduce(q, 'b (n r) d -> b n d', 'mean', r = 2)
-            k = reduce(k, 'b (n r) d -> b n d', 'mean', r = 2)
-            v = reduce(v, 'b (n r) d -> b n d', 'sum', r = 2)
+            q, k, v = map(lambda t: rearrange(t, 'b (n r) d -> b n r d', r = 2), (q, k, v))
 
-            coarsened_qkvs.append((q, k, v))
+            if exists(mask):
+                mask = rearrange(mask, 'b (n r) -> b n r', r = 2)
 
-        *coarsened_qkvs, top_level_qkvs = reversed(coarsened_qkvs)
+            # masked mean for queries and keys, but not values
 
-        # half-attention function
+            q = masked_aggregate(q, mask, dim = 2)
+            k = masked_aggregate(k, mask, dim = 2)
+            v = masked_aggregate(v, mask, dim = 2, average = False)
 
-        def calculate_Y_and_A(q, k, v, mask_A = False, remove_right_off_diagonals = False):
-            if remove_right_off_diagonals:
-                q, k, v = map(lambda t: rearrange(t, 'b (n r) z d -> b n r z d', r = 2), (q, k, v))
-                q, k, v = map(lambda t: t[:, :, 1], (q, k, v))
+            if exists(mask):
+                mask = torch.any(mask, dim = 2)
 
+            coarsened_qkvs = (q, k, v, mask)
+            qkvs.append(coarsened_qkvs)
+
+        # half-attention function
+
+        def calculate_Y_and_A(q, k, v, mask = None):
             S = einsum('... i d, ... j d -> ... i j', q, k)
 
-            if mask_A:
-                device = q.device
-                n = S.shape[-1]
+            if exists(mask):
                 mask_value = -torch.finfo(S.dtype).max
-                mask = torch.ones((n, n), device = device).triu(1).bool()
-                mask = rearrange(mask, 'i j -> () () i j')
                 S = S.masked_fill(mask, mask_value)
 
             S = S - torch.amax(S, dim = -1, keepdim = True)
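
The coarsening loop above halves the sequence once per level, mean-pooling queries and keys and sum-pooling values, with masked_aggregate ignoring padded positions. A standalone sketch of a single unmasked level (dimensions assumed):

```python
import torch
from einops import rearrange

b, n, d = 1, 8, 4                                  # batch, sequence length, head dimension
q = torch.randn(b, n, d)

q = rearrange(q, 'b (n r) d -> b n r d', r = 2)    # (1, 4, 2, 4) - neighboring positions paired up
q = q.mean(dim = 2)                                # (1, 4, 4)    - sequence halved by averaging each pair
```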
@@ -117,47 +134,50 @@ def calculate_Y_and_A(q, k, v, mask_A = False, remove_right_off_diagonals = False):

             A = A.sum(dim = -1)
 
-            if remove_right_off_diagonals:
-                y = rearrange(y, 'b n z d -> b n () z d')
-                y = F.pad(y, (0, 0, 0, 0, 1, 0), value = 0.)
-                y = rearrange(y, 'b n r z d -> b (n r) z d')
-
-                A = rearrange(A, 'b n z -> b n () z')
-                A = F.pad(A, (0, 0, 1, 0), value = 0.)
-                A = rearrange(A, 'b n r z -> b (n r) z')
-
             y = rearrange(y, 'b ... n d -> b (... n) d')
             A = rearrange(A, 'b ... i -> b (... i)')
             return y, A
 
+        def flip_every_two(t):
+            t = rearrange(t, 'b (n r) ... -> b n r ...', r = 2)
+            t = torch.flip(t, dims = (2,)) # so we pay attention to the off-diagonal blocks in the attention matrix
+            t = rearrange(t, 'b n r ... -> b (n r) ...')
+            return t
+
+        to_blocks = lambda t: rearrange(t, 'b (n z) ... -> b n z ...', z = bsz)
+
         # calculate Ys, as in the paper
 
-        to_blocks = lambda t: rearrange(t, 'b (n z) d -> b n z d', z = bsz)
         Ys = []
 
-        for ind, (q, k, v) in enumerate(coarsened_qkvs):
+        for ind, (q, k, v, mask) in enumerate(reversed(qkvs)):
+            is_last = ind == (len(qkvs) - 1)
 
             q, k, v = map(to_blocks, (q, k, v))
 
-            k = rearrange(k, 'b (n r) z d -> b n r z d', r = 2)
-            k = torch.flip(k, dims = (2,)) # so we pay attention to the off-diagonal blocks in the attention matrix
-            k = rearrange(k, 'b n r z d -> b (n r) z d')
+            # generate the mask for S
+
+            S_mask = None
+            if exists(mask):
+                mask = to_blocks(mask)
+                q_mask = mask
+                k_mask = flip_every_two(mask) if is_last else mask
+                S_mask = rearrange(q_mask, '... n -> ... n ()') * rearrange(k_mask, '... n -> ... () n')
 
-            v = rearrange(v, 'b (n r) z d -> b n r z d', r = 2)
-            v = torch.flip(v, dims = (2,))
-            v = rearrange(v, 'b n r z d -> b (n r) z d')
+            # flip keys and values to capture the off-diagonals
 
-            coarsened_Y = calculate_Y_and_A(q, k, v, remove_right_off_diagonals = causal)
-            Ys.append(coarsened_Y)
+            if not is_last:
+                k, v = map(flip_every_two, (k, v))
 
-        top_level_Y = calculate_Y_and_A(*map(to_blocks, top_level_qkvs), mask_A = causal)
-        Ys.append(top_level_Y)
+            Y_level = calculate_Y_and_A(q, k, v, mask = S_mask)
+            Ys.append(Y_level)
 
         # interpolate
 
         Y = 0
         A = 0
 
-        for Y_level, A_level in Ys:
+        for Y_level, A_level in Ys[-2:]:
             if torch.is_tensor(Y):
                 Y = repeat(Y, 'b n d -> b (n r) d', r = 2)
 
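flip_every_two is what turns block-diagonal attention into attention over the off-diagonal blocks: each pair of neighboring blocks swaps places, so block 2i is matched against block 2i+1 and vice versa. A toy version on bare block indices (batch and feature axes dropped for clarity):

```python
import torch
from einops import rearrange

blocks = torch.tensor([0, 1, 2, 3])                 # block indices along the sequence

blocks = rearrange(blocks, '(n r) -> n r', r = 2)   # [[0, 1], [2, 3]]
blocks = torch.flip(blocks, dims = (1,))            # [[1, 0], [3, 2]]
blocks = rearrange(blocks, 'n r -> (n r)')          # [1, 0, 3, 2] - neighbors swapped
```
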
@@ -189,7 +209,6 @@ def __init__(
         max_seq_len,
         heads = 8,
         dim_head = 64,
-        causal = False,
         ff_mult = 4,
         block_size = 128 # this is the Nr in the paper - Nb = (max_seq_len / tokens_per_block)
     ):
@@ -202,7 +221,7 @@

         for _ in range(depth):
             self.layers.append(nn.ModuleList([
-                PreNorm(dim, HAttention1D(dim, causal = causal, dim_head = dim_head, heads = heads, block_size = block_size)),
+                PreNorm(dim, HAttention1D(dim, dim_head = dim_head, heads = heads, block_size = block_size)),
                 PreNorm(dim, FeedForward(dim, mult = ff_mult))
             ]))

1 change: 0 additions & 1 deletion setup.py
@@ -17,7 +17,6 @@
   ],
   install_requires=[
     'einops>=0.3',
-    'rotary-embedding-torch',
     'torch>=1.6'
   ],
   classifiers=[
