In this repository we implement implicit forms of common gradient descent algorithms available in PyTorch 1.0 and compare their performance to traditional (explicit) SGD methods.
For a detailed discussion of the theoretical advantages of implicit gradient descent, see this paper. To paraphrase, a key finding is that whereas explicit methods require the learning-rate schedule to be carefully tuned to balance statistical efficiency against numerical stability, with implicit methods the stability constraint effectively vanishes: any learning rate (or sequence thereof) yields a stable procedure, allowing higher rates and therefore faster convergence. On the other hand, as we shall see, computing the implicit update is computationally expensive and in many cases intractable, so in this repo we experiment to find situations in which the trade-off is beneficial.
A little background: an explicit gradient descent update typically looks something like

$$\theta_{t+1} = \theta_t - \gamma_t \nabla F(\theta_t).$$

The implicit update instead evaluates the gradient at the new iterate:

$$\theta_{t+1} = \theta_t - \gamma_t \nabla F(\theta_{t+1}),$$

where here we choose to use a first-order expansion to approximate

$$\nabla F(\theta_{t+1}) \approx \nabla F(\theta_t) + H(\theta_t)\,(\theta_{t+1} - \theta_t),$$

where $H$ denotes the Hessian of $F$. Now we can simply rearrange to find our update rule:

$$\theta_{t+1} = \theta_t - \gamma_t \left(I + \gamma_t H(\theta_t)\right)^{-1} \nabla F(\theta_t).$$
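As a concrete illustration, here is a minimal sketch of one such linearized implicit-SGD step on a flat parameter vector. It assumes a recent PyTorch (the `torch.autograd.functional.hessian` and `torch.linalg.solve` helpers postdate 1.0), and the function and variable names are ours for illustration, not a fixed API of this repo:

```python
import torch

def implicit_sgd_step(theta, loss_fn, lr):
    """One linearized implicit-SGD step on a flat parameter vector:
        theta_new = theta - lr * (I + lr * H)^(-1) * grad F(theta),
    where H is the Hessian of F at theta. Only feasible for small M = theta.numel().
    """
    theta = theta.detach().clone().requires_grad_(True)
    loss = loss_fn(theta)
    grad, = torch.autograd.grad(loss, theta)
    H = torch.autograd.functional.hessian(loss_fn, theta.detach())
    M = theta.numel()
    A = torch.eye(M, dtype=theta.dtype) + lr * H
    # Solve the M x M linear system rather than forming the inverse explicitly.
    delta = torch.linalg.solve(A, lr * grad)
    return theta.detach() - delta

# Toy usage on a quadratic loss F(theta) = 0.5 * ||theta||^2:
theta = torch.randn(5)
theta = implicit_sgd_step(theta, lambda t: 0.5 * (t ** 2).sum(), lr=10.0)
```

On this quadratic toy problem the step reduces to shrinking `theta` by a factor of `1 + lr`, so even a learning rate of 10 remains stable, which is exactly the behaviour the implicit form is meant to provide.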
In our investigation we restrict our attention to compositions of affine functions, bounded non-linearities, and ReLU (MLP, CNN, and RNN architectures), so we can approximate the Hessian well without computing second-order partial derivatives of the loss w.r.t. the parameters (see background.pdf for details). This lets us rely on PyTorch's auto-differentiation tools, and the increase in workload over the explicit method is due exclusively to the need to invert a square matrix of size MxM, where M is the number of parameters. In addition to the straightforward SGD update, we also implement alternative implicit update rules (such as implicit Adam).
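One common way to build such a first-order Hessian surrogate is the outer product of per-sample gradients (the empirical Fisher); whether this matches the exact construction in background.pdf is an assumption on our part, but it illustrates how only `torch.autograd.grad` calls are needed:

```python
import torch

def per_sample_gradients(model, loss_fn, xs, ys):
    """Flattened gradient of the loss for each sample, stacked into an (n, M) matrix.
    Loop version for clarity; only first-order autograd calls are used."""
    params = [p for p in model.parameters() if p.requires_grad]
    rows = []
    for x, y in zip(xs, ys):
        loss = loss_fn(model(x.unsqueeze(0)), y.unsqueeze(0))
        grads = torch.autograd.grad(loss, params)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    return torch.stack(rows)

def outer_product_hessian(G):
    """Empirical-Fisher surrogate: H_hat = (1/n) * G^T G, built purely from
    the per-sample gradient matrix G of shape (n, M)."""
    return G.t() @ G / G.shape[0]
```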
The implicit update requires the inversion of an MxM matrix, where M is the number of parameters, making it impractical for modern convolutional architectures, which regularly comprise millions (or more) of parameters. On the other hand, we are interested in identifying if and when (i.e., for what M) the speed-up in training afforded by higher allowable learning rates can offset the computational cost of inverting such a matrix.
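To get a rough sense of that cost, the snippet below times a single M x M solve for a few values of M; this is a hypothetical micro-benchmark, and absolute numbers depend entirely on hardware and BLAS backend.

```python
import time
import torch

# Gauge how the per-step overhead of the implicit update scales with M.
for M in (1_000, 5_000, 10_000):
    A = torch.eye(M) + 0.1 * torch.randn(M, M)
    b = torch.randn(M)
    t0 = time.perf_counter()
    torch.linalg.solve(A, b)
    print(f"M = {M:>6}: {time.perf_counter() - t0:.3f} s")
```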
While training deep networks from scratch with such methods seems impractical, we'd like to investigate potential applications to transfer learning (see our repo on neural splicing for an example and discussion), where weights from a pre-trained model are loaded into a slightly modified architecture and only a small fraction of the parameters (on the order of tens of thousands) are randomly initialized and trained from the start. We hope this approach can accelerate fine-tuning in these cases.
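A sketch of the kind of setup we have in mind (the backbone choice here is hypothetical, not something this repo prescribes): freeze a pre-trained model, replace its final layer, and check how small M becomes for the implicit update.

```python
import torch
import torchvision

# Load a pre-trained backbone, freeze it, and attach a fresh head so that
# only a few thousand parameters are randomly initialized and trainable.
model = torchvision.models.resnet18(pretrained=True)
for p in model.parameters():
    p.requires_grad = False
model.fc = torch.nn.Linear(model.fc.in_features, 10)  # new head, trains from scratch

M = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: M = {M}")  # 5130 for this particular choice
```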