forked from optuna/optuna
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1806beb
commit 92d0b91
Showing
4 changed files
with
88 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import optuna | ||
from optuna import type_checking | ||
|
||
if type_checking.TYPE_CHECKING: | ||
from optuna.trial import Trial # NOQA | ||
|
||
try: | ||
from ignite.engine import Engine # NOQA | ||
_available = True | ||
except ImportError as e: | ||
_import_error = e | ||
# IgnitePruningHandler is disabled because pytorch-ignite is not available. | ||
_available = False | ||
|
||
|
||
class IgnitePruningHandler(object): | ||
|
||
def __init__(self, trial, metric, trainer): | ||
# type: (Trial, str, Engine) -> None | ||
|
||
self.trial = trial | ||
self.metric = metric | ||
self.trainer = trainer | ||
|
||
def __call__(self, engine): | ||
# type: (Engine) -> None | ||
|
||
score = engine.state.metrics[self.metric] | ||
self.trial.report(score, engine.state.epoch) | ||
if self.trial.should_prune(): | ||
self.trainer.terminate() | ||
message = "Trial was pruned at {} epoch.".format(engine.state.epoch) | ||
raise optuna.structs.TrialPruned(message) | ||
|
||
|
||
def _check_pytorch_ignite_availability(): | ||
# type: () -> None | ||
|
||
if not _available: | ||
raise ImportError( | ||
'PyTorch Ignite is not available. Please install PyTorch Ignite to use this feature. ' | ||
'PyTorch Ignite can be installed by executing `$ pip install pytorch-ignite`. ' | ||
'For further information, please refer to the installation guide of PyTorch Ignite. ' | ||
'(The actual import error is as follows: ' + str(_import_error) + ')') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from ignite.engine import Engine | ||
from mock import Mock | ||
from mock import patch | ||
import pytest | ||
|
||
import optuna | ||
from optuna.testing.integration import create_running_trial | ||
from optuna.testing.integration import DeterministicPruner | ||
from optuna import type_checking | ||
|
||
if type_checking.TYPE_CHECKING: | ||
from typing import Iterable # NOQA | ||
|
||
|
||
def test__ignite_pruning_handler(): | ||
# type: () -> None | ||
|
||
def update(engine, batch): | ||
# type: (Engine, Iterable) -> None | ||
|
||
pass | ||
|
||
trainer = Engine(update) | ||
|
||
# The pruner is activated. | ||
study = optuna.create_study(pruner=DeterministicPruner(True)) | ||
trial = create_running_trial(study, 1.0) | ||
|
||
handler = optuna.integration.IgnitePruningHandler(trial, 'accuracy', trainer) | ||
with patch.object(trainer, 'state', epoch=Mock(return_value=1), metrics={'accuracy': 1}): | ||
with pytest.raises(optuna.structs.TrialPruned): | ||
handler(trainer) | ||
|
||
# # The pruner is not activated. | ||
study = optuna.create_study(pruner=DeterministicPruner(False)) | ||
trial = create_running_trial(study, 1.0) | ||
|
||
handler = optuna.integration.IgnitePruningHandler(trial, 'accuracy', trainer) | ||
with patch.object(trainer, 'state', epoch=Mock(return_value=1), metrics={'accuracy': 1}): | ||
handler(trainer) |