support multi-process training on ImageNet-1K
xwen99 committed Mar 15, 2023
1 parent 2391ad1 commit 4051103
Showing 7 changed files with 491 additions and 48 deletions.
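For context: "multi-process training" in the commit title means the usual one-process-per-GPU setup via torch.distributed. Below is a minimal sketch of how such a run could be launched with torch.multiprocessing; the commit's actual entry point and flags are not visible in this diff, so every name here is illustrative.

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank, world_size):
    # One process per GPU; NCCL is the standard backend for multi-GPU jobs.
    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:29500',
                            rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    # ... build the ImageNet-1K datasets and model, run the training loop ...
    dist.destroy_process_group()

if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    mp.spawn(worker, args=(world_size,), nprocs=world_size)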
9 changes: 8 additions & 1 deletion data/get_datasets.py
@@ -3,7 +3,7 @@
from data.cifar import get_cifar_10_datasets, get_cifar_100_datasets
from data.herbarium_19 import get_herbarium_datasets
from data.stanford_cars import get_scars_datasets
from data.imagenet import get_imagenet_100_datasets
from data.imagenet import get_imagenet_100_datasets, get_imagenet_1k_datasets
from data.cub import get_cub_datasets
from data.fgvc_aircraft import get_aircraft_datasets

@@ -18,6 +18,7 @@
'cifar10': get_cifar_10_datasets,
'cifar100': get_cifar_100_datasets,
'imagenet_100': get_imagenet_100_datasets,
'imagenet_1k': get_imagenet_1k_datasets,
'herbarium_19': get_herbarium_datasets,
'cub': get_cub_datasets,
'aircraft': get_aircraft_datasets,
@@ -106,6 +107,12 @@ def get_class_splits(args):
args.train_classes = range(50)
args.unlabeled_classes = range(50, 100)

elif args.dataset_name == 'imagenet_1k':

args.image_size = 224
args.train_classes = range(500)
args.unlabeled_classes = range(500, 1000)

elif args.dataset_name == 'scars':

args.image_size = 224
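The new branch can be exercised through the registry above. A sketch, assuming get_class_splits returns the updated args (a common pattern, though the return is not shown in this hunk) and that a bare Namespace suffices for this code path:

from argparse import Namespace
from data.get_datasets import get_class_splits

args = get_class_splits(Namespace(dataset_name='imagenet_1k'))
assert args.image_size == 224
assert list(args.train_classes) == list(range(500))
assert list(args.unlabeled_classes) == list(range(500, 1000))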
42 changes: 42 additions & 0 deletions data/imagenet.py
@@ -138,6 +138,48 @@ def get_imagenet_100_datasets(train_transform, test_transform, train_classes=ran

return all_datasets


def get_imagenet_1k_datasets(train_transform, test_transform, train_classes=range(500),
prop_train_labels=0.5, split_train_val=False, seed=0):

np.random.seed(seed)

# Init entire training set
whole_training_set = ImageNetBase(root=os.path.join(imagenet_root, 'train'), transform=train_transform)

# Get labelled training set which has subsampled classes, then subsample some indices from that
train_dataset_labelled = subsample_classes(deepcopy(whole_training_set), include_classes=train_classes)
subsample_indices = subsample_instances(train_dataset_labelled, prop_indices_to_subsample=prop_train_labels)
train_dataset_labelled = subsample_dataset(train_dataset_labelled, subsample_indices)

# Split into training and validation sets
train_idxs, val_idxs = get_train_val_indices(train_dataset_labelled)
train_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), train_idxs)
val_dataset_labelled_split = subsample_dataset(deepcopy(train_dataset_labelled), val_idxs)
val_dataset_labelled_split.transform = test_transform

# Get unlabelled data
unlabelled_indices = set(whole_training_set.uq_idxs) - set(train_dataset_labelled.uq_idxs)
train_dataset_unlabelled = subsample_dataset(deepcopy(whole_training_set), np.array(list(unlabelled_indices)))

# Get test set for all classes
test_dataset = ImageNetBase(root=os.path.join(imagenet_root, 'val'), transform=test_transform)

# Either split train into train and val or use test set as val
train_dataset_labelled = train_dataset_labelled_split if split_train_val else train_dataset_labelled
val_dataset_labelled = val_dataset_labelled_split if split_train_val else None

all_datasets = {
'train_labelled': train_dataset_labelled,
'train_unlabelled': train_dataset_unlabelled,
'val': val_dataset_labelled,
'test': test_dataset,
}

return all_datasets



if __name__ == '__main__':

x = get_imagenet_100_datasets(None, None, split_train_val=False,
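Since training is now multi-process, each rank presumably consumes a shard of these splits. A sketch of wrapping the unlabelled split with DistributedSampler; the loader parameters are assumptions, and a torch.distributed process group must already be initialised:

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from data.imagenet import get_imagenet_1k_datasets

# None transforms follow the file's own __main__ usage; real runs pass augmentations.
datasets = get_imagenet_1k_datasets(train_transform=None, test_transform=None)
unlabelled = datasets['train_unlabelled']
sampler = DistributedSampler(unlabelled, shuffle=True)
loader = DataLoader(unlabelled, batch_size=128, sampler=sampler,
                    num_workers=8, pin_memory=True)

for epoch in range(200):        # epoch count illustrative
    sampler.set_epoch(epoch)    # re-seed the per-rank shuffle each epoch
    for batch in loader:        # batch structure depends on ImageNetBase
        pass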
41 changes: 40 additions & 1 deletion model/loss.py → model.py
@@ -1,7 +1,46 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn import functional as F

class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True,
nlayers=3, hidden_dim=2048, bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
elif nlayers != 0:
layers = [nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False

def _init_weights(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)

def forward(self, x):
x_proj = self.mlp(x)
x = nn.functional.normalize(x, dim=-1, p=2)
# x = x.detach()
logits = self.last_layer(x)
return x_proj, logits


class ContrastiveLearningViewGenerator(object):
"""Take two random crops of one image as the query and key."""
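A quick forward-pass check clarifies the relocated head's two outputs; the dimensions here are illustrative (768 matches a ViT-B backbone):

import torch
from model import DINOHead

head = DINOHead(in_dim=768, out_dim=1000)   # out_dim illustrative
feats = torch.randn(4, 768)                 # e.g. a batch of [CLS] features
proj, logits = head(feats)
print(proj.shape)    # torch.Size([4, 256]): MLP projection (bottleneck_dim)
print(logits.shape)  # torch.Size([4, 1000]): weight-normed classifier on the
                     # L2-normalised input features, per forward() above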
41 changes: 1 addition & 40 deletions model/simgcd.py → train.py
@@ -14,46 +14,7 @@
from util.general_utils import AverageMeter, init_experiment
from util.cluster_and_log_utils import log_accs_from_preds
from config import exp_root
from model.loss import info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups


class DINOHead(nn.Module):
def __init__(self, in_dim, out_dim, use_bn=False, norm_last_layer=True,
nlayers=3, hidden_dim=2048, bottleneck_dim=256):
super().__init__()
nlayers = max(nlayers, 1)
if nlayers == 1:
self.mlp = nn.Linear(in_dim, bottleneck_dim)
elif nlayers != 0:
layers = [nn.Linear(in_dim, hidden_dim)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim))
self.mlp = nn.Sequential(*layers)
self.apply(self._init_weights)
self.last_layer = nn.utils.weight_norm(nn.Linear(in_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
if norm_last_layer:
self.last_layer.weight_g.requires_grad = False

def _init_weights(self, m):
if isinstance(m, nn.Linear):
torch.nn.init.trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)

def forward(self, x):
x_proj = self.mlp(x)
x = nn.functional.normalize(x, dim=-1, p=2)
# x = x.detach()
logits = self.last_layer(x)
return x_proj, logits
from model import DINOHead, info_nce_logits, SupConLoss, DistillLoss, ContrastiveLearningViewGenerator, get_params_groups


def train(student, train_loader, test_loader, unlabelled_train_loader, args):
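With DINOHead and the losses now importable from model, the student passed to train() can be wrapped for multi-process execution. A sketch using DistributedDataParallel; the model constructor is hypothetical and the process group is assumed initialised:

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

local_rank = dist.get_rank() % torch.cuda.device_count()
student = build_model(args)  # hypothetical constructor, not in this diff
student = DDP(student.cuda(local_rank), device_ids=[local_rank])
# train(student, train_loader, test_loader, unlabelled_train_loader, args)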
