-
Notifications
You must be signed in to change notification settings - Fork 15
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit b21dec0
Showing
89 changed files
with
7,460 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
[flake8] | ||
ignore = E203, E266, E501, W503, F403, F401 | ||
max-line-length = 100 | ||
max-complexity = 18 | ||
select = B,C,E,F,W,T4,B9 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
.idea | ||
wandb | ||
ckernel_fitting/wandb | ||
data |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
repos: | ||
- repo: https://github.com/ambv/black | ||
rev: stable | ||
hooks: | ||
- id: black | ||
# language_version: python3.6 | ||
- repo: https://gitlab.com/pycqa/flake8 | ||
rev: 3.7.9 | ||
hooks: | ||
- id: flake8 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
MIT License | ||
|
||
Copyright (c) 2021 David W. Romero | ||
|
||
Permission is hereby granted, free of charge, to any person obtaining a copy | ||
of this software and associated documentation files (the "Software"), to deal | ||
in the Software without restriction, including without limitation the rights | ||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
copies of the Software, and to permit persons to whom the Software is | ||
furnished to do so, subject to the following conditions: | ||
|
||
The above copyright notice and this permission notice shall be included in all | ||
copies or substantial portions of the Software. | ||
|
||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | ||
SOFTWARE. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
## CKConv: Continuous Kernel Convolution For Sequential Data | ||
|
||
This repository contains the source code accompanying the paper: | ||
|
||
[CKConv: Continuous Kernel Convolution For Sequential Data](https://arxiv.org/abs/2102.02611) [[Slides]](https://app.slidebean.com/p/wgp8j0zl62/CKConv-Continuous-Kernel-Convolutions-For-Sequential-Data) <br/>**[David W. Romero](https://www.davidromero.ml/), [Anna Kuzina](https://akuzina.github.io/), [Erik J. Bekkers](https://erikbekkers.bitbucket.io/), [Jakub M. Tomczak](https://jmtomczak.github.io/) & [Mark Hoogendoorn](https://www.cs.vu.nl/~mhoogen/)**. | ||
|
||
#### Abstract | ||
*Conventional neural architectures for sequential data present important limitations. Recurrent networks suffer from exploding | ||
and vanishing gradients, small effective memory horizons, and must be trained sequentially. Convolutional networks are unable to handle sequences of unknown size | ||
and their memory horizon must be defined a priori. In this work, we show that all these problems can be solved by formulating convolutional kernels | ||
in CNNs as continuous functions. The resulting Continuous Kernel Convolution (CKConv) allows us to model arbitrarily long sequences | ||
in a parallel manner, within a single operation, and without relying on any form of recurrence. We show that Continuous Kernel Convolutional Networks | ||
(CKCNNs) obtain state of the art results in multiple datasets, e.g., permuted MNIST, and, thanks to their continuous nature, are able to handle | ||
non-uniformly sampled datasets and irregularly sampled data natively. CKCNNs at least match neural ODEs designed for these purposes in a | ||
much faster and simple manner.* | ||
|
||
<img src="ckconv.png" alt="drawing" width="750"/> | ||
|
||
### Repository structure | ||
|
||
#### Folders | ||
|
||
This repository is organized as follows: | ||
|
||
* `ckconv` contains the main PyTorch library of our model. | ||
|
||
* `ckernel_fitting` contains source code to run experiments to approximate convolutional filters via MLPs. Please see `ckernel_fitting/README.md` for further details. | ||
|
||
* `demo` provides some minimalistic examples on the usage of CKConvs and the construction of CKCNNs. | ||
|
||
* `models` contains the models used throughout our experiments. | ||
|
||
* `probspec_routines` contains routines specific to some of the problems considered in this paper. | ||
|
||
* `runs` contains the `.sh` files with the corresponding arguments used to run our experiments. | ||
|
||
### Reproduce | ||
|
||
#### Install | ||
|
||
###### conda *(recommended)* | ||
In order to reproduce our results, please first install the required dependencies. This can be done by: | ||
``` | ||
conda create --name ckconv --file conda_requirements.txt | ||
``` | ||
This will create the conda environment `ckconv` with the correct dependencies. | ||
|
||
###### pip | ||
The same conda environment can be created with `pip` by running: | ||
``` | ||
conda create -n ckconv python=3.7 | ||
conda install pytorch==1.7.0 torchvision==0.8.1 torchaudio=0.7.0 cudatoolkit=10.1 -c pytorch | ||
conda activate ckconv | ||
pip install -r requirements.txt | ||
``` | ||
|
||
#### Experiments and `config` files | ||
To reproduce the experiments in the paper, please follow the configurations given in the file `runs/my_experiment.sh` | ||
|
||
Specifications on the parameters specified via the `argsparser` can be found in the corresponding `config.py` files. | ||
|
||
#### Pretrained models | ||
To use pretrained models, please add the argument `--config.pretrained==True` to the corresponding execution `.sh` file. | ||
|
||
#### Recommendations and details | ||
|
||
###### Replacing fft convolutions with spatial convolutions | ||
We leverage the convolution theorem in our experiments to accelerate the computation of the convolution operations (see | ||
`causal_fftconv` in `ckconv/nn/functional/causalconv.py`), and we strongly recommend using fft convolutions. | ||
However, for some applications it might be desirable to rely on spatial convolutions, e.g., small conv. kernels. This can be easily modified by replacing | ||
the call to `causal_fftconv` in the forward pass of the `CKConv` class (`ckconv/nn/ckconv.py:182`) by the function `causal_conv` found in `ckconv/nn/functional/causalconv.py`. | ||
|
||
### Cite | ||
If you found this work useful in your research, please consider citing: | ||
``` | ||
@article{romero2021ckconv, | ||
title={CKConv: Continuous Kernel Convolutions for Sequential Data}, | ||
author={Romero, David W and Kuzinna, Anna and Bekkers, Erik J and Tomczak, Jakub M and Hoogendoorn, Mark}, | ||
journal={arXiv preprint arXiv:2102.02611}, | ||
year={2021} | ||
} | ||
``` | ||
|
||
### Acknowledgements | ||
*We gratefully acknowledge Gabriel Dernbach for interesting analyses on the knot distribution of ReLU networks. We thank Emiel van Krieken and Ali el Hasouni as well for interesting questions and motivating comments at the beginning of this project. | ||
David W. Romero is financed as part of the Efficient Deep Learning (EDL) programme (grant number P16-25), partly | ||
funded by the Dutch Research Council (NWO) and Semiotic Labs. Anna Kuzina is funded by the Hybrid Intelligence Center, a 10-year programme funded | ||
by the Dutch Ministry of Education, Culture and Science through the Netherlands Organisation for | ||
Scientific Research. Erik J. Bekkers is financed by the | ||
research programme VENI (grant number 17290) funded by the Dutch Research Council. All authors are thankful to everyone | ||
involved in funding this work. | ||
This work was carried out on the Dutch national e-infrastructure with | ||
the support of SURF Cooperative.* |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .nn import * | ||
from .utils import * |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .activation_functions import Swish, Sine | ||
from .linear import Linear1d, Linear2d | ||
from .norm import LayerNorm | ||
from .ckconv import CKConv, KernelNet | ||
from .ck_block import CKBlock | ||
from .loss import LnLoss | ||
from .conv import CausalConv1d |
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
# torch | ||
import torch | ||
from ckconv.nn.misc import Expression | ||
|
||
|
||
def Swish(): | ||
""" | ||
out = x * sigmoid(x) | ||
""" | ||
return Expression(lambda x: x * torch.sigmoid(x)) | ||
|
||
|
||
def Sine(): | ||
""" | ||
out = sin(x) | ||
""" | ||
return Expression(lambda x: torch.sin(x)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
import torch | ||
import ckconv.nn | ||
|
||
|
||
class CKBlock(torch.nn.Module): | ||
def __init__( | ||
self, | ||
in_channels: int, | ||
out_channels: int, | ||
kernelnet_hidden_channels: int, | ||
kernelnet_activation_function: str, | ||
kernelnet_norm_type: str, | ||
dim_linear: int, | ||
bias: bool, | ||
omega_0: bool, | ||
dropout: float, | ||
weight_dropout: float, | ||
): | ||
""" | ||
Creates a Residual Block with CKConvs as: | ||
( Follows the Residual Block of Bai et. al., 2017 ) | ||
input | ||
| ---------------| | ||
CKConv | | ||
LayerNorm | | ||
ReLU | | ||
DropOut | | ||
| | | ||
CKConv | | ||
LayerNorm | | ||
ReLU | | ||
DropOut | | ||
+ <--------------| | ||
| | ||
ReLU | ||
| | ||
output | ||
:param in_channels: Number of channels in the input signal | ||
:param out_channels: Number of output (and hidden) channels of the block. | ||
:param kernelnet_hidden_channels: Number of hidden units in the KernelNets of the CKConvs. | ||
:param kernelnet_activation_function: Activation function used in the KernelNets of the CKConvs. | ||
:param kernelnet_norm_type: Normalization type used in the KernelNets of the CKConvs (only for non-Sine KernelNets). | ||
:param dim_linear: Spatial dimension of the input, e.g., for audio = 1, images = 2 (only 1 suported). | ||
:param bias: If True, adds a learnable bias to the output. | ||
:param omega_0: Value of the omega_0 value of the KernelNets. (only for non-Sine KernelNets). | ||
:param dropout: Dropout rate of the block | ||
:param weight_dropout: Dropout rate applied to the sampled convolutional kernels. | ||
""" | ||
super().__init__() | ||
|
||
# CKConv layers | ||
self.cconv1 = ckconv.nn.CKConv( | ||
in_channels, | ||
out_channels, | ||
kernelnet_hidden_channels, | ||
kernelnet_activation_function, | ||
kernelnet_norm_type, | ||
dim_linear, | ||
bias, | ||
omega_0, | ||
weight_dropout, | ||
) | ||
self.cconv2 = ckconv.nn.CKConv( | ||
out_channels, | ||
out_channels, | ||
kernelnet_hidden_channels, | ||
kernelnet_activation_function, | ||
kernelnet_norm_type, | ||
dim_linear, | ||
bias, | ||
omega_0, | ||
weight_dropout, | ||
) | ||
# Norm layers | ||
self.norm1 = ckconv.nn.LayerNorm(out_channels) | ||
self.norm2 = ckconv.nn.LayerNorm(out_channels) | ||
|
||
# Dropout | ||
self.dp = torch.nn.Dropout(dropout) | ||
|
||
shortcut = [] | ||
if in_channels != out_channels: | ||
shortcut.append(ckconv.nn.Linear1d(in_channels, out_channels)) | ||
self.shortcut = torch.nn.Sequential(*shortcut) | ||
|
||
def forward(self, x): | ||
shortcut = self.shortcut(x) | ||
out = self.dp(torch.relu(self.norm1(self.cconv1(x)))) | ||
out = torch.relu(self.dp(torch.relu(self.norm2(self.cconv2(out)))) + shortcut) | ||
return out |
Oops, something went wrong.