import os
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter


class BaseTrainer(ABC):
    """Base trainer that provides common training behavior.

    All customized trainers should be subclasses of this class.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.log_dir = cfg.log_dir
        self.model_dir = cfg.model_dir
        self.clock = TrainClock()
        self.batch_size = cfg.batch_size

        # build network
        self.build_net(cfg)

        # set loss function
        self.set_loss_function()

        # set optimizer
        self.set_optimizer(cfg)

        # set tensorboard writers for training and validation curves
        self.train_tb = SummaryWriter(os.path.join(self.log_dir, 'train.events'))
        self.val_tb = SummaryWriter(os.path.join(self.log_dir, 'val.events'))

    @abstractmethod
    def build_net(self, cfg):
        """build the network; implementations must assign self.net"""
        raise NotImplementedError

    def set_loss_function(self):
        """set loss function used in training"""
        pass

    def set_optimizer(self, cfg):
        """set optimizer and lr scheduler used in training"""
        self.optimizer = optim.Adam(self.net.parameters(), cfg.lr)
        self.scheduler = optim.lr_scheduler.StepLR(self.optimizer, cfg.lr_step_size)

    def save_ckpt(self, name=None):
        """save checkpoint during training for future restore"""
        if name is None:
            save_path = os.path.join(self.model_dir, "ckpt_epoch{}.pth".format(self.clock.epoch))
            print("Saving checkpoint epoch {}...".format(self.clock.epoch))
        else:
            save_path = os.path.join(self.model_dir, "{}.pth".format(name))

        # unwrap DataParallel and move weights to CPU so the checkpoint is device-agnostic
        if isinstance(self.net, nn.DataParallel):
            model_state_dict = self.net.module.cpu().state_dict()
        else:
            model_state_dict = self.net.cpu().state_dict()

        torch.save({
            'clock': self.clock.make_checkpoint(),
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
        }, save_path)

        # move the network back to GPU to resume training
        self.net.cuda()

    def load_ckpt(self, name=None):
        """load training state from a saved checkpoint; `name` is 'latest' or an epoch number"""
        name = name if name == 'latest' else "ckpt_epoch{}".format(name)
        load_path = os.path.join(self.model_dir, "{}.pth".format(name))
        if not os.path.exists(load_path):
            raise ValueError("Checkpoint {} does not exist.".format(load_path))

        checkpoint = torch.load(load_path)
        print("Loading checkpoint from {} ...".format(load_path))
        if isinstance(self.net, nn.DataParallel):
            self.net.module.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.net.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.clock.restore_checkpoint(checkpoint['clock'])

    @abstractmethod
    def forward(self, data):
        """forward logic for your network; should return (network outputs, losses dict)"""
        raise NotImplementedError

    def update_network(self, loss_dict):
        """update network weights by backpropagation"""
        loss = sum(loss_dict.values())
        self.optimizer.zero_grad()
        loss.backward()
        # optionally clip gradients to stabilize training
        if self.cfg.grad_clip is not None:
            nn.utils.clip_grad_norm_(self.net.parameters(), self.cfg.grad_clip)
        self.optimizer.step()

    def update_learning_rate(self):
        """record the current learning rate and advance the lr scheduler"""
        self.train_tb.add_scalar('learning_rate', self.optimizer.param_groups[-1]['lr'], self.clock.epoch)
        self.scheduler.step()

    def record_losses(self, loss_dict, mode='train'):
        """record losses to tensorboard"""
        losses_values = {k: v.item() for k, v in loss_dict.items()}
        tb = self.train_tb if mode == 'train' else self.val_tb
        for k, v in losses_values.items():
            tb.add_scalar(k, v, self.clock.step)

    def train_func(self, data):
        """one step of training"""
        self.net.train()
        outputs, losses = self.forward(data)
        self.update_network(losses)
        # log training losses every 10 steps to keep the event files small
        if self.clock.step % 10 == 0:
            self.record_losses(losses, 'train')
        return outputs, losses

    def val_func(self, data):
        """one step of validation"""
        self.net.eval()
        with torch.no_grad():
            outputs, losses = self.forward(data)
            self.record_losses(losses, 'validation')
        return outputs, losses

    def visualize_batch(self, data, tb, **kwargs):
        """write visualization results to tensorboard writer"""
        raise NotImplementedError
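

# A minimal usage sketch of a concrete trainer. It assumes a cfg exposing the
# attributes BaseTrainer reads (log_dir, model_dir, batch_size, lr,
# lr_step_size, grad_clip) plus a hypothetical cfg.in_dim; the two-layer
# regressor and the 'mse' loss key are illustrative choices, and a CUDA
# device is assumed, as in the base class.
class ExampleTrainer(BaseTrainer):
    def build_net(self, cfg):
        # build_net must assign self.net before set_optimizer runs
        self.net = nn.Sequential(
            nn.Linear(cfg.in_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
        ).cuda()

    def set_loss_function(self):
        self.criterion = nn.MSELoss()

    def forward(self, data):
        x, y = data['x'].cuda(), data['y'].cuda()
        pred = self.net(x)
        # return (outputs, dict of scalar loss tensors), as train_func expects
        return pred, {'mse': self.criterion(pred, y)}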


class TrainClock(object):
    """Clock object to track epoch and step during training"""

    def __init__(self):
        self.epoch = 1
        self.minibatch = 0
        self.step = 0

    def tick(self):
        """advance one minibatch (one global step)"""
        self.minibatch += 1
        self.step += 1

    def tock(self):
        """advance one epoch and reset the minibatch counter"""
        self.epoch += 1
        self.minibatch = 0

    def make_checkpoint(self):
        return {
            'epoch': self.epoch,
            'minibatch': self.minibatch,
            'step': self.step
        }

    def restore_checkpoint(self, clock_dict):
        self.epoch = clock_dict['epoch']
        self.minibatch = clock_dict['minibatch']
        self.step = clock_dict['step']
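

# A minimal training-loop sketch showing the intended clock protocol: tick()
# after every training step, then update_learning_rate(), save_ckpt() and
# tock() at the end of each epoch. The SimpleNamespace config and the toy
# regression data are assumptions for illustration; a CUDA device is required.
if __name__ == '__main__':
    from types import SimpleNamespace
    from torch.utils.data import DataLoader, TensorDataset

    cfg = SimpleNamespace(log_dir='./logs', model_dir='./models',
                          batch_size=32, lr=1e-3, lr_step_size=30,
                          grad_clip=None, in_dim=8)
    os.makedirs(cfg.model_dir, exist_ok=True)

    # toy random regression data, just to exercise the loop
    loader = DataLoader(TensorDataset(torch.randn(256, cfg.in_dim),
                                      torch.randn(256, 1)),
                        batch_size=cfg.batch_size, shuffle=True)

    trainer = ExampleTrainer(cfg)
    for _ in range(2):
        for bx, by in loader:
            trainer.train_func({'x': bx, 'y': by})
            trainer.clock.tick()
        trainer.update_learning_rate()
        trainer.save_ckpt()
        trainer.clock.tock()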