Autograd Doc for Complex Numbers (#41012)
Summary: Pull Request resolved: pytorch/pytorch#41012

Test Plan: Imported from OSS

Differential Revision: D22476911

Pulled By: anjali411

fbshipit-source-id: 7da20cb4312a0465272bebe053520d9911475828
anjali411 authored and facebook-github-bot committed Jul 10, 2020
1 parent e568b3f commit db38487
79 changes: 79 additions & 0 deletions docs/source/notes/autograd.rst
@@ -210,3 +210,82 @@ No thread safety on C++ hooks
Autograd relies on the user to write thread safe C++ hooks. If you want the hook
to be correctly applied in multithreading environment, you will need to write
proper thread locking code to ensure the hooks are thread safe.

Autograd for Complex Numbers
^^^^^^^^^^^^^^^^^^^^^^^^^^^^

**What notion of complex derivative does PyTorch use?**
*******************************************************

PyTorch follows `JAX's <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
convention for autograd with complex numbers.

Suppose we have a function :math:`F: ℂ → ℂ` which we can decompose into functions u and v
which compute the real and imaginary parts of the function:

.. code::

    def F(z):
        x, y = real(z), imag(z)
        return u(x, y) + v(x, y) * 1j

where :math:`1j` is a unit imaginary number.

We define the :math:`JVP` for the function :math:`F` at :math:`(x, y)` applied to a tangent
vector :math:`c+dj \in ℂ` as:

.. math:: \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

where

.. math::
    J = \begin{bmatrix}
        \frac{\partial u(x, y)}{\partial x} & \frac{\partial u(x, y)}{\partial y}\\
        \frac{\partial v(x, y)}{\partial x} & \frac{\partial v(x, y)}{\partial y} \end{bmatrix}

This is similar to the definition of the JVP for a function defined from :math:`ℝ^2 → ℝ^2`, and the left multiplication
by the row vector :math:`\begin{bmatrix} 1 & 1j \end{bmatrix}` is used to identify the result with a complex number.
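
As a sanity check of this formula, here is a framework-free sketch (the helper `jvp_F` and the choice :math:`F(z) = z^2` are purely illustrative, not part of the PyTorch API). For this :math:`F` we have :math:`u(x, y) = x^2 - y^2` and :math:`v(x, y) = 2xy`, and plugging the hand-written Jacobian into the JVP formula reproduces :math:`F'(z) \cdot (c + dj)`, as expected for a holomorphic function:

.. code::

    # Plain-Python sketch of the JVP formula for F(z) = z**2 (illustrative only).
    def jvp_F(x, y, c, d):
        du_dx, du_dy = 2 * x, -2 * y   # partials of u(x, y) = x**2 - y**2
        dv_dx, dv_dy = 2 * y, 2 * x    # partials of v(x, y) = 2 * x * y
        # [1, 1j] * J * [c, d]^T, written out component by component
        return (du_dx * c + du_dy * d) + (dv_dx * c + dv_dy * d) * 1j

    z, t = 1.0 + 2.0j, 0.5 - 0.25j     # point and tangent vector
    assert jvp_F(z.real, z.imag, t.real, t.imag) == 2 * z * t   # F'(z) * t for this holomorphic F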

We define the :math:`VJP` of :math:`F` at :math:`(x, y)` for a cotangent vector :math:`c+dj \in ℂ` as:

.. math:: \begin{bmatrix} c & -d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}

In PyTorch, the `VJP` is mostly what we care about, as it is the computation performed when we do backward
mode automatic differentiation. Notice that :math:`d` and :math:`1j` are negated in the formula above. Please refer to
the `JAX docs <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#Complex-numbers-and-differentiation>`_
for an explanation of the negative signs in the formula.
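
The same hand-written Jacobian for the illustrative :math:`F(z) = z^2` can be pushed through the VJP formula. This framework-free sketch (the helper `vjp_F` is not a PyTorch function) just makes the two negations visible:

.. code::

    # Plain-Python sketch of the VJP formula for F(z) = z**2 (illustrative only).
    def vjp_F(x, y, c, d):
        du_dx, du_dy = 2 * x, -2 * y   # partials of u(x, y) = x**2 - y**2
        dv_dx, dv_dy = 2 * y, 2 * x    # partials of v(x, y) = 2 * x * y
        # [c, -d] * J gives a 1x2 row vector ...
        row0 = c * du_dx - d * dv_dx
        row1 = c * du_dy - d * dv_dy
        # ... and right-multiplying by [1, -1j]^T maps it back to a complex number
        return row0 - row1 * 1j

    z, w = 1.0 + 2.0j, 0.5 - 0.25j     # point and cotangent vector (covector)
    print(vjp_F(z.real, z.imag, w.real, w.imag))   # prints (2+1.5j)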

**What happens if I call backward() on a complex scalar?**
*******************************************************************************

The gradient for a complex function is computed assuming the input function is a holomorphic function.
This is because, for general :math:`ℂ → ℂ` functions, the Jacobian has 4 real-valued degrees of freedom
(as in the `2x2` Jacobian matrix above), so we can't hope to represent all of them within a single complex number.
However, for holomorphic functions, the gradient can be fully represented with complex numbers because the
Cauchy-Riemann equations ensure that the `2x2` Jacobian has the special form of a scale-and-rotate
matrix in the complex plane, i.e. the action of a single complex number under multiplication. And so, we can
obtain that gradient using `backward`, which is just a call to `vjp` with the covector `1.0`.
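
A small framework-free sketch of that statement, again using the illustrative :math:`F(z) = z^2`: the Cauchy-Riemann equations force the Jacobian into a scale-and-rotate form with diagonal entries :math:`a` and off-diagonal entries :math:`\mp b`, and the VJP with covector `1.0` then returns :math:`a + b \cdot 1j = F'(z)` directly (the helper name is illustrative only):

.. code::

    # For F(z) = z**2 the Cauchy-Riemann equations give du/dx = dv/dy = 2x and
    # dv/dx = -du/dy = 2y, i.e. a scale-and-rotate Jacobian [[a, -b], [b, a]].
    def vjp_with_covector_one(x, y):
        a, b = 2 * x, 2 * y
        # [1, 0] * [[a, -b], [b, a]] * [1, -1j]^T = a + b * 1j
        return a + b * 1j

    z = 0.25 - 1.5j
    assert vjp_with_covector_one(z.real, z.imag) == 2 * z   # recovers F'(z)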

The net effect of this assumption is that the partial derivatives of the imaginary part of the function
(:math:`v(x, y)` above) are discarded for :func:`torch.autograd.backward` on a complex scalar
(e.g., this is equivalent to dropping the imaginary part of the loss before performing a backward pass).

For any other desired behavior, you can specify the covector `grad_output` in the :func:`torch.autograd.backward` call accordingly.
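
As a hedged sketch only (complex autograd coverage is still experimental, so this assumes the multiplication used below already has complex autograd support in your build), passing an explicit covector looks like the following; choosing `1j` keeps the partials of the imaginary part :math:`v(x, y)` instead of discarding them:

.. code::

    import torch

    # Hedged sketch: assumes multiplication of complex tensors already has
    # autograd support in your build of PyTorch.
    z = torch.tensor(1.0 + 2.0j, requires_grad=True)
    loss = z * z                       # a complex scalar

    # The default covector would be 1.0; passing 1j instead keeps the partial
    # derivatives of the imaginary part v(x, y) rather than discarding them.
    torch.autograd.backward([loss], [torch.tensor(0.0 + 1.0j)])
    print(z.grad)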

**How are the JVP and VJP defined for cross-domain functions?**
***************************************************************

Based on the formulas above and the behavior we expect to see (going from :math:`ℂ → ℝ^2 → ℂ` should be an identity),
we use the formulas given below for cross-domain functions.

The :math:`JVP` and :math:`VJP` for a function :math:`f1: ℂ → ℝ^2` are defined as:

.. math:: JVP = J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & d \end{bmatrix} * J * \begin{bmatrix} 1 \\ -1j \end{bmatrix}
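
A framework-free sketch of these two formulas, using an arbitrarily chosen :math:`f1(z) = (x \cdot y, x + y)` purely for illustration (the helpers `f1_jvp` and `f1_vjp` are not PyTorch functions):

.. code::

    # Illustrative only: f1(z) = (x * y, x + y) maps C -> R^2, with real
    # Jacobian J = [[y, x], [1, 1]].
    def f1_jvp(x, y, c, d):
        # J * [c, d]^T -- a real pair, since the outputs live in R^2
        return (y * c + x * d, c + d)

    def f1_vjp(x, y, c, d):
        # [c, d] * J * [1, -1j]^T -- a complex number, since the input lives in C
        return (c * y + d) - (c * x + d) * 1j

    print(f1_jvp(1.0, 2.0, 0.5, -0.25))   # tangent 0.5 - 0.25j at z = 1 + 2j
    print(f1_vjp(1.0, 2.0, 0.5, -0.25))   # cotangent pair (c, d) = (0.5, -0.25)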

The :math:`JVP` and :math:`VJP` for a function :math:`f2: ℝ^2 → ℂ` are defined as:

.. math:: JVP = \begin{bmatrix} 1 & 1j \end{bmatrix} * J * \begin{bmatrix} c \\ d \end{bmatrix}

.. math:: VJP = \begin{bmatrix} c & -d \end{bmatrix} * J
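
And a quick framework-free check of the identity requirement mentioned above, using the simplest cross-domain pair :math:`f1(z) = (real(z), imag(z))` and :math:`f2(x, y) = x + y \cdot 1j`, both of which have the `2x2` identity as their real Jacobian (the helper names are illustrative only):

.. code::

    # Both real Jacobians are the 2x2 identity, so chaining the cross-domain
    # rules must return the tangent / cotangent c + d*1j unchanged.
    def roundtrip_jvp(c, d):
        # f1: C -> R^2 gives J1 * [c, d]^T = [c, d]^T; f2: R^2 -> C then
        # applies [1, 1j] * J2 * [c, d]^T = c + d * 1j.
        return c + d * 1j

    def roundtrip_vjp(c, d):
        # For the composed C -> C map, VJP = [c, -d] * J2 * J1 * [1, -1j]^T.
        return c * 1 + (-d) * (-1j)

    w = 0.5 - 0.25j
    assert roundtrip_jvp(w.real, w.imag) == w
    assert roundtrip_vjp(w.real, w.imag) == w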
