# train.py -- forked from pytorch/examples
# (Captured from the GitHub file view: 108 lines (92 loc), 4.29 KB.)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os
import time
import glob
import torch
import torch.optim as O
import torch.nn as nn
from torchtext import data
from torchtext import datasets
from model import SNLIClassifier
from util import get_args
# Parse the command-line options and make the requested GPU the current
# CUDA device.
args = get_args()
torch.cuda.set_device(args.gpu)

# Two torchtext fields: `inputs` (optionally lower-cased) and the
# non-sequential `answers`; both are handed to the SNLI loader.
inputs = data.Field(lower=args.lower)
answers = data.Field(sequential=False)

train, dev, test = datasets.SNLI.splits(inputs, answers)

# The input vocabulary is built over all three splits.
inputs.build_vocab(train, dev, test)
if args.word_vectors:
    if os.path.isfile(args.vector_cache):
        # Reuse the vectors saved by an earlier run.
        inputs.vocab.vectors = torch.load(args.vector_cache)
    else:
        inputs.vocab.load_vectors(wv_dir=args.data_cache, wv_type=args.word_vectors, wv_dim=args.d_embed)
        # A bare file name has an empty dirname and os.makedirs('') raises,
        # so only create the cache directory when there is one to create.
        cache_dir = os.path.dirname(args.vector_cache)
        if cache_dir:
            os.makedirs(cache_dir, exist_ok=True)
        torch.save(inputs.vocab.vectors, args.vector_cache)
# The answer vocabulary is built from the training split only.
answers.build_vocab(train)

train_iter, dev_iter, test_iter = data.BucketIterator.splits(
    (train, dev, test), batch_size=args.batch_size, device=args.gpu)
# The parsed arguments double as the model configuration; the sizes that
# depend on the data are filled in here.
config = args
config.n_embed = len(inputs.vocab)
config.d_out = len(answers.vocab)
config.n_cells = config.n_layers
if config.birnn:
    # Twice as many cells when `birnn` is set.
    config.n_cells *= 2

if args.resume_snapshot:
    # NOTE: torch.load unpickles the file, so only resume from snapshots you
    # trust. `map_location` moves every stored tensor onto the selected GPU.
    model = torch.load(args.resume_snapshot, map_location=lambda storage, location: storage.cuda(args.gpu))
else:
    model = SNLIClassifier(config)
    if args.word_vectors:
        # Start the embedding table from the loaded word vectors.
        model.embed.weight.data = inputs.vocab.vectors
    # Move a freshly built model to the GPU whether or not pretrained
    # vectors were requested.
    model.cuda()

criterion = nn.CrossEntropyLoss()
opt = O.Adam(model.parameters(), lr=args.lr)
# Bookkeeping read and updated by the training loop.
iterations = 0
start = time.time()
best_dev_acc = -1

# The training iterator is set not to repeat.
train_iter.repeat = False

# Column specs are written comma-separated and converted into
# space-separated row templates. The dev row fills every column; the plain
# row takes arbitrary values ({}) for the two dev columns.
header = ' Time Epoch Iteration Progress (%Epoch) Loss Dev/Loss Accuracy Dev/Accuracy'
dev_log_template = '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.replace(',', ' ')
log_template = '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{},{:12.4f},{}'.replace(',', ' ')

# Snapshots are written under args.save_path; create it up front.
os.makedirs(args.save_path, exist_ok=True)
print(header)
def _save_snapshot(model, snapshot_prefix, suffix):
    """Save ``model`` to ``snapshot_prefix + suffix`` and prune older snapshots.

    Every other file matching ``snapshot_prefix + '*'`` is removed after the
    save, so only the newest snapshot for a given prefix stays on disk.
    """
    snapshot_path = snapshot_prefix + suffix
    torch.save(model, snapshot_path)
    for stale_path in glob.glob(snapshot_prefix + '*'):
        if stale_path != snapshot_path:
            os.remove(stale_path)


# Main loop: train for args.epochs epochs, periodically saving a snapshot
# (every args.save_every iterations), evaluating on the dev split (every
# args.dev_every iterations) or printing a progress row (every
# args.log_every iterations when no dev evaluation ran).
# NOTE(review): `.data[0]` reads a scalar loss in the pre-0.4 PyTorch style
# used throughout this script; newer releases use `.item()` -- confirm the
# targeted PyTorch version before changing it.
for epoch in range(args.epochs):
    train_iter.init_epoch()
    n_correct, n_total = 0, 0
    for batch_idx, batch in enumerate(train_iter):
        model.train()
        opt.zero_grad()
        iterations += 1

        answer = model(batch)

        # Running training accuracy, accumulated since the start of the epoch.
        n_correct += (torch.max(answer, 1)[1].view(batch.label.size()).data == batch.label.data).sum()
        n_total += batch.batch_size
        train_acc = 100. * n_correct/n_total

        loss = criterion(answer, batch.label)
        loss.backward()
        opt.step()

        # Periodic checkpoint; replaces the previous 'snapshot*' file.
        if iterations % args.save_every == 0:
            _save_snapshot(
                model,
                os.path.join(args.save_path, 'snapshot'),
                '_acc_{:.4f}_loss_{:.6f}_iter_{}_model.pt'.format(train_acc, loss.data[0], iterations))

        if iterations % args.dev_every == 0:
            # Evaluate on the whole dev split.
            model.eval()
            dev_iter.init_epoch()
            n_dev_correct, dev_loss = 0, 0
            for dev_batch in dev_iter:
                answer = model(dev_batch)
                n_dev_correct += (torch.max(answer, 1)[1].view(dev_batch.label.size()).data == dev_batch.label.data).sum()
                # NOTE(review): overwritten on every batch, so the reported
                # dev loss is that of the last dev batch, not an average.
                dev_loss = criterion(answer, dev_batch.label)
            dev_acc = 100. * n_dev_correct / len(dev)
            print(dev_log_template.format(time.time()-start,
                epoch, iterations, 1+batch_idx, len(train_iter),
                100. * (1+batch_idx) / len(train_iter), loss.data[0], dev_loss.data[0], train_acc, dev_acc))

            # Keep only the model with the best dev accuracy seen so far.
            if dev_acc > best_dev_acc:
                best_dev_acc = dev_acc
                _save_snapshot(
                    model,
                    os.path.join(args.save_path, 'best_snapshot'),
                    '_devacc_{}_devloss_{}__iter_{}_model.pt'.format(dev_acc, dev_loss.data[0], iterations))
        elif iterations % args.log_every == 0:
            # Progress row without dev figures; the dev columns are blank.
            print(log_template.format(time.time()-start,
                epoch, iterations, 1+batch_idx, len(train_iter),
                100. * (1+batch_idx) / len(train_iter), loss.data[0], ' '*8, train_acc, ' '*12))