Skip to content

Commit

Permalink
make crop size dependent on img size
Browse files Browse the repository at this point in the history
  • Loading branch information
untitled-author committed Jan 14, 2022
1 parent 867d6fa commit 2264731
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions datasets/ssl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.utils.data import sampler, DataLoader
from torch.utils.data.sampler import BatchSampler
import torch.distributed as dist
from datasets.DistributedProxySampler import DistributedProxySampler
# from datasets.DistributedProxySampler import DistributedProxySampler

import gc
import sys
Expand Down Expand Up @@ -180,7 +180,7 @@ def get_lb_test_data(self):
def get_transform(mean, std, crop_size, train=True):
if train:
return transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.RandomCrop(crop_size, padding=4, padding_mode='reflect'),
transforms.RandomCrop(crop_size, padding=int(crop_size * 0.125), padding_mode='reflect'),
transforms.ToTensor(),
transforms.Normalize(mean, std)])
else:
Expand All @@ -194,6 +194,11 @@ class SSL_Dataset:
separates labeled and unlabeled data,
and return BasicDataset: torch.utils.data.Dataset (see datasets.dataset.py)
"""
dataset2crop_size = {'STL10': 96,
'SVHN': 32,
'CIFAR10': 32,
'CIFAR100': 32,
'IMAGENET': 224}

def __init__(self,
args,
Expand All @@ -216,7 +221,8 @@ def __init__(self,
self.train = train
self.num_classes = num_classes
self.data_dir = data_dir
crop_size = 96 if self.name.upper() == 'STL10' else 224 if self.name.upper() == 'IMAGENET' else 32
crop_size = self.dataset2crop_size[self.name.upper()]
# crop_size = 96 if self.name.upper() == 'STL10' else 224 if self.name.upper() == 'IMAGENET' else 32
self.transform = get_transform(mean[name], std[name], crop_size, train)

def get_data(self, svhn_extra=True):
Expand Down

0 comments on commit 2264731

Please sign in to comment.