forked from YunseokJANG/l2l-da
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar_loader.py
176 lines (132 loc) · 6.77 KB
/
cifar_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
""" Code to build a cifar10 data loader """
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import cifar10.cifar_resnets as cifar_resnets
import cifar10.wide_resnets as wide_resnets
import utils.pytorch_utils as utils
import config
import os
import re
###############################################################################
# PARSE CONFIGS #
###############################################################################
# Module-level aliases for values read from the project `config` module once,
# at import time.
# Root directory handed to datasets.CIFAR10(..., download=True) in
# load_cifar_data when no dataset_dir is given.
DEFAULT_DATASETS_DIR = config.DEFAULT_DATASETS_DIR
# Directory joined with the checkpoint filenames ('cifar10_resnet<N>.th',
# 'cifar10_wide-resnet28x10.th') by the pretrained-model loaders below.
RESNET_WEIGHT_PATH = config.MODEL_PATH
# DataLoader defaults used by load_cifar_data (batch_size / num_workers).
DEFAULT_BATCH_SIZE = config.DEFAULT_BATCH_SIZE
DEFAULT_WORKERS = config.DEFAULT_WORKERS
# Normalization mean/std for the cifar_resnets models and load_cifar_data.
CIFAR10_MEANS = config.CIFAR10_MEANS
CIFAR10_STDS = config.CIFAR10_STDS
# Separate normalization mean/std used with the pretrained wide resnet.
WIDE_CIFAR10_MEANS = config.WIDE_CIFAR10_MEANS
WIDE_CIFAR10_STDS = config.WIDE_CIFAR10_STDS
###############################################################################
# END PARSE CONFIGS #
###############################################################################
##############################################################################
# #
# MODEL LOADER #
# #
##############################################################################
def load_pretrained_cifar_resnet(flavor=32,
                                 return_normalizer=False,
                                 manual_gpu=None):
    """ Helper fxn to initialize/load a pretrained CIFAR10 resnet.

    ARGS:
        flavor: int - resnet depth; one of 110, 1202, 20, 32, 44, 56. Selects
                both the constructor cifar_resnets.resnet<flavor> and the
                checkpoint file 'cifar10_resnet<flavor>.th' inside
                RESNET_WEIGHT_PATH
        return_normalizer: boolean - if True, also return a
                DifferentiableNormalize built from CIFAR10_MEANS/CIFAR10_STDS
        manual_gpu: boolean or None - if None, utils.use_gpu() decides;
                otherwise this value decides. When the GPU is not used, the
                checkpoint tensors are loaded onto the CPU
    RETURNS:
        classifier_net, or (classifier_net, normalizer) if return_normalizer
    RAISES:
        ValueError if flavor is not one of the supported depths
    """
    # Resolve load path. Use an explicit check rather than `assert`, which is
    # stripped under `python -O`.
    valid_flavor_numbers = (110, 1202, 20, 32, 44, 56)
    if flavor not in valid_flavor_numbers:
        raise ValueError('flavor must be one of %s, got %r'
                         % (sorted(valid_flavor_numbers), flavor))
    weight_path = os.path.join(RESNET_WEIGHT_PATH,
                               'cifar10_resnet%s.th' % flavor)

    # Resolve CPU/GPU stuff: None keeps tensors where they were saved,
    # 'cpu' remaps them so the checkpoint loads without CUDA.
    use_gpu = manual_gpu if manual_gpu is not None else utils.use_gpu()
    map_location = None if use_gpu else 'cpu'

    # The checkpoint keeps its weights under 'state_dict' with keys prefixed
    # by 'module.' (presumably saved from a DataParallel wrapper); strip the
    # prefix so the keys match the bare model.
    # TODO: LOAD THESE INTO MODEL ZOO
    checkpoint = torch.load(weight_path, map_location=map_location)
    state_dict = {re.sub(r'^module\.', '', k): v
                  for k, v in checkpoint['state_dict'].items()}

    # getattr performs the same attribute lookup as the former eval() call,
    # without evaluating a dynamically built string.
    classifier_net = getattr(cifar_resnets, 'resnet%s' % flavor)()
    classifier_net.load_state_dict(state_dict)

    if return_normalizer:
        normalizer = utils.DifferentiableNormalize(mean=CIFAR10_MEANS,
                                                   std=CIFAR10_STDS)
        return classifier_net, normalizer
    return classifier_net
