forked from CMU-CREATE-Lab/deep-smoke-machine
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbase_learner.py
201 lines (184 loc) · 8.16 KB
/
base_learner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import os
import logging
import absl.logging
import logging.handlers
from util import check_and_create_dir
from collections import OrderedDict
from torchvision.transforms import Compose
from video_transforms import RandomResizedCrop, RandomHorizontalFlip, ColorJitter, RandomPerspective, RandomErasing, Resize, Normalize, ToTensor
class RequestFormatter(logging.Formatter):
    """Log-record formatter hook; currently emits plain logging.Formatter output."""

    def format(self, record):
        # Placeholder for injecting request-specific fields later;
        # for now simply delegate to the parent formatter.
        formatted = super().format(record)
        return formatted
class Reshape(nn.Module):
    """Module that reshapes its input to (batch_size, *shape), keeping the batch dim."""

    def __init__(self, shape):
        super().__init__()
        # Target shape of everything after the batch dimension.
        self.shape = shape

    def forward(self, input):
        """
        Reshapes the input according to the shape saved in the view data structure.
        """
        target = (input.size(0),) + tuple(self.shape)
        return input.reshape(target)
"""
Base PyTorch learners
Usage:
from base_pytorch_learner import BasePyTorchLearner
class Learner(BasePyTorchLearner):
def __init__(self):
super().__init__()
self.create_logger(log_path="../log/Learner.log")
def fit(self, Xt, Yt, Xv=None, Yv=None):
pass
def predict(self, X):
pass
"""
class BaseLearner(ABC):
    """
    Base class for PyTorch learners.

    Subclasses implement fit() and test(); this class provides weight
    saving/loading, logging, and the data-augmentation transform pipeline.
    """

    def __init__(self, use_cuda=None):
        """
        Initialize the learner.

        Input:
            use_cuda: None to auto-detect CUDA; True to request CUDA
                (honored only when CUDA is actually available); anything
                else forces CPU.
        """
        self.logger = None
        # BUG FIX: the original referenced torch.cuda.is_available without
        # calling it; a bound method is always truthy, so use_cuda was True
        # even on CPU-only machines.
        cuda_ok = torch.cuda.is_available()
        if use_cuda is None:
            self.use_cuda = cuda_ok
        else:
            # Preserve the original strict check: only use_cuda is True counts.
            self.use_cuda = (use_cuda is True) and cuda_ok

    # Train the model
    # Output: None
    @abstractmethod
    def fit(self):
        pass

    # Test the model
    # Output: None
    @abstractmethod
    def test(self):
        pass

    def save(self, model, out_path):
        """Save model weights to out_path (no-op if model or out_path is None)."""
        if model is None or out_path is None:
            return
        self.log("Save model weights to " + out_path)
        try:
            # nn.DataParallel wraps the real model in .module
            state_dict = model.module.state_dict()
        except AttributeError:
            state_dict = model.state_dict()  # single GPU model
        check_and_create_dir(out_path)
        torch.save(state_dict, out_path)

    def load(self, model, in_path, ignore_fc=False, fill_dim=False):
        """
        Load model weights from in_path into model (no-op if either is None).

        Input:
            model: the nn.Module to load weights into.
            in_path: path of the saved state_dict (or a dict containing one
                under the "state_dict" key).
            ignore_fc: if True, drop fully connected ("fc") weights from the
                checkpoint so the model keeps its own fc layer.
            fill_dim: if True, auto-fill mismatched dimensions by appending
                a mean slice (intended for the Inception-v1 I3D model only).
        """
        if model is None or in_path is None:
            return
        self.log("Load model weights from " + in_path)
        # Map to CPU when CUDA is unavailable so GPU checkpoints still load.
        sd_loaded = torch.load(in_path, map_location=None if self.use_cuda else "cpu")
        if "state_dict" in sd_loaded:
            sd_loaded = sd_loaded["state_dict"]
        sd_model = model.state_dict()
        # Reconcile ".net" naming differences between checkpoint and model keys.
        replace_dict = []
        for k, v in sd_loaded.items():
            if k not in sd_model and k.replace(".net", "") in sd_model:
                print("Load after remove .net: ", k)
                replace_dict.append((k, k.replace(".net", "")))
        for k, v in sd_model.items():
            if k not in sd_loaded and k.replace(".net", "") in sd_loaded:
                print("Load after adding .net: ", k)
                replace_dict.append((k.replace(".net", ""), k))
        for k, k_new in replace_dict:
            sd_loaded[k_new] = sd_loaded.pop(k)
        if ignore_fc:
            print("Ignore fully connected layer weights")
            # BUG FIX: the original filtered sd_model (the model's own current
            # weights), which silently discarded every loaded weight; filter
            # the loaded checkpoint instead so only "fc" entries are dropped.
            sd_loaded = {k: v for k, v in sd_loaded.items() if "fc" not in k}
        if fill_dim:
            # Note that this only works for the Inception-v1 I3D model
            print("Auto-fill the mismatched dimension for the i3d model...")
            for name, param in model.named_parameters():
                if param.requires_grad:
                    if param.data.size() != sd_loaded[name].size():
                        print("\t Found dimension mismatch for:", name)
                        ds = param.data.size()
                        ls = sd_loaded[name].size()
                        print("\t\t Desired data size:", param.data.size())
                        print("\t\t Loaded data size:", sd_loaded[name].size())
                        for i in range(len(ds)):
                            diff = ds[i] - ls[i]
                            if diff > 0:
                                print("\t\t\t Desired dimension %d is larger than the loaded dimension" % i)
                                # Fill the missing slice with the mean over axis i.
                                # NOTE(review): appends a single slice even when
                                # diff > 1 — sufficient for the i3d case (diff is 1).
                                m = sd_loaded[name].mean(i).unsqueeze(i)
                                print("\t\t\t Compute the missing dimension to have size:", m.size())
                                sd_loaded[name] = torch.cat([sd_loaded[name], m], i)
                                print("\t\t\t Loaded data are filled to have size:", sd_loaded[name].size())
        sd_model.update(sd_loaded)
        try:
            model.load_state_dict(sd_model)
        except Exception:
            # Checkpoints from nn.DataParallel / DistributedDataParallel carry
            # a "module." prefix on every key; strip it and retry.
            self.log("Weights were from nn.DataParallel or DistributedDataParallel...")
            self.log("Remove 'module.' prefix from state_dict keys...")
            new_state_dict = OrderedDict()
            for k, v in sd_model.items():
                new_state_dict[k.replace("module.", "")] = v
            model.load_state_dict(new_state_dict)

    def log(self, msg, lv="i"):
        """
        Print msg and also write it to self.logger (if create_logger was called).

        Input:
            msg: the message to log.
            lv: level — "i" info, "w" warning, "e" error.
        """
        print(msg)
        if self.logger is not None:
            if lv == "i":
                self.logger.info(msg)
            elif lv == "w":
                self.logger.warning(msg)
            elif lv == "e":
                self.logger.error(msg)

    def get_transform(self, mode, phase=None, image_size=224):
        """
        Build the data augmentation pipeline.

        Input:
            mode: "rgb" (3 channels), "flow" (2 channels: x and y), or
                "rgbd" (4 channels: r, g, b, and dark channel).
            phase: "train" enables random augmentations; any other value
                yields a deterministic resize + normalize pipeline.
            image_size: output spatial size (default 224).

        Output: a Compose of video transforms, or None if mode is unknown.
        """
        if mode == "rgb":  # three channels (r, g, and b)
            mean = (127.5, 127.5, 127.5)
            std = (127.5, 127.5, 127.5)
        elif mode == "flow":  # two channels (x and y)
            mean = (127.5, 127.5)
            std = (127.5, 127.5)
        elif mode == "rgbd":  # four channels (r, g, b, and dark channel)
            mean = (127.5, 127.5, 127.5, 127.5)
            std = (127.5, 127.5, 127.5, 127.5)
        else:
            return None
        nm = Normalize(mean=mean, std=std)  # same as (img/255)*2-1
        tt = ToTensor()
        if phase == "train":
            # Deals with small camera shifts, zoom changes, and rotations due to wind or maintenance
            rrc = RandomResizedCrop(image_size, scale=(0.9, 1), ratio=(3./4., 4./3.))
            rp = RandomPerspective(anglex=3, angley=3, anglez=3, shear=3)
            # Improve generalization
            rhf = RandomHorizontalFlip(p=0.5)
            # Deal with dirts, ants, or spiders on the camera lense
            re = RandomErasing(p=0.5, scale=(0.003, 0.01), ratio=(0.3, 3.3), value=0)
            if mode == "rgb" or mode == "rgbd":
                # Color jitter deals with different lighting and weather conditions
                cj = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=(-0.1, 0.1), gamma=0.3)
                # RandomErasing appears twice so up to two regions can be erased
                return Compose([cj, rrc, rp, rhf, tt, nm, re, re])
            elif mode == "flow":
                return Compose([rrc, rp, rhf, tt, nm, re, re])
        else:
            return Compose([Resize(image_size), tt, nm])

    def create_logger(self, log_path=None):
        """
        Create a rotating-file logger and store it on self.logger.

        Input:
            log_path: path of the log file; if None, no logger is created.
        """
        if log_path is None:
            return None
        check_and_create_dir(log_path)
        handler = logging.handlers.RotatingFileHandler(log_path, mode="a", maxBytes=100000000, backupCount=200)
        # NOTE(review): reaches into absl's private attributes to suppress
        # duplicated logging; may break with future absl versions.
        logging.root.removeHandler(absl.logging._absl_handler)
        absl.logging._warn_preinit_stderr = False
        formatter = RequestFormatter("[%(asctime)s] %(levelname)s: %(message)s")
        handler.setFormatter(formatter)
        # Key the logger by path so each log file gets its own logger object.
        logger = logging.getLogger(log_path)
        logger.setLevel(logging.INFO)
        for hdlr in logger.handlers[:]:
            logger.removeHandler(hdlr)  # remove old handlers
        logger.addHandler(handler)
        self.logger = logger