fix mnist_with_visdom + update logging API (pytorch#70)

alykhantejani authored Feb 9, 2018
1 parent daa9b80 commit 6b4c0e8
Showing 5 changed files with 96 additions and 106 deletions.
61 changes: 39 additions & 22 deletions examples/mnist_with_visdom.py
@@ -9,10 +9,17 @@
 from torch.optim import SGD
 from torchvision.datasets import MNIST
 from torchvision.transforms import Compose, ToTensor, Normalize
-import visdom
+try:
+    import visdom
+except ImportError:
+    raise RuntimeError("No visdom package is found. Please install it with command: \n pip install visdom")

-from ignite.trainer import Trainer, TrainingEvents
-from ignite.handlers.logging import log_training_simple_moving_average
+from ignite.trainer import Trainer
+from ignite.evaluator import Evaluator
+from ignite.engine import Events
+from ignite.handlers.evaluate import Evaluate
+from ignite.handlers.logging import log_simple_moving_average
 import numpy as np

@@ -46,9 +53,10 @@ def get_plot_training_loss_handler(vis, plot_every):
     def plot_training_loss_to_visdom(trainer):
         if trainer.current_iteration % plot_every == 0:
             vis.line(X=np.array([trainer.current_iteration]),
-                     Y=np.array([trainer.training_history.simple_moving_average(window_size=100)]),
+                     Y=np.array([trainer.history.simple_moving_average(window_size=100)]),
                      win=train_loss_plot_window,
                      update='append')

     return plot_training_loss_to_visdom


@@ -60,24 +68,26 @@ def get_plot_validation_accuracy_handler(vis):
                                                   title='Validation Accuracy')
                                         )

-    def plot_val_accuracy_to_visdom(trainer):
-        accuracy = sum([accuracy for (loss, accuracy) in trainer.validation_history])
-        accuracy = (accuracy * 100.) / len(trainer.validation_data.dataset)
+    def plot_val_accuracy_to_visdom(evaluator, trainer):
+        accuracy = sum([accuracy for (loss, accuracy) in evaluator.history])
+        accuracy = (accuracy * 100.) / len(evaluator.dataloader.dataset)
         vis.line(X=np.array([trainer.current_epoch]),
                  Y=np.array([accuracy]),
                  win=val_accuracy_plot_window,
                  update='append')

     return plot_val_accuracy_to_visdom


 def get_log_validation_loss_and_accuracy_handler(logger):
-    def log_validation_loss_and_accuracy(trainer):
-        avg_loss = np.mean([loss for (loss, accuracy) in trainer.validation_history])
-        accuracy = sum([accuracy for (loss, accuracy) in trainer.validation_history])
+    def log_validation_loss_and_accuracy(evaluator):
+        avg_loss = np.mean([loss for (loss, accuracy) in evaluator.history])
+        accuracy = sum([accuracy for (loss, accuracy) in evaluator.history])
         logger('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
-            avg_loss, accuracy, len(trainer.validation_data.dataset),
-            (accuracy * 100.) / len(trainer.validation_data.dataset)
+            avg_loss, accuracy, len(evaluator.dataloader.dataset),
+            (accuracy * 100.) / len(evaluator.dataloader.dataset)
         ))

     return log_validation_loss_and_accuracy

@@ -109,28 +119,35 @@ def training_update_function(batch):

     def validation_inference_function(batch):
         model.eval()
-        data, target = Variable(batch[0]), Variable(batch[1])
+        data, target = Variable(batch[0], volatile=True), Variable(batch[1])
         output = model(data)
         loss = F.nll_loss(output, target, size_average=False).data[0]
         pred = output.data.max(1, keepdim=True)[1]
         correct = pred.eq(target.data.view_as(pred)).sum()
         return loss, correct

-    trainer = Trainer(train_loader, training_update_function, val_loader, validation_inference_function)
-    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
-                              log_training_simple_moving_average,
+    trainer = Trainer(training_update_function)
+    evaluator = Evaluator(validation_inference_function)
+    run_evaluation = Evaluate(evaluator, val_loader, epoch_interval=1)
+
+    # trainer event handlers
+    trainer.add_event_handler(Events.ITERATION_COMPLETED,
+                              log_simple_moving_average,
                               window_size=100,
                               metric_name="NLL",
                               should_log=lambda trainer: trainer.current_iteration % log_interval == 0,
                               logger=logger)

-    trainer.add_event_handler(TrainingEvents.TRAINING_ITERATION_COMPLETED,
+    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                               get_plot_training_loss_handler(vis, plot_every=log_interval))
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_evaluation)

+    # evaluator event handlers
+    evaluator.add_event_handler(Events.STARTED, lambda evaluator: evaluator.history.clear())
+    evaluator.add_event_handler(Events.COMPLETED, get_log_validation_loss_and_accuracy_handler(logger))
+    evaluator.add_event_handler(Events.COMPLETED, get_plot_validation_accuracy_handler(vis), trainer)

-    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED, get_log_validation_loss_and_accuracy_handler(logger))
-    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED, get_plot_validation_accuracy_handler(vis))
-    trainer.add_event_handler(TrainingEvents.VALIDATION_COMPLETED, lambda trainer: trainer.validation_history.clear())
-    trainer.run(max_epochs=epochs, validate_every_epoch=True)
+    # kick everything off
+    trainer.run(train_loader, max_epochs=epochs)


 if __name__ == "__main__":
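
A minimal sketch of the wiring this commit moves the example to, using only calls that appear in the diff above; train_loader and val_loader stand in for existing DataLoader objects, and the update/inference functions are placeholders:

    from ignite.trainer import Trainer
    from ignite.evaluator import Evaluator
    from ignite.engine import Events
    from ignite.handlers.evaluate import Evaluate

    def update(batch):
        # would normally run forward/backward/step and return the batch loss
        return 0.0

    def infer(batch):
        # would normally return (loss, num_correct) for the batch
        return 0.0, 0

    trainer = Trainer(update)      # Trainer now takes only the update function
    evaluator = Evaluator(infer)   # validation lives in a separate Evaluator

    # the Evaluate handler runs the evaluator over val_loader once per epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              Evaluate(evaluator, val_loader, epoch_interval=1))

    trainer.run(train_loader, max_epochs=10)  # data is now passed to run()
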
6 changes: 3 additions & 3 deletions ignite/engine.py
@@ -106,11 +106,11 @@ def terminate(self):
         self._logger.info("Terminate signaled. Engine will stop after current iteration is finished")
         self.should_terminate = True

-    def _run_once_on_dataset(self, dataset):
-        self.dataset = dataset
+    def _run_once_on_dataset(self, dataloader):
+        self.dataloader = dataloader
         try:
             start_time = time.time()
-            for batch in dataset:
+            for batch in dataloader:
                 self.current_iteration += 1
                 self._fire_event(Events.ITERATION_STARTED)
                 step_result = self._process_function(batch)
6 changes: 3 additions & 3 deletions ignite/evaluator.py
@@ -18,11 +18,11 @@ def add_event_handler(self, event_name, handler, *args, **kwargs):

         super(Evaluator, self).add_event_handler(event_name, handler, *args, **kwargs)

-    def run(self, dataset):
-        self.dataset = dataset
+    def run(self, data):
+        self.dataloader = data
         self.current_iteration = 0
         self._fire_event(Events.STARTED)
-        hours, mins, secs = self._run_once_on_dataset(dataset)
+        hours, mins, secs = self._run_once_on_dataset(data)
         self._logger.info("Evaluation Complete. Time taken: %02d:%02d:%02d", hours, mins, secs)
         self._fire_event(Events.COMPLETED)
127 changes: 50 additions & 77 deletions ignite/handlers/logging.py
@@ -1,79 +1,52 @@
 from __future__ import print_function

-
-def log_training_simple_moving_average(trainer, window_size, history_transform=lambda x: x,
-                                       should_log=lambda trainer: True, metric_name="", logger=print):
-    if should_log(trainer):
-        iterations_per_epoch = len(trainer.training_data)
-        current_iteration = trainer.current_iteration % iterations_per_epoch
-        log_str = "Training Epoch[{}/{}] Iteration[{}/{} ({:.2f}%)]\t{}Simple Moving Average: {:.4f}" \
-            .format(trainer.current_epoch, trainer.max_epochs, current_iteration,
-                    iterations_per_epoch, (100. * current_iteration) / iterations_per_epoch,
-                    metric_name + " ",
-                    trainer.training_history.simple_moving_average(window_size, history_transform))
-        logger(log_str)
-
-
-def log_validation_simple_moving_average(trainer, window_size, history_transform=lambda x: x,
-                                         should_log=lambda trainer: True, metric_name="", logger=print):
-    if should_log(trainer):
-        total_iterations = len(trainer.validation_data)
-        current_iteration = trainer.current_iteration % total_iterations
-        log_str = "Validation Iteration[{}/{} ({:.2f}%)]\t{}Simple Moving Average: {:.4f}" \
-            .format(current_iteration, total_iterations,
-                    (100. * current_iteration) / total_iterations,
-                    metric_name + " ",
-                    trainer.validation_history.simple_moving_average(window_size, history_transform))
-        logger(log_str)
-
-
-def log_training_weighted_moving_average(trainer, window_size, weights, history_transform=lambda x: x,
-                                         should_log=lambda trainer: True, metric_name="", logger=print):
-    if should_log(trainer):
-        iterations_per_epoch = len(trainer.training_data)
-        current_iteration = trainer.current_iteration % iterations_per_epoch
-        log_str = "Training Epoch[{}/{}] Iteration[{}/{} ({:.2f}%)]\t{}Weighted Moving Average: {:.4f}" \
-            .format(trainer.current_epoch, trainer.max_epochs, current_iteration,
-                    iterations_per_epoch, (100. * current_iteration) / iterations_per_epoch,
-                    metric_name + " ",
-                    trainer.training_history.weighted_moving_average(window_size, weights, history_transform))
-        logger(log_str)
-
-
-def log_validation_weighted_moving_average(trainer, window_size, weights, history_transform=lambda x: x,
-                                           should_log=lambda trainer: True, metric_name="", logger=print):
-    if should_log(trainer):
-        total_iterations = len(trainer.validation_data)
-        current_iteration = trainer.current_iteration % total_iterations
-        log_str = "Validation Iteration[{}/{} ({:.2f}%)]\t{}Weighted Moving Average: {:.4f}" \
-            .format(current_iteration, total_iterations,
-                    (100. * current_iteration) / total_iterations,
-                    metric_name + " ",
-                    trainer.validation_history.weighted_moving_average(window_size, weights, history_transform))
-        logger(log_str)
-
-
-def log_training_exponential_moving_average(trainer, window_size, alpha, history_transform=lambda x: x,
-                                            should_log=lambda trainer: True, metric_name="", logger=print):
-    if should_log(trainer):
-        iterations_per_epoch = len(trainer.training_data)
-        current_iteration = trainer.current_iteration % iterations_per_epoch
-        log_str = "Training Epoch[{}/{}] Iteration[{}/{} ({:.2f}%)]\t{}Exponential Moving Average: {:.4f}" \
-            .format(trainer.current_epoch, trainer.max_epochs, current_iteration,
-                    iterations_per_epoch, (100. * current_iteration) / iterations_per_epoch,
-                    metric_name + " ",
-                    trainer.training_history.exponential_moving_average(window_size, alpha, history_transform))
-        logger(log_str)
-
-
-def log_validation_exponential_moving_average(trainer, window_size, alpha, history_transform=lambda x: x,
-                                              should_log=lambda trainer: True, metric_name="", logger=print):
-    if should_log(trainer):
-        total_iterations = len(trainer.validation_data)
-        current_iteration = trainer.current_iteration % total_iterations
-        log_str = "Validation Iteration[{}/{} ({:.2f}%)]\t{}Exponential Moving Average: {:.4f}" \
-            .format(trainer.current_validation_iteration, total_iterations,
-                    (100. * current_iteration) / total_iterations,
-                    metric_name + " ",
-                    trainer.validation_history.exponential_moving_average(window_size, alpha, history_transform))
-        logger(log_str)
+from functools import partial
+
+from ignite.evaluator import Evaluator
+from ignite.trainer import Trainer
+from ignite.history import History
+
+
+def _log_engine_history_average(engine, metric_name, msg_avg_type, history_avg_fn, logger):
+    total_iterations = len(engine.dataloader)
+    current_iteration = (engine.current_iteration - 1) % total_iterations + 1
+    history_average = history_avg_fn(engine.history)
+    msg_prefix = ""
+
+    if isinstance(engine, Trainer):
+        msg_prefix = "Training Epoch[{}/{}] ".format(engine.current_epoch, engine.max_epochs)
+    elif isinstance(engine, Evaluator):
+        msg_prefix = "Evaluation "
+
+    log_str = "{}Iteration[{}/{} ({:.2f}%)]\t{} {}: {:.4f}" \
+        .format(msg_prefix,
+                current_iteration, total_iterations, (100. * current_iteration) / total_iterations,
+                metric_name, msg_avg_type, history_average)
+    logger(log_str)
+
+
+def log_simple_moving_average(engine, window_size, history_transform=lambda x: x,
+                              should_log=lambda engine: True, metric_name="", logger=print):
+    if should_log(engine):
+        _log_engine_history_average(engine, metric_name, "Simple Moving Average",
+                                    partial(History.simple_moving_average, window_size=window_size,
+                                            transform=history_transform),
+                                    logger)
+
+
+def log_weighted_moving_average(engine, window_size, weights, history_transform=lambda x: x,
+                                should_log=lambda engine: True, metric_name="", logger=print):
+    if should_log(engine):
+        _log_engine_history_average(engine, metric_name, "Weighted Moving Average",
+                                    partial(History.weighted_moving_average, window_size=window_size,
+                                            weights=weights, transform=history_transform),
+                                    logger)
+
+
+def log_exponential_moving_average(engine, window_size, alpha, history_transform=lambda x: x,
+                                   should_log=lambda trainer: True, metric_name="", logger=print):
+    if should_log(engine):
+        _log_engine_history_average(engine, metric_name, "Exponential Moving Average",
+                                    partial(History.exponential_moving_average, window_size=window_size,
+                                            alpha=alpha, transform=history_transform),
+                                    logger)
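
The refactor above folds six near-duplicate loggers into one helper by binding keyword arguments onto an unbound History method with functools.partial; the engine's history is then supplied later as the method's first positional argument. A self-contained sketch of that pattern, with a simplified stand-in for ignite.history.History:

    from functools import partial

    class History(list):
        # simplified stand-in for ignite.history.History
        def simple_moving_average(self, window_size, transform=lambda x: x):
            window = self[-window_size:]
            return sum(transform(v) for v in window) / float(len(window))

    # bind the keyword arguments now; `self` arrives later
    avg_fn = partial(History.simple_moving_average, window_size=3)

    h = History([1.0, 2.0, 3.0, 4.0])
    print(avg_fn(h))  # (2.0 + 3.0 + 4.0) / 3 = 3.0
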
2 changes: 1 addition & 1 deletion ignite/trainer.py
@@ -54,7 +54,7 @@ def run(self, training_data, max_epochs=1):
         -------
         None
         """
-        self.dataset = training_data
+        self.dataloader = training_data
         self.current_iteration = 0
         self.current_epoch = 0
