Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
解决模型评估过程出现显存爆炸
  • Loading branch information
ice-tong authored Mar 14, 2020
1 parent 0df12e1 commit 0ef099d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def train():
optimizer.step()

acc = calculat_acc(output, target)
acc_history.append(acc)
loss_history.append(loss)
acc_history.append(float(acc))
loss_history.append(float(loss))
print('train_loss: {:.4}|train_acc: {:.4}'.format(
torch.mean(torch.Tensor(loss_history)),
torch.mean(torch.Tensor(acc_history)),
Expand All @@ -99,7 +99,7 @@ def train():
output = cnn(img)

acc = calculat_acc(output, target)
acc_history.append(acc)
acc_history.append(float(acc))
loss_history.append(float(loss))
print('test_loss: {:.4}|test_acc: {:.4}'.format(
torch.mean(torch.Tensor(loss_history)),
Expand All @@ -110,4 +110,4 @@ def train():

if __name__=="__main__":
train()
pass
pass

0 comments on commit 0ef099d

Please sign in to comment.