-
Notifications
You must be signed in to change notification settings - Fork 1
/
train_eval.py
121 lines (110 loc) · 4.91 KB
/
train_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#coding: UTF-8
import numpy as np
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from sklearn import metrics
import time
from utils import get_time_dif
from mindspore.common.initializer import initializer, XavierNormal, Normal, HeNormal, Constant
from mindspore.train import Model, CheckpointConfig, ModelCheckpoint, LossMonitor
class MyNet(nn.Cell):
    """Placeholder network: an ``nn.Cell`` subclass that declares no layers."""

    def __init__(self):
        # Only the base-class initialisation runs; no sub-cells or
        # parameters are registered here.
        super().__init__()
# Module-level instance of the placeholder MyNet cell defined above
# (MyNet declares no layers in this file).
net = MyNet()
def init_network(model, method='xavier', exclude='embedding', seed=123):
    """Re-initialise the parameters of ``model`` in place (Xavier by default).

    Args:
        model: Network whose parameters are re-initialised. It must provide
            ``parameters_and_names()`` yielding ``(name, parameter)`` pairs,
            as ``mindspore.nn.Cell`` does.
        method: ``'xavier'`` for Xavier-normal, ``'kaiming'`` for He-normal;
            any other value selects a plain normal initialiser with the
            MindSpore default arguments.
        exclude: Parameters whose name contains this substring are left
            untouched (by default, the embedding table).
        seed: Accepted for interface compatibility; currently unused.

    Parameters whose name contains neither ``'weight'`` nor ``'bias'`` are
    skipped.
    """
    for name, param in model.parameters_and_names():
        if exclude in name:
            continue
        if 'weight' in name:
            if method == 'xavier':
                init = XavierNormal()
            elif method == 'kaiming':
                init = HeNormal()
            else:
                init = Normal()
        elif 'bias' in name:
            init = Constant(0)
        else:
            continue
        # Constructing an initializer object does not modify anything; the
        # freshly generated tensor has to be written back into the parameter.
        param.set_data(initializer(init, param.shape, param.dtype))
