diff --git a/.conda_env.yml b/.conda_env.yml index 52d2e3fc3..9b0396e70 100644 --- a/.conda_env.yml +++ b/.conda_env.yml @@ -3,7 +3,7 @@ channels: - pytorch - defaults dependencies: - - pip=21.2.4 - - python=3.8.5 + - python=3.9.16 + - pip=23.1.2 - pip: - - -e .[lint,test,doc] + - -e .[lint,test,docs] diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index af35589c8..edcee9bb3 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -15,10 +15,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -30,10 +30,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -45,10 +45,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -61,10 +61,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -77,10 +77,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -92,10 +92,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip @@ -107,10 +107,10 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v1 - - name: Set up Python 3.8 + - name: Set up Python 3.9 uses: actions/setup-python@v1 with: - python-version: 3.8 + python-version: 3.9 - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.github/workflows/python-publish.yml b/.github/workflows/python-publish.yml index 48c52e137..d4bebe1e9 100644 --- a/.github/workflows/python-publish.yml +++ b/.github/workflows/python-publish.yml @@ -20,11 +20,11 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install setuptools wheel twine + pip install --upgrade twine build - name: Build and publish env: - TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} - TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + TWINE_USERNAME: __token__ + TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} run: | - python setup.py sdist bdist_wheel + python -m build twine upload dist/* diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 44af53e7c..d947771b1 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -19,14 +19,9 @@ jobs: strategy: matrix: - python-version: [3.8, 3.9] + python-version: [3.9] pytorch-version: - - "==1.9.1" - - "==1.10.1" - - "==1.11.0" - - "==1.12.1" - - "==1.13.1" - - 
"==2.0.1" + - "==2.2.0" - "" # latest steps: - uses: actions/checkout@v1 @@ -46,10 +41,3 @@ jobs: if: contains('refs/heads/master refs/heads/development refs/heads/release', github.ref) != 1 run: | make test-light - - - name: Test coveralls - python ${{ matrix.python-version }} - run: coveralls --service=github - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - flag-name: run-${{ matrix.python-version }} - parallel: true diff --git a/.readthedocs.yml b/.readthedocs.yml index 32cd568e7..931339f6f 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -6,11 +6,14 @@ version: 2 sphinx: configuration: docs_src/rtd/conf.py +build: + os: ubuntu-22.04 + tools: + python: "3.9" + python: - version: 3.8 install: - - method: pip - path: . - extra_requirements: - - docs - system_packages: true + - method: pip + path: . + extra_requirements: + - docs diff --git a/backpack/__init__.py b/backpack/__init__.py index 85b5aa578..671600141 100644 --- a/backpack/__init__.py +++ b/backpack/__init__.py @@ -1,4 +1,5 @@ """BackPACK.""" + from inspect import isclass from types import TracebackType from typing import Callable, Optional, Tuple, Type, Union diff --git a/backpack/context.py b/backpack/context.py index 39d19ebf7..969428c2b 100644 --- a/backpack/context.py +++ b/backpack/context.py @@ -1,4 +1,5 @@ """Context class for BackPACK.""" + from typing import Callable, Iterable, List, Tuple, Type from torch.nn import Module diff --git a/backpack/core/derivatives/adaptive_avg_pool_nd.py b/backpack/core/derivatives/adaptive_avg_pool_nd.py index f6f4d3ad1..1ff760149 100644 --- a/backpack/core/derivatives/adaptive_avg_pool_nd.py +++ b/backpack/core/derivatives/adaptive_avg_pool_nd.py @@ -1,12 +1,11 @@ """Implements the derivatives for AdaptiveAvgPool.""" + from typing import List, Tuple, Union -from warnings import warn from torch import Size from torch.nn import AdaptiveAvgPool1d, AdaptiveAvgPool2d, AdaptiveAvgPool3d from backpack.core.derivatives.avgpoolnd import AvgPoolNDDerivatives -from backpack.utils import ADAPTIVE_AVG_POOL_BUG class AdaptiveAvgPoolNDDerivatives(AvgPoolNDDerivatives): @@ -28,14 +27,6 @@ def check_parameters( Raises: NotImplementedError: if the given shapes do not match """ - if ADAPTIVE_AVG_POOL_BUG and module.input0.is_cuda and (self.N == 3): - warn( - "Be careful when computing gradients of AdaptiveAvgPool3d. " - "There is a bug using autograd.grad on cuda with AdaptiveAvgPool3d. " - "https://discuss.pytorch.org/t/bug-report-autograd-grad-adaptiveavgpool3d-cuda/124614 " # noqa: B950 - "BackPACK derivatives are correct." - ) - shape_input: Size = module.input0.shape shape_output: Size = module.output.shape diff --git a/backpack/core/derivatives/avgpoolnd.py b/backpack/core/derivatives/avgpoolnd.py index b51d21480..8f2c02261 100644 --- a/backpack/core/derivatives/avgpoolnd.py +++ b/backpack/core/derivatives/avgpoolnd.py @@ -3,6 +3,7 @@ Average pooling can be expressed as convolution over grouped channels with a constant kernel. 
""" + from typing import Any, List, Tuple from einops import rearrange diff --git a/backpack/core/derivatives/basederivatives.py b/backpack/core/derivatives/basederivatives.py index 2a55ac88d..c2c38e824 100644 --- a/backpack/core/derivatives/basederivatives.py +++ b/backpack/core/derivatives/basederivatives.py @@ -1,4 +1,5 @@ """Base classes for more flexible Jacobians and second-order information.""" + import warnings from abc import ABC from typing import Callable, List, Tuple diff --git a/backpack/core/derivatives/batchnorm_nd.py b/backpack/core/derivatives/batchnorm_nd.py index ba1810526..9c030e8e8 100644 --- a/backpack/core/derivatives/batchnorm_nd.py +++ b/backpack/core/derivatives/batchnorm_nd.py @@ -1,4 +1,5 @@ """Contains derivatives for BatchNorm.""" + from typing import List, Tuple, Union from torch import Size, Tensor, einsum diff --git a/backpack/core/derivatives/conv_transposend.py b/backpack/core/derivatives/conv_transposend.py index 3f7fb40e1..a9666e7e2 100644 --- a/backpack/core/derivatives/conv_transposend.py +++ b/backpack/core/derivatives/conv_transposend.py @@ -1,4 +1,5 @@ """Partial derivatives for ``torch.nn.ConvTranspose{1,2,3}d``.""" + from typing import List, Tuple, Union from einops import rearrange @@ -7,19 +8,13 @@ from torch.nn import ConvTranspose1d, ConvTranspose2d, ConvTranspose3d, Module from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_13 -from backpack.utils.conv import get_conv_function +from backpack.utils.conv import _grad_input_padding, get_conv_function from backpack.utils.conv_transpose import ( get_conv_transpose_function, unfold_by_conv_transpose, ) from backpack.utils.subsampling import subsample -if TORCH_VERSION_AT_LEAST_1_13: - from backpack.utils.conv import _grad_input_padding -else: - from torch.nn.grad import _grad_input_padding - class ConvTransposeNDDerivatives(BaseParameterDerivatives): """Base class for partial derivatives of transpose convolution.""" diff --git a/backpack/core/derivatives/convnd.py b/backpack/core/derivatives/convnd.py index c3d428694..dabbd2ed6 100644 --- a/backpack/core/derivatives/convnd.py +++ b/backpack/core/derivatives/convnd.py @@ -7,16 +7,10 @@ from torch.nn import Conv1d, Conv2d, Conv3d, Module from backpack.core.derivatives.basederivatives import BaseParameterDerivatives -from backpack.utils import TORCH_VERSION_AT_LEAST_1_13 -from backpack.utils.conv import get_conv_function, unfold_by_conv +from backpack.utils.conv import _grad_input_padding, get_conv_function, unfold_by_conv from backpack.utils.conv_transpose import get_conv_transpose_function from backpack.utils.subsampling import subsample -if TORCH_VERSION_AT_LEAST_1_13: - from backpack.utils.conv import _grad_input_padding -else: - from torch.nn.grad import _grad_input_padding - class weight_jac_t_save_memory: """Choose algorithm to apply transposed convolution weight Jacobian.""" diff --git a/backpack/core/derivatives/crossentropyloss.py b/backpack/core/derivatives/crossentropyloss.py index 5c8bc2768..212db2110 100644 --- a/backpack/core/derivatives/crossentropyloss.py +++ b/backpack/core/derivatives/crossentropyloss.py @@ -1,4 +1,5 @@ """Partial derivatives for cross-entropy loss.""" + from math import sqrt from typing import Callable, Dict, List, Tuple diff --git a/backpack/core/derivatives/dropout.py b/backpack/core/derivatives/dropout.py index b32aef49a..becff4440 100644 --- a/backpack/core/derivatives/dropout.py +++ b/backpack/core/derivatives/dropout.py @@ -1,4 
+1,5 @@ """Partial derivatives for the dropout layer.""" + from typing import List, Tuple from torch import Tensor, eq, ones_like diff --git a/backpack/core/derivatives/elu.py b/backpack/core/derivatives/elu.py index 5d3778223..a3d1185f8 100644 --- a/backpack/core/derivatives/elu.py +++ b/backpack/core/derivatives/elu.py @@ -1,4 +1,5 @@ """Partial derivatives for the ELU activation function.""" + from typing import List, Tuple from torch import Tensor, exp, le, ones_like, zeros_like diff --git a/backpack/core/derivatives/embedding.py b/backpack/core/derivatives/embedding.py index acb191be8..52f2e25b6 100644 --- a/backpack/core/derivatives/embedding.py +++ b/backpack/core/derivatives/embedding.py @@ -1,4 +1,5 @@ """Derivatives for Embedding.""" + from typing import List, Tuple from torch import Tensor, einsum, zeros diff --git a/backpack/core/derivatives/flatten.py b/backpack/core/derivatives/flatten.py index aac7f7992..ce216e2e0 100644 --- a/backpack/core/derivatives/flatten.py +++ b/backpack/core/derivatives/flatten.py @@ -1,4 +1,5 @@ """Partial derivatives of the flatten layer.""" + from typing import List, Tuple from torch import Tensor diff --git a/backpack/core/derivatives/leakyrelu.py b/backpack/core/derivatives/leakyrelu.py index 7cb0dfa1e..640ffcfcd 100644 --- a/backpack/core/derivatives/leakyrelu.py +++ b/backpack/core/derivatives/leakyrelu.py @@ -1,4 +1,5 @@ """Partial derivatives for the leaky ReLU layer.""" + from typing import List, Tuple from torch import Tensor, gt diff --git a/backpack/core/derivatives/linear.py b/backpack/core/derivatives/linear.py index a3156f927..59953fb86 100644 --- a/backpack/core/derivatives/linear.py +++ b/backpack/core/derivatives/linear.py @@ -1,4 +1,5 @@ """Contains partial derivatives for the ``torch.nn.Linear`` layer.""" + from typing import List, Tuple from torch import Size, Tensor, einsum diff --git a/backpack/core/derivatives/logsigmoid.py b/backpack/core/derivatives/logsigmoid.py index 917784010..d2da50675 100644 --- a/backpack/core/derivatives/logsigmoid.py +++ b/backpack/core/derivatives/logsigmoid.py @@ -1,4 +1,5 @@ """Contains partial derivatives for the ``torch.nn.LogSigmoid`` layer.""" + from typing import List, Tuple from torch import Tensor, exp diff --git a/backpack/core/derivatives/lstm.py b/backpack/core/derivatives/lstm.py index def7e80da..e93081f75 100644 --- a/backpack/core/derivatives/lstm.py +++ b/backpack/core/derivatives/lstm.py @@ -1,4 +1,5 @@ """Partial derivatives for nn.LSTM.""" + from typing import List, Tuple from torch import Tensor, cat, einsum, sigmoid, tanh, zeros diff --git a/backpack/core/derivatives/nll_base.py b/backpack/core/derivatives/nll_base.py index f2c5dd607..a6b84add0 100644 --- a/backpack/core/derivatives/nll_base.py +++ b/backpack/core/derivatives/nll_base.py @@ -1,4 +1,5 @@ """Partial derivative bases for NLL losses.""" + from math import sqrt from typing import List, Tuple diff --git a/backpack/core/derivatives/permute.py b/backpack/core/derivatives/permute.py index 396803876..8207e2ab2 100644 --- a/backpack/core/derivatives/permute.py +++ b/backpack/core/derivatives/permute.py @@ -1,4 +1,5 @@ """Module containing derivatives of Permute.""" + from typing import List, Tuple from torch import Tensor, argsort diff --git a/backpack/core/derivatives/relu.py b/backpack/core/derivatives/relu.py index 18dab75fa..1bce54e0e 100644 --- a/backpack/core/derivatives/relu.py +++ b/backpack/core/derivatives/relu.py @@ -1,4 +1,5 @@ """Partial derivatives for the ReLU activation function.""" + from typing import 
List, Tuple from torch import Tensor, gt diff --git a/backpack/core/derivatives/rnn.py b/backpack/core/derivatives/rnn.py index 792eda640..c80a62d3c 100644 --- a/backpack/core/derivatives/rnn.py +++ b/backpack/core/derivatives/rnn.py @@ -1,4 +1,5 @@ """Partial derivatives for the torch.nn.RNN layer.""" + from typing import List, Tuple from torch import Tensor, cat, einsum, zeros diff --git a/backpack/core/derivatives/scale_module.py b/backpack/core/derivatives/scale_module.py index 9965a204c..9f65abfab 100644 --- a/backpack/core/derivatives/scale_module.py +++ b/backpack/core/derivatives/scale_module.py @@ -1,4 +1,5 @@ """Derivatives of ScaleModule (implies Identity).""" + from typing import List, Tuple, Union from torch import Tensor diff --git a/backpack/core/derivatives/selu.py b/backpack/core/derivatives/selu.py index b6e1c6852..3d7b91a6c 100644 --- a/backpack/core/derivatives/selu.py +++ b/backpack/core/derivatives/selu.py @@ -1,4 +1,5 @@ """Partial derivatives for the SELU activation function.""" + from typing import List, Tuple from torch import Tensor, exp, le, ones_like, zeros_like diff --git a/backpack/core/derivatives/shape_check.py b/backpack/core/derivatives/shape_check.py index d9d998ae3..2ea506913 100644 --- a/backpack/core/derivatives/shape_check.py +++ b/backpack/core/derivatives/shape_check.py @@ -2,6 +2,7 @@ Helpers to check input and output sizes of Jacobian-matrix products. """ + import functools from typing import Any, Callable diff --git a/backpack/core/derivatives/sigmoid.py b/backpack/core/derivatives/sigmoid.py index f03573e57..042b6fbfa 100644 --- a/backpack/core/derivatives/sigmoid.py +++ b/backpack/core/derivatives/sigmoid.py @@ -1,4 +1,5 @@ """Partial derivatives for the Sigmoid activation function.""" + from typing import List, Tuple from torch import Tensor diff --git a/backpack/core/derivatives/slicing.py b/backpack/core/derivatives/slicing.py index 06c78e6f9..7fbc42af0 100644 --- a/backpack/core/derivatives/slicing.py +++ b/backpack/core/derivatives/slicing.py @@ -1,4 +1,5 @@ """Contains derivatives of slicing operation.""" + from typing import List, Tuple from torch import Tensor, zeros diff --git a/backpack/core/derivatives/sum_module.py b/backpack/core/derivatives/sum_module.py index e8383fe6b..da1b29e53 100644 --- a/backpack/core/derivatives/sum_module.py +++ b/backpack/core/derivatives/sum_module.py @@ -1,4 +1,5 @@ """Contains derivatives for SumModule.""" + from typing import List, Tuple from torch import Tensor diff --git a/backpack/core/derivatives/tanh.py b/backpack/core/derivatives/tanh.py index 2e8bd004b..240347299 100644 --- a/backpack/core/derivatives/tanh.py +++ b/backpack/core/derivatives/tanh.py @@ -1,4 +1,5 @@ """Partial derivatives for the Tanh activation function.""" + from typing import List, Tuple from torch import Tensor diff --git a/backpack/core/derivatives/zeropad2d.py b/backpack/core/derivatives/zeropad2d.py index 07af6c95e..13a588877 100644 --- a/backpack/core/derivatives/zeropad2d.py +++ b/backpack/core/derivatives/zeropad2d.py @@ -1,4 +1,5 @@ """Partial derivatives for the ZeroPad2d function.""" + from typing import List, Tuple from einops import rearrange diff --git a/backpack/custom_module/branching.py b/backpack/custom_module/branching.py index 109b88c87..201d17a73 100644 --- a/backpack/custom_module/branching.py +++ b/backpack/custom_module/branching.py @@ -1,4 +1,5 @@ """Emulating branching with modules.""" + from typing import Any, OrderedDict, Tuple, Union from torch import Tensor diff --git 
a/backpack/custom_module/graph_utils.py b/backpack/custom_module/graph_utils.py index 62bc0ea03..51b6a71d3 100644 --- a/backpack/custom_module/graph_utils.py +++ b/backpack/custom_module/graph_utils.py @@ -1,4 +1,5 @@ """Transformation tools to make graph BackPACK compatible.""" + from copy import deepcopy from typing import Tuple, Union from warnings import warn diff --git a/backpack/custom_module/permute.py b/backpack/custom_module/permute.py index 3213ccd54..54bc8d216 100644 --- a/backpack/custom_module/permute.py +++ b/backpack/custom_module/permute.py @@ -1,4 +1,5 @@ """Module containing Permute module.""" + from typing import Any from torch import Tensor diff --git a/backpack/custom_module/reduce_tuple.py b/backpack/custom_module/reduce_tuple.py index 02fa9f5cc..94b087d1d 100644 --- a/backpack/custom_module/reduce_tuple.py +++ b/backpack/custom_module/reduce_tuple.py @@ -1,4 +1,5 @@ """Module containing ReduceTuple module.""" + from typing import Union from torch import Tensor diff --git a/backpack/custom_module/scale_module.py b/backpack/custom_module/scale_module.py index 2ee03e1a3..d5021fee6 100644 --- a/backpack/custom_module/scale_module.py +++ b/backpack/custom_module/scale_module.py @@ -1,4 +1,5 @@ """Contains ScaleModule.""" + from torch import Tensor from torch.nn import Module diff --git a/backpack/extensions/backprop_extension.py b/backpack/extensions/backprop_extension.py index 84e12d254..1cb605e2b 100644 --- a/backpack/extensions/backprop_extension.py +++ b/backpack/extensions/backprop_extension.py @@ -1,4 +1,5 @@ """Implements the backpropagation mechanism.""" + from __future__ import annotations import abc diff --git a/backpack/extensions/curvmatprod/__init__.py b/backpack/extensions/curvmatprod/__init__.py index d197bd275..ba20fb28a 100644 --- a/backpack/extensions/curvmatprod/__init__.py +++ b/backpack/extensions/curvmatprod/__init__.py @@ -20,7 +20,6 @@ by Felix Dangel, Stefan Harmeling, Philipp Hennig, 2020. """ - from .ggnmp import GGNMP from .hmp import HMP from .pchmp import PCHMP diff --git a/backpack/extensions/firstorder/base.py b/backpack/extensions/firstorder/base.py index b3529df3f..d4c28edb5 100644 --- a/backpack/extensions/firstorder/base.py +++ b/backpack/extensions/firstorder/base.py @@ -1,4 +1,5 @@ """Base class for first order extensions.""" + from typing import Dict, List, Type from torch.nn import Module diff --git a/backpack/extensions/firstorder/batch_grad/__init__.py b/backpack/extensions/firstorder/batch_grad/__init__.py index d67bc1d4d..cf39dca2f 100644 --- a/backpack/extensions/firstorder/batch_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_grad/__init__.py @@ -2,6 +2,7 @@ It defines the module extension for each module. 
""" + from typing import List from torch.nn import ( diff --git a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py index bd8e75a0d..06abc2d15 100644 --- a/backpack/extensions/firstorder/batch_grad/batch_grad_base.py +++ b/backpack/extensions/firstorder/batch_grad/batch_grad_base.py @@ -1,4 +1,5 @@ """Calculates the batch_grad derivative.""" + from __future__ import annotations from typing import TYPE_CHECKING, Callable, List, Tuple diff --git a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py index 83759b0ae..045826413 100644 --- a/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py +++ b/backpack/extensions/firstorder/batch_grad/batchnorm_nd.py @@ -1,4 +1,5 @@ """Contains grad_batch extension for BatchNorm.""" + from typing import Tuple, Union from torch import Tensor diff --git a/backpack/extensions/firstorder/batch_grad/embedding.py b/backpack/extensions/firstorder/batch_grad/embedding.py index 35b41f7b0..4476452a4 100644 --- a/backpack/extensions/firstorder/batch_grad/embedding.py +++ b/backpack/extensions/firstorder/batch_grad/embedding.py @@ -1,4 +1,5 @@ """BatchGrad extension for Embedding.""" + from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase diff --git a/backpack/extensions/firstorder/batch_grad/rnn.py b/backpack/extensions/firstorder/batch_grad/rnn.py index 9b92f2642..e092f3c44 100644 --- a/backpack/extensions/firstorder/batch_grad/rnn.py +++ b/backpack/extensions/firstorder/batch_grad/rnn.py @@ -1,4 +1,5 @@ """Contains BatchGradRNN.""" + from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.batch_grad.batch_grad_base import BatchGradBase diff --git a/backpack/extensions/firstorder/batch_l2_grad/__init__.py b/backpack/extensions/firstorder/batch_l2_grad/__init__.py index 8be80f08e..c3c7dbd6d 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/__init__.py +++ b/backpack/extensions/firstorder/batch_l2_grad/__init__.py @@ -3,6 +3,7 @@ Defines the backpropagation extension. Within it, define the extension for each module. 
""" + from torch.nn import ( LSTM, RNN, diff --git a/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py index f7b4f79dd..cce324dfb 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py +++ b/backpack/extensions/firstorder/batch_l2_grad/batch_l2_base.py @@ -1,4 +1,5 @@ """Contains Base class for batch_l2_grad.""" + from __future__ import annotations from typing import TYPE_CHECKING, Callable, List, Tuple diff --git a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py index 9e1941804..5857df97d 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py +++ b/backpack/extensions/firstorder/batch_l2_grad/batchnorm_nd.py @@ -1,4 +1,5 @@ """Contains batch_l2 extension for BatchNorm.""" + from typing import Tuple, Union from torch import Tensor diff --git a/backpack/extensions/firstorder/batch_l2_grad/convnd.py b/backpack/extensions/firstorder/batch_l2_grad/convnd.py index 991eb96e2..c5230bb30 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/convnd.py +++ b/backpack/extensions/firstorder/batch_l2_grad/convnd.py @@ -1,4 +1,5 @@ """batch_l2 extension for Conv.""" + from torch import einsum from backpack.core.derivatives.conv1d import Conv1DDerivatives diff --git a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py index 3c54be1f5..4cbf3a991 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py +++ b/backpack/extensions/firstorder/batch_l2_grad/convtransposend.py @@ -1,4 +1,5 @@ """batch_l2 extension for ConvTranspose.""" + from torch import einsum from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives diff --git a/backpack/extensions/firstorder/batch_l2_grad/embedding.py b/backpack/extensions/firstorder/batch_l2_grad/embedding.py index eca2b10cb..87697a077 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/embedding.py +++ b/backpack/extensions/firstorder/batch_l2_grad/embedding.py @@ -1,4 +1,5 @@ """BatchL2 extension for Embedding.""" + from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base diff --git a/backpack/extensions/firstorder/batch_l2_grad/linear.py b/backpack/extensions/firstorder/batch_l2_grad/linear.py index 89da03276..51cb05e3e 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/linear.py +++ b/backpack/extensions/firstorder/batch_l2_grad/linear.py @@ -1,4 +1,5 @@ """Contains batch_l2 extension for Linear.""" + from __future__ import annotations from typing import TYPE_CHECKING, Tuple diff --git a/backpack/extensions/firstorder/batch_l2_grad/rnn.py b/backpack/extensions/firstorder/batch_l2_grad/rnn.py index dbb1a1644..c70de5924 100644 --- a/backpack/extensions/firstorder/batch_l2_grad/rnn.py +++ b/backpack/extensions/firstorder/batch_l2_grad/rnn.py @@ -1,4 +1,5 @@ """Contains BatchL2RNN.""" + from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.batch_l2_grad.batch_l2_base import BatchL2Base diff --git a/backpack/extensions/firstorder/gradient/__init__.py b/backpack/extensions/firstorder/gradient/__init__.py index 89c7cff43..a6740058a 100644 --- a/backpack/extensions/firstorder/gradient/__init__.py +++ b/backpack/extensions/firstorder/gradient/__init__.py @@ -2,4 
+2,5 @@ It calculates the same result as torch backward(). """ + # TODO: Rewrite variance to not need this extension diff --git a/backpack/extensions/firstorder/gradient/base.py b/backpack/extensions/firstorder/gradient/base.py index b9f198855..2cfbe3d26 100644 --- a/backpack/extensions/firstorder/gradient/base.py +++ b/backpack/extensions/firstorder/gradient/base.py @@ -1,4 +1,5 @@ """Calculates the gradient.""" + from backpack.extensions.firstorder.base import FirstOrderModuleExtension diff --git a/backpack/extensions/firstorder/gradient/batchnorm_nd.py b/backpack/extensions/firstorder/gradient/batchnorm_nd.py index 5bacc2ad6..9141daa10 100644 --- a/backpack/extensions/firstorder/gradient/batchnorm_nd.py +++ b/backpack/extensions/firstorder/gradient/batchnorm_nd.py @@ -1,4 +1,5 @@ """Gradient extension for BatchNorm.""" + from typing import Tuple, Union from torch import Tensor diff --git a/backpack/extensions/firstorder/gradient/embedding.py b/backpack/extensions/firstorder/gradient/embedding.py index c394ae509..a323a03da 100644 --- a/backpack/extensions/firstorder/gradient/embedding.py +++ b/backpack/extensions/firstorder/gradient/embedding.py @@ -1,4 +1,5 @@ """Gradient extension for Embedding.""" + from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.extensions.firstorder.gradient.base import GradBaseModule diff --git a/backpack/extensions/firstorder/gradient/rnn.py b/backpack/extensions/firstorder/gradient/rnn.py index 7ba76e626..604c8284a 100644 --- a/backpack/extensions/firstorder/gradient/rnn.py +++ b/backpack/extensions/firstorder/gradient/rnn.py @@ -1,4 +1,5 @@ """Contains GradRNN.""" + from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.gradient.base import GradBaseModule diff --git a/backpack/extensions/firstorder/sum_grad_squared/__init__.py b/backpack/extensions/firstorder/sum_grad_squared/__init__.py index 76891cff6..f31466c7d 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/__init__.py +++ b/backpack/extensions/firstorder/sum_grad_squared/__init__.py @@ -2,6 +2,7 @@ Defines module extension for each module. 
""" + from torch.nn import ( LSTM, RNN, diff --git a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py index 9ad99de07..649cf6e8a 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py +++ b/backpack/extensions/firstorder/sum_grad_squared/batchnorm_nd.py @@ -1,4 +1,5 @@ """SGS extension for BatchNorm.""" + from typing import Tuple, Union from torch import Tensor diff --git a/backpack/extensions/firstorder/sum_grad_squared/embedding.py b/backpack/extensions/firstorder/sum_grad_squared/embedding.py index 62f34e86b..4d5f56b27 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/embedding.py +++ b/backpack/extensions/firstorder/sum_grad_squared/embedding.py @@ -1,4 +1,5 @@ """SGS extension for Embedding.""" + from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase diff --git a/backpack/extensions/firstorder/sum_grad_squared/rnn.py b/backpack/extensions/firstorder/sum_grad_squared/rnn.py index 129229144..57becd11b 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/rnn.py +++ b/backpack/extensions/firstorder/sum_grad_squared/rnn.py @@ -1,4 +1,5 @@ """Contains SGSRNN module.""" + from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.firstorder.sum_grad_squared.sgs_base import SGSBase diff --git a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py index 3e1d171ab..60950c71c 100644 --- a/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py +++ b/backpack/extensions/firstorder/sum_grad_squared/sgs_base.py @@ -1,4 +1,5 @@ """Contains SGSBase, the base module for sum_grad_squared extension.""" + from __future__ import annotations from typing import TYPE_CHECKING, Callable, List, Tuple diff --git a/backpack/extensions/firstorder/variance/__init__.py b/backpack/extensions/firstorder/variance/__init__.py index eeb90902f..d8f8772bc 100644 --- a/backpack/extensions/firstorder/variance/__init__.py +++ b/backpack/extensions/firstorder/variance/__init__.py @@ -2,6 +2,7 @@ Defines module extension for each module. 
""" + from torch.nn import ( LSTM, RNN, diff --git a/backpack/extensions/firstorder/variance/batchnorm_nd.py b/backpack/extensions/firstorder/variance/batchnorm_nd.py index d2b8512e5..fde59aa09 100644 --- a/backpack/extensions/firstorder/variance/batchnorm_nd.py +++ b/backpack/extensions/firstorder/variance/batchnorm_nd.py @@ -1,4 +1,5 @@ """Variance extension for BatchNorm.""" + from typing import Tuple, Union from torch import Tensor diff --git a/backpack/extensions/firstorder/variance/embedding.py b/backpack/extensions/firstorder/variance/embedding.py index 1b38472a6..7dfa9ccae 100644 --- a/backpack/extensions/firstorder/variance/embedding.py +++ b/backpack/extensions/firstorder/variance/embedding.py @@ -1,4 +1,5 @@ """Variance extension for Embedding.""" + from backpack.extensions.firstorder.gradient.embedding import GradEmbedding from backpack.extensions.firstorder.sum_grad_squared.embedding import SGSEmbedding from backpack.extensions.firstorder.variance.variance_base import VarianceBaseModule diff --git a/backpack/extensions/firstorder/variance/variance_base.py b/backpack/extensions/firstorder/variance/variance_base.py index b91aac935..b08d52bd0 100644 --- a/backpack/extensions/firstorder/variance/variance_base.py +++ b/backpack/extensions/firstorder/variance/variance_base.py @@ -1,4 +1,5 @@ """Contains VarianceBaseModule.""" + from __future__ import annotations from typing import TYPE_CHECKING, Callable, List, Tuple diff --git a/backpack/extensions/mat_to_mat_jac_base.py b/backpack/extensions/mat_to_mat_jac_base.py index 937d78b33..b5f1d3677 100644 --- a/backpack/extensions/mat_to_mat_jac_base.py +++ b/backpack/extensions/mat_to_mat_jac_base.py @@ -1,4 +1,5 @@ """Contains base class for second order extensions.""" + from typing import List, Tuple, Union from torch import Tensor diff --git a/backpack/extensions/module_extension.py b/backpack/extensions/module_extension.py index 96e4fdfd1..836031e5b 100644 --- a/backpack/extensions/module_extension.py +++ b/backpack/extensions/module_extension.py @@ -1,4 +1,5 @@ """Contains base class for BackPACK module extensions.""" + from __future__ import annotations from typing import TYPE_CHECKING, Any, List, Tuple diff --git a/backpack/extensions/saved_quantities.py b/backpack/extensions/saved_quantities.py index c5006fc2a..38b73bb3c 100644 --- a/backpack/extensions/saved_quantities.py +++ b/backpack/extensions/saved_quantities.py @@ -1,4 +1,5 @@ """Class for saving backpropagation quantities.""" + from typing import Any, Callable, Dict, Union from torch import Tensor diff --git a/backpack/extensions/secondorder/base.py b/backpack/extensions/secondorder/base.py index d65fa548f..a5bfa9ec7 100644 --- a/backpack/extensions/secondorder/base.py +++ b/backpack/extensions/secondorder/base.py @@ -1,4 +1,5 @@ """Contains base classes for second order extensions.""" + from backpack.extensions.backprop_extension import BackpropExtension diff --git a/backpack/extensions/secondorder/diag_ggn/__init__.py b/backpack/extensions/secondorder/diag_ggn/__init__.py index 02a4322ac..2fc55d649 100644 --- a/backpack/extensions/secondorder/diag_ggn/__init__.py +++ b/backpack/extensions/secondorder/diag_ggn/__init__.py @@ -8,6 +8,7 @@ BatchDiagGGNExact(BatchDiagGGN) BatchDiagGGNMC(BatchDiagGGN) """ + from torch import Tensor from torch.nn import ( ELU, diff --git a/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py b/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py index b2cfceb46..ab9f5d4f8 100644 --- 
a/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py +++ b/backpack/extensions/secondorder/diag_ggn/adaptive_avg_pool_nd.py @@ -1,4 +1,5 @@ """DiagGGN extension for AdaptiveAvgPool.""" + from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule diff --git a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py index c0aa7c29b..964342289 100644 --- a/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py +++ b/backpack/extensions/secondorder/diag_ggn/batchnorm_nd.py @@ -1,4 +1,5 @@ """DiagGGN extension for BatchNorm.""" + from typing import Tuple, Union from torch import Tensor diff --git a/backpack/extensions/secondorder/diag_ggn/custom_module.py b/backpack/extensions/secondorder/diag_ggn/custom_module.py index 293ed4281..04ea80043 100644 --- a/backpack/extensions/secondorder/diag_ggn/custom_module.py +++ b/backpack/extensions/secondorder/diag_ggn/custom_module.py @@ -1,4 +1,5 @@ """DiagGGN extensions for backpack's custom modules.""" + from backpack.core.derivatives.scale_module import ScaleModuleDerivatives from backpack.core.derivatives.sum_module import SumModuleDerivatives from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule diff --git a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py index 203b8ebd6..0ce986971 100644 --- a/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py +++ b/backpack/extensions/secondorder/diag_ggn/diag_ggn_base.py @@ -1,4 +1,5 @@ """Contains DiagGGN base class.""" + from typing import Callable, List, Tuple, Union from torch import Tensor diff --git a/backpack/extensions/secondorder/diag_ggn/embedding.py b/backpack/extensions/secondorder/diag_ggn/embedding.py index 1021b089b..211fe4cf4 100644 --- a/backpack/extensions/secondorder/diag_ggn/embedding.py +++ b/backpack/extensions/secondorder/diag_ggn/embedding.py @@ -1,4 +1,5 @@ """DiagGGN extension for Embedding.""" + from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule diff --git a/backpack/extensions/secondorder/diag_ggn/permute.py b/backpack/extensions/secondorder/diag_ggn/permute.py index 7e7db118c..5ac8466be 100644 --- a/backpack/extensions/secondorder/diag_ggn/permute.py +++ b/backpack/extensions/secondorder/diag_ggn/permute.py @@ -1,4 +1,5 @@ """Module defining DiagGGNPermute.""" + from backpack.core.derivatives.permute import PermuteDerivatives from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule diff --git a/backpack/extensions/secondorder/diag_ggn/rnn.py b/backpack/extensions/secondorder/diag_ggn/rnn.py index 7c926c945..e588df396 100644 --- a/backpack/extensions/secondorder/diag_ggn/rnn.py +++ b/backpack/extensions/secondorder/diag_ggn/rnn.py @@ -1,4 +1,5 @@ """Module implementing GGN for RNN.""" + from backpack.core.derivatives.lstm import LSTMDerivatives from backpack.core.derivatives.rnn import RNNDerivatives from backpack.extensions.secondorder.diag_ggn.diag_ggn_base import DiagGGNBaseModule diff --git a/backpack/extensions/secondorder/diag_hessian/__init__.py b/backpack/extensions/secondorder/diag_hessian/__init__.py index 246d48e7b..5d4cfaeb2 100644 --- a/backpack/extensions/secondorder/diag_hessian/__init__.py +++ b/backpack/extensions/secondorder/diag_hessian/__init__.py 
@@ -3,9 +3,13 @@ - Hessian diagonal - Per-sample (individual) Hessian diagonal """ + from torch.nn import ( ELU, SELU, + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, AvgPool1d, AvgPool2d, AvgPool3d, @@ -38,6 +42,7 @@ from . import ( activations, + adaptive_avg_pool_nd, conv1d, conv2d, conv3d, @@ -81,6 +86,9 @@ def __init__(self): MaxPool3d: pooling.DiagHMaxPool3d(), AvgPool2d: pooling.DiagHAvgPool2d(), AvgPool3d: pooling.DiagHAvgPool3d(), + AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagHAdaptiveAvgPoolNd(1), + AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagHAdaptiveAvgPoolNd(2), + AdaptiveAvgPool3d: adaptive_avg_pool_nd.DiagHAdaptiveAvgPoolNd(3), ZeroPad2d: padding.DiagHZeroPad2d(), Conv1d: conv1d.DiagHConv1d(), Conv2d: conv2d.DiagHConv2d(), @@ -131,6 +139,9 @@ def __init__(self): MaxPool3d: pooling.DiagHMaxPool3d(), AvgPool2d: pooling.DiagHAvgPool2d(), AvgPool3d: pooling.DiagHAvgPool3d(), + AdaptiveAvgPool1d: adaptive_avg_pool_nd.DiagHAdaptiveAvgPoolNd(1), + AdaptiveAvgPool2d: adaptive_avg_pool_nd.DiagHAdaptiveAvgPoolNd(2), + AdaptiveAvgPool3d: adaptive_avg_pool_nd.DiagHAdaptiveAvgPoolNd(3), ZeroPad2d: padding.DiagHZeroPad2d(), Conv1d: conv1d.BatchDiagHConv1d(), Conv2d: conv2d.BatchDiagHConv2d(), diff --git a/backpack/extensions/secondorder/diag_hessian/adaptive_avg_pool_nd.py b/backpack/extensions/secondorder/diag_hessian/adaptive_avg_pool_nd.py new file mode 100644 index 000000000..44f10feaf --- /dev/null +++ b/backpack/extensions/secondorder/diag_hessian/adaptive_avg_pool_nd.py @@ -0,0 +1,16 @@ +"""DiagH extension for AdaptiveAvgPool.""" + +from backpack.core.derivatives.adaptive_avg_pool_nd import AdaptiveAvgPoolNDDerivatives +from backpack.extensions.secondorder.diag_hessian.diag_h_base import DiagHBaseModule + + +class DiagHAdaptiveAvgPoolNd(DiagHBaseModule): + """DiagH extension for AdaptiveAvgPool.""" + + def __init__(self, N: int): + """Initialization. + + Args: + N: number of free dimensions, e.g. 
use N=1 for AdaptiveAvgPool1d + """ + super().__init__(derivatives=AdaptiveAvgPoolNDDerivatives(N=N)) diff --git a/backpack/extensions/secondorder/diag_hessian/conv1d.py b/backpack/extensions/secondorder/diag_hessian/conv1d.py index 3f9a59e64..9d304c413 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv1d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv1d.py @@ -1,4 +1,5 @@ """Module extensions for diagonal Hessian properties of ``torch.nn.Conv1d``.""" + from backpack.core.derivatives.conv1d import Conv1DDerivatives from backpack.extensions.secondorder.diag_hessian.convnd import ( BatchDiagHConvND, diff --git a/backpack/extensions/secondorder/diag_hessian/conv2d.py b/backpack/extensions/secondorder/diag_hessian/conv2d.py index fe7d71a75..53501603b 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv2d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv2d.py @@ -1,4 +1,5 @@ """Module extensions for diagonal Hessian properties of ``torch.nn.Conv2d``.""" + from backpack.core.derivatives.conv2d import Conv2DDerivatives from backpack.extensions.secondorder.diag_hessian.convnd import ( BatchDiagHConvND, diff --git a/backpack/extensions/secondorder/diag_hessian/conv3d.py b/backpack/extensions/secondorder/diag_hessian/conv3d.py index 0b2c9e5eb..388e44c5f 100644 --- a/backpack/extensions/secondorder/diag_hessian/conv3d.py +++ b/backpack/extensions/secondorder/diag_hessian/conv3d.py @@ -1,4 +1,5 @@ """Module extensions for diagonal Hessian properties of ``torch.nn.Conv3d``.""" + from backpack.core.derivatives.conv3d import Conv3DDerivatives from backpack.extensions.secondorder.diag_hessian.convnd import ( BatchDiagHConvND, diff --git a/backpack/extensions/secondorder/hbp/custom_module.py b/backpack/extensions/secondorder/hbp/custom_module.py index ec8728c7e..04c52261f 100644 --- a/backpack/extensions/secondorder/hbp/custom_module.py +++ b/backpack/extensions/secondorder/hbp/custom_module.py @@ -1,4 +1,5 @@ """Module extensions for custom properties of HBPBaseModule.""" + from backpack.core.derivatives.scale_module import ScaleModuleDerivatives from backpack.core.derivatives.sum_module import SumModuleDerivatives from backpack.extensions.secondorder.hbp.hbpbase import HBPBaseModule diff --git a/backpack/extensions/secondorder/sqrt_ggn/activations.py b/backpack/extensions/secondorder/sqrt_ggn/activations.py index 3aaf8fff2..b6390913e 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/activations.py +++ b/backpack/extensions/secondorder/sqrt_ggn/activations.py @@ -1,4 +1,5 @@ """Contains extensions for activation layers used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.elu import ELUDerivatives from backpack.core.derivatives.leakyrelu import LeakyReLUDerivatives from backpack.core.derivatives.logsigmoid import LogSigmoidDerivatives diff --git a/backpack/extensions/secondorder/sqrt_ggn/base.py b/backpack/extensions/secondorder/sqrt_ggn/base.py index 425766f8e..6d4e37c6b 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/base.py +++ b/backpack/extensions/secondorder/sqrt_ggn/base.py @@ -1,4 +1,5 @@ """Contains base class for ``SqrtGGN{Exact, MC}`` module extensions.""" + from __future__ import annotations from typing import TYPE_CHECKING, Callable, List, Tuple, Union @@ -33,9 +34,7 @@ def __init__(self, derivatives: BaseDerivatives, params: List[str] = None): super().__init__(derivatives, params=params) - def _make_param_function( - self, param_str: str - ) -> Callable[ + def _make_param_function(self, param_str: str) -> Callable[ 
[Union[SqrtGGNExact, SqrtGGNMC], Module, Tuple[Tensor], Tuple[Tensor], Tensor], Tensor, ]: diff --git a/backpack/extensions/secondorder/sqrt_ggn/convnd.py b/backpack/extensions/secondorder/sqrt_ggn/convnd.py index 74a88651c..e81606258 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/convnd.py +++ b/backpack/extensions/secondorder/sqrt_ggn/convnd.py @@ -1,4 +1,5 @@ """Contains extensions for convolution layers used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.conv1d import Conv1DDerivatives from backpack.core.derivatives.conv2d import Conv2DDerivatives from backpack.core.derivatives.conv3d import Conv3DDerivatives diff --git a/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py index a18331976..eab96be33 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py +++ b/backpack/extensions/secondorder/sqrt_ggn/convtransposend.py @@ -1,4 +1,5 @@ """Contains transpose convolution layer extensions used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.conv_transpose1d import ConvTranspose1DDerivatives from backpack.core.derivatives.conv_transpose2d import ConvTranspose2DDerivatives from backpack.core.derivatives.conv_transpose3d import ConvTranspose3DDerivatives diff --git a/backpack/extensions/secondorder/sqrt_ggn/dropout.py b/backpack/extensions/secondorder/sqrt_ggn/dropout.py index 2f03b8aa9..097773307 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/dropout.py +++ b/backpack/extensions/secondorder/sqrt_ggn/dropout.py @@ -1,4 +1,5 @@ """Contains extensions for dropout layers used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.dropout import DropoutDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule diff --git a/backpack/extensions/secondorder/sqrt_ggn/embedding.py b/backpack/extensions/secondorder/sqrt_ggn/embedding.py index 070ad217c..8d94a1194 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/embedding.py +++ b/backpack/extensions/secondorder/sqrt_ggn/embedding.py @@ -1,4 +1,5 @@ """Contains extension for the embedding layer used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.embedding import EmbeddingDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule diff --git a/backpack/extensions/secondorder/sqrt_ggn/flatten.py b/backpack/extensions/secondorder/sqrt_ggn/flatten.py index 2a045c957..c564366b7 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/flatten.py +++ b/backpack/extensions/secondorder/sqrt_ggn/flatten.py @@ -1,4 +1,5 @@ """Contains extensions for the flatten layer used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.flatten import FlattenDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule diff --git a/backpack/extensions/secondorder/sqrt_ggn/linear.py b/backpack/extensions/secondorder/sqrt_ggn/linear.py index 4aecca6f5..4e753cdca 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/linear.py +++ b/backpack/extensions/secondorder/sqrt_ggn/linear.py @@ -1,4 +1,5 @@ """Contains extension for the linear layer used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.linear import LinearDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule diff --git a/backpack/extensions/secondorder/sqrt_ggn/losses.py b/backpack/extensions/secondorder/sqrt_ggn/losses.py index 6f0112e23..1e0cae136 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/losses.py +++ 
b/backpack/extensions/secondorder/sqrt_ggn/losses.py @@ -1,4 +1,5 @@ """Contains base class and extensions for losses used by ``SqrtGGN{Exact, MC}``.""" + from __future__ import annotations from typing import TYPE_CHECKING, Tuple, Union diff --git a/backpack/extensions/secondorder/sqrt_ggn/padding.py b/backpack/extensions/secondorder/sqrt_ggn/padding.py index 18574f685..703f17af4 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/padding.py +++ b/backpack/extensions/secondorder/sqrt_ggn/padding.py @@ -1,4 +1,5 @@ """Contains extensions for padding layers used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.zeropad2d import ZeroPad2dDerivatives from backpack.extensions.secondorder.sqrt_ggn.base import SqrtGGNBaseModule diff --git a/backpack/extensions/secondorder/sqrt_ggn/pooling.py b/backpack/extensions/secondorder/sqrt_ggn/pooling.py index e19cfba2a..d602600fc 100644 --- a/backpack/extensions/secondorder/sqrt_ggn/pooling.py +++ b/backpack/extensions/secondorder/sqrt_ggn/pooling.py @@ -1,4 +1,5 @@ """Contains extensions for pooling layers used by ``SqrtGGN{Exact, MC}``.""" + from backpack.core.derivatives.avgpool1d import AvgPool1DDerivatives from backpack.core.derivatives.avgpool2d import AvgPool2DDerivatives from backpack.core.derivatives.avgpool3d import AvgPool3DDerivatives diff --git a/backpack/hessianfree/ggnvp.py b/backpack/hessianfree/ggnvp.py index 165aff253..b40d29e9c 100644 --- a/backpack/hessianfree/ggnvp.py +++ b/backpack/hessianfree/ggnvp.py @@ -1,4 +1,5 @@ """Autodiff-only matrix-free multiplication by the generalized Gauss-Newton/Fisher.""" + from typing import List, Tuple from torch import Tensor diff --git a/backpack/hessianfree/hvp.py b/backpack/hessianfree/hvp.py index 548013ada..c7e7a1d9d 100644 --- a/backpack/hessianfree/hvp.py +++ b/backpack/hessianfree/hvp.py @@ -42,7 +42,9 @@ def hessian_vector_product(f, params, v, grad_params=None, detach=True): if grad_params is not None: df_dx = tuple(grad_params) else: - df_dx = torch.autograd.grad(f, params, create_graph=True, retain_graph=True) + df_dx = torch.autograd.grad( + f, params, create_graph=True, retain_graph=True, materialize_grads=True + ) Hv = R_op(df_dx, params, v) diff --git a/backpack/hessianfree/rop.py b/backpack/hessianfree/rop.py index 007b9e7ef..e50733cad 100644 --- a/backpack/hessianfree/rop.py +++ b/backpack/hessianfree/rop.py @@ -18,10 +18,17 @@ def R_op(ys, xs, vs, retain_graph=True, detach=True): create_graph=True, retain_graph=retain_graph, allow_unused=True, + materialize_grads=True, ) re = torch.autograd.grad( - gs, ws, grad_outputs=vs, create_graph=True, retain_graph=True, allow_unused=True + gs, + ws, + grad_outputs=vs, + create_graph=True, + retain_graph=True, + allow_unused=True, + materialize_grads=True, ) if detach: diff --git a/backpack/utils/__init__.py b/backpack/utils/__init__.py index 381b6c519..39f7fa2b1 100644 --- a/backpack/utils/__init__.py +++ b/backpack/utils/__init__.py @@ -1,9 +1 @@ """Contains utility functions.""" -from pkg_resources import get_distribution, packaging - -TORCH_VERSION = packaging.version.parse(get_distribution("torch").version) -TORCH_VERSION_AT_LEAST_1_9_1 = TORCH_VERSION >= packaging.version.parse("1.9.1") -TORCH_VERSION_AT_LEAST_2_0_0 = TORCH_VERSION >= packaging.version.parse("2.0.0") -TORCH_VERSION_AT_LEAST_1_13 = TORCH_VERSION >= packaging.version.parse("1.13") - -ADAPTIVE_AVG_POOL_BUG: bool = not TORCH_VERSION_AT_LEAST_2_0_0 diff --git a/backpack/utils/errors.py b/backpack/utils/errors.py index 690dc451b..5dd51e2bb 100644 --- 
a/backpack/utils/errors.py +++ b/backpack/utils/errors.py @@ -1,4 +1,5 @@ """Contains errors for BackPACK.""" + from typing import Union from warnings import warn diff --git a/backpack/utils/examples.py b/backpack/utils/examples.py index 3798ebde9..f946c9205 100644 --- a/backpack/utils/examples.py +++ b/backpack/utils/examples.py @@ -1,4 +1,5 @@ """Utility functions for examples.""" + from typing import Iterator, List, Tuple from torch import Tensor, stack, zeros diff --git a/backpack/utils/linear.py b/backpack/utils/linear.py index 6978545f6..6fe8b9c29 100644 --- a/backpack/utils/linear.py +++ b/backpack/utils/linear.py @@ -1,4 +1,5 @@ """Contains utility functions to extract the GGN diagonal for linear layers.""" + from torch import Tensor, einsum from torch.nn import Linear diff --git a/backpack/utils/module_classification.py b/backpack/utils/module_classification.py index c70e5247d..9be947cbc 100644 --- a/backpack/utils/module_classification.py +++ b/backpack/utils/module_classification.py @@ -1,4 +1,5 @@ """Contains util function for classification of modules.""" + from torch.fx import GraphModule from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Module, MSELoss, Sequential from torch.nn.modules.loss import _Loss diff --git a/backpack/utils/subsampling.py b/backpack/utils/subsampling.py index 62d399f4c..75cb6e73c 100644 --- a/backpack/utils/subsampling.py +++ b/backpack/utils/subsampling.py @@ -1,4 +1,5 @@ """Utility functions to enable mini-batch subsampling in extensions.""" + from typing import List from torch import Tensor diff --git a/changelog.md b/changelog.md index da8d00832..57fcb6203 100644 --- a/changelog.md +++ b/changelog.md @@ -6,6 +6,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [1.7.0] - 2024-11-12 + +This version deprecates Python 3.8 and bumps the PyTorch dependency to +`>=2.2.0`, allowing several internal clean-ups. We also move from a +`setup.py`-based installation mechanism to `pyproject.toml`. + +### Added/New + +- Tutorial explaining how to write second-order extensions for custom layers + ([PR](https://github.com/f-dangel/backpack/pull/320), + [docs](https://docs.backpack.pt/en/master/use_cases/example_custom_module.html#second-order-extension)) +- Support `AdaptiveAvgPool{1,2,3}d` layers in `DiagHessian` and + `BatchDiagHessian` extensions + ([PR](https://github.com/f-dangel/backpack/pull/314)) + +### Fixed/Removed +- Fix RTD configuration after `system_packages` deprecation + ([PR1](https://github.com/f-dangel/backpack/pull/315), + [PR2](https://github.com/f-dangel/backpack/pull/323)) + +### Internal +- We will deprecate the `development` branch and from now on directly + merge new features and fixes into `master` +- Deprecate Python 3.8 and PyTorch <2.2.0 + ([PR1](https://github.com/f-dangel/backpack/pull/331), + [PR2](https://github.com/f-dangel/backpack/pull/333)) +- Move from `setup.py` to `pyproject.toml` + ([PR](https://github.com/f-dangel/backpack/pull/332)) + ## [1.6.0] - 2023-06-26 With this patch, BackPACK supports `torch==2.x` and deprecates `python3.7` @@ -460,7 +489,8 @@ co-authoring many PRs shipped in this release. 
Initial release -[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.6.0...HEAD +[Unreleased]: https://github.com/f-dangel/backpack/compare/v1.7.0...HEAD +[1.7.0]: https://github.com/f-dangel/backpack/compare/1.6.0...1.7.0 [1.6.0]: https://github.com/f-dangel/backpack/compare/1.5.2...1.6.0 [1.5.2]: https://github.com/f-dangel/backpack/compare/1.5.2...1.5.1 [1.5.1]: https://github.com/f-dangel/backpack/compare/1.5.1...1.5.0 diff --git a/docs_src/examples/basic_usage/example_all_in_one.py b/docs_src/examples/basic_usage/example_all_in_one.py index cb7aba42d..3f8e35915 100644 --- a/docs_src/examples/basic_usage/example_all_in_one.py +++ b/docs_src/examples/basic_usage/example_all_in_one.py @@ -2,8 +2,8 @@ Example using all extensions ============================== -Basic example showing how compute the gradient, -and and other quantities with BackPACK, +Basic example showing how to compute the gradient, +and other quantities with BackPACK, on a linear model for MNIST. """ diff --git a/docs_src/examples/use_cases/example_batched_jacobians.py b/docs_src/examples/use_cases/example_batched_jacobians.py index bf6a12ece..99fa9599e 100644 --- a/docs_src/examples/use_cases/example_batched_jacobians.py +++ b/docs_src/examples/use_cases/example_batched_jacobians.py @@ -26,6 +26,7 @@ Let's start by importing the required functionality and write a setup function to create our synthetic data. """ + import itertools from math import sqrt from typing import List, Tuple diff --git a/docs_src/examples/use_cases/example_custom_module.py b/docs_src/examples/use_cases/example_custom_module.py index e59b10c8d..1aa8bc2dd 100644 --- a/docs_src/examples/use_cases/example_custom_module.py +++ b/docs_src/examples/use_cases/example_custom_module.py @@ -1,19 +1,28 @@ """Custom module example ========================================= -This tutorial shows how to support a custom module in a simple fashion. -We focus on `BackPACK's first-order extensions `_. -They don't backpropagate additional information and thus require less functionality be -implemented. +This tutorial explains how to support new layers in BackPACK. + +We will write a custom module and show how to implement first-order extensions, +specifically :py:class:`BatchGrad `, and second-order +extensions, specifically :py:class:`DiagGGNExact `. Let's get the imports out of our way. """ # noqa: B950 +from typing import Tuple + import torch +from einops import einsum +from torch.nn.utils.convert_parameters import parameters_to_vector from backpack import backpack, extend from backpack.extensions import BatchGrad from backpack.extensions.firstorder.base import FirstOrderModuleExtension +from backpack.extensions.module_extension import ModuleExtension +from backpack.extensions.secondorder.diag_ggn import DiagGGNExact +from backpack.hessianfree.ggnvp import ggn_vector_product +from backpack.utils.convert_parameters import vector_to_parameter_list # make deterministic torch.manual_seed(0) @@ -32,42 +41,55 @@ class ScaleModule(torch.nn.Module): """Defines the module.""" - def __init__(self, weight=2.0): + def __init__(self, weight: float = 2.0): """Store scalar weight. Args: - weight(float, optional): Initial value for weight. Defaults to 2.0. + weight: Initial value for weight. Defaults to 2.0. """ super(ScaleModule, self).__init__() self.weight = torch.nn.Parameter(torch.tensor([weight])) - def forward(self, input): + def forward(self, input: torch.Tensor) -> torch.Tensor: """Defines forward pass. 
Args: - input(torch.Tensor): input + input: The layer input. Returns: - torch.Tensor: product of input and weight + Product of input and weight. """ return input * self.weight # %% -# You don't necessarily need to write a custom layer. Any PyTorch layer can be extended -# as described (it should be a :py:class:`torch.nn.Module `'s because -# BackPACK uses module hooks). +# We choose this custom simple layer as its related operations for backpropagation are +# easy to understand. Of course, you don't have to define a new layer if it already +# exists within :py:mod:`torch.nn`. +# +# It is important to understand, though, that BackPACK relies on module hooks and therefore +# can only be extended on the modular level: If your desired functionality is not a +# :py:class:`torch.nn.Module ` yet, you need to wrap it in a +# :py:class:`torch.nn.Module `. +# +# First-order extensions +# ---------------------- +# First we focus on `BackPACK's first-order extensions +# `_. +# They don't backpropagate additional information and thus require less functionality. # -# Custom module extension -# ----------------------- # Let's make BackPACK support computing individual gradients for ``ScaleModule``. # This is done by the :py:class:`BatchGrad ` extension. # To support the new module, we need to create a module extension that implements -# how individual gradients are extracted with respect to ``ScaleModule``'s parameter. +# how individual gradients are extracted with respect to ``ScaleModule``'s parameter +# called ``weight``. # # The module extension must implement methods named after the parameters passed to the -# constructor. Here it goes. +# constructor (in this case ``weight``). For a module with additional parameters, e.g. a +# ``bias``, an additional method named like the parameter has to be added. +# +# Here it goes. class ScaleModuleBatchGrad(FirstOrderModuleExtension): @@ -75,24 +97,36 @@ class ScaleModuleBatchGrad(FirstOrderModuleExtension): def __init__(self): """Store parameters for which individual gradients should be computed.""" - # specify parameter names super().__init__(params=["weight"]) - def weight(self, ext, module, g_inp, g_out, bpQuantities): + def weight( + self, + ext: BatchGrad, + module: ScaleModule, + g_inp: Tuple[torch.Tensor], + g_out: Tuple[torch.Tensor], + bpQuantities: None, + ) -> torch.Tensor: """Extract individual gradients for ScaleModule's ``weight`` parameter. Args: - ext(BatchGrad): extension that is used - module(ScaleModule): module that performed forward pass - g_inp(tuple[torch.Tensor]): input gradient tensors - g_out(tuple[torch.Tensor]): output gradient tensors - bpQuantities(None): additional quantities for second-order + ext: BackPACK extension that is used. + module: The module that performed the forward pass. + g_inp: Input gradient tensors. + g_out: Output gradient tensors. + bpQuantities: The quantity backpropagated for the extension by BackPACK. + ``None`` for ``BatchGrad``. Returns: - torch.Tensor: individual gradients + The per-example gradients w.r.t. the ``weight`` parameter. + Has shape ``[batch_size, *weight.shape]``. """ - show_useful = True + # The ``BatchGrad`` extension supports considering only a subset of + # data in the mini-batch. We will not account for this here for simplicity + # and therefore raise an exception if this feature is active. 
+ assert ext.get_subsampling() is None + show_useful = True if show_useful: print("Useful quantities:") # output is saved under field output @@ -103,10 +137,16 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities): print("\tg_out[0].shape: ", g_out[0].shape) # actual computation - return (g_out[0] * module.input0).flatten(start_dim=1).sum(axis=1).unsqueeze(-1) + return einsum(g_out[0], module.input0, "batch d,batch d->batch").unsqueeze(-1) # %% + +# +# Note that we have access to the layer's inputs and outputs from the forward pass, as +# they are stored by BackPACK. The computation itself evaluates a vector-Jacobian +# product of the incoming gradient with the layer's output-parameter Jacobian for +# each sample in the batch. +# # Lastly, we need to register the mapping between layer (``ScaleModule``) and layer # extension (``ScaleModuleBatchGrad``) in an instance of # :py:class:`BatchGrad `. @@ -121,8 +161,8 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities): # gradients with respect to ``ScaleModule``'s ``weight`` parameter. # %% -# Test custom module -# ------------------ +# Verifying first-order extensions +# -------------------------------- # Here, we verify the custom module extension on a small net with random inputs. # Let's create these. @@ -196,3 +236,437 @@ def weight(self, ext, module, g_inp, g_out, bpQuantities): "Individual gradients don't match:" + f"\n{grad_batch_autograd}\nvs.\n{grad_batch_backpack}" ) + +# %% +# Second-order extension +# ---------------------- +# Next, we focus on `BackPACK's second-order extensions +# `_. +# They backpropagate additional information and thus require more functionality to be +# implemented and a more in-depth understanding of BackPACK's internals and +# the quantity of interest. +# +# Let's make BackPACK support computing the exact diagonal of the generalized +# Gauss-Newton (GGN) matrix +# (:py:class:`DiagGGNExact `) for ``ScaleModule``. +# +# To do that, we need to write a module extension that implements how the exact +# GGN diagonal is computed for ``ScaleModule``'s parameter called ``weight``. +# Also, we need to implement how information is propagated from the layer's output +# to the layer's input. +# +# We need to understand the following details about +# :py:class:`DiagGGNExact `: +# +# 1. The extension backpropagates a matrix square root factorization of the loss +# function's Hessian w.r.t. its input via vector-Jacobian products. +# 2. To compute the GGN diagonal for a parameter, we need to multiply the incoming +# matrix square root of the GGN with the output-parameter Jacobian of the layer, +# then square it to obtain the GGN, and take its diagonal. +# +# These details vary between different second-order extensions; a good place to get +# started understanding them is the BackPACK paper. +# +# We now describe the details for the GGN diagonal. +# +# Definition of the GGN +# ^^^^^^^^^^^^^^^^^^^^^ +# +# The GGN is calculated by multiplying the neural network's Jacobian (w.r.t. the +# parameters) with the Hessian of the loss function w.r.t. its prediction, +# +# .. math:: +# \mathbf{G}(\mathbf{\theta}) +# = +# (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top\; +# \nabla^2_{f_\mathbf{\theta}(x)} \ell (f_\mathbf{\theta}(x), y) \; +# (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))\,. +# +# The Jacobian (left & right of RHS) is the matrix of all first-order derivatives +# of the function (neural network) w.r.t. the parameters.
+# The Hessian (center of RHS) is the matrix of all second-order derivatives of the +# loss function w.r.t. the neural network's output. +# The GGN (LHS) is a matrix with dimension :math:`p \times p` where :math:`p` is the +# number of parameters. Note that in the presence of multiple data (a batch), the GGN +# is a sum/mean over per-sample GGNs. We will focus on the GGN for one sample, but +# also handle the parallel computation over all samples in the batch in the code. +# +# Our goal is to compute the diagonal of that matrix. To do that, we will re-write it +# in terms of a self-outer product as follows: Note that the loss function is convex. +# Let the neural network's prediction be +# :math:`f_\mathbf{\theta}(x) \in \mathbb{R}^C` where :math:`C` is the number of +# classes. Due to the convexity of :math:`\ell`, we can find a symmetric factorization +# of its Hessian, +# +# .. math:: +# \exists \mathbf{S} \in \mathbb{R}^{C \times C} +# \text{ s.t. } +# \mathbf{S} \mathbf{S}^\top +# = +# \nabla^2_{f_\mathbf{\theta}(x)} \ell (f_\mathbf{\theta}(x), y)\,. +# +# For our purposes, we will use a loss that is already supported within BackPACK, +# and there we don't need to be concerned how to compute this factorization. +# +# With that, we can define +# :math:`\mathbf{V}= (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top\;\mathbf{S}` +# and write the GGN as +# +# .. math:: +# \mathbf{G}(\mathbf{\theta}) = \mathbf{V} \mathbf{V}^\top\,. +# +# Instead of computing the GGN, we will compute :math:`\mathbf{V}` by backpropagating +# :math:`\mathbf{S}` via vector-Jacobian products, then square-and-take-the-diagonal +# to obtain the GGN's diagonal. +# +# Backpropagation for the GGN diagonal +# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +# To break down the multiplication with +# :math:`(\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top` to the per-layer level, +# we will use the chain rule. +# +# Consider the following computation graph, where :math:`x = x^{(0)}`: +# +# .. image:: ../../images/comp_graph.jpg +# :width: 75% +# :align: center +# +# Each node in the graph represents a tensor. The arrows represent the flow of +# information and the computation associated with the incoming and outgoing tensors: +# :math:`f_{\mathbf{\theta}^{(k)}}^{(k)}(x^{(k)}) = x^{(k+1)}`. The intermediates +# correspond to the outputs of the neural network layers. +# +# The parameter vector :math:`\mathbf{\theta}` contains all NN parameters, flattened +# and concatenated over layers, +# +# .. math:: +# \mathbf{\theta} +# = +# \begin{pmatrix} +# \mathbf{\theta}^{(1)} +# \\ +# \mathbf{\theta}^{(2)} +# \\ +# \vdots +# \\ +# \mathbf{\theta}^{(l)} +# \end{pmatrix}\,. +# +# The Jacobian inherits this structure and is a stack of Jacobians of each layer, +# +# .. math:: +# (\mathbf{J}_\mathbf{\theta} f_\mathbf{\theta}(x))^\top +# = +# \begin{pmatrix} +# (\mathbf{J}_{\mathbf{\theta}^{(1)}} f_{\mathbf{\theta}}(x))^\top +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(2)}} f_{\mathbf{\theta}}(x))^\top +# \\ +# \vdots +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(l)}} f_\mathbf{\theta}(x))^\top +# \end{pmatrix}\,. +# +# The same holds for the matrix :math:`\mathbf{V}`, +# +# .. 
math:: +# \mathbf{V} +# = +# \begin{pmatrix} +# \mathbf{V}_{\mathbf{\theta}^{(1)}} +# \\ +# \mathbf{V}_{\mathbf{\theta}^{(2)}} +# \\ +# \vdots +# \\ +# \mathbf{V}_{\mathbf{\theta}^{(l)}} +# \end{pmatrix} +# = +# \begin{pmatrix} +# (\mathbf{J}_{\mathbf{\theta}^{(1)}} f_{\mathbf{\theta}}(x))^\top \mathbf{S} +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(2)}} f_{\mathbf{\theta}}(x))^\top \mathbf{S} +# \\ +# \vdots +# \\ +# (\mathbf{J}_{\mathbf{\theta}^{(l)}} f_\mathbf{\theta}(x))^\top \mathbf{S} +# \end{pmatrix}\,. +# +# With the chain rule recursions +# +# .. math:: +# (\mathbf{J}_{\mathbf{\theta}^{(k)}} f_{\mathbf{\theta}}(x))^\top +# = +# (\mathbf{J}_{\mathbf{\theta}^{(k)}} x^{(k)})^\top +# \;(\mathbf{J}_{x^{(k)}} f_{\mathbf{\theta}}(x))^\top +# +# and +# +# .. math:: +# (\mathbf{J}_{x^{(k-1)}} f_{\mathbf{\theta}}(x))^\top +# = +# (\mathbf{J}_{x^{(k-1)}} x^{(k)})^\top +# \;(\mathbf{J}_{x^{(k)}} f_{\mathbf{\theta}}(x))^\top +# +# we can identify the following recursions for the blocks of :math:`\mathbf{V}`: +# +# .. math:: +# \mathbf{V}_{\mathbf{\theta}^{(k)}} +# = +# (\mathbf{J}_{\mathbf{\theta}^{(k)}} x^{(k)})^\top +# \mathbf{V}_{x^{(k)}} +# +# and +# +# .. math:: +# \mathbf{V}_{x^{(k-1)}} +# = +# (\mathbf{J}_{x^{(k-1)}} x^{(k)})^\top +# \mathbf{V}_{x^{(k)}}\,. +# +# The above two recursions are the backpropagations performed by BackPACK's +# :py:class:`DiagGGNExact `. Layer :math:`k` +# receives the backpropagated quantity :math:`\mathbf{V}_{x^{(k)}}`, then +# (i) computes :math:`\mathbf{V}_{\mathbf{\theta}^{(k)}}`, then +# :math:`\mathrm{diag}(\mathbf{V}_{\mathbf{\theta}^{(k)}} +# \mathbf{V}_{\mathbf{\theta}^{(k)}}^\top)`, which is the GGN diagonal for +# the layer's parameters, and (ii) computes :math:`\mathbf{V}_{x^{(k-1)}}` +# which is sent to its parent layer :math:`k-1` which proceeds likewise. +# +# Implementation +# ^^^^^^^^^^^^^^ +# +# Now, let's create a module extension that specifies two methods: +# Step (i) from above is implemented by a function whose name +# matches the layer parameter's name (``weight`` in our case). Step (ii) +# is implemented by a function named ``backpropagate``. + + +class ScaleModuleDiagGGNExact(ModuleExtension): + """Backpropagation through ``ScaleModule`` for computing the GGN diagonal.""" + + def __init__(self): + """Store parameter names for which the GGN diagonal will be computed.""" + super().__init__(params=["weight"]) + + def backpropagate( + self, + ext: DiagGGNExact, + module: ScaleModule, + grad_inp: Tuple[torch.Tensor], + grad_out: Tuple[torch.Tensor], + bpQuantities: torch.Tensor, + ) -> torch.Tensor: + """Propagate GGN diagonal information from layer output to input. + + Args: + ext: The GGN diagonal extension. + module: Layer through which to perform backpropagation. + grad_inp: Input gradients. + grad_out:: Output gradients. + bpQuantities: Backpropagation information. For the GGN diagonal + this is a tensor V of shape ``[C, *module.output.shape]`` where + ``C`` is the neural network's output dimension and the layer's + output shape is typically something like ``[batch_size, D_out]``. + + Returns: + The GGN diagonal's backpropagated quantity V for the layer input. + Has shape ``[C, *layer.input0.shape]``. + """ + # The GGN diagonal extension supports considering only a sub-set of + # data in the mini-batch. We will not account for this here for simplicity + # and therefore raise an exception if this feature is active. 
+ assert ext.get_subsampling() is None + + # Layer: + # - Input to the layer has shape ``[batch_size, D_in]`` + # - Output of the layer has shape ``[batch_size, D_out]`` + + # Loss function: + # - The neural network's prediction has shape ``[batch_size, C]`` + + # Quantity backpropagated by ``DiagGGNExact`` has shape + # ``[C, batch_size, D_out]``; imagine this as a set of ``C`` vectors + # which all have the same shape as the layer's output and represent + # the rows of the incoming V. + + # What we need to do: + # - Take each of the C vectors + # - Multiply each of them with the layer's output-input Jacobian. + # The result of each VJP will have shape ``[batch_size, D_in]`` + # - Stack them back together into a tensor of shape + # ``[C, batch_size, D_in]`` that represents the outgoing V + + input0 = module.input0 + output = module.output + weight = module.weight + V_out = bpQuantities + + C = V_out.shape[0] + batch_size, D_in = input0.shape + assert V_out.shape == (C, *output.shape) + + show_useful = True + if show_useful: + print("backpropagate: Useful quantities:") + print(f" module.output.shape: {output.shape}") + print(f" module.input.shape: {input0.shape}") + print(f" V_out.shape: {V_out.shape}") + print(f" V_in.shape: {(C, *input0.shape)}") + + V_in = torch.zeros( + (C, batch_size, D_in), device=input0.device, dtype=input0.dtype + ) + + # forward pass computation performs: ``X * weight`` + # (``[batch_size, D_in] * [1] = [batch_size, D_out=D_in]``) + for c in range(C): + V_in[c] = bpQuantities[c] * weight + # NOTE We could do this more efficiently with the following: + # V_in = V_out * weight + assert V_in.shape == (C, *input0.shape) + + return V_in + + def weight( + self, + ext: DiagGGNExact, + module: ScaleModule, + g_inp: Tuple[torch.Tensor], + g_out: Tuple[torch.Tensor], + bpQuantities: torch.Tensor, + ) -> torch.Tensor: + """Extract the GGN diagonal for the ``weight`` parameter. + + Args: + ext: The BackPACK extension. + module: Module through which to perform backpropagation. + g_inp: Input gradients. + g_out: Output gradients. + bpQuantities: Backpropagation information. For the GGN diagonal + this is a tensor V of shape ``[C, *module.output.shape]`` where + ``C`` is the neural network's output dimension and the layer's + output shape is typically something like ``[batch_size, D_out]``. + + Returns: + The GGN diagonal w.r.t. the layer's ``weight``. + Has the same shape as ``weight``. + """ + input0 = module.input0 + output = module.output + weight = module.weight + V_out = bpQuantities + + C = bpQuantities.shape[0] + assert V_out.shape == (C, *output.shape) + + show_useful = True + if show_useful: + print("weight: Useful quantities:") + print(f" module.output.shape {output.shape}") + print(f" module.input.shape {input0.shape}") + print(f" module.weight.shape {weight.shape}") + print(f" bpQuantities.shape {bpQuantities.shape}") + print(f" returned.shape {weight.shape}") + + # forward pass computation performs: ``X * weight`` + # (``[batch_size, D_in] * [1] = [batch_size, D_out]``) + V_theta = einsum(V_out, input0, "c batch d, batch d -> c batch") + # compute diag( V_theta @ V_theta^T ) + weight_ggn_diag = einsum(V_theta, V_theta, "c batch, c batch ->").unsqueeze(0) + + assert weight_ggn_diag.shape == weight.shape + return weight_ggn_diag + + +# %% +# After we have implemented the module extension, we need to register the mapping +# between layer (``ScaleModule``) and layer extension (``ScaleModuleDiagGGNExact``) +# in an instance of :py:class:`DiagGGNExact `.
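+# The two lines below do exactly that, mirroring the first-order example above: we
+# instantiate the extension and tell it which module extension handles ``ScaleModule``.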
+ +extension = DiagGGNExact() +extension.set_module_extension(ScaleModule, ScaleModuleDiagGGNExact()) + +# %% +# We can then use this extension to compute the exact GGN diagonal for +# ``ScaleModule``s. +# +# +# Verifying second-order extensions +# --------------------------------- +# +# Here, we verify the custom module extension on a small net with random inputs. +# First, the setup: + +batch_size = 10 +input_size = 4 + +inputs = torch.randn(batch_size, input_size, device=device) +targets = torch.randint(0, 2, (batch_size,), device=device) + +reduction = ["mean", "sum"][1] + +my_module = ScaleModule().to(device) +lossfunc = torch.nn.CrossEntropyLoss(reduction=reduction).to(device) + +# %% +# As ground truth, we compute the GGN diagonal using GGN-vector products +# which exclusively rely on PyTorch's autodiff: +params = list(my_module.parameters()) +ggn_dim = sum(p.numel() for p in params) +diag_ggn_flat = torch.zeros(ggn_dim, device=inputs.device, dtype=inputs.dtype) + +outputs = my_module(inputs) +loss = lossfunc(outputs, targets) + +# compute GGN-vector products with all one-hot vectors +for d in range(ggn_dim): + # create unit vector d + e_d = torch.zeros(ggn_dim, device=inputs.device, dtype=inputs.dtype) + e_d[d] = 1.0 + # convert to list format + e_d = vector_to_parameter_list(e_d, params) + + # multiply GGN onto the unit vector -> get back column d of the GGN + ggn_e_d = ggn_vector_product(loss, outputs, my_module, e_d) + # flatten + ggn_e_d = parameters_to_vector(ggn_e_d) + + # extract the d-th entry (which is on the GGN's diagonal) + diag_ggn_flat[d] = ggn_e_d[d] + +print(f"Tr(GGN): {diag_ggn_flat.sum():.3f}") + +# %% +# Now we can use BackPACK to compute the GGN diagonal: + +my_module = extend(my_module) +lossfunc = extend(lossfunc) + +outputs = my_module(inputs) +loss = lossfunc(outputs, targets) + +with backpack(extension): + loss.backward() + +diag_ggn_flat_backpack = parameters_to_vector( + [p.diag_ggn_exact for p in my_module.parameters()] +) +print(f"Tr(GGN, BackPACK): {diag_ggn_flat_backpack.sum():.3f}") + +# %% +# +# Finally, let's compare the two results. + +match = torch.allclose(diag_ggn_flat, diag_ggn_flat_backpack) +print(f"Do manual and BackPACK GGN match? {match}") + +if not match: + raise AssertionError( + "Exact GGN diagonals do not match:" + + f"\n{diag_ggn_flat}\nvs.\n{diag_ggn_flat_backpack}" + ) + +# %% +# +# That's all for now. diff --git a/docs_src/examples/use_cases/example_resnet_all_in_one.py b/docs_src/examples/use_cases/example_resnet_all_in_one.py index 0b3398e3b..5d7bd2649 100644 --- a/docs_src/examples/use_cases/example_resnet_all_in_one.py +++ b/docs_src/examples/use_cases/example_resnet_all_in_one.py @@ -1,6 +1,7 @@ """Residual networks ==================== """ + # %% # There are three different approaches to using BackPACK with ResNets. # diff --git a/docs_src/examples/use_cases/example_rnn.py b/docs_src/examples/use_cases/example_rnn.py index 5e1f2a374..8459fa316 100644 --- a/docs_src/examples/use_cases/example_rnn.py +++ b/docs_src/examples/use_cases/example_rnn.py @@ -1,6 +1,7 @@ """Recurrent networks ==================== """ + # %% # There are two different approaches to using BackPACK with RNNs. # @@ -21,13 +22,9 @@ # # Not all extensions support RNNs (yet). Please create a feature request in the # repository if the extension you need is not supported. - -from pkg_resources import packaging - -# %% +# # Let's get the imports out of the way. 
from torch import ( - _C, allclose, cat, device, @@ -44,20 +41,11 @@ from backpack.custom_module.permute import Permute from backpack.custom_module.reduce_tuple import ReduceTuple from backpack.extensions import BatchGrad, DiagGGNExact -from backpack.utils import TORCH_VERSION from backpack.utils.examples import autograd_diag_ggn_exact manual_seed(0) DEVICE = device("cpu") # Verification via autograd only works on CPU -# %% -# -# .. note:: -# Due to `#99413 `_, we have to disable -# MKLDNN for PyTorch 2.0.1 to get the double-backward through LSTMs working. -if TORCH_VERSION == packaging.version.parse("2.0.1"): - _C._set_mkldnn_enabled(False) - # %% # For this demo, we will use the Tolstoi Char RNN from diff --git a/docs_src/images/comp_graph.jpg b/docs_src/images/comp_graph.jpg new file mode 100644 index 000000000..23d9be235 Binary files /dev/null and b/docs_src/images/comp_graph.jpg differ diff --git a/fully_documented.txt b/fully_documented.txt index f05271763..ebf4bd3ba 100644 --- a/fully_documented.txt +++ b/fully_documented.txt @@ -1,5 +1,3 @@ -setup.py - backpack/__init__.py backpack/context.py backpack/custom_module/ diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 000000000..bb920a114 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,94 @@ +# This file is used to configure the project. +# Read more about the various options under: +# https://packaging.python.org/en/latest/guides/writing-pyproject-toml +# https://setuptools.pypa.io/en/latest/userguide/pyproject_config.html +[build-system] +requires = ["setuptools >= 61.0", "setuptools_scm"] +build-backend = "setuptools.build_meta" +############################################################################### +# Main library # +############################################################################### +[project] +name = "backpack-for-pytorch" +authors = [ + { name = "Felix Dangel" }, + { name = "Frederik Kunstner" }, +] +urls = { Repository = "https://github.com/f-dangel/backpack" } +description = "BackPACK: Packing more into backprop" +readme = { file = "README.md", content-type = "text/markdown; charset=UTF-8; variant=GFM" } +license = { text = "MIT" } +# Add all kinds of additional classifiers as defined under +# https://pypi.python.org/pypi?%3Aaction=list_classifiers +classifiers = [ + "Development Status :: 4 - Beta", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", +] +dynamic = ["version"] +# Dependencies of the project: +dependencies = [ + "torch>=2.2.0", + "torchvision>=0.7.0", + "einops>=0.3.0,<1.0.0", + "unfoldNd>=0.2.0,<1.0.0", +] +# Require a specific Python version, e.g. Python 2.7 or >= 3.4 +requires-python = ">=3.9" +############################################################################### +# Development dependencies # +############################################################################### +[project.optional-dependencies] +# Dependencies needed to run the tests. +test = [ + "scipy", + "numpy<2", + "pytest>=4.5.0,<5.0.0", + "pytest-benchmark>=3.2.2,<4.0.0", + "pytest-optional-tests>=0.1.1", + "pytest-cov", + "coveralls", +] +# Dependencies needed for linting. 
+lint = [ + "black", + "flake8", + "mccabe", + "pycodestyle", + "pyflakes", + "pep8-naming", + "flake8-bugbear", + "flake8-comprehensions", + "flake8-tidy-imports", + "darglint", + "pydocstyle", + "isort", +] +# Dependencies needed to build/view the documentation. +docs = [ + "matplotlib", + "sphinx<7", + "sphinx-gallery", + "sphinx-rtd-theme", + "memory_profiler", + "tabulate", +] +############################################################################### +# Development tool configurations # +############################################################################### +[tool.setuptools] +packages = ["backpack"] +[tool.setuptools_scm] +[tool.isort] +profile = "black" +multi_line_output = 3 +include_trailing_comma = true +force_grid_wrap = 0 +use_parentheses = true +[tool.pydocstyle] +convention = "google" +match = '.*\.py' \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index 8be33bc68..940ff4910 100644 --- a/setup.cfg +++ b/setup.cfg @@ -2,96 +2,8 @@ # Read more about the various options under: # http://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files -############################################################################### -# Main library # -############################################################################### - -[metadata] -name = backpack-for-pytorch -author = Felix Dangel, Frederik Kunstner -url = https://github.com/f-dangel/backpack -description = BackPACK: Packing more into backprop -long_description = file: README.md -long_description_content_type = text/markdown; charset=UTF-8; variant=GFM -license = MIT -# Change if running only on Windows, Mac or Linux (comma-separated) -platforms = any -# Add all kinds of additional classifiers as defined under -# https://pypi.python.org/pypi?%3Aaction=list_classifiers -classifiers = - Development Status :: 4 - Beta - License :: OSI Approved :: MIT License - Operating System :: OS Independent - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - -[options] -zip_safe = False -packages = find: -include_package_data = True -setup_requires = - setuptools_scm -# Dependencies of the project (semicolon/line-separated): -install_requires = - torch >= 1.9.0 - torchvision >= 0.7.0 - einops >= 0.3.0, < 1.0.0 - unfoldNd >= 0.2.0, < 1.0.0 -# Require a specific Python version, e.g. 
Python 2.7 or >= 3.4 -python_requires = >=3.8 - -[options.packages.find] -exclude = test* - -############################################################################### -# Development dependencies # -############################################################################### - -[options.extras_require] -# Dependencies needed to run the tests (semicolon/line-separated) -test = - scipy - pytest >= 4.5.0, < 5.0.0 - pytest-benchmark >= 3.2.2, < 4.0.0 - pytest-optional-tests >= 0.1.1 - pytest-cov - coveralls - -# Dependencies needed for linting (semicolon/line-separated) -lint = - darglint - flake8 - mccabe - pycodestyle - pydocstyle - pyflakes - pep8-naming - flake8-bugbear - flake8-comprehensions - flake8-tidy-imports - black - isort - -# Dependencies needed to build/view the documentation (semicolon/line-separated) -docs = - matplotlib - sphinx-gallery - sphinx-rtd-theme - memory_profiler - tabulate - -############################################################################### -# Development tool configurations # -############################################################################### - -[isort] -profile=black -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -use_parentheses=True - +# Note: These tools do not yet support `pyproject.toml`, but these options +# should be moved there once support is added. [flake8] select = B,C,E,F,P,W,B9 max-line-length = 88 @@ -146,8 +58,4 @@ exclude = docs, build, .git, docs_src/rtd, docs_src/rtd_output, .eggs [darglint] docstring_style = google # short, long, full -strictness = full - -[pydocstyle] -convention = google -match = .*\.py \ No newline at end of file +strictness = full \ No newline at end of file diff --git a/setup.py b/setup.py deleted file mode 100644 index 3358418d6..000000000 --- a/setup.py +++ /dev/null @@ -1,18 +0,0 @@ -"""Setup file for BackPACK. - -Use ``setup.cfg`` for configuration. 
-""" -import sys - -from pkg_resources import VersionConflict, require -from setuptools import setup - -try: - require("setuptools>=38.3") -except VersionConflict: - print("Error: version of setuptools is too old (<38.3)!") - sys.exit(1) - - -if __name__ == "__main__": - setup(use_scm_version=True) diff --git a/test/adaptive_avg_pool/problem.py b/test/adaptive_avg_pool/problem.py index e96713ed4..9a3de1903 100644 --- a/test/adaptive_avg_pool/problem.py +++ b/test/adaptive_avg_pool/problem.py @@ -1,4 +1,5 @@ """Test problems for the AdaptiveAvgPool shape checker.""" + from __future__ import annotations import copy diff --git a/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py b/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py index 58fdaaebb..ff3982a0f 100644 --- a/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py +++ b/test/adaptive_avg_pool/settings_adaptive_avg_pool_nd.py @@ -1,4 +1,5 @@ """Settings to run test_adaptive_avg_pool_nd.""" + from typing import Any, Dict, List from torch import Size diff --git a/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py b/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py index eebf1931b..f628b010e 100644 --- a/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py +++ b/test/adaptive_avg_pool/test_adaptive_avg_pool_nd.py @@ -1,4 +1,5 @@ """Test the shape checker of AdaptiveAvgPoolNDDerivatives.""" + from test.adaptive_avg_pool.problem import AdaptiveAvgPoolProblem, make_test_problems from test.adaptive_avg_pool.settings_adaptive_avg_pool_nd import SETTINGS from typing import List diff --git a/test/conv2d_test.py b/test/conv2d_test.py index 0421c8697..57896fc93 100644 --- a/test/conv2d_test.py +++ b/test/conv2d_test.py @@ -5,6 +5,7 @@ Chellapilla: High Performance Convolutional Neural Networks for Document Processing (2007). 
""" + from random import choice, randint import pytest diff --git a/test/converter/converter_cases.py b/test/converter/converter_cases.py index 0715be0b5..34fcd1843 100644 --- a/test/converter/converter_cases.py +++ b/test/converter/converter_cases.py @@ -7,6 +7,7 @@ Network with multiply operation Network with add operation """ + import abc from typing import List, Type diff --git a/test/converter/resnet_cases.py b/test/converter/resnet_cases.py index 4d5e00de7..46b716a46 100644 --- a/test/converter/resnet_cases.py +++ b/test/converter/resnet_cases.py @@ -1,4 +1,5 @@ """Contains example ResNets to be used in tests.""" + from torch import flatten, tensor from torch.nn import ( AdaptiveAvgPool2d, diff --git a/test/converter/test_converter.py b/test/converter/test_converter.py index 4f7054e1e..090c14473 100644 --- a/test/converter/test_converter.py +++ b/test/converter/test_converter.py @@ -3,9 +3,9 @@ - whether converted network is equivalent to original network - whether DiagGGN runs without errors on new network """ + from test.converter.converter_cases import CONVERTER_MODULES, ConverterModule from test.core.derivatives.utils import classification_targets, regression_targets -from test.utils.skip_test import skip_torch_2_0_1_lstm from typing import Tuple from pytest import fixture @@ -32,7 +32,6 @@ def model_and_input(request) -> Tuple[Module, Tensor, Module]: """ manual_seed(0) model: ConverterModule = request.param() - skip_torch_2_0_1_lstm(model) inputs: Tensor = model.input_fn() loss_fn: Module = model.loss_fn() yield model, inputs, loss_fn diff --git a/test/core/derivatives/__init__.py b/test/core/derivatives/__init__.py index 36bc03bfc..acf983a38 100644 --- a/test/core/derivatives/__init__.py +++ b/test/core/derivatives/__init__.py @@ -1,4 +1,5 @@ """Test functionality of `backpack.core.derivatives` module.""" + from torch.nn import ( ELU, LSTM, diff --git a/test/core/derivatives/batch_norm_settings.py b/test/core/derivatives/batch_norm_settings.py index 7994e1716..d0a1c6fd3 100644 --- a/test/core/derivatives/batch_norm_settings.py +++ b/test/core/derivatives/batch_norm_settings.py @@ -12,6 +12,7 @@ "id_prefix" (str): Prefix to be included in the test name. 
"seed" (int): seed for the random number for torch.rand """ + from test.utils.evaluation_mode import initialize_batch_norm_eval from torch import rand diff --git a/test/core/derivatives/derivatives_test.py b/test/core/derivatives/derivatives_test.py index e5219ed2d..4cb34cebe 100644 --- a/test/core/derivatives/derivatives_test.py +++ b/test/core/derivatives/derivatives_test.py @@ -23,12 +23,10 @@ from test.core.derivatives.settings import SETTINGS from test.core.derivatives.slicing_settings import CUSTOM_SLICING_SETTINGS from test.utils.skip_test import ( - skip_adaptive_avg_pool3d_cuda, skip_batch_norm_train_mode_with_subsampling, skip_BCEWithLogitsLoss, skip_BCEWithLogitsLoss_non_binary_labels, skip_subsampling_conflict, - skip_torch_2_0_1_lstm, ) from typing import List, Union from warnings import warn @@ -112,9 +110,11 @@ def test_param_mjp( print(f"testing with save_memory={save_memory}") mat = rand_mat_like_output(V, problem, subsampling=subsampling) - with weight_jac_t_save_memory( - save_memory=save_memory - ) if test_save_memory else nullcontext(): + with ( + weight_jac_t_save_memory(save_memory=save_memory) + if test_save_memory + else nullcontext() + ): backpack_res = BackpackDerivatives(problem).param_mjp( param_str, mat, sum_batch, subsampling=subsampling ) @@ -138,7 +138,6 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: V: Number of vectorized Jacobian-vector products. Default: ``3``. """ problem.set_up() - skip_torch_2_0_1_lstm(problem.module) mat = rand(V, *problem.input_shape).to(problem.device) backpack_res = BackpackDerivatives(problem).jac_mat_prod(mat) @@ -167,21 +166,15 @@ def test_jac_mat_prod(problem: DerivativesTestProblem, V: int = 3) -> None: + CUSTOM_SLICING_MODULE_IDS, ) def test_jac_t_mat_prod( - problem: DerivativesTestProblem, - subsampling: Union[None, List[int]], - request, - V: int = 3, + problem: DerivativesTestProblem, subsampling: Union[None, List[int]], V: int = 3 ) -> None: """Test the transposed Jacobian-matrix product. Args: problem: Problem for derivative test. subsampling: Indices of active samples. - request: Pytest request, used for getting id. V: Number of vectorized transposed Jacobian-vector products. Default: ``3``. """ - skip_adaptive_avg_pool3d_cuda(request) - problem.set_up() skip_batch_norm_train_mode_with_subsampling(problem, subsampling) skip_subsampling_conflict(problem, subsampling) @@ -484,7 +477,7 @@ def test_sum_hessian_should_fail(problem): @mark.parametrize("problem", NO_LOSS_PROBLEMS, ids=NO_LOSS_IDS) -def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None: +def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem) -> None: """Test KFRA backpropagation. H_in → 1/N ∑ₙ Jₙ^T H_out Jₙ @@ -497,10 +490,7 @@ def test_ea_jac_t_mat_jac_prod(problem: DerivativesTestProblem, request) -> None Args: problem: Test case. - request: PyTest request, used to get test id. 
""" - skip_adaptive_avg_pool3d_cuda(request) - problem.set_up() out_features = problem.output_shape[1:].numel() mat = rand(out_features, out_features).to(problem.device) diff --git a/test/core/derivatives/embedding_settings.py b/test/core/derivatives/embedding_settings.py index e6f8b1486..53c34ab8c 100644 --- a/test/core/derivatives/embedding_settings.py +++ b/test/core/derivatives/embedding_settings.py @@ -1,4 +1,5 @@ """Settings for testing derivatives of Embedding.""" + from torch import randint from torch.nn import Embedding diff --git a/test/core/derivatives/implementation/autograd.py b/test/core/derivatives/implementation/autograd.py index 4f72c0ba2..4b03edbf7 100644 --- a/test/core/derivatives/implementation/autograd.py +++ b/test/core/derivatives/implementation/autograd.py @@ -1,4 +1,5 @@ """Derivatives computed with PyTorch's autograd.""" + from test.core.derivatives.implementation.base import DerivativesImplementation from typing import List @@ -235,8 +236,7 @@ def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor: w.r.t. `x[a, b, c]`. If ``tensor`` is linear in ``x``, autograd raises a ``RuntimeError``. - If ``tensor`` does not depend on ``x``, autograd raises an ``AttributeError``. - In both cases, a Hessian of zeros is created manually and returned. + In that case, a Hessian of zeros is created manually and returned. Arguments: tensor: An arbitrary tensor. @@ -248,7 +248,7 @@ def _elementwise_hessian(self, tensor: Tensor, x: Tensor) -> Tensor: for t in tensor.flatten(): try: yield self._hessian(t, x) - except (RuntimeError, AttributeError, TypeError): + except RuntimeError: yield zeros(*x.shape, *x.shape, device=x.device, dtype=x.dtype) def hessian_is_zero(self) -> bool: # noqa: D102 diff --git a/test/core/derivatives/implementation/backpack.py b/test/core/derivatives/implementation/backpack.py index f6e1a1693..3fd2a3bb9 100644 --- a/test/core/derivatives/implementation/backpack.py +++ b/test/core/derivatives/implementation/backpack.py @@ -1,4 +1,5 @@ """Contains derivative calculation with BackPACK.""" + from test.core.derivatives.implementation.base import DerivativesImplementation from test.utils import chunk_sizes from typing import List diff --git a/test/core/derivatives/implementation/base.py b/test/core/derivatives/implementation/base.py index 1bf91c387..5ef5d88df 100644 --- a/test/core/derivatives/implementation/base.py +++ b/test/core/derivatives/implementation/base.py @@ -1,4 +1,5 @@ """Contains DerivativesImplementation, the base class for autograd and backpack.""" + from abc import ABC, abstractmethod from typing import List diff --git a/test/core/derivatives/scale_module_settings.py b/test/core/derivatives/scale_module_settings.py index 3bc089ef9..b580f9550 100644 --- a/test/core/derivatives/scale_module_settings.py +++ b/test/core/derivatives/scale_module_settings.py @@ -1,4 +1,5 @@ """Test settings for ScaleModule derivatives.""" + from torch import rand from torch.nn import Identity diff --git a/test/core/derivatives/slicing_settings.py b/test/core/derivatives/slicing_settings.py index 70863c26d..ea23712bf 100644 --- a/test/core/derivatives/slicing_settings.py +++ b/test/core/derivatives/slicing_settings.py @@ -1,6 +1,5 @@ """Contains test cases of BackPACK's custom Slicing module.""" - from torch import rand from backpack.custom_module.slicing import Slicing diff --git a/test/core/derivatives/utils.py b/test/core/derivatives/utils.py index 21bfe1834..c699673d6 100644 --- a/test/core/derivatives/utils.py +++ b/test/core/derivatives/utils.py @@ 
-1,4 +1,5 @@ """Utility functions to test `backpack.core.derivatives`.""" + from test.core.derivatives import derivatives_for from typing import Tuple, Type diff --git a/test/extensions/automated_settings.py b/test/extensions/automated_settings.py index f2334c515..1ac1ae4f3 100644 --- a/test/extensions/automated_settings.py +++ b/test/extensions/automated_settings.py @@ -1,4 +1,5 @@ """Contains helpers to create CNN test cases.""" + from test.core.derivatives.utils import classification_targets from typing import Any, Tuple, Type diff --git a/test/extensions/firstorder/batch_grad/batch_grad_settings.py b/test/extensions/firstorder/batch_grad/batch_grad_settings.py index 7b1926d63..f85631aad 100644 --- a/test/extensions/firstorder/batch_grad/batch_grad_settings.py +++ b/test/extensions/firstorder/batch_grad/batch_grad_settings.py @@ -3,6 +3,7 @@ The tests are taken from ``test.extensions.firstorder.firstorder_settings``, but additional custom tests can be defined here by appending it to the list. """ + from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS SHARED_SETTINGS = FIRSTORDER_SETTINGS diff --git a/test/extensions/firstorder/batch_grad/test_batch_grad.py b/test/extensions/firstorder/batch_grad/test_batch_grad.py index 7c916568e..790f548bb 100644 --- a/test/extensions/firstorder/batch_grad/test_batch_grad.py +++ b/test/extensions/firstorder/batch_grad/test_batch_grad.py @@ -1,4 +1,5 @@ """Test BackPACK's ``BatchGrad`` extension.""" + from test.automated_test import check_sizes_and_values from test.extensions.firstorder.batch_grad.batch_grad_settings import ( BATCH_GRAD_SETTINGS, diff --git a/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py b/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py index b26059607..16e02930c 100644 --- a/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py +++ b/test/extensions/firstorder/batch_l2_grad/batchl2grad_settings.py @@ -3,6 +3,7 @@ The tests are taken from `test.extensions.firstorder.firstorder_settings`, but additional custom tests can be defined here by appending it to the list. """ + from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS SHARED_SETTINGS = FIRSTORDER_SETTINGS diff --git a/test/extensions/firstorder/firstorder_settings.py b/test/extensions/firstorder/firstorder_settings.py index 4fc1f1a8e..2eb59cf29 100644 --- a/test/extensions/firstorder/firstorder_settings.py +++ b/test/extensions/firstorder/firstorder_settings.py @@ -18,6 +18,7 @@ "id_prefix" (str): Prefix to be included in the test name. "seed" (int): seed set before initializing a case. """ + from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.automated_settings import make_simple_cnn_setting from test.utils.evaluation_mode import initialize_training_false_recursive diff --git a/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py b/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py index 7d5aa2daa..3b3d465d6 100644 --- a/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py +++ b/test/extensions/firstorder/sum_grad_squared/sumgradsquared_settings.py @@ -3,6 +3,7 @@ The tests are taken from `test.extensions.firstorder.firstorder_settings`, but additional custom tests can be defined here by appending it to the list. 
""" + from test.core.derivatives.utils import classification_targets from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS diff --git a/test/extensions/firstorder/sum_grad_squared/test_sumgradsquared.py b/test/extensions/firstorder/sum_grad_squared/test_sumgradsquared.py index 0f244159c..be26cdab3 100644 --- a/test/extensions/firstorder/sum_grad_squared/test_sumgradsquared.py +++ b/test/extensions/firstorder/sum_grad_squared/test_sumgradsquared.py @@ -6,6 +6,7 @@ - sum of the square of batch gradients of convolutional layers """ + from test.automated_test import check_sizes_and_values from test.extensions.firstorder.sum_grad_squared.sumgradsquared_settings import ( SUMGRADSQUARED_SETTINGS, diff --git a/test/extensions/firstorder/variance/test_variance.py b/test/extensions/firstorder/variance/test_variance.py index 8680c2086..760cbe32d 100644 --- a/test/extensions/firstorder/variance/test_variance.py +++ b/test/extensions/firstorder/variance/test_variance.py @@ -1,4 +1,5 @@ """Test BackPACK's ``Variance`` extension.""" + from test.automated_test import check_sizes_and_values from test.extensions.firstorder.variance.variance_settings import VARIANCE_SETTINGS from test.extensions.implementation.autograd import AutogradExtensions diff --git a/test/extensions/firstorder/variance/variance_settings.py b/test/extensions/firstorder/variance/variance_settings.py index c8a8de2da..1dc8b0442 100644 --- a/test/extensions/firstorder/variance/variance_settings.py +++ b/test/extensions/firstorder/variance/variance_settings.py @@ -3,6 +3,7 @@ Uses shared test cases from `test.extensions.firstorder.firstorder_settings`, and the local cases defined in this file. """ + from test.extensions.firstorder.firstorder_settings import FIRSTORDER_SETTINGS SHARED_SETTINGS = FIRSTORDER_SETTINGS diff --git a/test/extensions/graph_clear_test.py b/test/extensions/graph_clear_test.py index 17b6419ee..1fce6ed2c 100644 --- a/test/extensions/graph_clear_test.py +++ b/test/extensions/graph_clear_test.py @@ -1,4 +1,5 @@ """Test whether the graph is clear after a backward pass.""" + from typing import Tuple from pytest import fixture diff --git a/test/extensions/implementation/autograd.py b/test/extensions/implementation/autograd.py index b4f1701fd..f006b345d 100644 --- a/test/extensions/implementation/autograd.py +++ b/test/extensions/implementation/autograd.py @@ -1,4 +1,5 @@ """Autograd implementation of BackPACK's extensions.""" + from math import isclose from test.extensions.implementation.base import ExtensionsImplementation from typing import Iterator, List, Union diff --git a/test/extensions/implementation/backpack.py b/test/extensions/implementation/backpack.py index 74c3e7cc7..861254c53 100644 --- a/test/extensions/implementation/backpack.py +++ b/test/extensions/implementation/backpack.py @@ -1,4 +1,5 @@ """Extension implementations with BackPACK.""" + from test.extensions.implementation.base import ExtensionsImplementation from test.extensions.implementation.hooks import ( BatchL2GradHook, diff --git a/test/extensions/implementation/base.py b/test/extensions/implementation/base.py index 8f53af7fe..579c3d3ee 100644 --- a/test/extensions/implementation/base.py +++ b/test/extensions/implementation/base.py @@ -1,4 +1,5 @@ """Base class containing the functions to compare BackPACK and autograd.""" + from abc import ABC, abstractmethod from test.extensions.problem import ExtensionsTestProblem from typing import List, Union diff --git a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py 
b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py index 869e4b64b..508d008fd 100644 --- a/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py +++ b/test/extensions/secondorder/diag_ggn/diag_ggn_settings.py @@ -9,6 +9,7 @@ Shared settings are taken from `test.extensions.secondorder.secondorder_settings`. Additional local cases can be defined here through ``LOCAL_SETTINGS``. """ + from test.converter.resnet_cases import ResNet1, ResNet2 from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS diff --git a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py index b965ef1d9..0675c25eb 100644 --- a/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_batch_diag_ggn.py @@ -1,11 +1,11 @@ """Test BatchDiagGGN extension.""" + from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS from test.utils.skip_extension_test import skip_BCEWithLogitsLoss_non_binary_labels -from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda, skip_torch_2_0_1_lstm import pytest @@ -14,16 +14,13 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_diag_ggn_exact_batch(problem, request): +def test_diag_ggn_exact_batch(problem): """Test the individual diagonal of Generalized Gauss-Newton/Fisher. Args: problem (ExtensionsTestProblem): Problem for extension test. 
- request: problem request """ - skip_adaptive_avg_pool3d_cuda(request) problem.set_up() - skip_torch_2_0_1_lstm(problem.model) backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() autograd_res = AutogradExtensions(problem).diag_ggn_exact_batch() @@ -48,7 +45,6 @@ def test_diag_ggn_mc_batch_light(problem): """ problem.set_up() skip_BCEWithLogitsLoss_non_binary_labels(problem) - skip_torch_2_0_1_lstm(problem.model) backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() mc_samples = 6000 @@ -72,7 +68,6 @@ def test_diag_ggn_mc_batch(problem): """ problem.set_up() skip_BCEWithLogitsLoss_non_binary_labels(problem) - skip_torch_2_0_1_lstm(problem.model) backpack_res = BackpackExtensions(problem).diag_ggn_exact_batch() mc_samples = 300000 diff --git a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py index d21d63d3a..5d2b4ec95 100644 --- a/test/extensions/secondorder/diag_ggn/test_diag_ggn.py +++ b/test/extensions/secondorder/diag_ggn/test_diag_ggn.py @@ -1,11 +1,11 @@ """Test DiagGGN extension.""" + from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions from test.extensions.problem import make_test_problems from test.extensions.secondorder.diag_ggn.diag_ggn_settings import DiagGGN_SETTINGS from test.utils.skip_extension_test import skip_BCEWithLogitsLoss_non_binary_labels -from test.utils.skip_test import skip_adaptive_avg_pool3d_cuda, skip_torch_2_0_1_lstm import pytest @@ -14,16 +14,13 @@ @pytest.mark.parametrize("problem", PROBLEMS, ids=IDS) -def test_diag_ggn(problem, request): +def test_diag_ggn(problem): """Test the diagonal of generalized Gauss-Newton. Args: problem (ExtensionsTestProblem): Problem for extension test. - request: problem request """ - skip_adaptive_avg_pool3d_cuda(request) problem.set_up() - skip_torch_2_0_1_lstm(problem.model) backpack_res = BackpackExtensions(problem).diag_ggn() autograd_res = AutogradExtensions(problem).diag_ggn() @@ -48,7 +45,6 @@ def test_diag_ggn_mc_light(problem): """ problem.set_up() skip_BCEWithLogitsLoss_non_binary_labels(problem) - skip_torch_2_0_1_lstm(problem.model) backpack_res = BackpackExtensions(problem).diag_ggn() mc_samples = 3000 @@ -72,7 +68,6 @@ def test_diag_ggn_mc(problem): """ problem.set_up() skip_BCEWithLogitsLoss_non_binary_labels(problem) - skip_torch_2_0_1_lstm(problem.model) backpack_res = BackpackExtensions(problem).diag_ggn() mc_samples = 300000 diff --git a/test/extensions/secondorder/diag_hessian/diagh_settings.py b/test/extensions/secondorder/diag_hessian/diagh_settings.py index 60fc77b96..5eda20217 100644 --- a/test/extensions/secondorder/diag_hessian/diagh_settings.py +++ b/test/extensions/secondorder/diag_hessian/diagh_settings.py @@ -4,10 +4,21 @@ but additional custom tests can be defined here by appending it to the list. 
""" -from test.extensions.automated_settings import make_simple_act_setting +from test.extensions.automated_settings import ( + make_simple_act_setting, + make_simple_pooling_setting, +) from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS -from torch.nn import LogSigmoid +from torch.nn import ( + AdaptiveAvgPool1d, + AdaptiveAvgPool2d, + AdaptiveAvgPool3d, + Conv1d, + Conv2d, + Conv3d, + LogSigmoid, +) SHARED_SETTINGS = SECONDORDER_SETTINGS LOCAL_SETTINGS = [ @@ -15,4 +26,14 @@ make_simple_act_setting(LogSigmoid, bias=False), ] +############################################################################### +# test setting: Adaptive Pooling Layers # +############################################################################### +LOCAL_SETTINGS += [ + make_simple_pooling_setting((3, 3, 7), Conv1d, AdaptiveAvgPool1d, (2,)), + make_simple_pooling_setting((3, 3, 11, 11), Conv2d, AdaptiveAvgPool2d, (2,)), + make_simple_pooling_setting((3, 3, 7, 7, 7), Conv3d, AdaptiveAvgPool3d, (2,)), +] + + DiagHESSIAN_SETTINGS = SHARED_SETTINGS + LOCAL_SETTINGS diff --git a/test/extensions/secondorder/hbp/test_kfac.py b/test/extensions/secondorder/hbp/test_kfac.py index fa76400d8..e0b3f24cd 100644 --- a/test/extensions/secondorder/hbp/test_kfac.py +++ b/test/extensions/secondorder/hbp/test_kfac.py @@ -1,4 +1,5 @@ """Test BackPACK's KFAC extension.""" + from test.automated_test import check_sizes_and_values from test.extensions.implementation.autograd import AutogradExtensions from test.extensions.implementation.backpack import BackpackExtensions diff --git a/test/extensions/secondorder/secondorder_settings.py b/test/extensions/secondorder/secondorder_settings.py index b011f4e37..942464a0b 100644 --- a/test/extensions/secondorder/secondorder_settings.py +++ b/test/extensions/secondorder/secondorder_settings.py @@ -20,7 +20,6 @@ "seed" (int): seed for the random number for rand """ - from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.automated_settings import ( make_simple_act_setting, diff --git a/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py b/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py index e95fabcd5..57411dc49 100644 --- a/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py +++ b/test/extensions/secondorder/sqrt_ggn/sqrt_ggn_settings.py @@ -1,4 +1,5 @@ """Contains test settings for testing SqrtGGN extension.""" + from test.converter.resnet_cases import ResNet1, ResNet2 from test.core.derivatives.utils import classification_targets, regression_targets from test.extensions.secondorder.secondorder_settings import SECONDORDER_SETTINGS diff --git a/test/extensions/test_hooks.py b/test/extensions/test_hooks.py index bc6058f87..577a57ef7 100644 --- a/test/extensions/test_hooks.py +++ b/test/extensions/test_hooks.py @@ -3,6 +3,7 @@ These tests aim at demonstrating the pitfalls one may run into when using hooks that iterate over ``module.parameters()``. 
""" + from test.core.derivatives.utils import classification_targets, get_available_devices from typing import Tuple diff --git a/test/interface_test.py b/test/interface_test.py index 671875705..70da01883 100644 --- a/test/interface_test.py +++ b/test/interface_test.py @@ -1,6 +1,7 @@ """ Test of the interface - calls every method that needs implementation """ + import pytest import torch from torch.nn import Conv2d, CrossEntropyLoss, Linear, ReLU, Sequential diff --git a/test/test_batch_first.py b/test/test_batch_first.py index ec962dab4..57dee4b04 100644 --- a/test/test_batch_first.py +++ b/test/test_batch_first.py @@ -1,4 +1,5 @@ """Tests whether batch axis is always first.""" + from pytest import raises from backpack.custom_module.permute import Permute diff --git a/test/test_problems_activations.py b/test/test_problems_activations.py index 75cfffb22..8d6bac573 100644 --- a/test/test_problems_activations.py +++ b/test/test_problems_activations.py @@ -17,23 +17,23 @@ for act_name, act_cls in ACTIVATIONS.items(): for lin_name, lin_cls in LINEARS.items(): - TEST_PROBLEMS[ - "{}{}-regression".format(lin_name, act_name) - ] = make_regression_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + TEST_PROBLEMS["{}{}-regression".format(lin_name, act_name)] = ( + make_regression_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) ) - TEST_PROBLEMS[ - "{}{}-classification".format(lin_name, act_name) - ] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + TEST_PROBLEMS["{}{}-classification".format(lin_name, act_name)] = ( + make_classification_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) ) - TEST_PROBLEMS[ - "{}{}-2layer-classification".format(lin_name, act_name) - ] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + TEST_PROBLEMS["{}{}-2layer-classification".format(lin_name, act_name)] = ( + make_classification_problem( + INPUT_SHAPE, + two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) ) diff --git a/test/test_problems_bn.py b/test/test_problems_bn.py index 724b763d2..a3fd0866f 100644 --- a/test/test_problems_bn.py +++ b/test/test_problems_bn.py @@ -34,17 +34,18 @@ def bn_layer2(): + [BatchNorm1d(TEST_SETTINGS["out_features"])], ) - TEST_PROBLEMS[ - "{}-bn-classification".format(lin_name) - ] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) - + [bn_layer1()], + TEST_PROBLEMS["{}-bn-classification".format(lin_name)] = ( + make_classification_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) + + [bn_layer1()], + ) ) - TEST_PROBLEMS[ - "{}-bn-2layer-classification".format(lin_name) - ] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + [bn_layer2()], + TEST_PROBLEMS["{}-bn-2layer-classification".format(lin_name)] = ( + make_classification_problem( + INPUT_SHAPE, + two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + + [bn_layer2()], + ) ) diff --git a/test/test_problems_convolutions.py b/test/test_problems_convolutions.py index 2e4832198..806166850 100644 --- a/test/test_problems_convolutions.py +++ b/test/test_problems_convolutions.py @@ -104,12 +104,12 @@ def make_2layer_classification_problem(conv_cls, act_cls): TEST_PROBLEMS = 
{} for conv_name, conv_cls in CONVS.items(): for act_name, act_cls in ACTIVATIONS.items(): - TEST_PROBLEMS[ - "{}-{}-regression".format(conv_name, act_name) - ] = make_regression_problem(conv_cls, act_cls) - TEST_PROBLEMS[ - "{}-{}-classification".format(conv_name, act_name) - ] = make_classification_problem(conv_cls, act_cls) - TEST_PROBLEMS[ - "{}-{}-2layer-classification".format(conv_name, act_name) - ] = make_2layer_classification_problem(conv_cls, act_cls) + TEST_PROBLEMS["{}-{}-regression".format(conv_name, act_name)] = ( + make_regression_problem(conv_cls, act_cls) + ) + TEST_PROBLEMS["{}-{}-classification".format(conv_name, act_name)] = ( + make_classification_problem(conv_cls, act_cls) + ) + TEST_PROBLEMS["{}-{}-2layer-classification".format(conv_name, act_name)] = ( + make_2layer_classification_problem(conv_cls, act_cls) + ) diff --git a/test/test_problems_kfacs.py b/test/test_problems_kfacs.py index 924c7d223..57e0c503e 100644 --- a/test/test_problems_kfacs.py +++ b/test/test_problems_kfacs.py @@ -17,11 +17,11 @@ REGRESSION_PROBLEMS = {} for act_name, act_cls in ACTIVATIONS.items(): for lin_name, lin_cls in LINEARS.items(): - REGRESSION_PROBLEMS[ - "{}{}-regression".format(lin_name, act_name) - ] = make_regression_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + REGRESSION_PROBLEMS["{}{}-regression".format(lin_name, act_name)] = ( + make_regression_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) ) TEST_PROBLEMS = { @@ -29,16 +29,16 @@ } for act_name, act_cls in ACTIVATIONS.items(): for lin_name, lin_cls in LINEARS.items(): - TEST_PROBLEMS[ - "{}{}-classification".format(lin_name, act_name) - ] = make_classification_problem( - INPUT_SHAPE, - single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + TEST_PROBLEMS["{}{}-classification".format(lin_name, act_name)] = ( + make_classification_problem( + INPUT_SHAPE, + single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) ) - TEST_PROBLEMS[ - "{}{}-2layer-classification".format(lin_name, act_name) - ] = make_classification_problem( - INPUT_SHAPE, - two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + TEST_PROBLEMS["{}{}-2layer-classification".format(lin_name, act_name)] = ( + make_classification_problem( + INPUT_SHAPE, + two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=act_cls), + ) ) diff --git a/test/test_problems_linear.py b/test/test_problems_linear.py index ad785f4a5..8b7a53471 100644 --- a/test/test_problems_linear.py +++ b/test/test_problems_linear.py @@ -24,8 +24,8 @@ INPUT_SHAPE, single_linear_layer(TEST_SETTINGS, lin_cls, activation_cls=None) ) - TEST_PROBLEMS[ - "{}-2layer-classification".format(lin_name) - ] = make_classification_problem( - INPUT_SHAPE, two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + TEST_PROBLEMS["{}-2layer-classification".format(lin_name)] = ( + make_classification_problem( + INPUT_SHAPE, two_linear_layers(TEST_SETTINGS, lin_cls, activation_cls=None) + ) ) diff --git a/test/test_problems_padding.py b/test/test_problems_padding.py index eec06d099..90f0a61ac 100644 --- a/test/test_problems_padding.py +++ b/test/test_problems_padding.py @@ -54,6 +54,6 @@ def make_2layer_classification_problem(padding_cls): TEST_PROBLEMS = {} for pad_name, pad_cls in PADDINGS.items(): - TEST_PROBLEMS[ - "conv+{}-classification-2layer".format(pad_name) - ] = make_2layer_classification_problem(pad_cls) + 
TEST_PROBLEMS["conv+{}-classification-2layer".format(pad_name)] = ( + make_2layer_classification_problem(pad_cls) + ) diff --git a/test/test_problems_pooling.py b/test/test_problems_pooling.py index 23c5e469a..020680789 100644 --- a/test/test_problems_pooling.py +++ b/test/test_problems_pooling.py @@ -85,9 +85,9 @@ def make_2layer_classification_problem(pooling_cls): TEST_PROBLEMS["conv+{}-regression".format(pool_name)] = make_regression_problem( pool_cls ) - TEST_PROBLEMS[ - "conv+{}-classification".format(pool_name) - ] = make_classification_problem(pool_cls) - TEST_PROBLEMS[ - "conv+{}-classification-2layer".format(pool_name) - ] = make_2layer_classification_problem(pool_cls) + TEST_PROBLEMS["conv+{}-classification".format(pool_name)] = ( + make_classification_problem(pool_cls) + ) + TEST_PROBLEMS["conv+{}-classification-2layer".format(pool_name)] = ( + make_2layer_classification_problem(pool_cls) + ) diff --git a/test/test_retain_graph.py b/test/test_retain_graph.py index 055f18d14..bf548a78b 100644 --- a/test/test_retain_graph.py +++ b/test/test_retain_graph.py @@ -1,4 +1,5 @@ """Test autograd functionality like retain_graph.""" + from test.automated_test import check_sizes_and_values from pytest import raises diff --git a/test/utils/evaluation_mode.py b/test/utils/evaluation_mode.py index f4e57be77..f41ba2f57 100644 --- a/test/utils/evaluation_mode.py +++ b/test/utils/evaluation_mode.py @@ -1,4 +1,5 @@ """Tools for initializing in evaluation mode, especially BatchNorm.""" + from typing import Union from torch import rand_like diff --git a/test/utils/skip_test.py b/test/utils/skip_test.py index 304656101..9e44b06de 100644 --- a/test/utils/skip_test.py +++ b/test/utils/skip_test.py @@ -4,48 +4,8 @@ from test.extensions.problem import ExtensionsTestProblem from typing import List, Union -from pkg_resources import packaging from pytest import skip -from torch.nn import ( - LSTM, - BatchNorm1d, - BatchNorm2d, - BatchNorm3d, - BCEWithLogitsLoss, - Module, -) - -from backpack.utils import ADAPTIVE_AVG_POOL_BUG, TORCH_VERSION - - -def skip_torch_2_0_1_lstm(module: Module): - """Skip if module contains LSTMs and we are using PyTorch 2.0.1. - - Args: - module: Neural network - """ - # double-backward not supported https://github.com/pytorch/pytorch/issues/99413 - TORCH_VERSION_2_0_1 = TORCH_VERSION == packaging.version.parse("2.0.1") - lstm = any(isinstance(m, LSTM) for m in module.modules()) - if lstm and TORCH_VERSION_2_0_1: - skip("Double-backward not supported for LSTM in PyTorch 2.0.1 (#99413)") - - -def skip_adaptive_avg_pool3d_cuda(request) -> None: - """Skips test if AdaptiveAvgPool3d and cuda. - - Args: - request: problem request - """ - if ADAPTIVE_AVG_POOL_BUG: - if all( - string in request.node.callspec.id - for string in ["AdaptiveAvgPool3d", "cuda"] - ): - skip( - "Skip test because AdaptiveAvgPool3d does not work on cuda. " - "Should be fixed in torch 2.0." - ) +from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d, BCEWithLogitsLoss def skip_batch_norm_train_mode_with_subsampling( diff --git a/test/utils/test_conv_settings.py b/test/utils/test_conv_settings.py index d78a9d7d8..bf073726d 100644 --- a/test/utils/test_conv_settings.py +++ b/test/utils/test_conv_settings.py @@ -14,6 +14,7 @@ "id_prefix" (str): Prefix to be included in the test name. 
"seed" (int): seed for the random number for torch.rand """ + import torch SETTINGS = [] diff --git a/test/utils/test_conv_transpose_settings.py b/test/utils/test_conv_transpose_settings.py index 1cd7cd4d3..34986e42b 100644 --- a/test/utils/test_conv_transpose_settings.py +++ b/test/utils/test_conv_transpose_settings.py @@ -11,6 +11,7 @@ "id_prefix" (str): Prefix to be included in the test name. "seed" (int): seed for the random number for torch.rand """ + import torch SETTINGS = []