Skip to content

Commit

Permalink
Open source vision transformers (facebookresearch#646)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#646

Open source Vision Transformers from https://arxiv.org/abs/2010.11929

Reviewed By: vreis

Differential Revision: D24840754

fbshipit-source-id: b5bbe1fd77aca2730c36edd472fa52d2bedf4b61
  • Loading branch information
mannatsingh authored and facebook-github-bot committed Nov 10, 2020
1 parent 1ba6b03 commit 7c58f6d
Show file tree
Hide file tree
Showing 6 changed files with 617 additions and 2 deletions.
2 changes: 2 additions & 0 deletions classy_vision/heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,15 @@ def build_head(config):
from .fully_connected_head import FullyConnectedHead # isort:skip
from .fully_convolutional_linear_head import FullyConvolutionalLinearHead # isort:skip
from .identity_head import IdentityHead # isort:skip
from .vision_transformer_head import VisionTransformerHead # isort:skip


__all__ = [
"ClassyHead",
"FullyConnectedHead",
"FullyConvolutionalLinearHead",
"IdentityHead",
"VisionTransformerHead",
"build_head",
"register_head",
]
60 changes: 60 additions & 0 deletions classy_vision/heads/vision_transformer_head.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Vision Transformer head implementation from https://arxiv.org/abs/2010.11929.
References:
https://github.com/google-research/vision_transformer
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
"""

import copy
from collections import OrderedDict

import torch.nn as nn
from classy_vision.heads import ClassyHead, register_head

from ..models.lecun_normal_init import lecun_normal_init


@register_head("vision_transformer_head")
class VisionTransformerHead(ClassyHead):
def __init__(
self,
in_plane,
num_classes,
hidden_dim=None,
):
super().__init__()
if hidden_dim is None:
layers = [("head", nn.Linear(in_plane, num_classes))]
else:
layers = [
("pre_logits", nn.Linear(in_plane, hidden_dim)),
("act", nn.Tanh()),
("head", nn.Linear(hidden_dim, num_classes)),
]
self.layers = nn.Sequential(OrderedDict(layers))
self.init_weights()

def init_weights(self):
if hasattr(self.layers, "pre_logits"):
lecun_normal_init(
self.layers.pre_logits.weight, fan_in=self.layers.pre_logits.in_features
)
nn.init.zeros_(self.layers.pre_logits.bias)
nn.init.zeros_(self.layers.head.weight)
nn.init.zeros_(self.layers.head.bias)

@classmethod
def from_config(cls, config):
config = copy.deepcopy(config)
config.pop("unique_id")
return cls(**config)

def forward(self, x):
return self.layers(x)
8 changes: 6 additions & 2 deletions classy_vision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,17 +96,17 @@ def build_model(config):
) # isort:skip
from .densenet import DenseNet # isort:skip
from .efficientnet import EfficientNet # isort:skip
from .lecun_normal_init import lecun_normal_init # isort:skip
from .mlp import MLP # isort:skip
from .regnet import RegNet # isort:skip
from .resnet import ResNet # isort:skip
from .resnext import ResNeXt # isort:skip
from .resnext3d import ResNeXt3D # isort:skip
from .squeeze_and_excitation_layer import SqueezeAndExcitationLayer # isort:skip
from .vision_transformer import VisionTransformer # isort:skip


__all__ = [
"build_model",
"register_model",
"ClassyBlock",
"ClassyModel",
"ClassyModelHeadExecutorWrapper",
Expand All @@ -119,4 +119,8 @@ def build_model(config):
"ResNeXt",
"ResNeXt3D",
"SqueezeAndExcitationLayer",
"VisionTransformer",
"build_model",
"lecun_normal_init",
"register_model",
]
11 changes: 11 additions & 0 deletions classy_vision/models/lecun_normal_init.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch.nn as nn
import math


def lecun_normal_init(tensor, fan_in):
nn.init.trunc_normal_(tensor, std=math.sqrt(1 / fan_in))
Loading

0 comments on commit 7c58f6d

Please sign in to comment.