-
Notifications
You must be signed in to change notification settings - Fork 17
/
train.py
241 lines (181 loc) · 7.33 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
# encoding = utf-8
"""
A baseline built on Torch-Template.

How to add a new model:
1. Copy the ``Default`` folder under the ``network`` directory and rename the
   copy (e.g. ``MyNet``).
2. Import your Model in ``network/__init__.py`` and register it in ``models``::

       from MyNet.Model import Model as MyNet
       models = {
           'default': Default,
           'MyNet': MyNet,
       }

3. Run ``python train.py --model MyNet`` to check that it works.

File Structure:
    AliProducts
    ├── train.py          :Train and evaluation loop, errors and outputs visualization (Powered by TensorBoard)
    ├── eval.py           :Evaluation and test (with visualization)
    ├── test.py           :Test
    │
    ├── clear.py          :Clear cache, be CAREFUL to use it
    │
    ├── run_log.txt       :Record your command logs (except --tag cache)
    │
    ├── network
    │   ├── __init__.py   :Declare all models here so that `--model` can work properly
    │   ├── Default
    │   │   ├── Model.py  :Define default model, losses and parameter updating procedure
    │   │   └── res101.py
    │   └── MyNet
    │       ├── Model.py  :Define your model, losses and parameter updating procedure
    │       └── mynet.py
    │
    ├── options
    │   └── options.py    :Define options
    │
    ├── dataloader/       :Define Dataloaders
    │   ├── __init__.py   :imports all dataloaders in dataloaders.py
    │   ├── dataloaders.py:Define all dataloaders here
    │   └── products.py   :Custom Dataset
    │
    ├── checkpoints/<tag> :Trained checkpoints
    ├── logs/<tag>        :Logs and TensorBoard event files
    └── results/<tag>     :Test results

Datasets:
    datasets
    ├── train
    │   ├── 00001
    │   ├── 00002
    │   └── .....
    ├── val
    │   ├── 00001
    │   ├── 00002
    │   └── .....
    ├── train.json
    ├── val.json
    └── product_tree.json

Usage:
    #### Train
    python train.py --tag train_1 --epochs 500 -b 8 --gpu 1
    #### Resume or Fine Tune
    python train.py --load checkpoints/train_1 --which-epoch 500
    #### Evaluation
    python eval.py --tag eval_1 --model MyNet --load checkpoints/MyNet --which-epoch 499
    #### Test
    python test.py --tag test_1
    #### Clear
    python clear.py [--tag cache]  # (DO NOT use this command unless you know what you are doing.)

License: MIT
"""
import os
import pdb
import time
import numpy as np
from collections.abc import Iterable
import torch
from torch import optim
from torch.autograd import Variable
import dataloader as dl
from options import opt
from network import get_model
from eval import evaluate
from utils import *
# from torch_template.utils.torch_utils import create_summary_writer, write_meters_loss, LR_Scheduler
from utils.torch_utils import create_summary_writer, write_meters_loss, write_image
# from utils.send_sms import send_notification
import misc_utils as utils
######################
#       Paths
######################
# Per-run output directories are keyed by --tag.
save_root = os.path.join(opt.checkpoint_dir, opt.tag)
log_root = os.path.join(opt.log_dir, opt.tag)

utils.try_make_dir(save_root)
utils.try_make_dir(log_root)

train_dataloader = dl.train_dataloader
val_dataloader = dl.val_dataloader

# init log
logger = init_log(training=True)

######################
#     Init model
######################
Model = get_model(opt.model)
model = Model(opt)

# if len(opt.gpu_ids):
#     model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)

model = model.to(device=opt.device)

if opt.load:
    load_epoch = model.load(opt.load)
    # Resume at the next epoch only for a plain resume; with --sampler the
    # epoch counter is restarted at 1 (mini-epochs are tracked by step count).
    start_epoch = load_epoch + 1 if opt.resume and not opt.sampler else 1
else:
    start_epoch = 1

model.train()

# Start training
print('Start training...')
start_step = (start_epoch - 1) * len(train_dataloader)
global_step = start_step
total_steps = opt.epochs * len(train_dataloader)
start = time.time()

#####################
#  Scheduler setup
#####################
# Both the optimizer and the LR scheduler are owned by the project Model.
optimizer = model.optimizer
scheduler = model.scheduler

######################
#   Summary_writer
######################
writer = create_summary_writer(log_root)

start_time = time.time()
######################
#     Train loop
######################
# Runs `opt.epochs` epochs of training with periodic checkpointing, TensorBoard
# logging and validation.  Any exception is recorded to run_log.txt (unless the
# run is tagged 'cache') and then re-raised.
try:
    eval_result = ''
    for epoch in range(start_epoch, opt.epochs + 1):
        for iteration, data in enumerate(train_dataloader):
            global_step += 1

            # Throughput (steps/sec) since the run started, used for the ETA.
            rate = (global_step - start_step) / (time.time() - start)
            remaining = (total_steps - global_step) / rate

            img, label = data['input'], data['label']
            img_var = Variable(img, requires_grad=False).to(device=opt.device)
            label_var = Variable(label, requires_grad=False).to(device=opt.device)

            ##############################
            #    Update parameters
            ##############################
            update = model.update(img_var, label_var)
            predicted = update.get('predicted')

            pre_msg = 'Epoch:%d' % epoch
            msg = f'lr:{round(scheduler.get_lr()[0], 6) : .6f} (loss) {str(model.avg_meters)} ETA: {utils.format_time(remaining)}'
            utils.progress_bar(iteration, len(train_dataloader), pre_msg, msg)

            # Log averaged losses to TensorBoard once every 1000 steps.
            if global_step % 1000 == 999:
                write_meters_loss(writer, 'train', model.avg_meters, global_step)

            # With --sampler one pass over the data can be very long, so save and
            # evaluate every `mini_freq` steps ("mini epochs") instead of waiting
            # for a real epoch boundary.
            mini_freq = 100000
            if opt.sampler and global_step % mini_freq == 0:
                print()
                mini_epoch = global_step // mini_freq + 1
                model.save(mini_epoch)
                eval_result = evaluate(model, val_dataloader, mini_epoch, writer, logger)

        logger.info(f'Train epoch: {epoch}, lr: {round(scheduler.get_lr()[0], 6) : .6f}, (loss) ' + str(model.avg_meters))

        if epoch % opt.save_freq == 0 or epoch == opt.epochs:  # always save the final epoch
            model.save(epoch)

        ####################
        #    Validation
        ####################
        if epoch % opt.eval_freq == 0:
            model.eval()
            eval_result = evaluate(model, val_dataloader, epoch, writer, logger)
            model.train()

        if scheduler is not None:
            scheduler.step()

    # send_notification([opt.tag[:12], '', '', eval_result])

    if opt.tag != 'cache':
        with open('run_log.txt', 'a') as f:
            # FIX: was f.writelines(<str>), which only worked because a str
            # iterates per character; f.write is the correct call for one string.
            f.write(' Accuracy:' + eval_result + '\n')

except Exception as e:
    # if not opt.debug:  # debug mode sends no SMS; 12 is the SMS template length limit
    #     send_notification([opt.tag[:12], str(e)[:12]], template='error')

    if opt.tag != 'cache':
        with open('run_log.txt', 'a') as f:
            f.write(' Error: ' + str(e)[:120] + '\n')

    # FIX: bare `raise` re-raises the ORIGINAL exception, preserving its type
    # and full traceback (the old `raise Exception('Error')` replaced it with
    # a generic one and hid the real error type from the caller).
    raise