Skip to content

Commit

Permalink
fix bugs in generalization error calculation (pytorch#179)
Browse files Browse the repository at this point in the history
  • Loading branch information
chao1224 authored and soumith committed Jul 7, 2017
1 parent cab5705 commit 1b26501
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions mnist_hogwild/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,11 @@ def test_epoch(model, data_loader):
for data, target in data_loader:
data, target = Variable(data, volatile=True), Variable(target)
output = model(data)
test_loss += F.nll_loss(output, target).data[0]
test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss
pred = output.data.max(1)[1] # get the index of the max log-probability
correct += pred.eq(target.data).cpu().sum()

test_loss = test_loss
test_loss /= len(data_loader) # loss function already averages over batch size
test_loss /= len(test_loader.dataset)
print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
test_loss, correct, len(data_loader.dataset),
100. * correct / len(data_loader.dataset)))

0 comments on commit 1b26501

Please sign in to comment.