
PYTORCH VISION TRANSFORMERS

Pytorch implementation of the Vision Transformer from "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale". Trainable on CIFAR10 only (for now).

Based on Pytorch Lightning.
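The "image is worth 16x16 words" idea boils down to splitting each image into a grid of square patches and treating each flattened patch as one token. A minimal sketch of that arithmetic (the patch size of 4 below is for illustration; the value actually used by this repo may differ):

```python
def patch_grid(image_size: int, patch_size: int, channels: int = 3):
    """Return (num_patches, patch_dim): the token count and the per-token
    vector length obtained by splitting an image into square patches."""
    assert image_size % patch_size == 0, "patch size must divide image size"
    num_patches = (image_size // patch_size) ** 2
    patch_dim = channels * patch_size ** 2
    return num_patches, patch_dim

# CIFAR10 images are 32x32x3; with 4x4 patches each image becomes a
# sequence of 64 tokens of dimension 48 (before the linear projection).
print(patch_grid(32, 4))
```

For the original paper's setting (224x224 images, 16x16 patches) the same function gives 196 tokens of dimension 768.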

SUMMARY

Last Commit Changes Log

  • Initial Commit

Installation

(Back to top)

Clone repo:

git clone https://github.com/the-dharma-bum/Pytorch-Vision-Transformers-CIFAR10

Install dependencies by running:

pip install -r requirements.txt

Usage

(Back to top)

All hyperparameters can be configured in config.py. One can set up a training config inside this file and then simply run:

python main.py
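The exact contents of config.py are not shown here; as an illustration only, a training config is often a simple dataclass along these lines (every field name below is hypothetical and may not match the repo):

```python
from dataclasses import dataclass

# Hypothetical sketch of a training config such as the one in config.py;
# the actual class and field names in the repo may differ.
@dataclass
class TrainConfig:
    batch_size: int = 128
    lr: float = 3e-4
    max_epochs: int = 50
    patch_size: int = 4
    embed_dim: int = 256

# Override a single hyperparameter while keeping the other defaults.
cfg = TrainConfig(lr=1e-3)
```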

This command accepts a huge number of parameters. Run

python main.py -h

to see them all, or refer to the Pytorch Lightning documentation.

Some useful parameters:

  • --gpus n : launch training on n GPUs
  • --distributed_backend ddp : use DistributedDataParallel as the multi-GPU training backend.
  • --fast_dev_run True : run a full training loop (train, eval, test) on a single batch. Use it for debugging.

If fast_dev_run doesn't suit your debugging needs (for instance if you want to see what happens between two epochs), you can use:

  • --limit_train_batches i --limit_val_batches j --max_epochs k

    where i, j, k are three integers of your choice.

Fastai Integration

(Back to top)

It's very easy to integrate Lightning code into the Fastai training environment.

One must define:

  • a model (see model.py):
import config as cfg
from model import LightningModel

model = LightningModel(cfg.TrainConfig())
  • a datamodule:
from pl_bolts.datamodules import CIFAR10DataModule

dm = CIFAR10DataModule(*args, **kwargs)

Using this datamodule, two fastai DataLoaders can be defined like this:

from fastai.vision.all import DataLoaders

data = DataLoaders(dm.train_dataloader(), dm.val_dataloader()).cuda()

Then a Learner can be defined and used like in standard Fastai code, for instance:

import torch.nn.functional as F
from fastai.vision.all import Learner, Adam, accuracy

learn = Learner(data, model, loss_func=F.cross_entropy, opt_func=Adam, metrics=accuracy)
learn.fit_one_cycle(1, 0.001)

This makes every Fastai training functionality available (callbacks, transforms, visualizations, ...).

