forked from keon/seq2seq
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
115 lines (100 loc) · 4.1 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import math
import argparse
import torch
from torch import optim
from torch.autograd import Variable
from torch.nn.utils import clip_grad_norm_
from torch.nn import functional as F
from model import Encoder, Decoder, Seq2Seq
from utils import load_dataset
def parse_arguments(argv=None):
    """Parse the training hyperparameters from the command line.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, so
            existing ``parse_arguments()`` calls behave as before.

    Returns:
        argparse.Namespace with the attributes ``epochs`` (int),
        ``batch_size`` (int), ``lr`` (float) and ``grad_clip`` (float).
    """
    p = argparse.ArgumentParser(description='Hyperparams')
    p.add_argument('-epochs', type=int, default=100,
                   help='number of epochs for train')
    # The help text used to be a copy of the -epochs help.
    p.add_argument('-batch_size', type=int, default=32,
                   help='batch size')
    p.add_argument('-lr', type=float, default=0.0001,
                   help='initial learning rate')
    p.add_argument('-grad_clip', type=float, default=10.0,
                   help='in case of gradient explosion')
    return p.parse_args(argv)
def evaluate(model, val_iter, vocab_size, DE, EN):
    """Return the mean per-batch NLL loss of ``model`` over ``val_iter``.

    The model is switched to eval mode (and left in it) and run without
    gradient tracking, with ``teacher_forcing_ratio=0.0``.

    Args:
        model: Module called as ``model(src, trg, teacher_forcing_ratio=0.0)``.
            Its output is fed to ``F.nll_loss``, so it is presumably a
            tensor of log-probabilities whose last dimension is
            ``vocab_size`` -- confirm against the model definition.
        val_iter: Sized iterable of batches. Each batch exposes ``src`` and
            ``trg``, each a 2-tuple whose first item is the token tensor.
        vocab_size: Size of the last dimension of the model output.
        DE: Unused; kept for interface compatibility.
        EN: Target field; ``EN.vocab.stoi['<pad>']`` gives the index that
            is ignored by the loss.

    Returns:
        Sum of the per-batch losses divided by ``len(val_iter)``.

    Raises:
        ZeroDivisionError: If ``val_iter`` has length zero.
    """
    pad = EN.vocab.stoi['<pad>']
    # Move the inputs to wherever the model lives instead of hard-coding
    # CUDA; a model without parameters keeps the previous CUDA behaviour.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda')

    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for batch in val_iter:
            src, _ = batch.src
            trg, _ = batch.trg
            src, trg = src.to(device), trg.to(device)
            output = model(src, trg, teacher_forcing_ratio=0.0)
            # The first step of output and target is skipped (presumably
            # the start-of-sequence position -- confirm against the model).
            loss = F.nll_loss(output[1:].reshape(-1, vocab_size),
                              trg[1:].reshape(-1),
                              ignore_index=pad)
            total_loss += loss.item()
    return total_loss / len(val_iter)
def train(e, model, optimizer, train_iter, vocab_size, grad_clip, DE, EN):
    """Run one pass over ``train_iter``, updating ``model`` in place.

    For every batch: zero the gradients, compute the NLL loss, backpropagate,
    clip the gradient norm to ``grad_clip`` and step the optimizer. At batch
    indices 100, 200, ... the mean loss since the previous report and its
    exponential (perplexity) are printed.

    Args:
        e: Epoch number; unused, kept for interface compatibility.
        model: Module called as ``model(src, trg)``. Its output is fed to
            ``F.nll_loss``, so it is presumably a tensor of log-probabilities
            whose last dimension is ``vocab_size`` -- confirm against the
            model definition.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        train_iter: Iterable of batches. Each batch exposes ``src`` and
            ``trg``, each a 2-tuple whose first item is the token tensor.
        vocab_size: Size of the last dimension of the model output.
        grad_clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        DE: Unused; kept for interface compatibility.
        EN: Target field; ``EN.vocab.stoi['<pad>']`` gives the index that
            is ignored by the loss.

    Returns:
        None.
    """
    log_every = 100
    pad = EN.vocab.stoi['<pad>']
    # Move the inputs to wherever the model lives instead of hard-coding
    # CUDA; a model without parameters keeps the previous CUDA behaviour.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device('cuda')

    model.train()
    running_loss = 0.0
    n_accumulated = 0
    for b, batch in enumerate(train_iter):
        src, _ = batch.src
        trg, _ = batch.trg
        src, trg = src.to(device), trg.to(device)

        optimizer.zero_grad()
        output = model(src, trg)
        # The first step of output and target is skipped (presumably the
        # start-of-sequence position -- confirm against the model).
        loss = F.nll_loss(output[1:].reshape(-1, vocab_size),
                          trg[1:].reshape(-1),
                          ignore_index=pad)
        loss.backward()
        clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        running_loss += loss.item()
        n_accumulated += 1

        if b % log_every == 0 and b != 0:
            # Divide by the number of batches actually accumulated: the
            # first window covers b = 0..100, i.e. 101 batches, not 100.
            mean_loss = running_loss / n_accumulated
            print("[%d][loss:%5.2f][pp:%5.2f]" %
                  (b, mean_loss, math.exp(mean_loss)))
            running_loss = 0.0
            n_accumulated = 0
