Skip to content

Commit

Permalink
NAS documents general improvements (v2.8) (microsoft#4942)
Browse files Browse the repository at this point in the history
  • Loading branch information
ultmaster authored Jun 20, 2022
1 parent b99e268 commit 8978659
Show file tree
Hide file tree
Showing 11 changed files with 108 additions and 18 deletions.
41 changes: 40 additions & 1 deletion docs/source/nas/evaluator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ A model evaluator is for training and validating each generated model. They are
Customize Evaluator with Any Function
-------------------------------------

The simplest way to customize a new evaluator is with :class:`FunctionalEvaluator <nni.retiarii.evaluator.FunctionalEvaluator>`, which is very easy when training code is already available. Users only need to write a fit function that wraps everything, which usually includes training, validating and testing of a single model. This function takes one positional arguments (``model_cls``) and possible keyword arguments. The keyword arguments (other than ``model_cls``) are fed to :class:`FunctionalEvaluator <nni.retiarii.evaluator.FunctionalEvaluator>` as its initialization parameters (note that they will be :doc:`serialized <./serialization>`). In this way, users get everything under their control, but expose less information to the framework and as a result, further optimizations like :ref:`CGO <cgo-execution-engine>` might be not feasible. An example is as belows:
The simplest way to customize a new evaluator is with :class:`~nni.retiarii.evaluator.FunctionalEvaluator`, which is very easy when training code is already available. Users only need to write a fit function that wraps everything, which usually includes training, validating and testing of a single model. This function takes one positional arguments (``model_cls``) and possible keyword arguments. The keyword arguments (other than ``model_cls``) are fed to :class:`~nni.retiarii.evaluator.FunctionalEvaluator` as its initialization parameters (note that they will be :doc:`serialized <./serialization>`). In this way, users get everything under their control, but expose less information to the framework and as a result, further optimizations like :ref:`CGO <cgo-execution-engine>` might be not feasible. An example is as belows:

.. code-block:: python
Expand Down Expand Up @@ -42,6 +42,41 @@ The simplest way to customize a new evaluator is with :class:`FunctionalEvaluato
If the conversion is successful, the model will be able to be visualized with powerful tools `Netron <https://netron.app/>`__.

Use Evaluators to Train and Evaluate Models
-------------------------------------------

Users can use evaluators to train or evaluate a single, concrete architecture. This is very useful when:

* Debugging your evaluator against a baseline model.
* Fully train, validate and test your model after the search process is complete.

The usage is shown below:

.. code-block:: python
# Class definition of single model, for example, ResNet.
class SingleModel(nn.Module):
def __init__(): # Can't have init parameters here.
...
# Use a callable returning a model
evaluator.evaluate(SingleModel)
# Or initialize the model beforehand
evaluator.evaluate(SingleModel())
The underlying implementation of :meth:`~nni.retiarii.Evaluator.evaluate` depends on concrete evaluator that you used.
For example, if :class:`~nni.retiarii.evaluator.FunctionalEvaluator` is used, it will run your customized fit function.
If lightning evaluators like :class:`nni.retiarii.evaluator.pytorch.Classification` are used, it will invoke the ``trainer.fit()`` of Lightning.

To evaluate an architecture that is exported from experiment (i.e., from :meth:`~nni.retiarii.experiment.pytorch.RetiariiExperiment.export_top_models`), use :func:`nni.retiarii.fixed_arch` to instantiate the exported model::

with fixed_arch(exported_model):
model = ModelSpace()
# Then use evaluator.evaluate
evaluator.evaluate(model)

.. tip:: There is a way to port the trained checkpoint of super-net produced by one-shot strategies, to the concrete chosen architecture, thanks to :func:`nni.retiarii.utils.original_state_dict_hooks`. This is helpful in implementing recent multi-stage NAS algorithms like `SPOS <https://arxiv.org/abs/1904.00420>`__.

.. _lightning-evaluator:

Evaluators with PyTorch-Lightning
Expand Down Expand Up @@ -134,6 +169,10 @@ An example is as follows:
if stage == 'fit':
nni.report_final_result(self.trainer.callback_metrics['val_loss'].item())
.. note::

If you are trying to use your customized evaluator with one-shot strategy, bear in mind that your defined methods will be reassembled into another LightningModule, which might result in extra constraints when writing the LightningModule. For example, your validation step could appear else where (e.g., in ``training_step``). This prohibits you from returning arbitrary object in ``validation_step``.

Then, users need to wrap everything (including LightningModule, trainer and dataloaders) into a :class:`nni.retiarii.evaluator.pytorch.Lightning` object, and pass this object into a Retiarii experiment.

.. code-block:: python
Expand Down
10 changes: 9 additions & 1 deletion docs/source/nas/exploration_strategy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,15 @@ Starting from v2.8, the usage of one-shot strategies are much alike to multi-tri
import nni.retiarii.strategy as strategy
import nni.retiarii.evaluator.pytorch.lightning as pl
evaluator = pl.Classification(...)
evaluator = pl.Classification(
# Need to use `pl.DataLoader` instead of `torch.utils.data.DataLoader` here,
# or use `nni.trace` to wrap `torch.utils.data.DataLoader`.
train_dataloaders=pl.DataLoader(train_dataset, batch_size=100),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=100),
# Other keyword arguments passed to pytorch_lightning.Trainer.
max_epochs=10,
gpus=1,
)
exploration_strategy = strategy.DARTS()
exp_config.execution_engine = 'oneshot'
Expand Down
5 changes: 5 additions & 0 deletions docs/source/reference/nas/evaluator.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,14 @@ Utilities
Customization
-------------

