Skip to content

Commit

Permalink
[feat] Adding Metaformer support (facebookresearch#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
blefaudeux authored May 9, 2022
1 parent 4a09585 commit 90a1ec7
Show file tree
Hide file tree
Showing 14 changed files with 491 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- MLP benchmark
- Move all triton kernels to triton v2 [#272]
- Mem efficient attention, BW pass [#281]
- Metaformer support [#294]

## [0.0.10] - 2022-03-14
### Fixed
Expand Down
57 changes: 57 additions & 0 deletions HOWTO.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ Let's present here a couple of code snippets on how to solve a couple of questio
- [Intro](#intro)
- [Transformer](#transformer)
- [In practice](#in-practice)
- [Hierarchical Transformers](#hierarchical-transformers)


## Understanding the dimension conventions

Expand Down Expand Up @@ -749,3 +751,58 @@ class xFormerStackConfig:
[2]: Kitaev, N., Kaiser, Ł., & Levskaya, A. (2020). Reformer: The Efficient Transformer.

[3]: Vaswani et al., Attention is all you need, 2017


### Hierarchical Transformers

The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged cross the depth of the model, and this prove to be really efficient in many domains.

It seems that some domains could however benefit from an architecture more typical from CNN, where there's a tradeoff across the depth of the model in between the spatial extent (ie: number of tokens) and their expressiveness (ie: the model or embedding dimension). These architectures are handled in xformers, through the "patch_embedding" element, which translates the sequence of tokens from one layer to another.

A small helper is provided to make it easier to generate matching configurations, as follows. We present in this example a truncated version of a small [Metaformer](https://arxiv.org/abs/2111.11418v1).

```python
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
BasicLayerConfig,
get_hierarchical_configuration,
)


base_hierarchical_configs = [
BasicLayerConfig(
embedding=64, # the dimensions just have to match along the layers
attention_mechanism="scaled_dot_product", # anything you like
patch_size=7,
stride=4,
padding=2,
seq_len=image_size * image_size // 16,
),
BasicLayerConfig(
embedding=128,
attention_mechanism="scaled_dot_product",
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
),
BasicLayerConfig(
embedding=320,
attention_mechanism="scaled_dot_product",
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
),
]

# Fill in the gaps in the config
xformer_config = get_hierarchical_configuration(
base_hierarchical_configs,
layernorm_style="pre",
use_rotary_embeddings=False,
mlp_multiplier=4,
dim_head=32,
)
config = xFormerConfig(xformer_config)
```
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ Patrick et al., 2021](https://arxiv.org/abs/2106.05392)*
2. transformer block benchmark
3. [LRA](xformers/benchmarks/LRA/README.md), with SLURM suppot
4. Programatic and sweep friendly layer and model construction
1. Compatible with hierarchical Transformers, like Swin or Metaformer
5. Hackable
1. Not using monolithic CUDA kernels, composable building blocks
2. Using [Triton](https://triton-lang.org/) for some optimized parts, explicit, pythonic and user-accessible
Expand Down
Binary file added docs/assets/metaformer.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
56 changes: 56 additions & 0 deletions docs/source/tutorials/hierarchical.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
Hierarchical Transformers
=========================

The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged cross the depth of the model, and this prove to be really efficient in many domains.

It seems that some domains could however benefit from an architecture more typical from CNN, where there's a tradeoff across the depth of the model in between the spatial extent (ie: number of tokens) and their expressiveness (ie: the model or embedding dimension). These architectures are handled in xformers, through the "patch_embedding" element, which translates the sequence of tokens from one layer to another.

A small helper is provided to make it easier to generate matching configurations, as follows. We present in this example a truncated version of a small Metaformer_.

.. _Metaformer: https://arxiv.org/abs/2111.11418v1

.. code-block:: python
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
BasicLayerConfig,
get_hierarchical_configuration,
)
base_hierarchical_configs = [
BasicLayerConfig(
embedding=64, # the dimensions just have to match along the layers
attention_mechanism="scaled_dot_product", # anything you like
patch_size=7,
stride=4,
padding=2,
seq_len=image_size * image_size // 16,
),
BasicLayerConfig(
embedding=128,
attention_mechanism="scaled_dot_product",
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
),
BasicLayerConfig(
embedding=320,
attention_mechanism="scaled_dot_product",
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
),
]
# Fill in the gaps in the config
xformer_config = get_hierarchical_configuration(
base_hierarchical_configs,
layernorm_style="pre",
use_rotary_embeddings=False,
mlp_multiplier=4,
dim_head=32,
)
config = xFormerConfig(xformer_config)
1 change: 1 addition & 0 deletions docs/source/tutorials/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ Tutorials
pytorch_encoder
reversible
triton
hierarchical
7 changes: 7 additions & 0 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,10 @@ If your current machine does not expose enough RAM and the example reports an `O
This is meant to be an easy introduction to using xformers in practice, mirroring closely [this Pytorch Lightning](https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html) tutorial. The default settings are close to this tutorial, which trains a 11M parameters ResNet on the CIFAR dataset, we train a 10.6M ViT on the same dataset. The ViT configuration is not optimal for CIFAR, since the pictures have a very small size to begin with and information is probably lost given the patches. Nevertheless you should be able to reach about 80% accuracy within about an hour on a single GPU.

![Example curves](../docs/assets/microViT.png)


### MicroMetaformer

This is very close to the MicroViT example above, but illustrating the use of a hierarchical Transformer ([Metaformer](https://arxiv.org/pdf/2111.11418.pdf)) this time, through a helper function which generates the required configuration given the pooling parameters. The suggested configuration is about 6.6M parameters big (half of a ResNet18) and trains to about 86% top-1 Cifar10 within minutes.

![Example curves](../docs/assets/metaformer.png)
184 changes: 184 additions & 0 deletions examples/cifarMetaformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from examples.microViT import Classifier, VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
BasicLayerConfig,
get_hierarchical_configuration,
)


class MetaVisionTransformer(VisionTransformer):
def __init__(
self,
steps,
learning_rate=5e-3,
betas=(0.9, 0.99),
weight_decay=0.03,
image_size=32,
num_classes=10,
dim=384,
attention="scaled_dot_product",
layer_norm_style="pre",
use_rotary_embeddings=True,
linear_warmup_ratio=0.1,
classifier=Classifier.GAP,
):

super(VisionTransformer, self).__init__()

# all the inputs are saved under self.hparams (hyperparams)
self.save_hyperparameters()

# Generate the skeleton of our hierarchical Transformer

# This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
# Any other related config would work,
# and the attention mechanisms don't have to be the same across layers
base_hierarchical_configs = [
BasicLayerConfig(
embedding=64,
attention_mechanism=attention,
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 4,
),
BasicLayerConfig(
embedding=128,
attention_mechanism=attention,
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 16,
),
BasicLayerConfig(
embedding=320,
attention_mechanism=attention,
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 64,
),
BasicLayerConfig(
embedding=512,
attention_mechanism=attention,
patch_size=3,
stride=2,
padding=1,
seq_len=image_size * image_size // 256,
),
]

# Fill in the gaps in the config
xformer_config = get_hierarchical_configuration(
base_hierarchical_configs,
layernorm_style=layer_norm_style,
use_rotary_embeddings=use_rotary_embeddings,
mlp_multiplier=4,
dim_head=32,
)

# Now instantiate the metaformer trunk
config = xFormerConfig(xformer_config)
print(config)
self.trunk = xFormer.from_config(config)
print(self.trunk)

# The classifier head
dim = base_hierarchical_configs[-1].embedding
self.ln = nn.LayerNorm(dim)
self.head = nn.Linear(dim, num_classes)
self.criterion = torch.nn.CrossEntropyLoss()
self.val_accuracy = Accuracy()

def forward(self, x):
x = self.trunk(x)
x = self.ln(x)

if self.hparams.classifier == Classifier.TOKEN:
x = x[:, 0] # only consider the token, we're classifying anyway
elif self.hparams.classifier == Classifier.GAP:
x = x.mean(dim=1) # mean over sequence len

x = self.head(x)
return x


if __name__ == "__main__":
pl.seed_everything(42)

# Adjust batch depending on the available memory on your machine.
# You can also use reversible layers to save memory
REF_BATCH = 512
BATCH = 512 # lower if not enough GPU memory

MAX_EPOCHS = 50
NUM_WORKERS = 4
GPUS = 1

train_transforms = transforms.Compose(
[
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
cifar10_normalization(),
]
)

test_transforms = transforms.Compose(
[
transforms.ToTensor(),
cifar10_normalization(),
]
)

# We'll use a datamodule here, which already handles dataset/dataloader/sampler
# See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
# for a full tutorial
dm = CIFAR10DataModule(
data_dir="data",
batch_size=BATCH,
num_workers=NUM_WORKERS,
pin_memory=True,
)
dm.train_transforms = train_transforms
dm.test_transforms = test_transforms
dm.val_transforms = test_transforms

image_size = dm.size(-1) # 32 for CIFAR
num_classes = dm.num_classes # 10 for CIFAR

# compute total number of steps
batch_size = BATCH * GPUS
steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
lm = MetaVisionTransformer(
steps=steps,
image_size=image_size,
num_classes=num_classes,
attention="scaled_dot_product",
layer_norm_style="pre",
use_rotary_embeddings=True,
)
trainer = pl.Trainer(
gpus=GPUS,
max_epochs=MAX_EPOCHS,
precision=16,
accumulate_grad_batches=REF_BATCH // BATCH,
)
trainer.fit(lm, dm)

# check the training
trainer.test(lm, datamodule=dm)
2 changes: 1 addition & 1 deletion examples/microViT.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def training_step(self, batch, _):
"train_loss": loss.mean(),
"learning_rate": self.lr_schedulers().get_last_lr()[0],
},
step=trainer.global_step,
step=self.global_step,
)

return loss
Expand Down
7 changes: 7 additions & 0 deletions tests/test_attentions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import Tuple

import pytest
Expand Down Expand Up @@ -96,6 +97,9 @@ def test_order_invariance(
device: torch.device,
):

if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "poolling":
pytest.skip(f"{attention_name} requires squared sequence lengths")

torch.manual_seed(42)

multi_head = _get_multihead(
Expand Down Expand Up @@ -282,6 +286,9 @@ def test_broadcast_batch_dimension(
device: torch.device,
batch_sizes: Tuple[int, int, int],
):
if int(math.sqrt(SEQ)) ** 2 != SEQ and attention_name == "poolling":
pytest.skip(f"{attention_name} requires squared sequence lengths")

Q_BATCH, K_BATCH, V_BATCH = batch_sizes
multi_head = _get_multihead(attention_name, 0.0, 0.0, False, heads, device)

Expand Down
Loading

0 comments on commit 90a1ec7

Please sign in to comment.