Skip to content

Commit

Permalink
Add torch.no_grad to test
Browse files Browse the repository at this point in the history
  • Loading branch information
kuangliu authored May 3, 2018
1 parent 7f16208 commit b1cf0f1
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,18 +108,19 @@ def test(epoch):
test_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)

test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))
with torch.no_grad():
for batch_idx, (inputs, targets) in enumerate(testloader):
inputs, targets = inputs.to(device), targets.to(device)
outputs = net(inputs)
loss = criterion(outputs, targets)

test_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()

progress_bar(batch_idx, len(testloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (test_loss/(batch_idx+1), 100.*correct/total, correct, total))

# Save checkpoint.
acc = 100.*correct/total
Expand Down

0 comments on commit b1cf0f1

Please sign in to comment.