Skip to content

Commit

Permalink
accuracy now an average over log_interval batches
Browse files Browse the repository at this point in the history
  • Loading branch information
bmccann authored and soumith committed Mar 14, 2017
1 parent 39bb701 commit 0e77a0b
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions OpenNMT/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,8 @@ def trainEpoch(epoch):
# shuffle mini batch order
batchOrder = torch.randperm(len(trainData))

total_loss, report_loss = 0, 0
total_words, report_tgt_words, report_src_words = 0, 0, 0
total_num_correct = 0
total_loss, total_words, total_num_correct = 0
report_loss, report_tgt_words, report_src_words, report_num_correct = 0
start = time.time()
for i in range(len(trainData)):

Expand All @@ -206,23 +205,24 @@ def trainEpoch(epoch):
# update the parameters
optim.step()

report_loss += loss
total_num_correct += num_correct
total_loss += loss
num_words = targets.data.ne(onmt.Constants.PAD).sum()
total_words += num_words
report_loss += loss
report_num_correct += num_correct
report_tgt_words += num_words
report_src_words += batch[0].data.ne(onmt.Constants.PAD).sum()
total_loss += loss
total_num_correct += num_correct
total_words += num_words
if i % opt.log_interval == -1 % opt.log_interval:
print("Epoch %2d, %5d/%5d; acc: %6.2f; ppl: %6.2f; %3.0f src tok/s; %3.0f tgt tok/s; %6.0f s elapsed" %
(epoch, i, len(trainData),
num_correct / num_words * 100,
report_num_correct / report_tgt_words * 100,
math.exp(report_loss / report_tgt_words),
report_src_words/(time.time()-start),
report_tgt_words/(time.time()-start),
time.time()-start_time))

report_loss = report_tgt_words = report_src_words = 0
report_loss = report_tgt_words = report_src_words = report_num_correct = 0
start = time.time()

return total_loss / total_words, total_num_correct / total_words
Expand Down

0 comments on commit 0e77a0b

Please sign in to comment.