forked from CVMI-Lab/SimGCD
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 1765452
Showing
29 changed files
with
2,444 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2022 Xin Wen | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
# SimGCD: A Simple Parametric Classification Baseline for Generalized Category Discovery | ||
|
||
This repo contains code for our paper: [A Simple Parametric Classification Baseline for Generalized Category Discovery](https://arxiv.org/abs/xxxx.xxxxx). | ||
|
||
 | ||
|
||
Generalized category discovery (GCD) is a problem setting where the goal is to discover novel categories within an unlabelled dataset using the knowledge learned from a set of labelled samples. | ||
Recent works in GCD argue that a non-parametric classifier formed using semi-supervised $k$-means can outperform strong baselines which use parametric classifiers as it can alleviate the over-fitting to seen categories in the labelled set. | ||
|
||
In this paper, we revisit the reason that makes previous parametric classifiers fail to recognise new classes for GCD. | ||
By investigating the design choices of parametric classifiers from the perspective of model architecture, representation learning, and classifier learning, we conclude that the less discriminative representations and unreliable pseudo-labelling strategy are key factors that make parametric classifiers lag behind non-parametric ones. | ||
Motivated by our investigation, we present a simple yet effective parametric classification baseline that outperforms the previous best methods by a large margin on multiple popular GCD benchmarks. | ||
We hope the investigations and the simple baseline can serve as a cornerstone to facilitate future studies. | ||
|
||
## Running | ||
|
||
### Dependencies | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Config | ||
|
||
Set paths to datasets and desired log directories in ```config.py``` | ||
|
||
|
||
### Datasets | ||
|
||
We use fine-grained benchmarks in this paper, including: | ||
|
||
* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6) | ||
|
||
We also use generic object recognition datasets, including: | ||
|
||
* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet](https://image-net.org/download.php) | ||
|
||
|
||
### Scripts | ||
|
||
**Train the model**: | ||
|
||
``` | ||
bash scripts/run_${DATASET_NAME}.sh | ||
``` | ||
|
||
Then check the results starting with `Metrics with best model on test set:` in the logs. | ||
This means the model is picked according to its performance on the test set, and then evaluated on the unlabelled instances of the train set. | ||
|
||
## Results | ||
Our results in three independent runs: | ||
|
||
| Dataset | All | Old | New | | ||
|:-------------: |:--------: |:--------: |:--------: | | ||
| CIFAR10 | 93.2±0.4 | 82.0±1.2 | 98.9±0.0 | | ||
| CIFAR100 | 78.1±0.8 | 77.6±1.5 | 78.0±2.5 | | ||
| ImageNet-100 | 82.4±0.9 | 90.7±0.6 | 78.3±1.2 | | ||
| CUB | 60.3±0.1 | 65.6±0.9 | 57.7±0.4 | | ||
| Stanford Cars | 46.8±1.8 | 64.9±1.3 | 38.0±2.1 | | ||
| FGVC-Aircraft | 48.8±2.2 | 51.0±2.2 | 47.8±2.7 | | ||
| Herbarium 19 | 43.3±0.3 | 57.9±0.5 | 35.3±0.2 | | ||
|
||
## Citing this work | ||
|
||
If you find this repo useful for your research, please consider citing our paper: | ||
|
||
``` | ||
@article{wen2022simgcd, | ||
title={A Simple Parametric Classification Baseline for Generalized Category Discovery}, | ||
author={Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan}, | ||
journal={arXiv preprint arXiv:2211.xxxxx}, | ||
year={2022} | ||
} | ||
``` | ||
|
||
## Acknowledgements | ||
|
||
The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery. | ||
|
||
## License | ||
|
||
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details. |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# ----------------- | ||
# DATASET ROOTS | ||
# ----------------- | ||
cifar_10_root = '${DATASET_DIR}/cifar10' | ||
cifar_100_root = '${DATASET_DIR}/cifar100' | ||
cub_root = '${DATASET_DIR}/cub' | ||
aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b' | ||
car_root = '${DATASET_DIR}/cars' | ||
herbarium_dataroot = '${DATASET_DIR}/herbarium_19' | ||
imagenet_root = '${DATASET_DIR}/ImageNet' | ||
|
||
# OSR Split dir | ||
osr_split_dir = 'data/ssb_splits' | ||
|
||
# ----------------- | ||
# OTHER PATHS | ||
# ----------------- | ||
exp_root = 'dev_outputs' # All logs and checkpoints will be saved here |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from torchvision import transforms | ||
|
||
import torch | ||
|
||
def get_transform(transform_type='imagenet', image_size=32, args=None): | ||
|
||
if transform_type == 'imagenet': | ||
|
||
mean = (0.485, 0.456, 0.406) | ||
std = (0.229, 0.224, 0.225) | ||
interpolation = args.interpolation | ||
crop_pct = args.crop_pct | ||
|
||
train_transform = transforms.Compose([ | ||
transforms.Resize(int(image_size / crop_pct), interpolation), | ||
transforms.RandomCrop(image_size), | ||
transforms.RandomHorizontalFlip(p=0.5), | ||
transforms.ColorJitter(), | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
mean=torch.tensor(mean), | ||
std=torch.tensor(std)) | ||
]) | ||
|
||
test_transform = transforms.Compose([ | ||
transforms.Resize(int(image_size / crop_pct), interpolation), | ||
transforms.CenterCrop(image_size), | ||
transforms.ToTensor(), | ||
transforms.Normalize( | ||
mean=torch.tensor(mean), | ||
std=torch.tensor(std)) | ||
]) | ||
|
||
else: | ||
|
||
raise NotImplementedError | ||
|
||
return (train_transform, test_transform) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,195 @@ | ||
from torchvision.datasets import CIFAR10, CIFAR100 | ||
from copy import deepcopy | ||
import numpy as np | ||
|
||
from data.data_utils import subsample_instances | ||
from config import cifar_10_root, cifar_100_root | ||
|
||
|
||
class CustomCIFAR10(CIFAR10): | ||
|
||
def __init__(self, *args, **kwargs): | ||
|
||
super(CustomCIFAR10, self).__init__(*args, **kwargs) | ||
|
||
self.uq_idxs = np.array(range(len(self))) | ||
|
||
def __getitem__(self, item): | ||
|
||
img, label = super().__getitem__(item) | ||
uq_idx = self.uq_idxs[item] | ||
|
||
return img, label, uq_idx | ||
|
||
def __len__(self): | ||
return len(self.targets) | ||
|
||
|
||
class CustomCIFAR100(CIFAR100): | ||
|
||
def __init__(self, *args, **kwargs): | ||
super(CustomCIFAR100, self).__init__(*args, **kwargs) | ||
|
||
self.uq_idxs = np.array(range(len(self))) | ||
|
||
def __getitem__(self, item): | ||
img, label = super().__getitem__(item) | ||
uq_idx = self.uq_idxs[item] | ||
|
||
return img, label, uq_idx | ||
|
||
def __len__(self): | ||
return len(self.targets) | ||
|
||
|
||
def subsample_dataset(dataset, idxs): | ||
|
||
# Allow for setting in which all empty set of indices is passed | ||
|
||
if len(idxs) > 0: | ||
|
||
dataset.data = dataset.data[idxs] | ||
dataset.targets = np.array(dataset.targets)[idxs].tolist() | ||
dataset.uq_idxs = dataset.uq_idxs[idxs] | ||
|
||
return dataset | ||
|
||
else: | ||
|
||
return None | ||
|
||
|
||
def subsample_classes(dataset, include_classes=(0, 1, 8, 9)): | ||
|
||
cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes] | ||
|
||
target_xform_dict = {} | ||
for i, k in enumerate(include_classes): | ||
target_xform_dict[k] = i | ||
|
||
dataset = subsample_dataset(dataset, cls_idxs) | ||
|
||
# dataset.target_transform = lambda x: target_xform_dict[x] | ||
|
||
return dataset | ||
|
||
|
||
def get_train_val_indices(train_dataset, val_split=0.2): | ||
|
||
train_classes = np.unique(train_dataset.targets) | ||
|
||
# Get train/test indices | ||
train_idxs = [] | ||
val_idxs = [] | ||
for cls in train_classes: | ||
|
||
cls_idxs = np.where(train_dataset.targets == cls)[0] | ||
|
||
v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),)) | ||
t_ = [x for x in cls_idxs if x not in v_] | ||
|
||
train_idxs.extend(t_) | ||
val_idxs.extend(v_) | ||
|
||
return train_idxs, val_idxs | ||
|
||
|
||
def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9), | ||
prop_train_labels=0.8, split_train_val=False, seed=0): | ||
|
||
np.random.seed(seed) | ||
|
||
# Init entire training set | ||
whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True) | ||
|
||
# Get labelled training set which has subsampled classes, then subsample some indices from that | ||
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) | ||
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) | ||
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) | ||
|
||
# Split into training and validation sets | ||
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) | ||
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) | ||
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) | ||
val_dataset_labelled_split.transform = test_transform | ||
|
||
# Get unlabelled data | ||
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) | ||
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) | ||
|
||
# Get test set for all classes | ||
test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False) | ||
|
||
# Either split train into train and val or use test set as val | ||
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled | ||
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None | ||
|
||
all_datasets = { | ||
'train_labelled': train_dataset_labelled, | ||
'train_unlabelled': train_dataset_unlabelled, | ||
'val': val_dataset_labelled, | ||
'test': test_dataset, | ||
} | ||
|
||
return all_datasets | ||
|
||
|
||
def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80), | ||
prop_train_labels=0.8, split_train_val=False, seed=0): | ||
|
||
np.random.seed(seed) | ||
|
||
# Init entire training set | ||
whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True) | ||
|
||
# Get labelled training set which has subsampled classes, then subsample some indices from that | ||
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes) | ||
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels) | ||
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices) | ||
|
||
# Split into training and validation sets | ||
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled) | ||
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs) | ||
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs) | ||
val_dataset_labelled_split.transform = test_transform | ||
|
||
# Get unlabelled data | ||
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs) | ||
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices))) | ||
|
||
# Get test set for all classes | ||
test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False) | ||
|
||
# Either split train into train and val or use test set as val | ||
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled | ||
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None | ||
|
||
all_datasets = { | ||
'train_labelled': train_dataset_labelled, | ||
'train_unlabelled': train_dataset_unlabelled, | ||
'val': val_dataset_labelled, | ||
'test': test_dataset, | ||
} | ||
|
||
return all_datasets | ||
|
||
|
||
if __name__ == '__main__': | ||
|
||
x = get_cifar_100_datasets(None, None, split_train_val=False, | ||
train_classes=range(80), prop_train_labels=0.5) | ||
|
||
print('Printing lens...') | ||
for k, v in x.items(): | ||
if v is not None: | ||
print(f'{k}: {len(v)}') | ||
|
||
print('Printing labelled and unlabelled overlap...') | ||
print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs))) | ||
print('Printing total instances in train...') | ||
print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs))) | ||
|
||
print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}') | ||
print(f'Num Unabelled Classes: {len(set(x["train_unlabelled"].targets))}') | ||
print(f'Len labelled set: {len(x["train_labelled"])}') | ||
print(f'Len unlabelled set: {len(x["train_unlabelled"])}') |
Oops, something went wrong.