.. autoclass:: nni.retiarii.Evaluator
:members:

.. autoclass:: nni.retiarii.evaluator.pytorch.Lightning
:members:

.. autoclass:: nni.retiarii.evaluator.pytorch.LightningModule
:members:

Cross-graph Optimization (experimental)
---------------------------------------
Expand Down
5 changes: 5 additions & 0 deletions nni/retiarii/evaluator/pytorch/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ class Lightning(Evaluator):
in trainer. Two hooks are added at the end of validation epoch and the end of ``fit``, respectively. The metric name
and type depend on the specific task.
.. warning::
The Lightning evaluator are stateful. If you try to use a previous Lightning evaluator,
please note that the inner ``lightning_module`` and ``trainer`` will be reused.
Parameters
----------
lightning_module
Expand Down
3 changes: 3 additions & 0 deletions nni/retiarii/experiment/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,9 @@ def export_top_models(self, top_k: int = 1, optimize_mode: str = 'maximize', for
For one-shot algorithms, only top-1 is supported. For others, ``optimize_mode`` and ``formatter`` are
available for customization.
The concrete behavior of export depends on each strategy.
See the documentation of each strategy for detailed specifications.
Parameters
----------
top_k : int
Expand Down
10 changes: 9 additions & 1 deletion nni/retiarii/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid

__all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']
__all__ = ['Evaluator', 'Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'Mutation', 'IllegalGraphError', 'MetricData']


MetricData = Any
Expand All @@ -43,6 +43,13 @@ class Evaluator(abc.ABC):
For example, functional evaluator might directly import the function and call the function.
"""

def evaluate(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
"""To run evaluation of a model. The model could be either a concrete model or a callable returning a model.
The concrete implementation of evaluate depends on the implementation of ``_execute()`` in sub-class.
"""
return self._execute(model_cls)

def __repr__(self):
items = ', '.join(['%s=%r' % (k, v) for k, v in self.__dict__.items()])
return f'{self.__class__.__name__}({items})'
Expand Down Expand Up @@ -355,6 +362,7 @@ def add_node(self, name, operation_or_type, parameters=None): # type: ignore

@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...

@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
Expand Down
4 changes: 2 additions & 2 deletions nni/retiarii/oneshot/pytorch/base_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
Mutation hooks are callable that inputs an Module and returns a
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
For each submodule, the hook list are invoked subsequently,
the later hooks can see the result from previous hooks.
The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
Expand Down Expand Up @@ -189,7 +189,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
it means the hook suggests to
keep the module unchanged, and nothing will happen.
An example of mutation hook is given in :func:`no_default_hook`.
An example of mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
However it's recommended to implement mutation hooks by deriving
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
and add its classmethod ``mutate`` to this list.
Expand Down
8 changes: 5 additions & 3 deletions nni/retiarii/oneshot/pytorch/differentiable.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class DartsLightningModule(BaseOneShotLightningModule):
The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet.
The current implementation corresponds to DARTS (1st order) in paper.
Second order (unrolled 2nd-order derivatives) is not supported yet.
.. versionadded:: 2.8
Expand Down Expand Up @@ -186,8 +187,9 @@ class GumbelDartsLightningModule(DartsLightningModule):
See `FBNet <https://arxiv.org/abs/1812.03443>`__ and `SNAS <https://arxiv.org/abs/1812.09926>`__.
This is a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
Essentially, it samples one path on forward,
and implements its own backward to update the architecture parameters based on only one path.
Essentially, it tries to mimick the behavior of sampling one path on forward by gradually
cool down the temperature, aiming to bridge the gap between differentiable architecture weights and
discretization of architectures.
.. versionadded:: 2.8
Expand Down
4 changes: 3 additions & 1 deletion nni/retiarii/oneshot/pytorch/enas.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ class ReinforceController(nn.Module):
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
Target probability that skipconnect will appear.
Target probability that skipconnect (chosen by InputChoice) will appear.
If the chosen number of inputs is away from the ``skip_connect``, there will be
a sample skip penalty which is a KL divergence added.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Expand Down
35 changes: 26 additions & 9 deletions nni/retiarii/oneshot/pytorch/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""Experimental version of sampling-based one-shot implementation."""

from __future__ import annotations
import warnings
from typing import Any

import pytorch_lightning as pl
Expand Down Expand Up @@ -76,6 +77,18 @@ def training_step(self, batch, batch_idx):
self.resample()
return self.model.training_step(batch, batch_idx)

def export(self) -> dict[str, Any]:
"""
Export of Random one-shot. It will return an arbitrary architecture.
"""
warnings.warn(
'Direct export from RandomOneShot returns an arbitrary architecture. '
'Sampling the best architecture from this trained supernet is another search process. '
'Users need to do another search based on the checkpoint of the one-shot strategy.',
UserWarning
)
return super().export()


class EnasLightningModule(RandomSamplingLightningModule):
_enas_note = """
Expand All @@ -86,8 +99,10 @@ class EnasLightningModule(RandomSamplingLightningModule):
- Firstly, training model parameters.
- Secondly, training ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.
ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
See explanation of ``reward_metric_name`` for details.
.. note::
ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
See explanation of ``reward_metric_name`` for details.
The supported mutation primitives of ENAS are:
Expand All @@ -105,22 +120,24 @@ class EnasLightningModule(RandomSamplingLightningModule):
{{module_params}}
{base_params}
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
Optional kwargs that will be passed to :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController`.
entropy_weight : float
Weight of sample entropy loss.
Weight of sample entropy loss in RL.
skip_weight : float
Weight of skip penalty loss.
Weight of skip penalty loss. See :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController` for details.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
Decay factor of reward baseline, which is used to normalize the reward in RL.
At each step, the new reward baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
Number of steps for which the gradients will be accumulated,
before updating the weights of RL controller.
ctrl_grad_clip : float
Gradient clipping value of controller.
reward_metric_name : str or None
The name of the metric which is treated as reward.
This will be not effective when there's only one metric returned from evaluator.
If there are multiple, it will find the metric with key name ``reward_metric_name``,
which is "default" by default.
If there are multiple, by default, it will find the metric with key name ``default``.
If reward_metric_name is specified, it will find reward_metric_name.
Otherwise it raises an exception indicating multiple metrics are found.
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
Expand Down
1 change: 1 addition & 0 deletions nni/retiarii/oneshot/pytorch/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def run(self, base_model: Model, applied_mutators):
evaluator.trainer.fit(self.model, train_loader, val_loader)

def export_top_models(self, top_k: int = 1) -> list[Any]:
"""The behavior of export top models in strategy depends on the implementation of inner one-shot module."""
if self.model is None:
raise RuntimeError('One-shot strategy needs to be run before export.')
if top_k != 1:
Expand Down

0 comments on commit 8978659

Please sign in to comment.