whj363636/SPTNet (forked from Visual-AI/SPTNet)

The official repository for the ICLR 2024 paper "SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning"

SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning (ICLR 2024)

By Hongjun Wang, Sagar Vaze, and Kai Han.


Update

[05.2024] We have updated the results of SPTNet with DINOv2 on CUB; please check the latest version on arXiv.

Dataset (backbone)   All    Old    New
CUB (DINO)           65.8   68.8   65.1
CUB (DINOv2)         76.3   79.5   74.6

Prerequisite 🛠️

First, clone the SPTNet repository from GitHub. Open a terminal and run the following commands:

git clone https://github.com/Visual-AI/SPTNet.git
cd SPTNet

We recommend setting up a conda environment for the project:

conda create --name=spt python=3.9
conda activate spt
pip install -r requirements.txt

Running 🏃

Config

Set the paths to the datasets and the desired log directories in config.py.
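Because a wrong dataset path usually only surfaces once a run has started, it can help to check the configured directories up front. The sketch below is a hypothetical helper, not part of this repository: the entry names (cub_root, aircraft_root, cifar_100_root) and the log directory are placeholders, so copy the real names and values from config.py.

```python
from pathlib import Path

# Hypothetical dataset roots; the real variable names and values live in
# this repository's config.py and may differ from these placeholders.
EXAMPLE_DATASET_ROOTS = {
    "cub_root": "/path/to/CUB",
    "aircraft_root": "/path/to/fgvc-aircraft",
    "cifar_100_root": "/path/to/cifar100",
}
EXAMPLE_LOG_DIR = "./logs"  # hypothetical log directory


def prepare_paths(dataset_roots: dict[str, str], log_dir: str) -> list[str]:
    """Create the log directory if needed and return the names of missing dataset roots."""
    # Log directories can simply be created; dataset roots must already exist.
    Path(log_dir).mkdir(parents=True, exist_ok=True)
    return sorted(name for name, path in dataset_roots.items() if not Path(path).is_dir())


if __name__ == "__main__":
    missing = prepare_paths(EXAMPLE_DATASET_ROOTS, EXAMPLE_LOG_DIR)
    print("Missing dataset roots:", missing or "none")
```

Dataset roots are only reported, never created, since an empty directory would just move the failure to data loading.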

Datasets

We use generic object recognition datasets, including CIFAR-10/100 and ImageNet-100/1K:

We also use fine-grained benchmarks (CUB, Stanford Cars, FGVC-Aircraft, Herbarium-19). You can find the datasets in:

Checkpoints

Download the SPTNet checkpoints for the different datasets and put them in the "checkpoints" folder (they are only used during evaluation).

Scripts

Evaluate the model:

# Switch --prompt_type to 'patch' for 'cifar10' and 'cifar100'.
CUDA_VISIBLE_DEVICES=0 python eval.py \
    --dataset_name 'aircraft' \
    --pretrained_model_path ./checkpoints/fgvc/dinoB16_best.pt \
    --prompt_type 'all' \
    --eval_funcs 'v2'

To reproduce all main results in the paper, change the dataset name (--dataset_name) and the corresponding checkpoint path (--pretrained_model_path) to those of the pretrained model downloaded from the link above.
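To evaluate several datasets in one go, a small driver can assemble the eval.py commands. This is a sketch under stated assumptions: only the aircraft checkpoint path and the cifar10/cifar100 names come from the command above, while the cifar checkpoint paths are hypothetical placeholders for wherever you put the downloaded files. It applies the note from the command: --prompt_type 'patch' for cifar10/cifar100, and 'all' (as in the aircraft example) for everything else, which is an assumption for datasets not mentioned there.

```python
import os
import shlex
import subprocess

# Checkpoint path per dataset. Only the 'aircraft' entry mirrors the example
# command; the other paths are hypothetical placeholders.
CHECKPOINTS = {
    "aircraft": "./checkpoints/fgvc/dinoB16_best.pt",
    "cifar10": "./checkpoints/cifar10/dinoB16_best.pt",    # hypothetical
    "cifar100": "./checkpoints/cifar100/dinoB16_best.pt",  # hypothetical
}

# Datasets that use patch-wise prompts instead of 'all'.
PATCH_PROMPT_DATASETS = {"cifar10", "cifar100"}


def eval_command(dataset_name: str, checkpoint: str) -> list[str]:
    """Build the argument list for one eval.py run."""
    prompt_type = "patch" if dataset_name in PATCH_PROMPT_DATASETS else "all"
    return [
        "python", "eval.py",
        "--dataset_name", dataset_name,
        "--pretrained_model_path", checkpoint,
        "--prompt_type", prompt_type,
        "--eval_funcs", "v2",
    ]


def run_all(dry_run: bool = True) -> None:
    """Print (and optionally execute) one evaluation command per dataset on GPU 0."""
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": "0"}
    for name, ckpt in CHECKPOINTS.items():
        cmd = eval_command(name, ckpt)
        print("CUDA_VISIBLE_DEVICES=0", shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True, env=env)


if __name__ == "__main__":
    run_all(dry_run=True)
```

Passing the arguments as a list rather than one shell string sidesteps the quoting and line-continuation pitfalls of long multi-line commands.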

Train the model:

CUDA_VISIBLE_DEVICES=0 python train_spt.py \
    --dataset_name 'aircraft' \
    --batch_size 128 \
    --grad_from_block 11 \
    --epochs 1000 \
    --num_workers 8 \
    --use_ssb_splits \
    --sup_weight 0.35 \
    --weight_decay 5e-4 \
    --transform 'imagenet' \
    --lr 1 \
    --lr2 0.05 \
    --prompt_size 1 \
    --freq_rep_learn 20 \
    --pretrained_model_path ${YOUR_OWN_PRETRAINED_PATH} \
    --prompt_type 'all' \
    --eval_funcs 'v2' \
    --warmup_teacher_temp 0.07 \
    --teacher_temp 0.04 \
    --warmup_teacher_temp_epochs 10 \
    --memax_weight 1 \
    --model_path ${YOUR_OWN_SAVE_DIR}
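Two of the flags above are easier to reason about with a concrete picture. Assuming SPTNet keeps the conventions of SimGCD, which it uses by default (an assumption, not stated in this README): --grad_from_block 11 leaves only ViT blocks with index 11 or higher trainable (for a 12-block ViT-B/16 that is just the last block), and the teacher temperature rises or falls linearly from --warmup_teacher_temp (0.07) to --teacher_temp (0.04) over --warmup_teacher_temp_epochs (10) epochs and is then held constant. The sketch reproduces both rules in plain Python; parameter names such as blocks.11.attn.qkv.weight follow DINO/timm ViT naming and are illustrative.

```python
def trainable_param_names(param_names: list[str], grad_from_block: int) -> list[str]:
    """Names of parameters left trainable: transformer blocks with index >= grad_from_block.

    Everything else (patch embedding, cls token, earlier blocks, final norm)
    stays frozen. This mirrors a SimGCD-style rule and is assumed, not confirmed.
    """
    trainable = []
    for name in param_names:
        parts = name.split(".")
        if parts[0] == "blocks" and int(parts[1]) >= grad_from_block:
            trainable.append(name)
    return trainable


def teacher_temp_schedule(warmup_temp: float, temp: float,
                          warmup_epochs: int, epochs: int) -> list[float]:
    """Per-epoch teacher temperature: linear warm-up, then constant."""
    if warmup_epochs == 1:
        warmup = [warmup_temp]
    else:
        # Evenly spaced values including both endpoints (like numpy.linspace).
        warmup = [warmup_temp + (temp - warmup_temp) * i / (warmup_epochs - 1)
                  for i in range(warmup_epochs)]
    return warmup + [temp] * (epochs - warmup_epochs)


if __name__ == "__main__":
    names = ["cls_token", "patch_embed.proj.weight",
             "blocks.10.mlp.fc1.weight", "blocks.11.attn.qkv.weight", "norm.weight"]
    print(trainable_param_names(names, grad_from_block=11))
    print(teacher_temp_schedule(0.07, 0.04, 10, 1000)[:12])
```

With PyTorch, the same rule would be applied by setting `param.requires_grad` for each `(name, param)` in `model.named_parameters()`.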

Remember to set the dataset name (--dataset_name) and the corresponding pretrained model path (--pretrained_model_path). SPTNet can be combined with different pretrained models: changing --pretrained_model_path swaps the underlying model, which makes it quick to build on the current state-of-the-art (SOTA) method. Our default settings use SimGCD.
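Since the default settings build on SimGCD, --sup_weight and --memax_weight presumably play the same role as in SimGCD's objective (an assumption; check train_spt.py). There, the unsupervised (self-distillation) and supervised terms are mixed as (1 - sup_weight) * unsup + sup_weight * sup, and a mean-entropy penalty, log K - H(mean prediction) weighted by memax_weight, is added to the unsupervised term so that the batch-averaged prediction does not collapse onto a few clusters. The sketch computes these scalars in plain Python on toy probability vectors; the loss inputs are placeholders, and SimGCD's contrastive terms, mixed with the same weights, are omitted.

```python
import math


def mean_entropy_penalty(probs: list[list[float]]) -> float:
    """log K - H(mean prediction); zero when the batch-averaged prediction is uniform.

    probs: per-sample probability vectors over K clusters (each sums to 1),
    e.g. softmax outputs of the classifier head.
    """
    k = len(probs[0])
    avg = [sum(p[j] for p in probs) / len(probs) for j in range(k)]
    entropy = -sum(a * math.log(a) for a in avg if a > 0)
    return math.log(k) - entropy


def combined_loss(unsup_loss: float, sup_loss: float, probs: list[list[float]],
                  sup_weight: float = 0.35, memax_weight: float = 1.0) -> float:
    """SimGCD-style mixing (assumed): the penalty joins the unsupervised term before weighting."""
    unsup_total = unsup_loss + memax_weight * mean_entropy_penalty(probs)
    return (1 - sup_weight) * unsup_total + sup_weight * sup_loss


if __name__ == "__main__":
    balanced = [[0.9, 0.1], [0.1, 0.9]]   # average is uniform: no penalty
    collapsed = [[1.0, 0.0], [1.0, 0.0]]  # everything in one cluster: penalty log 2
    print(mean_entropy_penalty(balanced), mean_entropy_penalty(collapsed))
```

Because the penalty acts on the batch average rather than on individual samples, each prediction can still be confident while the clusters stay in use overall.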

Results

Generic results:

Dataset        All    Old    New
CIFAR-10       97.3   95.0   98.6
CIFAR-100      81.3   84.3   75.6
ImageNet-100   85.4   93.2   81.4

Fine-grained results:

Dataset         All    Old    New
CUB             65.8   68.8   65.1
Stanford Cars   59.0   79.2   49.3
FGVC-Aircraft   59.3   61.8   58.1
Herbarium-19    43.4   58.7   35.2
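The All/Old/New columns report accuracy on all test images, on images from the known (labelled) classes, and on images from the novel classes. Because cluster indices are arbitrary, predictions must first be matched to ground-truth labels. Our reading of the 'v2' evaluation (--eval_funcs 'v2') inherited from the GCD codebase is that one optimal one-to-one mapping is computed over all classes and then reused for the Old and New subsets; treat this as an assumption. The sketch finds the mapping by exhaustive search over permutations, which only works for a handful of clusters; a real implementation would use the Hungarian algorithm (e.g. scipy.optimize.linear_sum_assignment) instead. Multiply the fractions by 100 to compare with the tables.

```python
from itertools import permutations


def best_mapping(y_true: list[int], y_pred: list[int], num_classes: int) -> dict[int, int]:
    """Cluster -> class mapping that maximises the number of correct predictions.

    Exhaustive search: toy-sized inputs only. The Hungarian algorithm solves
    the same assignment problem in polynomial time.
    """
    counts = [[0] * num_classes for _ in range(num_classes)]  # counts[cluster][class]
    for t, p in zip(y_true, y_pred):
        counts[p][t] += 1
    best = max(
        permutations(range(num_classes)),
        key=lambda perm: sum(counts[c][perm[c]] for c in range(num_classes)),
    )
    return dict(enumerate(best))


def split_accuracy(y_true, y_pred, old_classes, num_classes):
    """Return (all, old, new) accuracy using one mapping shared by every subset."""
    mapping = best_mapping(y_true, y_pred, num_classes)
    correct = [mapping[p] == t for t, p in zip(y_true, y_pred)]

    def acc(mask):
        selected = [c for c, m in zip(correct, mask) if m]
        return sum(selected) / len(selected) if selected else float("nan")

    is_old = [t in old_classes for t in y_true]
    return acc([True] * len(y_true)), acc(is_old), acc([not o for o in is_old])


if __name__ == "__main__":
    # Classes 0 and 1 are "old" (labelled); class 2 is "new".
    y_true = [0, 0, 1, 1, 2, 2]
    y_pred = [1, 1, 0, 0, 2, 0]  # cluster ids are permuted relative to class ids
    print(split_accuracy(y_true, y_pred, old_classes={0, 1}, num_classes=3))
```

Sharing a single mapping means old and new classes compete for the same clusters, so a model cannot score well on New simply by reusing clusters already matched to Old classes.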

Citing this work

If you find this repo useful for your research, please consider citing our paper:

@inproceedings{wang2024sptnet,
    author    = {Wang, Hongjun and Vaze, Sagar and Han, Kai},
    title     = {SPTNet: An Efficient Alternative Framework for Generalized Category Discovery with Spatial Prompt Tuning},
    booktitle = {International Conference on Learning Representations (ICLR)},
    year      = {2024}
}
