Trainer does not save checkpoints every n steps as it should #23

Merged (5 commits) on Mar 21, 2023
Changes from 1 commit
improved callbacks docs
Alberto Gasparin committed Mar 19, 2023
commit d3b91c8908fe9112c9a080d3f493c49046d1513a
4 changes: 4 additions & 0 deletions docs/source/references/prob_model/fit_config.rst
@@ -2,10 +2,12 @@ Posterior fitting configuration
===============================
This section describes :class:`~fortuna.prob_model.fit_config.base.FitConfig`,
an object that configures the posterior fitting process. It is made of several objects:

- :class:`~fortuna.prob_model.fit_config.optimizer.FitOptimizer`: to configure the optimization process;
- :class:`~fortuna.prob_model.fit_config.checkpointer.FitCheckpointer`: to save and restore checkpoints;
- :class:`~fortuna.prob_model.fit_config.monitor.FitMonitor`: to monitor the process and trigger early stopping;
- :class:`~fortuna.prob_model.fit_config.processor.FitProcessor`: to decide how and where the computation is processed;
- List[:class:`~fortuna.prob_model.fit_config.callbacks.Callback`]: to allow the user to perform custom actions at different stages of the training process.

.. _fit_config:

@@ -18,3 +20,5 @@
.. autoclass:: fortuna.prob_model.fit_config.checkpointer.FitCheckpointer

.. autoclass:: fortuna.prob_model.fit_config.processor.FitProcessor

.. autoclass:: fortuna.prob_model.fit_config.callbacks.Callback
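For orientation, the composition described in the docs above could be sketched as follows. The class bodies and every field name here (`n_epochs`, `save_every_n_steps`, and so on) are stand-ins invented for illustration, not Fortuna's actual constructors:

```python
from dataclasses import dataclass, field
from typing import List

# Stand-ins for the components listed in fit_config.rst; the real classes
# live under fortuna.prob_model.fit_config and have richer constructors.
@dataclass
class FitOptimizer:
    n_epochs: int = 10  # assumed field name

@dataclass
class FitCheckpointer:
    save_every_n_steps: int = 100  # assumed field name

@dataclass
class FitMonitor:
    early_stopping_patience: int = 3  # assumed field name

@dataclass
class FitProcessor:
    devices: int = 1  # assumed field name

class Callback:
    """Base class for hooks invoked at different stages of training."""

@dataclass
class FitConfig:
    """Bundles the five configuration objects the docs describe."""
    optimizer: FitOptimizer = field(default_factory=FitOptimizer)
    checkpointer: FitCheckpointer = field(default_factory=FitCheckpointer)
    monitor: FitMonitor = field(default_factory=FitMonitor)
    processor: FitProcessor = field(default_factory=FitProcessor)
    callbacks: List[Callback] = field(default_factory=list)

config = FitConfig(
    checkpointer=FitCheckpointer(save_every_n_steps=50),
    callbacks=[Callback()],
)
print(config.checkpointer.save_every_n_steps)  # 50
```

The point of the sketch is only the shape of the API: each concern (optimization, checkpointing, monitoring, processing, callbacks) is its own object, and `FitConfig` is a thin container over them.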
2 changes: 1 addition & 1 deletion fortuna/prob_model/fit_config/base.py
@@ -4,7 +4,7 @@
from fortuna.prob_model.fit_config.monitor import FitMonitor
from fortuna.prob_model.fit_config.optimizer import FitOptimizer
from fortuna.prob_model.fit_config.processor import FitProcessor
-from fortuna.training.callbacks import Callback
+from fortuna.prob_model.fit_config.callbacks import Callback


class FitConfig:
67 changes: 67 additions & 0 deletions fortuna/prob_model/fit_config/callbacks.py
@@ -0,0 +1,67 @@
from fortuna.training.train_state import TrainState


class Callback:
"""
Base class to define new callback functions. To define a new callback, create a child of this class and
override the relevant methods.

Example
-------
    The following is a custom callback that prints the number of model parameters at the start of each epoch.

    .. code-block:: python

        import logging

        from jax.flatten_util import ravel_pytree

        logger = logging.getLogger(__name__)

        class CountParamsCallback(Callback):
            def training_epoch_start(self, state: TrainState) -> TrainState:
                params, _ = ravel_pytree(state.params)
                logger.info(f"num params: {len(params)}")
                return state
"""
def training_epoch_start(self, state: TrainState) -> TrainState:
"""
        Called at the beginning of every training epoch.

Parameters
----------
state: TrainState
The training state

Returns
-------
TrainState
The (possibly updated) training state
"""
return state

def training_epoch_end(self, state: TrainState) -> TrainState:
"""
        Called at the end of every training epoch.

Parameters
----------
state: TrainState
The training state

Returns
-------
TrainState
The (possibly updated) training state
"""
return state

def training_step_end(self, state: TrainState) -> TrainState:
"""
        Called after every minibatch update.

Parameters
----------
state: TrainState
The training state

Returns
-------
TrainState
The (possibly updated) training state
"""
return state
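The three hooks above can be seen in action with a minimal driver loop. Everything below is a self-contained sketch: `TrainState`, the `Callback` base, and `CheckpointEveryNSteps` are simplified stand-ins written for illustration, not Fortuna's actual classes. The subclass mirrors what this PR is about, acting every `n` minibatch steps:

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class TrainState:
    """Simplified stand-in for fortuna.training.train_state.TrainState."""
    step: int = 0

class Callback:
    """Minimal mirror of the Callback interface defined above."""
    def training_epoch_start(self, state: TrainState) -> TrainState:
        return state
    def training_epoch_end(self, state: TrainState) -> TrainState:
        return state
    def training_step_end(self, state: TrainState) -> TrainState:
        return state

class CheckpointEveryNSteps(Callback):
    """Hypothetical callback: record a checkpoint every n minibatch steps."""
    def __init__(self, n: int):
        self.n = n
        self.saved_at = []
    def training_step_end(self, state: TrainState) -> TrainState:
        if state.step % self.n == 0:
            self.saved_at.append(state.step)  # stand-in for writing a checkpoint
        return state

def run(callbacks, epochs=2, steps_per_epoch=5) -> TrainState:
    """Toy training loop invoking each hook at the stage it documents."""
    state = TrainState()
    for _ in range(epochs):
        for c in callbacks:
            state = c.training_epoch_start(state)
        for _ in range(steps_per_epoch):
            state = replace(state, step=state.step + 1)  # one minibatch update
            for c in callbacks:
                state = c.training_step_end(state)
        for c in callbacks:
            state = c.training_epoch_end(state)
    return state

cb = CheckpointEveryNSteps(3)
final = run([cb])
print(cb.saved_at)  # checkpoints at steps 3, 6, 9 of the 10 total steps
```

Because every hook returns the (possibly updated) state, a callback can also mutate training itself, which is what makes the interface general enough to cover checkpointing, logging, or parameter surgery.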
@@ -4,7 +4,6 @@
from typing import Any, Callable, Dict, Optional, Tuple, Union, List

import jax.numpy as jnp
-import numpy as np
from flax.core import FrozenDict
from jax import random, vmap
from jax._src.prng import PRNGKeyArray
@@ -14,7 +13,7 @@
from fortuna.distribution.base import Distribution
from fortuna.prob_model.posterior.posterior_trainer import PosteriorTrainerABC
from fortuna.prob_model.posterior.state import PosteriorState
-from fortuna.training.callbacks import Callback
+from fortuna.prob_model.fit_config.callbacks import Callback
from fortuna.typing import Array, Batch, CalibMutable, CalibParams, Params, Mutable


2 changes: 1 addition & 1 deletion fortuna/prob_model/posterior/swag/swag_trainer.py
@@ -7,7 +7,7 @@

from fortuna.prob_model.posterior.map.map_trainer import MAPTrainer
from fortuna.prob_model.posterior.swag.swag_state import SWAGState
-from fortuna.training.callbacks import Callback
+from fortuna.prob_model.fit_config.callbacks import Callback
from fortuna.training.trainer import JittedMixin, MultiDeviceMixin
from fortuna.typing import Array, Batch

19 changes: 0 additions & 19 deletions fortuna/training/callbacks.py

This file was deleted.

2 changes: 1 addition & 1 deletion fortuna/training/trainer.py
@@ -17,7 +17,7 @@
from tqdm.std import tqdm as TqdmDecorator

from fortuna.data.loader import DataLoader
-from fortuna.training.callbacks import Callback
+from fortuna.prob_model.fit_config.callbacks import Callback
from fortuna.training.mixin import (InputValidatorMixin,
WithCheckpointingMixin,
WithEarlyStoppingMixin)
5 changes: 2 additions & 3 deletions tests/fortuna/test_trainer.py
@@ -7,10 +7,9 @@
from flax.core import FrozenDict
from jax import numpy as jnp
from jax._src.prng import PRNGKeyArray
-from optax._src.base import GradientTransformation, PyTree
+from optax._src.base import PyTree

from fortuna.prob_model.joint.state import JointState
-from fortuna.training.callbacks import Callback
+from fortuna.prob_model.fit_config.callbacks import Callback
from fortuna.training.train_state import TrainState
from fortuna.training.trainer import TrainerABC
from fortuna.typing import Params, Batch, Mutable, CalibParams, CalibMutable