def load_pretrained_cifar_wide_resnet(use_gpu=False, return_normalizer=False):
    """ Helper fxn to initialize/load a pretrained 28x10 CIFAR resnet

    ARGS:
        use_gpu: boolean - if False (default), checkpoint tensors are remapped
                 onto the CPU while loading, so the checkpoint also loads on
                 machines without CUDA; if True, tensors are loaded onto the
                 device they were saved from
        return_normalizer: boolean - if True, also return a
                 DifferentiableNormalize built from WIDE_CIFAR10_MEANS/STDS
    RETURNS:
        classifier_net, or (classifier_net, normalizer) if return_normalizer
    """
    weight_path = os.path.join(RESNET_WEIGHT_PATH,
                               'cifar10_wide-resnet28x10.th')
    # `use_gpu` used to be ignored, so a GPU-saved checkpoint could not be
    # deserialized on a CPU-only machine. Mirror load_pretrained_cifar_resnet:
    # map to CPU unless the GPU is requested.
    map_location = None if use_gpu else 'cpu'
    state_dict = torch.load(weight_path, map_location=map_location)
    # NOTE(review): args presumably (depth=28, widen_factor=10,
    # dropout_rate=0, num_classes=10) -- confirm against wide_resnets.
    classifier_net = wide_resnets.Wide_ResNet(28, 10, 0, 10)
    classifier_net.load_state_dict(state_dict)

    if return_normalizer:
        normalizer = utils.DifferentiableNormalize(mean=WIDE_CIFAR10_MEANS,
                                                   std=WIDE_CIFAR10_STDS)
        return classifier_net, normalizer
    return classifier_net
##############################################################################
# #
# DATA LOADER #
# #
##############################################################################
def _build_cifar_transform(no_transform, normalize):
    """ Builds the torchvision transform chain used by load_cifar_data. """
    transform_list = []
    if not no_transform:
        # Random horizontal flip, then a random 32x32 crop taken from the
        # image padded by 4 pixels on each side.
        transform_list.extend([transforms.RandomHorizontalFlip(),
                               transforms.RandomCrop(32, 4)])
    transform_list.append(transforms.ToTensor())
    if normalize:
        transform_list.append(transforms.Normalize(mean=CIFAR10_MEANS,
                                                   std=CIFAR10_STDS))
    return transforms.Compose(transform_list)


def load_cifar_data(train_or_val, extra_args=None, dataset_dir=None,
                    normalize=False, batch_size=None, manual_gpu=None,
                    shuffle=True, no_transform=False):
    """ Builds a CIFAR10 data loader for either training or evaluation of
    CIFAR10 data. See the 'DEFAULTS' section in the fxn for default args
    ARGS:
        train_or_val: string - one of 'train' or 'val' for whether we should
                      load training or validation data
        extra_args: dict - if not None is the kwargs to be passed to DataLoader
                    constructor (these override the defaults built here)
        dataset_dir: string - if not None is a directory to load the data from;
                     otherwise DEFAULT_DATASETS_DIR. The dataset is downloaded
                     there if missing
        normalize: boolean - if True, we normalize the data by subtracting out
                   means and dividing by standard devs
        batch_size: int or None - batch size; if None (or 0), we use
                    DEFAULT_BATCH_SIZE
        manual_gpu : boolean or None- if None, we use the GPU if we can
                     else, we use the GPU iff this is True (controls
                     the DataLoader's pin_memory flag)
        shuffle: boolean - if True, we load the data in a shuffled order
        no_transform: boolean - if True, we don't do any random cropping/
                      reflections of the data
    RETURNS:
        a torch.utils.data.DataLoader over the requested CIFAR10 split
    RAISES:
        ValueError if train_or_val is not 'train' or 'val'
    """
    # Validate first so a bad split fails before any other work; explicit
    # raise instead of `assert`, which `python -O` strips.
    if train_or_val not in ('train', 'val'):
        raise ValueError("train_or_val must be 'train' or 'val', got %r"
                         % (train_or_val,))

    ##################################################################
    # DEFAULTS #
    ##################################################################
    dataset_dir = dataset_dir or DEFAULT_DATASETS_DIR
    batch_size = batch_size or DEFAULT_BATCH_SIZE

    # Extra arguments for DataLoader constructor
    use_gpu = manual_gpu if manual_gpu is not None else utils.use_gpu()
    constructor_kwargs = {'batch_size': batch_size,
                          'shuffle': shuffle,
                          'num_workers': DEFAULT_WORKERS,
                          # pinned host memory only pays off for GPU copies
                          'pin_memory': use_gpu}
    constructor_kwargs.update(extra_args or {})

    # NOTE(review): random flip/crop is applied to the 'val' split as well
    # unless the caller passes no_transform=True -- confirm this is intended.
    transform_chain = _build_cifar_transform(no_transform, normalize)

    ##################################################################
    # Build DataLoader #
    ##################################################################
    dataset = datasets.CIFAR10(root=dataset_dir,
                               train=(train_or_val == 'train'),
                               transform=transform_chain, download=True)
    return torch.utils.data.DataLoader(dataset, **constructor_kwargs)