This repository contains the training code for the quantized Vision Transformers (ViTs) introduced in our work "Oscillation-free Quantization for Low-bit Vision Transformers", accepted at ICML 2023.
In this work, we discuss the issue of weight oscillation in quantization-aware training and how it degrades model performance. We find that the learnable scaling factor commonly used in quantization aggravates weight oscillation. To address this, we propose three techniques: statistical weight quantization (StatsQ), confidence-guided annealing (CGA), and query-key reparameterization (QKR). Evaluated on ViT models, these techniques improve quantization robustness and accuracy; our 2-bit DeiT-T and DeiT-S models outperform the previous state-of-the-art by 9.8% and 7.7%, respectively.
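To illustrate the core idea behind StatsQ, here is a minimal sketch of statistics-based weight quantization, where the scale is derived from weight statistics rather than learned. The mean-of-absolute-values statistic below is a hypothetical choice for illustration; see the paper for the exact formulation.

```python
import torch

def statsq_weight_quant(w: torch.Tensor, n_bits: int = 2) -> torch.Tensor:
    """Sketch of statistics-based weight quantization (the StatsQ idea):
    the scaling factor is computed from weight statistics instead of
    being a learnable parameter, so it cannot oscillate during training.
    The mean-|w| statistic here is illustrative only."""
    qmax = 2 ** (n_bits - 1) - 1            # e.g. 1 for 2-bit signed weights
    scale = w.abs().mean() / max(qmax, 1)   # statistics-derived scale (assumption)
    w_int = torch.round(w / scale).clamp(-qmax - 1, qmax)  # quantize to integer grid
    return w_int * scale                    # dequantized (fake-quantized) weights
```

In actual quantization-aware training, the rounding step would be wrapped with a straight-through estimator so gradients flow to the latent weights.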
- numpy==1.22.3
- torch==2.0.0
- torchvision==0.15.1
- timm==0.5.4
- pyyaml
Please replace "/your/miniconda3/envs/ofq/lib/python3.8/site-packages/timm/data/dataset_factory.py" with "timm_fix_imagenet_loading_bugs/dataset_factory.py"; the original file raises `TypeError: __init__() got an unexpected keyword argument 'download'`.
- Pretrained models will be downloaded automatically if you set args.pretrained to True.
- Example training and finetuning (CGA) scripts are provided under "train_scripts/", and evaluation scripts are under "eval_scripts/". Please use the exact same effective batch size (batch_size * world_size) as provided in the evaluation scripts to reproduce the results reported in the paper.
- Please modify the data path in the scripts to point to your own dataset.
Model | #Bits (W/A) | Top-1 Accuracy (Model Link) | Eval Script |
---|---|---|---|
DeiT-T | 32-32 | 72.02 | ------- |
OFQ DeiT-T | 2-2 | 64.33 | eval_scripts/deit_t/w2a2.sh |
OFQ DeiT-T | 3-3 | 72.72 | eval_scripts/deit_t/w3a3.sh |
OFQ DeiT-T | 4-4 | 75.46 | eval_scripts/deit_t/w4a4.sh |
DeiT-S | 32-32 | 79.9 | ------- |
OFQ DeiT-S | 2-2 | 75.72 | eval_scripts/deit_s/w2a2.sh |
OFQ DeiT-S | 3-3 | 79.57 | eval_scripts/deit_s/w3a3.sh |
OFQ DeiT-S | 4-4 | 81.10 | eval_scripts/deit_s/w4a4.sh |
Swin-T | 32-32 | 81.2 | ------- |
OFQ Swin-T | 2-2 | 78.52 | eval_scripts/swin_t/w2a2.sh |
OFQ Swin-T | 3-3 | 81.09 | eval_scripts/swin_t/w3a3.sh |
OFQ Swin-T | 4-4 | 81.88 | eval_scripts/swin_t/w4a4.sh |
Our code is based on the DeiT repository.
If you find our code useful for your research, please consider citing:
@misc{https://doi.org/10.48550/arxiv.2302.02210,
doi = {10.48550/ARXIV.2302.02210},
url = {https://arxiv.org/abs/2302.02210},
author = {Liu, Shih-Yang and Liu, Zechun and Cheng, Kwang-Ting},
keywords = {Computer Vision and Pattern Recognition (cs.CV), Artificial Intelligence (cs.AI), Hardware Architecture (cs.AR), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences},
title = {Oscillation-free Quantization for Low-bit Vision Transformers},
publisher = {arXiv},
year = {2023},
copyright = {arXiv.org perpetual, non-exclusive license}
}
Shih-Yang Liu, HKUST (sliuau at connect.ust.hk)