Authors: Maksim Zhdanov, David Ruhe, Maurice Weiler, Ana Lucic, Johannes Brandstetter, and Patrick Forré
ArXiv | Blog | Playbook | Google Colab
We present Clifford-Steerable Convolutional Neural Networks (CS-CNNs), a novel class of
To install all the necessary requirements, including JAX and PyTorch, run:
bash setup.sh
Below is a simple example of initializing and applying a CS-ResNet to a random multivector input:
import jax
from algebra.cliffordalgebra import CliffordAlgebra
from models.resnets import CSResNet
algebra = CliffordAlgebra((1, 1))
config = dict(
    algebra=algebra,
    time_history=4,
    time_future=1,
    hidden_channels=16,
    kernel_num_layers=4,
    kernel_hidden_dim=12,
    kernel_size=7,
    bias_dims=(0,),
    product_paths_sum=algebra.geometric_product_paths.sum().item(),
    make_channels=1,
    blocks=(2, 2, 2, 2),
    norm=True,
    padding_mode="symmetric",
)
csresnet = CSResNet(**config)
# random input for initialization
rng = jax.random.PRNGKey(42)
mv_field = jax.random.normal(rng, (16, config["time_history"], 64, 64, algebra.n_blades))
params = csresnet.init(rng, mv_field)
# compute the output
out = csresnet.apply(params, mv_field)
Note that the field must have shape (Batch, Channels, ..., Blades), where ... stands for the grid dimensions (depth, width, etc.).
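As a quick sanity check of the layout described above, the sketch below computes the expected input shape from a metric signature. It assumes that the number of blades is 2^n for an n-dimensional signature, which is a standard property of Clifford algebras and should agree with algebra.n_blades. The helper names n_blades and expected_field_shape are hypothetical and not part of this repository.

```python
from math import comb


def n_blades(metric):
    """Number of basis blades of a Clifford algebra with the given signature."""
    n = len(metric)
    # One blade per subset of basis vectors: sum over grades k of C(n, k) = 2**n.
    return sum(comb(n, k) for k in range(n + 1))


def expected_field_shape(batch, channels, grid, metric):
    """Shape (Batch, Channels, *grid, Blades) of a multivector field."""
    return (batch, channels, *grid, n_blades(metric))


# Mirrors the example above: 16 samples, 4 channels (the time history),
# a 64x64 grid, and the metric (1, 1).
shape = expected_field_shape(16, 4, (64, 64), (1, 1))
print(shape)  # (16, 4, 64, 64, 4)
```

Comparing this tuple against mv_field.shape before calling csresnet.init is a cheap way to catch a missing or misplaced blade axis.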
Instructions for generating the Navier-Stokes data can be found in datasets/data/ns/README.md.
cd datasets/data/ns
bash download.sh
python preprocess.py
To reproduce the Navier-Stokes experiment, run:
python experiment.py --experiment ns --model gcresnet --metric 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 8 --norm 1 --hidden_channels 48
python experiment.py --experiment ns --model resnet --metric 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 8 --norm 1 --hidden_channels 96
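When running several variants of the commands above, it can help to assemble the arguments programmatically. The following is a minimal sketch that uses only the flags and values appearing in those two commands. The helper build_command is a hypothetical name, and the sketch assumes the command is launched from the repository root.

```python
def build_command(model, hidden_channels):
    """Return the argument list for one Navier-Stokes run of experiment.py."""
    # Flags and values are copied from the commands above, in the same order.
    args = {
        "experiment": "ns",
        "model": model,
        "metric": (1, 1),
        "time_history": 4,
        "time_future": 1,
        "num_data": 64,
        "batch_size": 8,
        "norm": 1,
        "hidden_channels": hidden_channels,
    }
    cmd = ["python", "experiment.py"]
    for flag, value in args.items():
        # Multi-valued flags such as --metric expand to several tokens.
        values = value if isinstance(value, (list, tuple)) else (value,)
        cmd.append(f"--{flag}")
        cmd.extend(str(v) for v in values)
    return cmd


# The two configurations listed above: (model, hidden_channels).
for model, channels in [("gcresnet", 48), ("resnet", 96)]:
    print(" ".join(build_command(model, channels)))
```

The returned list can be passed directly to subprocess.run(cmd, check=True) to launch a run.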
Instructions for generating the Maxwell 3D data can be found in datasets/data/maxwell3d/README.md.
cd datasets/data/maxwell3d
bash download.sh
python preprocess.py
To reproduce the Maxwell 3D experiment, run:
python experiment.py --experiment maxwell3d --model gcresnet --metric 1 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 2 --norm 1 --hidden_channels 12 --scheduler cosine
python experiment.py --experiment maxwell3d --model resnet --metric 1 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 2 --norm 1 --hidden_channels 12 --scheduler cosine
Instructions for generating the Maxwell 2D+1 data can be found in datasets/data/maxwell2d/datagen/README.md.
cd datasets/datagen/maxwell2d
bash generate.sh --num_points 512 --partition train
To reproduce the Maxwell 2D+1 experiment, run:
python experiment.py --experiment maxwell2d --model gcresnet --metric -1 1 1 --time_history 32 --time_future 32 --num_data 512 --batch_size 16 --norm 0 --hidden_channels 12
python experiment.py --experiment maxwell2d --model resnet --metric -1 1 1 --time_history 32 --time_future 32 --num_data 512 --batch_size 16 --norm 0 --hidden_channels 13
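The --metric values used in the commands above (1 1, 1 1 1, and -1 1 1) presumably list the squares of the basis vectors, so -1 1 1 would describe a 2D+1 spacetime in which one basis vector squares to -1. The sketch below illustrates, under that assumption, how such a signature determines the product of basis blades. It uses the standard bitmask encoding of blades and is independent of this repository's CliffordAlgebra class. The function name blade_product is hypothetical.

```python
def blade_product(a, b, metric):
    """Geometric product of two basis blades encoded as bitmasks.

    Bit i of a mask is set when basis vector e_(i+1) is part of the blade.
    Returns (coefficient, bitmask) of the resulting blade.
    """
    # Count the transpositions needed to bring the factors into canonical order.
    swaps = 0
    shifted = a >> 1
    while shifted:
        swaps += bin(shifted & b).count("1")
        shifted >>= 1
    coeff = -1 if swaps % 2 else 1
    # Basis vectors present in both blades contract to their metric entry.
    for i, square in enumerate(metric):
        if ((a & b) >> i) & 1:
            coeff *= square
    return coeff, a ^ b


e1, e2, e12 = 0b01, 0b10, 0b11

print(blade_product(e1, e2, (1, 1)))      # (1, 3): e1 e2 = e12
print(blade_product(e2, e1, (1, 1)))      # (-1, 3): e2 e1 = -e12
print(blade_product(e12, e12, (1, 1)))    # (-1, 0): e12 squares to -1
print(blade_product(e1, e1, (-1, 1, 1)))  # (-1, 0): e1 squares to -1
```

Under this reading, changing the sign of a metric entry changes only the coefficients of the products, not the number of blades.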
The repository is currently incomplete. The roadmap is as follows:
- implementation of Clifford-steerable kernels/convolutions (in JAX)
- implementation of Clifford-steerable ResNet and basic ResNet (in JAX)
- demonstration example + equivariance test (escnn + PyTorch required)
- code for the data generation (Maxwell on spacetime)
- replicating experimental results
  - Navier-Stokes (PDEarena)
  - Maxwell 3D (PDEarena)
  - Maxwell 2D+1 (PyCharge)
- implementation of Clifford ResNet and Steerable ResNet (in PyTorch)
If you find this repository useful in your research, please consider citing us:
@inproceedings{Zhdanov2024CliffordSteerableCN,
title = {Clifford-Steerable Convolutional Neural Networks},
author = {Maksim Zhdanov and David Ruhe and Maurice Weiler and Ana Lucic and Johannes Brandstetter and Patrick Forr{\'e}},
booktitle = {International {Conference} on {Machine} {Learning} ({ICML})},
year = {2024},
}