forked from facebookresearch/xformers
[feat] Adding Metaformer support (facebookresearch#294)

1 parent 4a09585 · commit 90a1ec7

Showing 14 changed files with 491 additions and 4 deletions.
@@ -0,0 +1,56 @@
Hierarchical Transformers
=========================

The original Transformer proposal processes ("transforms") sequences of tokens, across possibly many layers. Crucially, the number of tokens is unchanged across the depth of the model, and this has proved to be very efficient in many domains.

Some domains could however benefit from an architecture more typical of CNNs, where there is a tradeoff across the depth of the model between the spatial extent (i.e. the number of tokens) and the expressiveness of each token (i.e. the model or embedding dimension). These architectures are handled in xformers through the "patch_embedding" element, which translates the sequence of tokens from one layer to the next.

A small helper is provided to make it easier to generate matching configurations, as follows. We present in this example a truncated version of a small Metaformer_.

.. _Metaformer: https://arxiv.org/abs/2111.11418v1

.. code-block:: python

    from xformers.factory import xFormer, xFormerConfig
    from xformers.helpers.hierarchical_configs import (
        BasicLayerConfig,
        get_hierarchical_configuration,
    )

    # NOTE: `image_size` is assumed to be defined by the caller (the input resolution)
    base_hierarchical_configs = [
        BasicLayerConfig(
            embedding=64,  # the dimensions just have to match along the layers
            attention_mechanism="scaled_dot_product",  # anything you like
            patch_size=7,
            stride=4,
            padding=2,
            seq_len=image_size * image_size // 16,
        ),
        BasicLayerConfig(
            embedding=128,
            attention_mechanism="scaled_dot_product",
            patch_size=3,
            stride=2,
            padding=1,
            seq_len=image_size * image_size // 64,
        ),
        BasicLayerConfig(
            embedding=320,
            attention_mechanism="scaled_dot_product",
            patch_size=3,
            stride=2,
            padding=1,
            seq_len=image_size * image_size // 256,
        ),
    ]

    # Fill in the gaps in the config
    xformer_config = get_hierarchical_configuration(
        base_hierarchical_configs,
        layernorm_style="pre",
        use_rotary_embeddings=False,
        mlp_multiplier=4,
        dim_head=32,
    )
    config = xFormerConfig(xformer_config)
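
For intuition, the "patch_embedding" step mentioned above behaves much like a strided convolution: it reduces the spatial extent (fewer tokens) while increasing the embedding dimension. The following is a minimal, self-contained sketch of that idea, written for this note rather than taken from the xformers code base (the class name and shapes are illustrative):

.. code-block:: python

    import torch
    from torch import nn

    class PatchDownsample(nn.Module):
        """Toy stand-in for a patch embedding between two hierarchical stages."""

        def __init__(self, in_dim, out_dim, patch_size, stride, padding):
            super().__init__()
            self.proj = nn.Conv2d(
                in_dim, out_dim, kernel_size=patch_size, stride=stride, padding=padding
            )

        def forward(self, x, h, w):
            # x: (batch, seq_len, in_dim) with seq_len == h * w
            x = x.transpose(1, 2).reshape(x.shape[0], -1, h, w)
            x = self.proj(x)  # (batch, out_dim, h', w')
            h, w = x.shape[-2:]
            return x.flatten(2).transpose(1, 2), h, w  # (batch, h' * w', out_dim)

    tokens = torch.randn(2, 32 * 32, 64)  # a 32x32 grid of 64-dim tokens
    down = PatchDownsample(64, 128, patch_size=3, stride=2, padding=1)
    out, h, w = down(tokens, 32, 32)
    print(out.shape)  # torch.Size([2, 256, 128]) -- a 16x16 grid of 128-dim tokens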
@@ -11,3 +11,4 @@ Tutorials
   pytorch_encoder
   reversible
   triton
   hierarchical
@@ -0,0 +1,184 @@
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.


import pytorch_lightning as pl
import torch
from pl_bolts.datamodules import CIFAR10DataModule
from pl_bolts.transforms.dataset_normalizations import cifar10_normalization
from torch import nn
from torchmetrics import Accuracy
from torchvision import transforms

from examples.microViT import Classifier, VisionTransformer
from xformers.factory import xFormer, xFormerConfig
from xformers.helpers.hierarchical_configs import (
    BasicLayerConfig,
    get_hierarchical_configuration,
)


class MetaVisionTransformer(VisionTransformer):
    def __init__(
        self,
        steps,
        learning_rate=5e-3,
        betas=(0.9, 0.99),
        weight_decay=0.03,
        image_size=32,
        num_classes=10,
        dim=384,
        attention="scaled_dot_product",
        layer_norm_style="pre",
        use_rotary_embeddings=True,
        linear_warmup_ratio=0.1,
        classifier=Classifier.GAP,
    ):

        super(VisionTransformer, self).__init__()

        # all the inputs are saved under self.hparams (hyperparams)
        self.save_hyperparameters()

        # Generate the skeleton of our hierarchical Transformer

        # This is a small poolformer configuration, adapted to the small CIFAR10 pictures (32x32)
        # Any other related config would work,
        # and the attention mechanisms don't have to be the same across layers
        base_hierarchical_configs = [
            BasicLayerConfig(
                embedding=64,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 4,
            ),
            BasicLayerConfig(
                embedding=128,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 16,
            ),
            BasicLayerConfig(
                embedding=320,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 64,
            ),
            BasicLayerConfig(
                embedding=512,
                attention_mechanism=attention,
                patch_size=3,
                stride=2,
                padding=1,
                seq_len=image_size * image_size // 256,
            ),
        ]

        # Fill in the gaps in the config
        xformer_config = get_hierarchical_configuration(
            base_hierarchical_configs,
            layernorm_style=layer_norm_style,
            use_rotary_embeddings=use_rotary_embeddings,
            mlp_multiplier=4,
            dim_head=32,
        )

        # Now instantiate the metaformer trunk
        config = xFormerConfig(xformer_config)
        print(config)
        self.trunk = xFormer.from_config(config)
        print(self.trunk)

        # The classifier head
        dim = base_hierarchical_configs[-1].embedding
        self.ln = nn.LayerNorm(dim)
        self.head = nn.Linear(dim, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()
        self.val_accuracy = Accuracy()

    def forward(self, x):
        x = self.trunk(x)
        x = self.ln(x)

        if self.hparams.classifier == Classifier.TOKEN:
            x = x[:, 0]  # only consider the token, we're classifying anyway
        elif self.hparams.classifier == Classifier.GAP:
            x = x.mean(dim=1)  # mean over sequence len

        x = self.head(x)
        return x


if __name__ == "__main__":
    pl.seed_everything(42)

    # Adjust batch depending on the available memory on your machine.
    # You can also use reversible layers to save memory
    REF_BATCH = 512
    BATCH = 512  # lower if not enough GPU memory

    MAX_EPOCHS = 50
    NUM_WORKERS = 4
    GPUS = 1

    train_transforms = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    test_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            cifar10_normalization(),
        ]
    )

    # We'll use a datamodule here, which already handles dataset/dataloader/sampler
    # See https://pytorchlightning.github.io/lightning-tutorials/notebooks/lightning_examples/cifar10-baseline.html
    # for a full tutorial
    dm = CIFAR10DataModule(
        data_dir="data",
        batch_size=BATCH,
        num_workers=NUM_WORKERS,
        pin_memory=True,
    )
    dm.train_transforms = train_transforms
    dm.test_transforms = test_transforms
    dm.val_transforms = test_transforms

    image_size = dm.size(-1)  # 32 for CIFAR
    num_classes = dm.num_classes  # 10 for CIFAR

    # compute total number of steps
    batch_size = BATCH * GPUS
    steps = dm.num_samples // REF_BATCH * MAX_EPOCHS
    lm = MetaVisionTransformer(
        steps=steps,
        image_size=image_size,
        num_classes=num_classes,
        attention="scaled_dot_product",
        layer_norm_style="pre",
        use_rotary_embeddings=True,
    )
    trainer = pl.Trainer(
        gpus=GPUS,
        max_epochs=MAX_EPOCHS,
        precision=16,
        accumulate_grad_batches=REF_BATCH // BATCH,
    )
    trainer.fit(lm, dm)

    # check the training
    trainer.test(lm, datamodule=dm)
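
A quick note on the seq_len values in the CIFAR example above: with patch_size=3, stride=2, padding=1, each patch embedding halves both spatial dimensions, so every stage keeps a quarter of the tokens. An illustrative check (not part of the commit):

.. code-block:: python

    def conv_output_side(side, patch_size=3, stride=2, padding=1):
        # standard convolution output-size formula
        return (side + 2 * padding - patch_size) // stride + 1

    side = 32  # CIFAR10 images are 32x32
    for _ in range(4):
        side = conv_output_side(side)
        print(side * side)  # 256, 64, 16, 4 == 32*32 // 4, // 16, // 64, // 256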