TS2Kit (Version 1.1) is a self-contained PyTorch library which computes auto-differentiable forward and inverse discrete Spherical Harmonic Transforms (SHTs). The routines in TS2Kit are based on the seminal S2Kit and SOFT packages, but are designed for evaluation on a GPU. Specifically, the Discrete Legendre Transform (DLT) is computed via sparse matrix multiplication in what is essentially a tensorized version of the so-called "semi-naive" algorithm. This enables parallelization while keeping the memory footprint small, and the end result is a pair of auto-differentiable forward and inverse SHTs that are fast and efficient in a practical sense. For example, given a spherical signal (tensor) taking values on a 128 X 128 spherical grid with b = 4096 batch dimensions, TS2Kit computes a forward SHT followed by an inverse SHT in tens of milliseconds at floating precision.
Please see TS2Kit.pdf for a detailed review of the chosen conventions and implementation details.
To use TS2Kit, simply copy the TS2Kit folder into your project directory.
Several tensors are pre-computed at initialization, and at higher bandlimits (B >= 64) this can take some time. To avoid re-computing these quantities at every initialization, the modules check whether the tensors have been saved in a cache directory and either (a) load the tensors directly from the cache, or (b) compute the tensors and save them to the cache directory so they can be loaded the next time the modules are initialized.
The repository contains the folder cache, which serves as the default cache directory. If you'd like to save the precomputed tensors in a different directory, set the variable defaultCacheDir at the top of the ts2kit.py file to the absolute path of the directory, e.g.
defaultCacheDir = '/absolute/path/to/your/chosen/directory'
The cache directory can be cleared (of .pt files) at any time by importing and running the clearTS2KitCache function:
from TS2Kit.ts2kit import clearTS2KitCache
clearTS2KitCache()
The front-end of TS2Kit consists of the torch.nn.Module classes FTSHT and ITSHT, corresponding to the forward and inverse SHT, respectively. At initialization, the modules are passed an integer argument B which determines the bandlimit of the forward and inverse SHT, e.g.
from TS2Kit.ts2kit import FTSHT, ITSHT
## Bandlimit
B = 64
## Initialize the (B-1)-bandlimited forward SHT
FT = FTSHT(B)
## Initialize the (B-1)-bandlimited inverse SHT
IT = ITSHT(B)
Initialized with bandlimit B, calling the FTSHT module applies the forward SHT to a spherical signal with several batch dimensions. Specifically, inputs are b X 2B X 2B real or complex torch tensors, where b is the batch dimension and the second and third dimensions increment over the values in the 2B X 2B Driscoll-Healy (DH) spherical grid (see TS2Kit.pdf). For example, given a tensor psi of size 100 X 128 X 128 (b = 100, B = 64), the element psi[26, 47, 12] is the value of the spherical signal in batch dimension 26 at coordinates (theta_47, phi_12) in the DH spherical grid. To assist in sampling to a DH grid, the user can import the gridDH function, which takes as input a fixed bandlimit B and returns two 2B X 2B tensors theta and phi giving the spherical coordinates of the corresponding DH grid indices.
The forward call returns a b X (2B-1) X B complex torch tensor giving the array of SH coefficients -- with m and l incremented along the second and third dimensions, respectively -- of the spherical signals in each batch dimension of the input tensor. For example, passing the real or complex 100 X 128 X 128 tensor psi to the module returns the complex 100 X 127 X 64 tensor of SH coefficients:
F = FTSHT(B)
psiCoeff = F(psi)
The (l, m)-th SH coefficient in batch dimension c can be accessed via psiCoeff[c, m + (B-1), l], e.g. for l = 5, m = -5, and c = 12, the corresponding SH coefficient is psiCoeff[12, 58, 5]. For l < |m|, the values in psiCoeff will be zero.
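If it helps readability, the index arithmetic can be wrapped in a small helper; shIndex below is hypothetical (not part of TS2Kit) and simply restates the convention above:
## Hypothetical helper (not part of TS2Kit): second-dimension index of order m
def shIndex(m, B):
    return m + (B - 1)

coeff = psiCoeff[12, shIndex(-5, 64), 5]  ## same as psiCoeff[12, 58, 5]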
Initialized with bandlimit B, calling the ITSHT module applies the inverse SHT to a signal composed of several arrays of SH coefficients. Inputs are b X (2B-1) X B complex torch tensors consisting of b channels of SH coefficient arrays, structured in exactly the same way as the output of the FTSHT module. The forward call returns a b X 2B X 2B complex torch tensor corresponding to the spherical signals reconstructed from the SH coefficients in each batch dimension:
I = ITSHT(B)
psi = I(psiCoeff)
The output tensor is complex-valued, so if the input SH coefficient tensor corresponds to a real-valued signal, the imaginary part of the output tensor will be zero and it can be cast to a real tensor (e.g. by calling psi.real) without loss of information.
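For instance, a real-valued round trip through both modules might look like the following sketch (reusing the FT and IT modules initialized earlier and a real double-precision tensor psi as above):
## Forward then inverse SHT; for a real input the imaginary part is numerically zero
psiRecon = IT(FT(psi)).real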
The FTSHT and ITSHT modules are initialized at double precision. That is, the forward call of FTSHT maps tensors of type torch.double (real-valued) or torch.cdouble (complex-valued) to tensors of type torch.cdouble. Similarly, the forward call of ITSHT maps tensors of type torch.cdouble to tensors of the same type.
The modules can also be cast to floating precision at initialization, e.g. via FTSHT(B).float() and ITSHT(B).float(). In this case, the forward call of FTSHT maps tensors of type torch.float and torch.cfloat to tensors of type torch.cfloat, and that of ITSHT maps tensors of type torch.cfloat to tensors of the same type.
Casting to floating precision results in half the memory overhead and about an order of magnitude decrease in run-time, at the cost of several orders of magnitude in accuracy. For example, given a tensor of double-precision SH coefficients on the GPU:
import torch

b, B = 100, 64  ## batch size and bandlimit, as in the examples above
device = torch.device('cuda')

## Random double-precision SH coefficients; entries with l < |m| are zeroed out
Psi = torch.view_as_complex(2 * (torch.rand(b, 2*B - 1, B, 2).double() - 0.5)).to(device)
for m in range(-(B-1), B):
    for l in range(0, B):
        if (l * l < m * m):
            Psi[:, m + (B-1), l] = 0.0
one can expect the following error to be very, very small
F = FTSHT(B).to(device)
I = ITSHT(B).to(device)
Psi2 = F(I(Psi))
## This error should be very, very small
error = torch.sum(torch.abs(Psi-Psi2)) / torch.sum(torch.abs(Psi))
Casting to floating precision will result in a significant speed-up and less memory overhead, but a larger error:
Psi_f = Psi.cfloat()
F_f = FTSHT(B).to(device).float()
I_f = ITSHT(B).to(device).float()
## This should run about an order of magnitude faster
Psi2_f = F_f(I_f(Psi_f))
## This error will be much larger
error = torch.sum(torch.abs(Psi_f - Psi2_f)) / torch.sum(torch.abs(Psi_f))
This does not imply that the FTSHT and ITSHT modules are "slow" at double precision nor "inaccurate" at floating precision; rather, it all depends on the application. The test_ts2kit.ipynb notebook can be used to compare the transforms at different precisions and bandlimits to see what makes sense for your use case.
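For a rough comparison outside the notebook, one might time the two round trips along these lines (a sketch assuming the Psi, Psi_f, F, I, F_f, and I_f objects from the examples above and a CUDA device):
import time

torch.cuda.synchronize()
start = time.time()
Psi2 = F(I(Psi))  ## double-precision round trip
torch.cuda.synchronize()
print('double: {:.1f} ms'.format(1000.0 * (time.time() - start)))

torch.cuda.synchronize()
start = time.time()
Psi2_f = F_f(I_f(Psi_f))  ## floating-precision round trip
torch.cuda.synchronize()
print('float: {:.1f} ms'.format(1000.0 * (time.time() - start)))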
Author: Thomas (Tommy) Mitchel (tmitchel 'at' jhu 'dot' edu)
Citation:
@inproceedings{10.1145/3528233.3530724,
  author = {Mitchel, Thomas W. and Aigerman, Noam and Kim, Vladimir G. and Kazhdan, Michael},
  title = {M\"{o}bius Convolutions for Spherical CNNs},
  year = {2022},
  isbn = {9781450393379},
  publisher = {Association for Computing Machinery},
  address = {New York, NY, USA},
  url = {https://doi.org/10.1145/3528233.3530724},
  doi = {10.1145/3528233.3530724},
  booktitle = {ACM SIGGRAPH 2022 Conference Proceedings},
  articleno = {30},
  numpages = {9},
  keywords = {Neural networks, M\"{o}bius transformations, Group equivariance, Convolution, Conformal transformations},
  location = {Vancouver, BC, Canada},
  series = {SIGGRAPH '22}
}