This project is a PyTorch implementation of a Vision Transformer (ViT) model, inspired by the architecture outlined in "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2021). ViT reimagines the use of Transformers in image classification by applying attention mechanisms to sequences of image patches.
## Table of Contents

- Project Structure
- Features
- Installation
- Usage
- Configuration
- Model Overview
- Relation to ViT Paper
- Training and Evaluation
- License
## Project Structure

- `ViT.py`: Contains the full implementation of the Vision Transformer model, including configurations, model architecture, and helper functions.
## Features

- Configurable Model: Define the transformer parameters, image dimensions, patch sizes, and number of classes.
- Self-Attention Mechanism: Implements multi-head self-attention to process image patches.
- Training on Multiple Devices: Supports training on CPU, CUDA (GPU), and MPS (Mac GPU).
## Installation

Ensure you have Python 3.7+ and PyTorch installed.

- Clone the repository:

  ```bash
  git clone <repository_url>
  cd vision-transformer-pytorch
  ```

- Install dependencies:

  ```bash
  pip install torch torchvision transformers numpy datasets
  ```
## Usage

- Configuration: Modify the `VITConfig` class to set parameters such as image size, patch size, embedding size, number of heads, layers, and classes.
- Device Selection: The script automatically detects the available device (CPU, CUDA, or MPS); see the sketch below.
- Training: Use the provided dataset loader to train the model on your image dataset.
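A minimal sketch of the device-detection logic described above (the actual code in `ViT.py` may differ):

```python
import torch

# Pick the best available device: CUDA (NVIDIA GPU), MPS (Apple Silicon), or CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")
```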
## Configuration

All configurable parameters are in the `VITConfig` class:

- `n_emb`: Embedding size for each patch.
- `image_size`: Input image size (height and width should match).
- `patch_size`: Size of each image patch.
- `n_heads`: Number of attention heads.
- `n_layers`: Number of transformer encoder layers.
- `num_classes`: Number of classes for classification.
Example:

```python
config = VITConfig(
    n_emb=768,
    image_size=224,
    n_heads=12,
    patch_size=16,
    n_layers=12,
    num_classes=10
)
```
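With a config in hand, the model can be instantiated and sanity-checked on a dummy batch. This is a hedged sketch: the model class name `ViT` and its constructor signature are assumptions based on the file name, so adjust the import to match `ViT.py`:

```python
import torch
from ViT import VITConfig, ViT  # `ViT` class name is an assumption; check ViT.py

config = VITConfig(n_emb=768, image_size=224, n_heads=12,
                   patch_size=16, n_layers=12, num_classes=10)
model = ViT(config)

# One batch of 8 RGB images at the configured resolution.
x = torch.randn(8, 3, config.image_size, config.image_size)
logits = model(x)
print(logits.shape)  # expected: torch.Size([8, 10])
```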
## Model Overview

The Vision Transformer consists of the following components (a minimal sketch follows the list):

- Patch Embedding: Splits the image into patches and embeds them into the transformer input dimension.
- Transformer Encoder: Applies multiple self-attention layers with residual connections.
- Classification Head: Maps the transformer output to the number of classes.
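The pipeline above can be summarized in a compact, self-contained sketch. This is not the code in `ViT.py`; it uses `nn.TransformerEncoder` as a stand-in for the repository's own attention blocks, and all names here are illustrative:

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    """Illustrative pipeline only: patch embedding -> encoder -> classification head."""

    def __init__(self, image_size=224, patch_size=16, n_emb=768,
                 n_heads=12, n_layers=12, num_classes=10):
        super().__init__()
        num_patches = (image_size // patch_size) ** 2
        # Patch embedding: a strided convolution splits the image into patches
        # and projects each one to the embedding dimension in a single step.
        self.patch_embed = nn.Conv2d(3, n_emb, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, n_emb))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, n_emb))
        # Stand-in encoder; the repository implements its own attention blocks.
        layer = nn.TransformerEncoderLayer(d_model=n_emb, nhead=n_heads,
                                           batch_first=True, norm_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Linear(n_emb, num_classes)

    def forward(self, x):                          # x: (B, 3, H, W)
        x = self.patch_embed(x)                    # (B, n_emb, H/P, W/P)
        x = x.flatten(2).transpose(1, 2)           # (B, num_patches, n_emb)
        cls = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1) + self.pos_embed
        x = self.encoder(x)
        return self.head(x[:, 0])                  # classify from the [CLS] token
```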
## Relation to ViT Paper

This implementation follows the methods outlined in "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2021). Key aspects include:

- Patch Embedding and Image Representation: The model treats each image as a sequence of patches (16x16 pixels each), embedding these patches linearly before feeding them into the Transformer. This is reflected in the `VITConfig` class in `ViT.py` through parameters like `patch_size` and `num_patches`, mirroring the paper's treatment of images as sequences.
- Transformer Encoder and Self-Attention: The core of the ViT model is a Transformer encoder that alternates self-attention and MLP blocks over the sequence of patches. The `SelfAttention` class in `ViT.py` implements this mechanism, applying global context across patches as described in the paper; an illustrative sketch follows this list.
- Classification Token: Similar to BERT's [CLS] token, a classification token is prepended to the input sequence and serves as the image representation for the final classification. The code includes this token in its processing sequence, and the final layer classifies based on it, consistent with the paper's structure.
- Model Configuration and Parameters: The `VITConfig` class allows modification of transformer parameters, enabling experimentation with model variants (Base, Large, Huge) as outlined in the paper.
- Position Embedding: Position embeddings are added to the patch sequence to retain spatial information, so positional context is available across patches in the Transformer, in line with the paper.
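For reference, the global self-attention described above boils down to the following. This is an illustrative multi-head attention module, not a copy of the repository's `SelfAttention` class; the layer layout and names are assumptions:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleSelfAttention(nn.Module):
    """Multi-head self-attention over a sequence of patch embeddings."""

    def __init__(self, n_emb=768, n_heads=12):
        super().__init__()
        assert n_emb % n_heads == 0
        self.n_heads = n_heads
        self.head_dim = n_emb // n_heads
        self.qkv = nn.Linear(n_emb, 3 * n_emb)   # joint Q, K, V projection
        self.proj = nn.Linear(n_emb, n_emb)      # output projection

    def forward(self, x):                        # x: (B, T, n_emb), T = patches + [CLS]
        B, T, C = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)
        # Reshape to (B, n_heads, T, head_dim) so each head attends independently.
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        # Scaled dot-product attention: every patch attends to every other patch.
        att = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        out = F.softmax(att, dim=-1) @ v
        out = out.transpose(1, 2).reshape(B, T, C)
        return self.proj(out)
```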
## Training and Evaluation

- Load Data: Use the `torchvision.datasets` and `torch.utils.data.DataLoader` modules to load your image dataset.
- Train: Define a training loop with an optimizer and a loss function, such as cross-entropy; a sketch follows this list.
- Evaluate: After training, evaluate the model on a validation/test set to assess performance.
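A minimal training and evaluation loop, assuming CIFAR-10 from `torchvision`, the example configuration from above, and a model class assumed to be named `ViT` (adjust the import to match `ViT.py`):

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from ViT import VITConfig, ViT  # `ViT` class name is an assumption; check ViT.py

device = "cuda" if torch.cuda.is_available() else "cpu"  # see the device sketch above for MPS
config = VITConfig(n_emb=768, image_size=224, n_heads=12,
                   patch_size=16, n_layers=12, num_classes=10)
model = ViT(config).to(device)

# CIFAR-10 resized to the configured image size.
tfm = transforms.Compose([transforms.Resize(config.image_size), transforms.ToTensor()])
train_loader = DataLoader(datasets.CIFAR10("data", train=True, download=True, transform=tfm),
                          batch_size=64, shuffle=True)
test_loader = DataLoader(datasets.CIFAR10("data", train=False, download=True, transform=tfm),
                         batch_size=64)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
criterion = nn.CrossEntropyLoss()

for epoch in range(10):
    model.train()
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()

    # Evaluation: accuracy on the held-out test set.
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            correct += (model(images).argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    print(f"epoch {epoch}: test accuracy {correct / total:.3f}")
```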
## License

This project is open-source and available under the MIT License.