def train(config, model, train_iter, dev_iter, test_iter):
    """Train ``model`` with Adam and early stopping, then run the test pass.

    Args:
        config: Object providing ``num_epochs``, ``save_path`` and
            ``require_improvement``; an optional ``learning_rate`` attribute
            overrides the default of 0.001.
        model: The ``mindspore.nn.Cell`` to optimise.
        train_iter: Iterable yielding ``(inputs, labels)`` training batches.
        dev_iter: Iterable of validation batches, passed to ``evaluate``.
        test_iter: Iterable of test batches, passed to ``test`` at the end.

    Every 100 batches the model is evaluated on ``dev_iter``; a checkpoint is
    written to ``config.save_path`` whenever the validation loss improves.
    Training stops early once more than ``config.require_improvement``
    batches pass without an improvement.
    """
    start_time = time.time()
    model.set_train(True)
    learning_rate = getattr(config, 'learning_rate', 0.001)
    optimizer = nn.Adam(model.trainable_params(), learning_rate=learning_rate)

    def forward_fn(inputs, targets):
        # Return the logits as an auxiliary output so the training accuracy
        # can be computed without a second forward pass.
        logits = model(inputs)
        return ops.cross_entropy(logits, targets), logits

    # MindSpore differentiates a function rather than a loss tensor:
    # value_and_grad yields both the outputs and the parameter gradients.
    grad_fn = mindspore.value_and_grad(
        forward_fn, None, optimizer.parameters, has_aux=True)

    total_batch = 0  # number of batches processed so far
    dev_best_loss = float('inf')
    last_improve = 0  # batch index of the last validation-loss improvement
    flag = False  # set once training has gone too long without improvement
    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        for trains, labels in train_iter:
            (loss, outputs), grads = grad_fn(trains, labels)
            optimizer(grads)
            if total_batch % 100 == 0:
                # Periodically report metrics on the current training batch
                # and on the whole validation set.
                true = labels.asnumpy()
                predic = outputs.argmax(axis=1).asnumpy()
                train_acc = metrics.accuracy_score(true, predic)
                dev_acc, dev_loss = evaluate(config, model, dev_iter)
                dev_loss = float(dev_loss)
                if dev_loss < dev_best_loss:
                    dev_best_loss = dev_loss
                    mindspore.save_checkpoint(model, config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:
                    improve = ''
                time_dif = get_time_dif(start_time)
                msg = 'Iter: {0:>6}, Train Loss: {1:>5.2}, Train Acc: {2:>6.2%}, Val Loss: {3:>5.2}, Val Acc: {4:>6.2%}, Time: {5} {6}'
                print(msg.format(total_batch, float(loss.asnumpy()), train_acc, dev_loss, dev_acc, time_dif, improve))
                # evaluate() switches the cell to inference mode; switch back.
                model.set_train(True)
            total_batch += 1
            if total_batch - last_improve > config.require_improvement:
                # Validation loss has not improved for more than
                # config.require_improvement batches: stop training.
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:
            break
    test(config, model, test_iter)
def test(config, model, test_iter):
    """Load the checkpoint at ``config.save_path`` and report test metrics.

    Args:
        config: Object providing ``save_path``; it is also forwarded to
            ``evaluate``.
        model: Network that receives the checkpoint parameters and is then
            evaluated.
        test_iter: Iterable of test batches.

    Prints the test loss and accuracy, the classification report, the
    confusion matrix and the elapsed time.

    NOTE(review): MindSpore's checkpoint helpers appear to expect a
    ``.ckpt`` file suffix -- confirm that ``config.save_path`` has one.
    """
    param_dict = mindspore.load_checkpoint(config.save_path)
    mindspore.load_param_into_net(model, param_dict)
    # nn.Cell has no eval(); inference mode is selected via set_train(False).
    model.set_train(False)
    start_time = time.time()
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)
    msg = 'Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'
    # float() makes the loss formattable whether it is a number or a tensor.
    print(msg.format(float(test_loss), test_acc))
    print("Precision, Recall and F1-Score...")
    print(test_report)
    print("Confusion Matrix...")
    print(test_confusion)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)
def evaluate(config, model, data_iter, test=False):
    """Compute accuracy and mean cross-entropy loss of ``model`` on ``data_iter``.

    Args:
        config: Object providing ``class_list`` (read only when ``test`` is
            true, for the classification report).
        model: Network to evaluate; it is switched to inference mode.
        data_iter: Iterable yielding ``(inputs, labels)`` batches.
        test: When true, also build the classification report and the
            confusion matrix.

    Returns:
        ``(accuracy, mean_loss)``, or
        ``(accuracy, mean_loss, report, confusion)`` when ``test`` is true.
        ``mean_loss`` is a Python float averaged over the batches seen.
    """
    # nn.Cell has no eval(); inference mode is selected via set_train(False).
    model.set_train(False)
    loss_total = 0.0
    num_batches = 0
    label_chunks = []
    predict_chunks = []
    # No gradient function is invoked here, so a plain forward pass suffices;
    # ops.stop_gradient is a tensor op, not a context manager.
    for texts, labels in data_iter:
        outputs = model(texts)
        loss = ops.cross_entropy(outputs, labels)
        loss_total += float(loss.asnumpy())
        num_batches += 1
        label_chunks.append(labels.asnumpy().reshape(-1))
        predict_chunks.append(outputs.argmax(axis=1).asnumpy().reshape(-1))
    # Concatenate once instead of growing the arrays on every batch.
    if label_chunks:
        labels_all = np.concatenate(label_chunks)
        predict_all = np.concatenate(predict_chunks)
    else:
        labels_all = np.array([], dtype=int)
        predict_all = np.array([], dtype=int)
    acc = metrics.accuracy_score(labels_all, predict_all)
    # Count batches while iterating so iterables without __len__ work, and
    # avoid dividing by zero when no batch was produced.
    mean_loss = loss_total / max(num_batches, 1)
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
        confusion = metrics.confusion_matrix(labels_all, predict_all)
        return acc, mean_loss, report, confusion
    return acc, mean_loss