Official implementation of NeurIPS 2024 "Visual Fourier Prompt Tuning"
Contact me: runjia.tech
(👉Under construction! You can currently check here for commands. There are several redundancies in the current version, and the commands and instructions are not yet ready for formal release. I will update them gradually, so please stay tuned.)
2024/12/07: Our code is publicly available now! Thank you for your attention and patience!
2024/12/02: Our homepage is available now (slides and video are included)! Check it out to see more details.
2024/11/14: Our preliminary key code is now available on GitHub.
If you are only interested in the key implementation from our paper, you can take just this part of the code.
# Visual Prompts
x = torch.cat((
    x[:, :1, :],  # first token (the class token in ViT)
    prompt_dropout(prompt_proj(prompt_embeddings).expand(B, -1, -1)),  # prompt tokens
    x[:, 1:, :]), dim=1)  # remaining patch tokens
# Visual Fourier Prompts (Fourier percentage equals 1.0)
x = torch.cat((
    x[:, :1, :],
    # FFT over the hidden dimension, then over the prompt sequence; keep the real part
    torch.fft.fft(torch.fft.fft(
        prompt_dropout(prompt_proj(prompt_embeddings).expand(B, -1, -1)),
        dim=-1), dim=-2).real,
    x[:, 1:, :]), dim=1)
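Written out as a formula, the Fourier branch above reads as follows. This is a sketch derived only from the snippet: P stands for one projected prompt matrix after dropout, with n prompt tokens and hidden size d. The symbols P, n, d, k, l, m, j are notation introduced here for illustration, and torch.fft.fft is taken with its default unnormalized forward convention.

```latex
\hat{P}_{k,l} \;=\; \operatorname{Re}\!\left( \sum_{m=0}^{n-1} \sum_{j=0}^{d-1}
    P_{m,j}\, e^{-2\pi i \left( \frac{k m}{n} + \frac{l j}{d} \right)} \right),
\qquad 0 \le k < n,\quad 0 \le l < d .
```

The inner sum over j corresponds to the FFT along the hidden dimension (dim=-1) and the outer sum over m to the FFT along the prompt sequence dimension (dim=-2). Only the real part is kept, so the Fourier prompts have the same shape and dtype as ordinary visual prompts and can be concatenated in the same position.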
Our code implementation is based on VPT and E2VPT. For your convenience, I have also placed part of the ViT VFPT implementation code (originally located at src/models/vit_prompt/vit_fourier.py) in the root directory as ./vit_VFPT.py.
For the loss landscape visualization, our code implementation is based on loss-landscape.
For the heatmap, our code implementation is based on gradcam. The attention map is simply obtained from the attention layer and visualized using Matplotlib.
The documentation below is copied and modified from VPT and E2VPT. Thanks for their effort.
See env_setup.sh
- src/configs: handles config parameters for the experiments.
  - 👉 src/configs/config.py: main config setup for the experiments, with an explanation of each parameter.
- src/data: loads and sets up the input datasets. The files in src/data/vtab_datasets are borrowed from an external codebase.
- src/engine: main training and eval actions.
- src/models: handles backbone architectures and heads for the different fine-tuning protocols.
  - 👉 src/models/vit_prompt: a folder containing the same backbones as the vit_backbones folder, adapted for VPT. This folder should contain the same file names as vit_backbones.
  - 👉 src/models/vit_models.py: main model for transformer-based models. ❗️Note❗️: the current version only supports ViT, Swin, and ViT with MAE or MoCo v3.
  - src/models/build_model.py: uses the config to build the model for training and eval.
- src/solver: optimization, losses, and learning rate schedules.
- src/utils: helper functions for I/O, logging, training, and visualization.
- 👉 train.py: call this to train and evaluate a model with a specified transfer type.
- 👉 tune_fgvc.py: call this to tune the learning rate and weight decay for a model with a specified transfer type. We used this script for FGVC tasks.
- 👉 tune_vtab.py: call this to tune VTAB tasks. It uses the 800/200 split to find the best lr and wd, then uses the best lr/wd for the final runs.
- launch.py: contains functions used to launch the job.
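The tuning flow described for tune_vtab.py (search lr and wd on the 800/200 split, then reuse the best pair for the final runs) can be sketched as below. This is a minimal illustration and not the repository's code: `grid_search`, `toy_val_accuracy`, and the candidate values are hypothetical.

```python
from itertools import product


def grid_search(lrs, wds, val_accuracy):
    """Return ((lr, wd), accuracy) for the best validation accuracy.

    val_accuracy is a callable (lr, wd) -> float. In a real run it would
    train on the 800-example split and evaluate on the 200-example split.
    """
    best_pair, best_acc = None, float("-inf")
    for lr, wd in product(lrs, wds):
        acc = val_accuracy(lr, wd)
        if acc > best_acc:
            best_pair, best_acc = (lr, wd), acc
    return best_pair, best_acc


def toy_val_accuracy(lr, wd):
    # Hypothetical stand-in for "train on 800, evaluate on 200".
    # It peaks at lr=0.1, wd=0.01 so the search has a clear winner.
    return 1.0 - abs(lr - 0.1) - abs(wd - 0.01)


if __name__ == "__main__":
    lrs = [1.0, 0.1, 0.01]    # hypothetical candidates
    wds = [0.01, 0.001, 0.0]  # hypothetical candidates
    (best_lr, best_wd), best_acc = grid_search(lrs, wds, toy_val_accuracy)
    # The winning pair would then be reused for the final runs.
    print(best_lr, best_wd)
```

Keeping the search separate from the scoring function makes it easy to swap the toy scorer for a real training run without touching the selection logic.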
- 🔥VFPT related:
  - MODEL.PROMPT_FOURIER.FOURIER_DIMENSION: all, sequence or hidden
  - MODEL.PROMPT_FOURIER.FOURIER_PERCENTAGE: 0.0 to 1.0
  - MODEL.PROMPT_FOURIER.FOURIER_LOCATION: append, prepend or random
- VPT related:
  - MODEL.PROMPT.NUM_TOKENS: prompt length
  - MODEL.PROMPT.DEEP: deep or shallow prompt
- Fine-tuning method specification:
  - MODEL.TRANSFER_TYPE
- Vision backbones:
  - DATA.FEATURE: specify which representation to use
  - MODEL.TYPE: the general backbone type, e.g., "vit" or "swin"
  - MODEL.MODEL_ROOT: folder with pre-trained model checkpoints
- Optimization related:
  - SOLVER.BASE_LR: learning rate for the experiment
  - SOLVER.WEIGHT_DECAY: weight decay value for the experiment
  - DATA.BATCH_SIZE
- Datasets related:
  - DATA.NAME
  - DATA.DATAPATH: where you put the datasets
  - DATA.NUMBER_CLASSES
- Others:
  - RUN_N_TIMES: ensures the experiment runs only once in case of duplicated submission; not used during VTAB runs
  - OUTPUT_DIR: output directory for the final model and logs
  - MODEL.SAVE_CKPT: if set to True, saves model checkpoints and the final outputs on both the val and test sets
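These keys are passed on the command line as alternating KEY VALUE pairs after the script flags. The sketch below shows how such dotted overrides could be folded into a nested config. It assumes yacs-style override semantics; `apply_overrides` and the default values are hypothetical, and this is not the repository's config code.

```python
import ast


def apply_overrides(cfg, opts):
    """Apply ["A.B.C", "value", ...] pairs to a nested dict in place."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must be KEY VALUE pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # raises KeyError for unknown sections
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        try:
            value = ast.literal_eval(raw)  # "10" -> 10, "True" -> True
        except (ValueError, SyntaxError):
            value = raw  # keep plain strings such as "./output/"
        node[leaf] = value
    return cfg


if __name__ == "__main__":
    # Hypothetical defaults, for illustration only.
    cfg = {
        "MODEL": {"PROMPT_FOURIER": {"DEEP": False, "NUM_TOKENS": 5,
                                     "FOURIER_PERCENTAGE": 0.0}},
        "DATA": {"BATCH_SIZE": 32},
        "OUTPUT_DIR": "./out",
    }
    apply_overrides(cfg, [
        "MODEL.PROMPT_FOURIER.DEEP", "True",
        "MODEL.PROMPT_FOURIER.NUM_TOKENS", "10",
        "MODEL.PROMPT_FOURIER.FOURIER_PERCENTAGE", "1.0",
        "OUTPUT_DIR", "./output/",
        "DATA.BATCH_SIZE", "64",
    ])
    print(cfg)
```

Rejecting unknown keys, as this sketch does, catches typos in long override lists before a run starts.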
See Table 8 in the Appendix for dataset details.
- Fine-Grained Visual Classification (FGVC) tasks: the datasets can be downloaded from their official links. We split the training data when no public validation set is available. The split datasets can be found here: Dropbox, Google Drive.
- Visual Task Adaptation Benchmark (VTAB): see VTAB_SETUP.md for detailed instructions and tips.
Download the pre-trained Transformer-based backbones and place them in MODEL.MODEL_ROOT (ConvNeXt-Base and ResNet50 are downloaded automatically via the links in the code). Note that you also need to rename the downloaded ViT-B/16 checkpoint from ViT-B_16.npz to imagenet21k_ViT-B_16.npz.
See Table 9 in the Appendix for more details about pre-trained backbones.
| Pre-trained Backbone | Pre-trained Objective | Link | md5sum |
|---|---|---|---|
| ViT-B/16 | Supervised | link | d9715d |
| ViT-B/16 | MoCo v3 | link | 8f39ce |
| ViT-B/16 | MAE | link | 8cad7c |
| Swin-B | Supervised | link | bf9cc1 |
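A small helper for the checkpoint step above: it checks a downloaded file against the md5sum value from the table and then renames it. This is a minimal sketch; `md5_hex` and `verify_and_rename` are hypothetical helper names that are not part of the repository. The table lists only six hex digits, so the sketch compares a prefix of the digest.

```python
import hashlib
from pathlib import Path


def md5_hex(path, chunk_size=1 << 20):
    """Return the hex md5 digest of a file, read in chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_and_rename(src, dst, md5_prefix):
    """Rename src to dst if its md5 digest starts with md5_prefix."""
    src, dst = Path(src), Path(dst)
    actual = md5_hex(src)
    if not actual.startswith(md5_prefix.lower()):
        raise ValueError(f"md5 mismatch for {src}: got {actual}")
    src.rename(dst)
    return dst
```

For the supervised ViT-B/16 checkpoint this would be called as `verify_and_rename(root / "ViT-B_16.npz", root / "imagenet21k_ViT-B_16.npz", "d9715d")`, where `root` is a `Path` pointing at MODEL.MODEL_ROOT.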
# Training of VFPT
python tune_vtab.py \
--train-type "prompt" \
--config-file ./configs/prompt/prompt_fourier/Natural/caltech101_forVPT.yaml \
MODEL.PROMPT_FOURIER.DEEP "True" \
MODEL.PROMPT_FOURIER.NUM_TOKENS "10" \
MODEL.PROMPT_FOURIER.DROPOUT "0.10" \
MODEL.PROMPT_FOURIER.FOURIER_PERCENTAGE "1.0" \
OUTPUT_DIR "./output/" \
DATA.BATCH_SIZE "64"
# Landscape Visualization of VFPT
# Replace {} with the correct args
python ./ls_plot_surface.py \
--lr {lr} \
--wd {wd} \
--train-type "prompt" \
--x=-1:1:51 \
--y=-1:1:51 \
--config-file {config} \
MODEL.PROMPT_FOURIER.DEEP "True" \
MODEL.PROMPT_FOURIER.DROPOUT "0.10" \
MODEL.PROMPT_FOURIER.FOURIER_PERCENTAGE "{percentage}" \
MODEL.PROMPT_FOURIER.NUM_TOKENS "{num}" \
DATA.BATCH_SIZE "{batch_size}" \
OUTPUT_DIR "{output_dir}"
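The --x and --y values follow the min:max:num format used by the loss-landscape tool, so -1:1:51 asks for 51 evenly spaced coordinates from -1 to 1 along that axis. The sketch below expands such a string. It assumes ls_plot_surface.py keeps that convention, and `parse_axis` is a hypothetical helper, not a function from the script.

```python
def parse_axis(spec):
    """Expand "min:max:num" into a list of num evenly spaced floats."""
    lo, hi, num = spec.split(":")
    lo, hi, num = float(lo), float(hi), int(num)
    if num < 2:
        raise ValueError("num must be at least 2")
    step = (hi - lo) / (num - 1)
    return [lo + i * step for i in range(num)]


if __name__ == "__main__":
    xs = parse_axis("-1:1:51")
    ys = parse_axis("-1:1:51")
    # Number of grid points on the 2D surface
    print(len(xs) * len(ys))
```

With 51 points per axis the surface has 2601 grid points, and the loss is evaluated at each of them, so reducing num is the simplest way to shorten a run.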
# GradCam Visualization of VFPT
PORT=20000 python \
tune_vtab_AS.py \
--train-type "prompt" \
--config-file ./configs/prompt/prompt_fourier/Natural/caltech101_forVPT.yaml \
MODEL.PROMPT_FOURIER.DEEP "True" \
MODEL.PROMPT_FOURIER.NUM_TOKENS "10" \
MODEL.PROMPT_FOURIER.DROPOUT "0.10" \
MODEL.PROMPT_FOURIER.FOURIER_PERCENTAGE "1.0" \
OUTPUT_DIR "./attn" \
DATA.BATCH_SIZE "64" \
ATTRIBUTION_TYPE "general" \
ATTRIBUTION_INTEGRATED_METHOD "pytorch_gradcam"
If you find our work helpful in your research, please cite it as:
@inproceedings{zeng2024visual,
title={Visual Fourier Prompt Tuning},
author={Zeng, Runjia and Han, Cheng and Wang, Qifan and Wu, Chunshu and Geng, Tong and Huang, Lifu and Wu, Ying Nian and Liu, Dongfang},
booktitle={NeurIPS},
year={2024}
}
The majority of VFPT is licensed under the CC-BY-NC 4.0 license (see LICENSE for details). Portions of the project are available under separate license terms: google-research/task_adaptation and huggingface/transformers are licensed under the Apache 2.0 license; Swin-Transformer, ConvNeXt and ViT-pytorch are licensed under the MIT license; and MoCo-v3 and MAE are licensed under the Attribution-NonCommercial 4.0 International license.