SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning (ICLR 2024)
By Hongjun Wang, Sagar Vaze, and Kai Han.
[05.2024] We have updated the results of SPTNet with DINOv2 on CUB; please check the latest version on arXiv.
Dataset | All | Old | New |
---|---|---|---|
CUB (DINO) | 65.8 | 68.8 | 65.1 |
CUB (DINOv2) | 76.3 | 79.5 | 74.6 |
First, you need to clone the SPTNet repository from GitHub. Open your terminal and run the following command:
git clone https://github.com/Visual-AI/SPTNet.git
cd SPTNet
We recommend setting up a conda environment for the project:
conda create --name=spt python=3.9
conda activate spt
pip install -r requirements.txt
Set paths to datasets and desired log directories in config.py
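As a rough illustration, the kind of settings config.py holds could look like the sketch below. The variable names (DATASET_ROOTS, exp_root) and the placeholder directories are assumptions made for this example and may not match the names in the repository's config.py, so check that file for the real ones. The helper only reports which configured dataset directories do not exist yet.

```python
import os

# Hypothetical config values: the names and paths here are illustrative
# assumptions, not necessarily those used in the repository's config.py.
DATASET_ROOTS = {
    "cifar10": "/path/to/datasets/cifar10",
    "cub": "/path/to/datasets/cub",
    "aircraft": "/path/to/datasets/fgvc-aircraft",
}
exp_root = "/path/to/logs"  # hypothetical log/output directory


def missing_paths(roots):
    """Return the dataset names whose configured directory does not exist."""
    return sorted(name for name, path in roots.items() if not os.path.isdir(path))


if __name__ == "__main__":
    for name in missing_paths(DATASET_ROOTS):
        print(f"dataset root for {name!r} not found: {DATASET_ROOTS[name]}")
```

Running a check like this before training or evaluation can catch a mistyped path early, before any data loading starts.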
We use generic object recognition datasets, including CIFAR-10/100 and ImageNet-100/1K:
We also use fine-grained benchmarks (CUB, Stanford Cars, FGVC-Aircraft, Herbarium19). You can find the datasets in:
Download the SPTNet checkpoints for the different datasets and put them in the "checkpoints" folder (only used during evaluation).
Evaluate the model:
# Use --prompt_type 'patch' for 'cifar10' and 'cifar100'
CUDA_VISIBLE_DEVICES=0 python eval.py \
--dataset_name 'aircraft' \
--pretrained_model_path ./checkpoints/fgvc/dinoB16_best.pt \
--prompt_type 'all' \
--eval_funcs 'v2'
To reproduce all main results in the paper, just change the dataset name (dataset_name) and its corresponding path (pretrained_model_path) to the pretrained model you downloaded from the above link.
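When evaluating several datasets in a row, it can help to build the command programmatically. The sketch below is a hypothetical helper (build_eval_command is not part of the repository) that assembles the eval.py arguments shown above and picks the prompt type according to the note in the evaluation command: 'patch' for 'cifar10' and 'cifar100', 'all' otherwise. Checkpoint paths for other datasets are not listed here, so pass whichever path you downloaded.

```python
import shlex

# Per the note in the evaluation command: CIFAR-10/100 use the 'patch'
# prompt type, every other dataset uses 'all'.
PATCH_PROMPT_DATASETS = {"cifar10", "cifar100"}


def build_eval_command(dataset_name, pretrained_model_path):
    """Return the eval.py argument list for one dataset/checkpoint pair.

    Hypothetical helper for illustration; it only mirrors the flags shown
    in the evaluation command and does not run anything itself.
    """
    prompt_type = "patch" if dataset_name in PATCH_PROMPT_DATASETS else "all"
    return [
        "python", "eval.py",
        "--dataset_name", dataset_name,
        "--pretrained_model_path", pretrained_model_path,
        "--prompt_type", prompt_type,
        "--eval_funcs", "v2",
    ]


if __name__ == "__main__":
    cmd = build_eval_command("aircraft", "./checkpoints/fgvc/dinoB16_best.pt")
    # Print a shell-safe version of the command for copy and paste.
    print(shlex.join(cmd))
```

The returned list can also be passed to subprocess.run, with CUDA_VISIBLE_DEVICES set through its env argument.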
Train the model:
CUDA_VISIBLE_DEVICES=0 python train_spt.py \
--dataset_name 'aircraft' \
--batch_size 128 \
--grad_from_block 11 \
--epochs 1000 \
--num_workers 8 \
--use_ssb_splits \
--sup_weight 0.35 \
--weight_decay 5e-4 \
--transform 'imagenet' \
--lr 1 \
--lr2 0.05 \
--prompt_size 1 \
--freq_rep_learn 20 \
--pretrained_model_path ${YOUR_OWN_PRETRAINED_PATH} \
--prompt_type 'all' \
--eval_funcs 'v2' \
--warmup_teacher_temp 0.07 \
--teacher_temp 0.04 \
--warmup_teacher_temp_epochs 10 \
--memax_weight 1 \
--model_path ${YOUR_OWN_SAVE_DIR}
Remember to change the dataset name (dataset_name) and the corresponding path (pretrained_model_path) to the pretrained model. SPTNet is adaptable to various pretrained models: the architecture can be changed by pointing pretrained_model_path at a different checkpoint. This enables quick adoption of state-of-the-art (SOTA) methods. Our default settings use the SimGCD method.
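The three teacher-temperature flags in the training command (warmup_teacher_temp, teacher_temp, warmup_teacher_temp_epochs) suggest a warmup schedule. The sketch below assumes the linear warmup commonly used in DINO-style self-distillation: the temperature moves linearly from the warmup value to the final value over the warmup epochs, then stays constant. This is an assumption about how the flags are consumed, not code taken from train_spt.py, and the function name is hypothetical.

```python
def teacher_temp_schedule(warmup_teacher_temp, teacher_temp, warmup_epochs, epochs):
    """Per-epoch teacher temperature: linear warmup, then constant.

    Assumed behaviour (DINO-style linear warmup); not copied from the
    repository. Both endpoints are included in the warmup segment.
    """
    if warmup_epochs <= 1:
        warmup = [warmup_teacher_temp] * max(warmup_epochs, 0)
    else:
        step = (teacher_temp - warmup_teacher_temp) / (warmup_epochs - 1)
        warmup = [warmup_teacher_temp + i * step for i in range(warmup_epochs)]
    # Truncate if training is shorter than the warmup, otherwise pad with
    # the final temperature for the remaining epochs.
    return warmup[:epochs] + [teacher_temp] * max(epochs - warmup_epochs, 0)


if __name__ == "__main__":
    # Values taken from the training command above.
    schedule = teacher_temp_schedule(0.07, 0.04, 10, 1000)
    print(len(schedule), schedule[0], schedule[10])
```

With the values from the training command, the schedule would start at 0.07, reach 0.04 by the tenth epoch, and hold 0.04 for the remaining epochs.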
Generic results:
Dataset | All | Old | New |
---|---|---|---|
CIFAR-10 | 97.3 | 95.0 | 98.6 |
CIFAR-100 | 81.3 | 84.3 | 75.6 |
ImageNet-100 | 85.4 | 93.2 | 81.4 |
Fine-grained results:
Dataset | All | Old | New |
---|---|---|---|
CUB | 65.8 | 68.8 | 65.1 |
Stanford Cars | 59.0 | 79.2 | 49.3 |
FGVC-Aircraft | 59.3 | 61.8 | 58.1 |
Herbarium19 | 43.4 | 58.7 | 35.2 |
If you find this repo useful for your research, please consider citing our paper:
@inproceedings{wang2024sptnet,
author = {Wang, Hongjun and Vaze, Sagar and Han, Kai},
title = {SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning},
booktitle = {International Conference on Learning Representations (ICLR)},
year = {2024}
}