RepViT: Revisiting Mobile CNN From ViT Perspective

henrywoo/RepViT (forked from THU-MIG/RepViT)

Official PyTorch implementation of RepViT, from the following paper:

RepViT: Revisiting Mobile CNN From ViT Perspective.
Ao Wang, Hui Chen, Zijia Lin, Hengjun Pu, and Guiguang Ding
[arXiv:2307.09283]


Models are trained on ImageNet-1K and deployed on an iPhone 12 with Core ML Tools to measure latency.

Abstract: Recently, lightweight Vision Transformers (ViTs) have demonstrated superior performance and lower latency compared with lightweight Convolutional Neural Networks (CNNs) on resource-constrained mobile devices. This improvement is usually attributed to the multi-head self-attention module, which enables the model to learn global representations. However, the architectural disparities between lightweight ViTs and lightweight CNNs have not been adequately examined. In this study, we revisit the efficient design of lightweight CNNs and emphasize their potential for mobile devices. We incrementally enhance the mobile-friendliness of a standard lightweight CNN, specifically MobileNetV3, by integrating the efficient architectural choices of lightweight ViTs. This results in a new family of pure lightweight CNNs, namely RepViT. Extensive experiments show that RepViT outperforms existing state-of-the-art lightweight ViTs and exhibits favorable latency in various vision tasks. On ImageNet, RepViT achieves over 80% top-1 accuracy with nearly 1 ms latency on an iPhone 12, which, to the best of our knowledge, is a first for a lightweight model. Our largest model, RepViT-M3, obtains 81.4% accuracy with only 1.3 ms latency.

Classification on ImageNet-1K

Models

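| Model | Top-1 (%, 300 epochs) | #Params | MACs | Latency (iPhone 12) | Ckpt | Core ML | Log |
|---|---|---|---|---|---|---|---|
| RepViT-M1 | 78.5 | 5.1M | 0.8G | 0.9 ms | M1 | M1 | M1 |
| RepViT-M2 | 80.6 | 8.2M | 1.3G | 1.1 ms | M2 | M2 | M2 |
| RepViT-M3 | 81.4 | 10.1M | 1.9G | 1.3 ms | M3 | M3 | M3 |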

Tips: Convert a training-time RepViT into the inference-time structure

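from timm.models import create_model
import utils  # utils.py from this repository

# Assumes 'repvit_m1' is registered with timm by importing this repository's model definitions
model = create_model('repvit_m1')
# Recursively fuse reparameterizable modules (conv + BN, parallel branches) for inference
utils.replace_batchnorm(model)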
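
Part of what the inference-time conversion relies on is the standard identity for folding a BatchNorm layer into the convolution before it. With BN parameters γ and β, running mean μ, running variance σ², and ε, the fused weight is W' = W·γ/√(σ²+ε) per output channel and the fused bias is b' = β + (b − μ)·γ/√(σ²+ε). The sketch below illustrates this identity in plain Python on a 1×1 convolution applied to a single pixel, which reduces to a per-pixel linear map. It is not the repository's `replace_batchnorm` code, and the function names are hypothetical.

```python
import math


def conv1x1(weight, bias, x):
    """Apply a 1x1 convolution to one pixel: out[o] = sum_i W[o][i] * x[i] + b[o]."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batchnorm_eval(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode BatchNorm using running statistics."""
    return [g * (yi - m) / math.sqrt(v + eps) + bt
            for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)]


def fuse_conv_bn(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Fold BN into the conv: W' = W * s, b' = beta + (b - mean) * s, where s = gamma / sqrt(var + eps)."""
    fused_w, fused_b = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        s = g / math.sqrt(v + eps)  # per-output-channel scale
        fused_w.append([w * s for w in row])
        fused_b.append(bt + (b - m) * s)
    return fused_w, fused_b


if __name__ == "__main__":
    W = [[0.5, -1.0], [2.0, 0.25]]  # 2 output channels, 2 input channels
    b = [0.1, -0.2]
    gamma, beta = [1.5, 0.8], [0.0, 0.3]
    mean, var = [0.2, -0.1], [0.9, 1.1]
    x = [1.0, 2.0]

    two_step = batchnorm_eval(conv1x1(W, b, x), gamma, beta, mean, var)
    Wf, bf = fuse_conv_bn(W, b, gamma, beta, mean, var)
    fused = conv1x1(Wf, bf, x)
    print(two_step, fused)  # the two outputs agree up to floating-point rounding
```

The fused model performs one convolution where the training-time model performed a convolution followed by a normalization. This reduction in work is what the inference-time structure exploits.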

Latency Measurement

The latency reported for RepViT on iPhone 12 (iOS 16) is measured with the benchmark tool from Xcode 14.

Tips: export the model to Core ML

python export_coreml.py --model repvit_m1 --ckpt pretrain/repvit_m1_distill_300.pth

Tips: measure the throughput on GPU

python speed_gpu.py --model repvit_m1

Result on a Tesla V100:

$ python speed_gpu.py --model repvit_m1
repvit_m1 cuda:0 3767.5837485638444 images/s @ batch size 2048
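
The images/s figure is a throughput: the number of images processed divided by wall-clock time. Below is a minimal, framework-agnostic sketch of such a measurement. It assumes an untimed warm-up phase followed by timed iterations. The function name `measure_throughput` and its parameters are hypothetical, and this is not the code of `speed_gpu.py`. On a GPU, the callable would also need to synchronize the device before returning (for example with `torch.cuda.synchronize()`), because CUDA kernels launch asynchronously.

```python
import time
from typing import Callable


def measure_throughput(run_batch: Callable[[], None], batch_size: int,
                       warmup: int = 10, iters: int = 30) -> float:
    """Return images/second for a callable that processes one batch per call.

    Assumption: run_batch blocks until the batch is finished (on a GPU it
    should synchronize the device before returning).
    """
    if batch_size <= 0 or iters <= 0:
        raise ValueError("batch_size and iters must be positive")
    for _ in range(warmup):  # untimed warm-up (caches, kernel selection, ...)
        run_batch()
    start = time.perf_counter()
    for _ in range(iters):
        run_batch()
    elapsed = time.perf_counter() - start
    return batch_size * iters / elapsed


if __name__ == "__main__":
    # Stand-in workload for illustration: a short CPU-bound loop per "batch".
    def fake_batch() -> None:
        sum(i * i for i in range(10_000))

    ips = measure_throughput(fake_batch, batch_size=2048, warmup=2, iters=5)
    print(f"{ips:.1f} images/s @ batch size 2048")
```

As a sanity check on the V100 result above, 3767.58 images/s at batch size 2048 corresponds to roughly 2048 / 3767.58 ≈ 0.54 s per batch.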

ImageNet

Prerequisites

A conda virtual environment is recommended.

conda create -n repvit python=3.8
conda activate repvit
pip install -r requirements.txt

Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/. The training and validation data are expected in the train and val folders, respectively:

|-- /path/to/imagenet/
    |-- train
    |-- val
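
Before launching training, it can help to check that the directory matches the layout above. The sketch below assumes the common torchvision `ImageFolder` convention, in which `train` and `val` each contain one subfolder per class. This convention is an assumption and is not stated in this README. Note that the raw validation archive is typically distributed as a flat folder of images, so it usually has to be sorted into class subfolders first. The helper name `check_imagenet_layout` is hypothetical. ImageNet-1K has 1,000 classes.

```python
from pathlib import Path
from typing import Dict


def check_imagenet_layout(root: str, expected_classes: int = 1000) -> Dict[str, int]:
    """Count class subfolders in root/train and root/val.

    Raises FileNotFoundError if a split is missing, and ValueError if the
    class count differs between splits or from expected_classes.
    """
    base = Path(root).expanduser()
    counts: Dict[str, int] = {}
    for split in ("train", "val"):
        split_dir = base / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"missing split directory: {split_dir}")
        # Assumed ImageFolder convention: one subdirectory per class.
        counts[split] = sum(1 for p in split_dir.iterdir() if p.is_dir())
    if counts["train"] != counts["val"]:
        raise ValueError(f"train/val class counts differ: {counts}")
    if counts["train"] != expected_classes:
        raise ValueError(f"expected {expected_classes} classes, found {counts['train']}")
    return counts
```

For a complete ImageNet-1K copy laid out this way, `check_imagenet_layout("~/imagenet")` should return `{'train': 1000, 'val': 1000}`.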

Training

To train RepViT-M1 on an 8-GPU machine:

python -m torch.distributed.launch --nproc_per_node=8 --master_port 12346 --use_env main.py --model repvit_m1 --data-path ~/imagenet --dist-eval
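
With `--use_env`, `torch.distributed.launch` passes each worker's local rank through the `LOCAL_RANK` environment variable instead of a `--local_rank` command-line argument, and it also sets `RANK` and `WORLD_SIZE`. Below is a minimal sketch of how a training script can read these variables. The function name and the shape of the returned dictionary are hypothetical and are not this repository's code. With `--nproc_per_node=8` on one machine, `WORLD_SIZE` is 8 and `LOCAL_RANK` ranges from 0 to 7.

```python
import os
from typing import Dict, Mapping, Optional


def read_distributed_env(env: Mapping[str, str] = os.environ) -> Optional[Dict[str, int]]:
    """Return rank info set by the launcher, or None for a single-process run."""
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return None  # not started by a distributed launcher
    info = {
        "rank": int(env["RANK"]),                      # global rank across all nodes
        "world_size": int(env["WORLD_SIZE"]),          # total number of processes
        "local_rank": int(env.get("LOCAL_RANK", 0)),   # typically used as the GPU index on this node
    }
    if not 0 <= info["rank"] < info["world_size"]:
        raise ValueError(f"inconsistent rank info: {info}")
    return info


if __name__ == "__main__":
    print(read_distributed_env())
```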

Tips: specify your data path and model name!

Testing

For example, to test RepViT-M3:

python main.py --eval --model repvit_m3 --resume pretrain/repvit_m3_distill_300.pth --data-path ~/imagenet
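
The evaluation reports top-1 accuracy: the percentage of validation images whose highest-scoring class matches the ground-truth label, as in the Top-1 column of the table above. Below is a minimal plain-Python sketch of the metric, generalized to top-k. The function name is hypothetical, and this is not the repository's evaluation code.

```python
from typing import Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int], k: int = 1) -> float:
    """Percentage of samples whose true label is among the k highest scores."""
    if len(logits) != len(labels) or not labels:
        raise ValueError("need one non-empty score row per label")
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores (ties broken by lower index first).
        top = sorted(range(len(scores)), key=lambda i: (-scores[i], i))[:k]
        hits += int(label in top)
    return 100.0 * hits / len(labels)


if __name__ == "__main__":
    scores = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.9], [0.0, 0.4, 0.5]]
    targets = [1, 2, 2]
    print(topk_accuracy(scores, targets, k=1))  # 2 of 3 correct
    print(topk_accuracy(scores, targets, k=2))  # 3 of 3 within top-2 -> 100.0
```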

Downstream Tasks

Object Detection and Instance Segmentation
Semantic Segmentation

Acknowledgement

The classification (ImageNet) code base is partly built on LeViT, PoolFormer, and EfficientFormer.

The detection and segmentation pipeline is from MMCV (MMDetection and MMSegmentation).

Thanks for the great implementations!

Citation

If our code or models help your work, please cite our paper:

@misc{wang2023repvit,
      title={RepViT: Revisiting Mobile CNN From ViT Perspective}, 
      author={Ao Wang and Hui Chen and Zijia Lin and Hengjun Pu and Guiguang Ding},
      year={2023},
      eprint={2307.09283},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

from timm.models import create_model
from hiq.vis import print_model

model = create_model('repvit_m1')
print_model(model)
