diff --git a/mnist_hogwild/train.py b/mnist_hogwild/train.py index 990a2e9aa7..211285fb7d 100644 --- a/mnist_hogwild/train.py +++ b/mnist_hogwild/train.py @@ -51,12 +51,11 @@ def test_epoch(model, data_loader): for data, target in data_loader: data, target = Variable(data, volatile=True), Variable(target) output = model(data) - test_loss += F.nll_loss(output, target).data[0] + test_loss += F.nll_loss(output, target, size_average=False).data[0] # sum up batch loss pred = output.data.max(1)[1] # get the index of the max log-probability correct += pred.eq(target.data).cpu().sum() - test_loss = test_loss - test_loss /= len(data_loader) # loss function already averages over batch size + test_loss /= len(test_loader.dataset) print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format( test_loss, correct, len(data_loader.dataset), 100. * correct / len(data_loader.dataset)))