release 1 conformer block
lucidrains committed Jan 4, 2021
1 parent 9453375 commit a70f5d9
Showing 4 changed files with 149 additions and 4 deletions.
21 changes: 21 additions & 0 deletions README.md
@@ -28,6 +28,27 @@ x = torch.randn(1, 1024, 512)
x = layer(x) + x
```

1 Conformer Block

```python
import torch
from conformer import ConformerBlock

block = ConformerBlock(
    dim = 512,
    dim_head = 64,
    heads = 8,
    ff_mult = 4,
    conv_expansion_factor = 2,
    conv_kernel_size = 31,
    attn_dropout = 0.,
    ff_dropout = 0.,
    conv_dropout = 0.
)

x = torch.randn(1, 1024, 512)
block(x) # (1, 1024, 512)
```
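
This release only exposes the block itself, so a deeper encoder is simply a stack of them. Below is a minimal sketch of such a stack; the ConformerEncoder wrapper, its depth, and the hyperparameters are illustrative and not part of the package:

```python
import torch
from torch import nn
from conformer import ConformerBlock

class ConformerEncoder(nn.Module):
    def __init__(self, dim, depth, **block_kwargs):
        super().__init__()
        # every ConformerBlock maps (batch, seq, dim) -> (batch, seq, dim), so blocks chain directly
        self.blocks = nn.ModuleList([ConformerBlock(dim = dim, **block_kwargs) for _ in range(depth)])

    def forward(self, x, mask = None):
        for block in self.blocks:
            x = block(x, mask = mask)
        return x

encoder = ConformerEncoder(dim = 512, depth = 4, heads = 8, dim_head = 64)
x = torch.randn(1, 1024, 512)
encoder(x) # (1, 1024, 512)
```
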
## Citations

2 changes: 1 addition & 1 deletion conformer/__init__.py
@@ -1 +1 @@
from conformer.conformer import ConformerConvModule
from conformer.conformer import ConformerConvModule, ConformerBlock
127 changes: 125 additions & 2 deletions conformer/conformer.py
@@ -1,9 +1,17 @@
import torch
from torch import nn
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange

# helper functions

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

def calc_same_padding(kernel_size):
    # e.g. kernel_size = 31 -> (15, 15); kernel_size = 4 -> (2, 1), so a stride-1 conv keeps the sequence length
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)
@@ -42,7 +50,86 @@ def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

# main class
# attention, feedforward, and conv module

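# Scale multiplies the wrapped module's output by a constant; ConformerBlock below uses it for the 0.5x (half-step) feed-forward residuals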
class Scale(nn.Module):
    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) * self.scale

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        x = self.norm(x)
        return self.fn(x, **kwargs)

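# multi-head dot-product attention; a context input turns it into cross-attention, and the boolean masks mark positions that may be attended to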
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_q = nn.Linear(dim, inner_dim, bias = False)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, context = None, mask = None, context_mask = None):
        device, h, has_context = x.device, self.heads, exists(context)
        context = default(context, x)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        if exists(mask) or exists(context_mask):
            # masks are boolean: True marks positions that may be attended to; defaults are all-True (no masking)
            mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device, dtype = torch.bool))
            context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device, dtype = torch.bool))
            mask_value = -torch.finfo(dots.dtype).max
            mask = mask[:, None, :, None] * context_mask[:, None, None, :]
            dots.masked_fill_(~mask, mask_value)

        attn = dots.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

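# position-wise feed-forward network: Linear -> Swish -> Dropout -> Linear -> Dropout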
class FeedForward(nn.Module):
    def __init__(
        self,
        dim,
        mult = 4,
        dropout = 0.
    ):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, dim * mult),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(dim * mult, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)

class ConformerConvModule(nn.Module):
    def __init__(
@@ -72,3 +159,39 @@ def __init__(

    def forward(self, x):
        return self.net(x)

# Conformer Block

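# the full block: 0.5x feed-forward -> self-attention -> convolution module -> 0.5x feed-forward -> final LayerNorm,
# with a residual connection around each sub-module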
class ConformerBlock(nn.Module):
    def __init__(
        self,
        *,
        dim,
        dim_head = 64,
        heads = 8,
        ff_mult = 4,
        conv_expansion_factor = 2,
        conv_kernel_size = 31,
        attn_dropout = 0.,
        ff_dropout = 0.,
        conv_dropout = 0.
    ):
        super().__init__()
        self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
        self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
        self.conv = ConformerConvModule(dim = dim, causal = False, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
        self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)

        self.attn = PreNorm(dim, self.attn)
        self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
        self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask = None):
        x = self.ff1(x) + x
        x = self.attn(x, mask = mask) + x
        x = self.conv(x) + x
        x = self.ff2(x) + x
        x = self.post_norm(x)
        return x
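
For reference, the mask only affects the attention branch; the convolution and feed-forward branches run over all positions. A minimal usage sketch, assuming a boolean mask where True marks real (non-padded) time steps:

```python
import torch
from conformer import ConformerBlock

block = ConformerBlock(dim = 512)

x = torch.randn(2, 1024, 512)
mask = torch.ones(2, 1024, dtype = torch.bool)
mask[1, 800:] = False         # second sequence is padding after step 800

out = block(x, mask = mask)   # (2, 1024, 512)
```
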
3 changes: 2 additions & 1 deletion setup.py
@@ -3,14 +3,15 @@
setup(
    name = 'conformer',
    packages = find_packages(),
    version = '0.1.0',
    version = '0.2.0',
    license='MIT',
    description = 'The convolutional module from the Conformer paper',
    author = 'Phil Wang',
    author_email = '[email protected]',
    url = 'https://github.com/lucidrains/conformer',
    keywords = ['transformers', 'artificial intelligence', 'transformer'],
    install_requires=[
        'einops',
        'torch'
    ],
    classifiers=[
