
Commit 1765452
initial commit
xwen99 committed Nov 19, 2022
Showing 29 changed files with 2,444 additions and 0 deletions.
21 changes: 21 additions & 0 deletions LICENSE
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Xin Wen

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
82 changes: 82 additions & 0 deletions README.md
@@ -0,0 +1,82 @@
# SimGCD: A Simple Parametric Classification Baseline for Generalized Category Discovery

This repo contains code for our paper: [A Simple Parametric Classification Baseline for Generalized Category Discovery](https://arxiv.org/abs/xxxx.xxxxx).

![teaser](assets/teaser.jpg)

Generalized category discovery (GCD) is a problem setting in which the goal is to discover novel categories within an unlabelled dataset, using the knowledge learned from a set of labelled samples.
Recent works in GCD argue that a non-parametric classifier formed using semi-supervised $k$-means can outperform strong baselines that use parametric classifiers, as it alleviates over-fitting to the seen categories in the labelled set.

In this paper, we revisit why previous parametric classifiers fail to recognise new classes in GCD.
By investigating the design choices of parametric classifiers from the perspectives of model architecture, representation learning, and classifier learning, we find that less discriminative representations and an unreliable pseudo-labelling strategy are the key factors that make parametric classifiers lag behind non-parametric ones.
Motivated by this investigation, we present a simple yet effective parametric classification baseline that outperforms the previous best methods by a large margin on multiple popular GCD benchmarks.
We hope the investigation and the simple baseline can serve as a cornerstone to facilitate future studies.

## Running

### Dependencies

```
pip install -r requirements.txt
```

### Config

Set the paths to your datasets and the desired log directory in `config.py`.
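
A minimal example of the relevant entries (the variable names match the `config.py` added in this commit; the path value is a placeholder to replace with your own):

```
cifar_10_root = '/path/to/datasets/cifar10'
exp_root = 'dev_outputs'  # all logs and checkpoints are saved here
```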


### Datasets

We use fine-grained benchmarks in this paper, including:

* [The Semantic Shift Benchmark (SSB)](https://github.com/sgvaze/osr_closed_set_all_you_need#ssb) and [Herbarium19](https://www.kaggle.com/c/herbarium-2019-fgvc6)

We also use generic object recognition datasets, including:

* [CIFAR-10/100](https://pytorch.org/vision/stable/datasets.html) and [ImageNet](https://image-net.org/download.php)


### Scripts

**Train the model**:

```
bash scripts/run_${DATASET_NAME}.sh
```
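
For example, training on CUB would be `bash scripts/run_cub.sh`, assuming the CUB script follows the naming pattern above.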

Then check the results starting with `Metrics with best model on test set:` in the logs.
Here, the model is selected according to its performance on the test set, and then evaluated on the unlabelled instances of the train set.

## Results
Our results over three independent runs (mean±std):

| Dataset | All | Old | New |
|:-------------: |:--------: |:--------: |:--------: |
| CIFAR10 | 93.2±0.4 | 82.0±1.2 | 98.9±0.0 |
| CIFAR100 | 78.1±0.8 | 77.6±1.5 | 78.0±2.5 |
| ImageNet-100 | 82.4±0.9 | 90.7±0.6 | 78.3±1.2 |
| CUB | 60.3±0.1 | 65.6±0.9 | 57.7±0.4 |
| Stanford Cars | 46.8±1.8 | 64.9±1.3 | 38.0±2.1 |
| FGVC-Aircraft | 48.8±2.2 | 51.0±2.2 | 47.8±2.7 |
| Herbarium 19 | 43.3±0.3 | 57.9±0.5 | 35.3±0.2 |

## Citing this work

If you find this repo useful for your research, please consider citing our paper:

```
@article{wen2022simgcd,
  title={A Simple Parametric Classification Baseline for Generalized Category Discovery},
  author={Wen, Xin and Zhao, Bingchen and Qi, Xiaojuan},
  journal={arXiv preprint arXiv:2211.xxxxx},
  year={2022}
}
```

## Acknowledgements

The codebase is largely built on this repo: https://github.com/sgvaze/generalized-category-discovery.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
Binary file added assets/teaser.jpg
18 changes: 18 additions & 0 deletions config.py
@@ -0,0 +1,18 @@
# -----------------
# DATASET ROOTS
# -----------------
cifar_10_root = '${DATASET_DIR}/cifar10'
cifar_100_root = '${DATASET_DIR}/cifar100'
cub_root = '${DATASET_DIR}/cub'
aircraft_root = '${DATASET_DIR}/fgvc-aircraft-2013b'
car_root = '${DATASET_DIR}/cars'
herbarium_dataroot = '${DATASET_DIR}/herbarium_19'
imagenet_root = '${DATASET_DIR}/ImageNet'

# OSR Split dir
osr_split_dir = 'data/ssb_splits'

# -----------------
# OTHER PATHS
# -----------------
exp_root = 'dev_outputs' # All logs and checkpoints will be saved here
38 changes: 38 additions & 0 deletions data/augmentations/__init__.py
@@ -0,0 +1,38 @@
from torchvision import transforms

import torch

def get_transform(transform_type='imagenet', image_size=32, args=None):

    if transform_type == 'imagenet':

        mean = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)
        interpolation = args.interpolation
        crop_pct = args.crop_pct

        train_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

        test_transform = transforms.Compose([
            transforms.Resize(int(image_size / crop_pct), interpolation),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=torch.tensor(mean),
                std=torch.tensor(std))
        ])

    else:

        raise NotImplementedError

    return (train_transform, test_transform)
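
A minimal usage sketch (not part of this commit): `get_transform` only needs `args` to expose `interpolation` and `crop_pct`. The values below are common defaults assumed for illustration:

```
from types import SimpleNamespace

from torchvision.transforms import InterpolationMode

from data.augmentations import get_transform

# Assumed values: bicubic interpolation and crop_pct=0.875 are common
# defaults, not taken from this commit.
args = SimpleNamespace(interpolation=InterpolationMode.BICUBIC, crop_pct=0.875)
train_transform, test_transform = get_transform('imagenet', image_size=224, args=args)
```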
195 changes: 195 additions & 0 deletions data/cifar.py
@@ -0,0 +1,195 @@
from torchvision.datasets import CIFAR10, CIFAR100
from copy import deepcopy
import numpy as np

from data.data_utils import subsample_instances
from config import cifar_10_root, cifar_100_root


class CustomCIFAR10(CIFAR10):

    def __init__(self, *args, **kwargs):

        super(CustomCIFAR10, self).__init__(*args, **kwargs)

        self.uq_idxs = np.array(range(len(self)))

    def __getitem__(self, item):

        img, label = super().__getitem__(item)
        uq_idx = self.uq_idxs[item]

        return img, label, uq_idx

    def __len__(self):
        return len(self.targets)


class CustomCIFAR100(CIFAR100):

    def __init__(self, *args, **kwargs):
        super(CustomCIFAR100, self).__init__(*args, **kwargs)

        self.uq_idxs = np.array(range(len(self)))

    def __getitem__(self, item):
        img, label = super().__getitem__(item)
        uq_idx = self.uq_idxs[item]

        return img, label, uq_idx

    def __len__(self):
        return len(self.targets)


def subsample_dataset(dataset, idxs):

    # Allow for the setting in which an empty set of indices is passed

    if len(idxs) > 0:

        dataset.data = dataset.data[idxs]
        dataset.targets = np.array(dataset.targets)[idxs].tolist()
        dataset.uq_idxs = dataset.uq_idxs[idxs]

        return dataset

    else:

        return None


def subsample_classes(dataset, include_classes=(0, 1, 8, 9)):

    cls_idxs = [x for x, t in enumerate(dataset.targets) if t in include_classes]

    target_xform_dict = {}
    for i, k in enumerate(include_classes):
        target_xform_dict[k] = i

    dataset = subsample_dataset(dataset, cls_idxs)

    # dataset.target_transform = lambda x: target_xform_dict[x]

    return dataset


def get_train_val_indices(train_dataset, val_split=0.2):

    train_classes = np.unique(train_dataset.targets)

    # Get train/val indices
    train_idxs = []
    val_idxs = []
    for cls in train_classes:

        cls_idxs = np.where(train_dataset.targets == cls)[0]

        v_ = np.random.choice(cls_idxs, replace=False, size=((int(val_split * len(cls_idxs))),))
        t_ = [x for x in cls_idxs if x not in v_]

        train_idxs.extend(t_)
        val_idxs.extend(v_)

    return train_idxs, val_idxs


def get_cifar_10_datasets(train_transform, test_transform, train_classes=(0, 1, 8, 9),
                          prop_train_labels=0.8, split_train_val=False, seed=0):

    np.random.seed(seed)

    # Init entire training set
    whole_training_set = CustomCIFAR10(root=cifar_10_root, transform=train_transform, train=True)

    # Get labelled training set which has subsampled classes, then subsample some indices from that
    train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
    subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
    train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)

    # Split into training and validation sets
    train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
    train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
    val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
    val_dataset_labelled_split.transform = test_transform

    # Get unlabelled data
    unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
    train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))

    # Get test set for all classes
    test_dataset = CustomCIFAR10(root=cifar_10_root, transform=test_transform, train=False)

    # Either split train into train and val or use test set as val
    train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
    val_dataset_labelled = val_dataset_labelled_split if split_train_val else None

    all_datasets = {
        'train_labelled': train_dataset_labelled,
        'train_unlabelled': train_dataset_unlabelled,
        'val': val_dataset_labelled,
        'test': test_dataset,
    }

    return all_datasets


def get_cifar_100_datasets(train_transform, test_transform, train_classes=range(80),
                           prop_train_labels=0.8, split_train_val=False, seed=0):

    np.random.seed(seed)

    # Init entire training set
    whole_training_set = CustomCIFAR100(root=cifar_100_root, transform=train_transform, train=True)

    # Get labelled training set which has subsampled classes, then subsample some indices from that
    train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
    subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
    train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)

    # Split into training and validation sets
    train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
    train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
    val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
    val_dataset_labelled_split.transform = test_transform

    # Get unlabelled data
    unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
    train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))

    # Get test set for all classes
    test_dataset = CustomCIFAR100(root=cifar_100_root, transform=test_transform, train=False)

    # Either split train into train and val or use test set as val
    train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
    val_dataset_labelled = val_dataset_labelled_split if split_train_val else None

    all_datasets = {
        'train_labelled': train_dataset_labelled,
        'train_unlabelled': train_dataset_unlabelled,
        'val': val_dataset_labelled,
        'test': test_dataset,
    }

    return all_datasets


if __name__ == '__main__':

    x = get_cifar_100_datasets(None, None, split_train_val=False,
                               train_classes=range(80), prop_train_labels=0.5)

    print('Printing lens...')
    for k, v in x.items():
        if v is not None:
            print(f'{k}: {len(v)}')

    print('Printing labelled and unlabelled overlap...')
    print(set.intersection(set(x['train_labelled'].uq_idxs), set(x['train_unlabelled'].uq_idxs)))
    print('Printing total instances in train...')
    print(len(set(x['train_labelled'].uq_idxs)) + len(set(x['train_unlabelled'].uq_idxs)))

    print(f'Num Labelled Classes: {len(set(x["train_labelled"].targets))}')
    print(f'Num Unlabelled Classes: {len(set(x["train_unlabelled"].targets))}')
    print(f'Len labelled set: {len(x["train_labelled"])}')
    print(f'Len unlabelled set: {len(x["train_unlabelled"])}')
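
A minimal sketch (not part of this commit) of consuming the returned dictionary with PyTorch DataLoaders, reusing the transforms from `get_transform` above; the batch size, worker count, and class split are illustrative:

```
from torch.utils.data import DataLoader

from data.cifar import get_cifar_10_datasets

datasets = get_cifar_10_datasets(train_transform, test_transform,
                                 train_classes=range(5), prop_train_labels=0.5)

# Each batch yields (img, label, uq_idx), following the Custom* __getitem__ above.
train_loader = DataLoader(datasets['train_labelled'], batch_size=128,
                          shuffle=True, num_workers=4)
```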