Skip to content

Commit

Permalink
Add pruning handler for ignite.
Browse files Browse the repository at this point in the history
  • Loading branch information
toshihikoyanase committed Sep 30, 2019
1 parent 1806beb commit 92d0b91
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 1 deletion.
2 changes: 2 additions & 0 deletions optuna/integration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
'chainer': ['ChainerPruningExtension'],
'chainermn': ['ChainerMNStudy'],
'cma': ['CmaEsSampler'],
'ignite': ["IgnitePruningHandler"],
'keras': ['KerasPruningCallback'],
'lightgbm': ['LightGBMPruningCallback'],
'sklearn': ['OptunaSearchCV'],
Expand All @@ -28,6 +29,7 @@
from optuna.integration.chainer import ChainerPruningExtension # NOQA
from optuna.integration.chainermn import ChainerMNStudy # NOQA
from optuna.integration.cma import CmaEsSampler # NOQA
from optuna.integration.ignite import IgnitePruningHandler # NOQA
from optuna.integration.keras import KerasPruningCallback # NOQA
from optuna.integration.lightgbm import LightGBMPruningCallback # NOQA
from optuna.integration.mxnet import MXNetPruningCallback # NOQA
Expand Down
44 changes: 44 additions & 0 deletions optuna/integration/ignite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import optuna
from optuna import type_checking

if type_checking.TYPE_CHECKING:
from optuna.trial import Trial # NOQA

try:
from ignite.engine import Engine # NOQA
_available = True
except ImportError as e:
_import_error = e
# IgnitePruningHandler is disabled because pytorch-ignite is not available.
_available = False


class IgnitePruningHandler(object):

def __init__(self, trial, metric, trainer):
# type: (Trial, str, Engine) -> None

self.trial = trial
self.metric = metric
self.trainer = trainer

def __call__(self, engine):
# type: (Engine) -> None

score = engine.state.metrics[self.metric]
self.trial.report(score, engine.state.epoch)
if self.trial.should_prune():
self.trainer.terminate()
message = "Trial was pruned at {} epoch.".format(engine.state.epoch)
raise optuna.structs.TrialPruned(message)


def _check_pytorch_ignite_availability():
# type: () -> None

if not _available:
raise ImportError(
'PyTorch Ignite is not available. Please install PyTorch Ignite to use this feature. '
'PyTorch Ignite can be installed by executing `$ pip install pytorch-ignite`. '
'For further information, please refer to the installation guide of PyTorch Ignite. '
'(The actual import error is as follows: ' + str(_import_error) + ')')
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def get_extras_require():
'testing': [
'bokeh', 'chainer>=5.0.0', 'cma', 'keras', 'lightgbm', 'mock',
'mpi4py', 'mxnet', 'pandas', 'plotly>=4.0.0', 'pytest', 'scikit-optimize',
'tensorflow', 'tensorflow-datasets', 'xgboost', 'scikit-learn>=0.19.0',
'tensorflow', 'tensorflow-datasets', 'xgboost', 'scikit-learn>=0.19.0', 'torch',
'pytorch-ignite'
],
'example': [
'chainer', 'keras', 'catboost', 'lightgbm', 'scikit-learn',
Expand Down
40 changes: 40 additions & 0 deletions tests/integration_tests/test_ignite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from ignite.engine import Engine
from mock import Mock
from mock import patch
import pytest

import optuna
from optuna.testing.integration import create_running_trial
from optuna.testing.integration import DeterministicPruner
from optuna import type_checking

if type_checking.TYPE_CHECKING:
from typing import Iterable # NOQA


def test__ignite_pruning_handler():
# type: () -> None

def update(engine, batch):
# type: (Engine, Iterable) -> None

pass

trainer = Engine(update)

# The pruner is activated.
study = optuna.create_study(pruner=DeterministicPruner(True))
trial = create_running_trial(study, 1.0)

handler = optuna.integration.IgnitePruningHandler(trial, 'accuracy', trainer)
with patch.object(trainer, 'state', epoch=Mock(return_value=1), metrics={'accuracy': 1}):
with pytest.raises(optuna.structs.TrialPruned):
handler(trainer)

# # The pruner is not activated.
study = optuna.create_study(pruner=DeterministicPruner(False))
trial = create_running_trial(study, 1.0)

handler = optuna.integration.IgnitePruningHandler(trial, 'accuracy', trainer)
with patch.object(trainer, 'state', epoch=Mock(return_value=1), metrics={'accuracy': 1}):
handler(trainer)

0 comments on commit 92d0b91

Please sign in to comment.