Skip to content

Commit

Permalink
Example of MNIST with tqdm (pytorch#253)
Browse files Browse the repository at this point in the history
* Example of MNIST with tqdm

* Example of MNIST with tqdm--fixing flake8

* Change name from mnist-with-tqdm to mnist.  Adding metrics to tqdm bar

* Removing comments and __future__

* Delete mnist-with-tqdm.py

* Merge functions log_training_loss_start and log_training_loss_end into log_training_loss(ITERATION_COMPLETED)

* Simplified example--tqdm only for iterations

* Tqdm bar defined outside of the handler

* Change from pdbar to pbar, initial value set to zero
  • Loading branch information
rave78 authored and alykhantejani committed Sep 19, 2018
1 parent 2de993f commit 7fb9679
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
1 change: 1 addition & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ install:
# Examples dependencies
- pip install visdom torchvision tensorboardX
- pip install gym
- pip install tqdm

script:
- py.test --cov ignite --cov-report term-missing
Expand Down
29 changes: 22 additions & 7 deletions examples/mnist/mnist.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from __future__ import print_function
from argparse import ArgumentParser

from torch import nn
Expand All @@ -12,6 +11,8 @@
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import CategoricalAccuracy, Loss

from tqdm import tqdm


class Net(nn.Module):
def __init__(self):
Expand Down Expand Up @@ -58,32 +59,46 @@ def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval):
'nll': Loss(F.nll_loss)},
device=device)

desc = "ITERATION - loss: {:.2f}"
pbar = tqdm(
initial=0, leave=False, total=len(train_loader),
desc=desc.format(0)
)

@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
iter = (engine.state.iteration - 1) % len(train_loader) + 1

if iter % log_interval == 0:
print("Epoch[{}] Iteration[{}/{}] Loss: {:.2f}"
"".format(engine.state.epoch, iter, len(train_loader), engine.state.output))
pbar.desc = desc.format(engine.state.output)
pbar.update(log_interval)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(engine):
pbar.refresh()
evaluator.run(train_loader)
metrics = evaluator.state.metrics
avg_accuracy = metrics['accuracy']
avg_nll = metrics['nll']
print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, avg_accuracy, avg_nll))
tqdm.write(
"Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, avg_accuracy, avg_nll)
)

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(engine):
evaluator.run(val_loader)
metrics = evaluator.state.metrics
avg_accuracy = metrics['accuracy']
avg_nll = metrics['nll']
print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, avg_accuracy, avg_nll))
tqdm.write(
"Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
.format(engine.state.epoch, avg_accuracy, avg_nll))

pbar.n = pbar.last_print_n = 0

trainer.run(train_loader, max_epochs=epochs)
pbar.close()


if __name__ == "__main__":
Expand Down

0 comments on commit 7fb9679

Please sign in to comment.