def main():
    """Train a seq2seq model, checkpoint on validation improvement, then test.

    Reads the hyperparameters via ``parse_arguments()``, loads the dataset
    with ``load_dataset``, builds the encoder/decoder pair on the GPU and
    trains for ``args.epochs`` epochs. After each epoch the validation loss
    is printed; whenever it is the lowest seen so far the model's state dict
    is written to ``./.save/seq2seq_<epoch>.pt``. Finally the test loss is
    printed.

    NOTE(review): the final test loss is computed with the weights from the
    last epoch, not with the best checkpoint -- confirm this is intended.
    """
    args = parse_arguments()
    hidden_size = 512
    embed_size = 256
    assert torch.cuda.is_available(), \
        "CUDA is required: the model is moved to the GPU with .cuda()"

    print("[!] preparing dataset...")
    train_iter, val_iter, test_iter, DE, EN = load_dataset(args.batch_size)
    de_size, en_size = len(DE.vocab), len(EN.vocab)
    print("[TRAIN]:%d (dataset:%d)\t[TEST]:%d (dataset:%d)"
          % (len(train_iter), len(train_iter.dataset),
             len(test_iter), len(test_iter.dataset)))
    print("[DE_vocab]:%d [en_vocab]:%d" % (de_size, en_size))

    print("[!] Instantiating models...")
    encoder = Encoder(de_size, embed_size, hidden_size,
                      n_layers=2, dropout=0.5)
    decoder = Decoder(embed_size, hidden_size, en_size,
                      n_layers=1, dropout=0.5)
    seq2seq = Seq2Seq(encoder, decoder).cuda()
    optimizer = optim.Adam(seq2seq.parameters(), lr=args.lr)
    print(seq2seq)

    best_val_loss = None
    for e in range(1, args.epochs+1):
        train(e, seq2seq, optimizer, train_iter,
              en_size, args.grad_clip, DE, EN)
        val_loss = evaluate(seq2seq, val_iter, en_size, DE, EN)
        print("[Epoch:%d] val_loss:%5.3f | val_pp:%5.2f"
              % (e, val_loss, math.exp(val_loss)))

        # Save the model if the validation loss is the best we've seen so far.
        # Compare against None explicitly so a best loss of 0.0 is not
        # mistaken for "no best loss recorded yet".
        if best_val_loss is None or val_loss < best_val_loss:
            print("[!] saving model...")
            os.makedirs(".save", exist_ok=True)
            torch.save(seq2seq.state_dict(), './.save/seq2seq_%d.pt' % (e))
            best_val_loss = val_loss

    test_loss = evaluate(seq2seq, test_iter, en_size, DE, EN)
    print("[TEST] loss:%5.2f" % test_loss)
if __name__ == "__main__":
    # Script entry point: run the training job and turn a Ctrl-C into a
    # short "[STOP]" message instead of a traceback.
    try:
        main()
    except KeyboardInterrupt as interrupt:
        print("[STOP]", interrupt)