Skip to content

Commit

Permalink
fixes for optim API
Browse files Browse the repository at this point in the history
  • Loading branch information
soumith committed Nov 8, 2016
1 parent 067fef5 commit 4f2a9b1
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions mnist/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, x):
criterion = nn.NLLLoss()

# Training settings
BATCH_SIZE = 150
BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
NUM_EPOCHS = 2

Expand All @@ -74,10 +74,13 @@ def train(epoch):
batch_data = Variable(batch_data_t, requires_grad=False)
batch_targets = Variable(batch_targets_t, requires_grad=False)
for i in range(0, training_data.size(0), BATCH_SIZE):
optimizer.zero_grad()
batch_data.data[:] = training_data[i:i+BATCH_SIZE]
batch_targets.data[:] = training_labels[i:i+BATCH_SIZE]
loss = optimizer.step(lambda: criterion(model(batch_data), batch_targets))
loss = criterion(model(batch_data), batch_targets)
loss.backward()
loss = loss.data[0]
optimizer.step()
print('Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.4f}'.format(epoch,
i+BATCH_SIZE, training_data.size(0),
float(i+BATCH_SIZE)/training_data.size(0)*100, loss))
Expand Down

0 comments on commit 4f2a9b1

Please sign in to comment.