Authors: Maksim Zhdanov, David Ruhe, Maurice Weiler, Ana Lucic, Johannes Brandstetter, and Patrick Forré
ArXiv | Blog | Playbook | Google Colab
We present Clifford-Steerable Convolutional Neural Networks (CS-CNNs), a novel class of E(p, q)-equivariant CNNs.
To install all the necessary requirements, including JAX and PyTorch, run:
bash setup.sh
Below is a simple example of initializing and applying a CS-ResNet to a random multivector input:
import jax
from algebra.cliffordalgebra import CliffordAlgebra
from models.resnets import CSResNet
algebra = CliffordAlgebra((1, 1))
config = dict(
algebra=algebra,
time_history=4,
time_future=1,
hidden_channels=16,
kernel_num_layers=4,
kernel_hidden_dim=12,
kernel_size=7,
bias_dims=(0,),
product_paths_sum=algebra.geometric_product_paths.sum().item(),
make_channels=1,
blocks=(2, 2, 2, 2),
norm=True,
padding_mode="symmetric",
)
csresnet = CSResNet(**config)
# random input for initialization
rng = jax.random.PRNGKey(42)
# config is a plain dict, so index it with a key rather than attribute access
mv_field = jax.random.normal(rng, (16, config["time_history"], 64, 64, algebra.n_blades))
params = csresnet.init(rng, mv_field)
# compute the output
out = csresnet.apply(params, mv_field)
Note that the field must have shape (Batch, Channels, ..., Blades), where ... stands for the grid dimensions (e.g. height, width, depth).
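As a hedged illustration of this layout, the sketch below computes the expected input shape without the repository code. The helper `expected_field_shape` is hypothetical, and it assumes, as in the example above, that the grid has one axis per basis vector of the algebra. It relies only on the fact that a Clifford algebra over an n-dimensional space has 2^n basis blades, which is presumably what `algebra.n_blades` reports (4 for the metric (1, 1)).

```python
def expected_field_shape(metric, batch, channels, grid):
    """Hypothetical helper: shape (Batch, Channels, *grid, Blades) of a multivector field.

    Assumes one grid axis per basis vector, i.e. len(grid) == len(metric).
    A Clifford algebra over an n-dimensional space has 2**n basis blades,
    e.g. 4 for metric (1, 1) and 8 for metric (-1, 1, 1).
    """
    n_dims = len(metric)
    if len(grid) != n_dims:
        raise ValueError(f"expected {n_dims} grid dimensions, got {len(grid)}")
    n_blades = 2 ** n_dims  # scalar, vectors, bivectors, ..., pseudoscalar
    return (batch, channels, *grid, n_blades)


# Navier-Stokes-like setting from the example above: metric (1, 1), 4 input time steps
print(expected_field_shape((1, 1), batch=16, channels=4, grid=(64, 64)))
# -> (16, 4, 64, 64, 4)
```

For the example above this gives (16, 4, 64, 64, 4), matching the shape of `mv_field` when `algebra.n_blades` is 4.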
Instructions for generating the Navier-Stokes data can be found in datasets/data/ns/README.md.
cd datasets/data/ns
bash download.sh
python preprocess.py
To reproduce the experiment, run:
python experiment.py --experiment ns --model gcresnet --metric 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 8 --norm 1 --hidden_channels 48
python experiment.py --experiment ns --model resnet --metric 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 8 --norm 1 --hidden_channels 96
Instructions for generating the Maxwell 3D data can be found in datasets/data/maxwell3d/README.md.
cd datasets/data/maxwell3d
bash download.sh
python preprocess.py
To reproduce the experiment, run:
python experiment.py --experiment maxwell3d --model gcresnet --metric 1 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 2 --norm 1 --hidden_channels 12 --scheduler cosine
python experiment.py --experiment maxwell3d --model resnet --metric 1 1 1 --time_history 4 --time_future 1 --num_data 64 --batch_size 2 --norm 1 --hidden_channels 12 --scheduler cosine
Instructions for generating the Maxwell 2D+1 data can be found in datasets/data/maxwell2d/datagen/README.md.
cd datasets/datagen/maxwell2d
bash generate.sh --num_points 512 --partition train
To reproduce the experiment, run:
python experiment.py --experiment maxwell2d --model gcresnet --metric -1 1 1 --time_history 32 --time_future 32 --num_data 512 --batch_size 16 --norm 0 --hidden_channels 12
python experiment.py --experiment maxwell2d --model resnet --metric -1 1 1 --time_history 32 --time_future 32 --num_data 512 --batch_size 16 --norm 0 --hidden_channels 13
The repository is currently incomplete; the roadmap is below:
- implementation of Clifford-steerable kernels/convolutions (in JAX)
- implementation of Clifford-steerable ResNet and basic ResNet (in JAX)
- demonstrating example + test equivariance (escnn + PyTorch required)
- code for the data generation (Maxwell on spacetime)
- replicating experimental results
  - Navier-Stokes (PDEarena)
  - Maxwell 3D (PDEarena)
  - Maxwell 2D+1 (PyCharge)
- implementation of Clifford ResNet and Steerable ResNet (in PyTorch)
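The spacetime experiments on the roadmap (Maxwell 2D+1, run above with `--metric -1 1 1`) rely on a pseudo-Euclidean metric rather than a Euclidean one. The sketch below is a minimal, hypothetical NumPy illustration, not part of the repository and not its equivariance test. Assuming the entries of `--metric` are the diagonal of the metric, it builds η = diag(-1, 1, 1), checks that a Lorentz boost Λ preserves it (ΛᵀηΛ = η), and checks that the toy map f(v) = Q(v)·v commutes with the boost. The names `quadratic_form`, `boost`, and `toy_map` are illustrative assumptions.

```python
import numpy as np

# Assumed reading of --metric -1 1 1: one time-like axis (-1), two space-like axes (+1)
eta = np.diag([-1.0, 1.0, 1.0])


def quadratic_form(v):
    """Q(v) = v^T eta v for the pseudo-Euclidean metric eta."""
    return v @ eta @ v


def boost(rapidity):
    """Lorentz boost mixing axis 0 (time) and axis 1 (x); an element of O(1, 2)."""
    c, s = np.cosh(rapidity), np.sinh(rapidity)
    return np.array([[c, s, 0.0],
                     [s, c, 0.0],
                     [0.0, 0.0, 1.0]])


def toy_map(v):
    """Toy O(1, 2)-equivariant map f(v) = Q(v) * v (not a CS-CNN layer)."""
    return quadratic_form(v) * v


rng = np.random.default_rng(0)
v = rng.normal(size=3)
lam = boost(0.7)

# The boost preserves the metric: lam^T eta lam == eta
metric_preserved = np.allclose(lam.T @ eta @ lam, eta)
# Equivariance: transform then apply f equals apply f then transform
equivariant = np.allclose(toy_map(lam @ v), lam @ toy_map(v))
print(metric_preserved, equivariant)  # True True
```

The toy map is equivariant because Q is invariant under any transformation that preserves η; a Euclidean rotation would not preserve this η in general, which is why the metric passed via `--metric` matters.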
If you find this repository useful in your research, please consider citing us:
@inproceedings{Zhdanov2024CliffordSteerableCN,
title = {Clifford-Steerable Convolutional Neural Networks},
author = {Maksim Zhdanov and David Ruhe and Maurice Weiler and Ana Lucic and Johannes Brandstetter and Patrick Forr{\'e}},
booktitle = {International {Conference} on {Machine} {Learning} ({ICML})},
year = {2024},
}
