Skip to content

Commit

Permalink
format.sh
Browse files Browse the repository at this point in the history
  • Loading branch information
tohmae committed Oct 3, 2021
1 parent 92ad9fc commit 768fcb2
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions optuna/integration/pytorch_lightning.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import warnings

from packaging import version
import sqlalchemy

import optuna
from optuna.storages._cached_storage import _CachedStorage
from packaging import version


with optuna._imports.try_import() as _imports:
import pytorch_lightning as pl
Expand Down Expand Up @@ -49,9 +50,7 @@ def on_init_start(self, trainer: Trainer):
self.is_ddp_backend = trainer.accelerator_connector.distributed_backend is not None
if self.is_ddp_backend is True:
if version.parse(pl.__version__) < version.parse("1.4.0"):
raise ValueError(
"PyTorch Lightning>=1.4.0 is required in DDP."
)
raise ValueError("PyTorch Lightning>=1.4.0 is required in DDP.")
if not isinstance(self._trial.study._storage, _CachedStorage):
raise ValueError(
":class:`~optuna.integration.PyTorchLightningPruningCallback`"
Expand Down Expand Up @@ -80,8 +79,6 @@ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> Non
self._trial.report(current_score, step=epoch)
except sqlalchemy.exc.IntegrityError:
pass
except:
raise

if self._trial.should_prune():
trainer.should_stop = True
Expand Down

0 comments on commit 768fcb2

Please sign in to comment.