Python3 implementation of the paper Sliced Gromov-Wasserstein (NeurIPS 2019)
Sliced Gromov-Wasserstein (SGW) is an Optimal Transport discrepancy between measures whose supports do not necessarily live in the same metric space. It is based on a closed-form expression of the Gromov-Wasserstein distance (GW) [2] for 1D measures, which allows a sliced version of GW akin to the Sliced Wasserstein distance. SGW scales to large applications (about 1 s between two measures of 1 million points each on a standard GPU) and can easily be plugged into a deep learning architecture.
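The slicing idea can be illustrated with a minimal NumPy sketch under simplifying assumptions: both point clouds have the same number n of points with uniform weights, and the lower-dimensional cloud is zero-padded so both can be projected on the same random directions (the padding choice is an assumption of this sketch). Each 1D problem compares only the two candidate matchings used in the paper's 1D result, the identity and the anti-identity of the sorted projections. `gw_1d` and `sliced_gw` are hypothetical names, not this repository's API, and the 1D cost is computed by brute force in O(n²), so this is for small examples only.

```python
import numpy as np


def gw_1d(x, y):
    """1D GW-type cost between two uniform samples of equal size n.

    Sorts both samples and returns the smaller of the costs of the identity
    and anti-identity matchings, where a matching's cost is
    sum_{i,k} ((x_i - x_k)^2 - (y_i - y_k)^2)^2 / n^2.
    Brute force: O(n^2) time and memory (hypothetical helper).
    """
    x = np.sort(np.asarray(x, dtype=float))
    y = np.sort(np.asarray(y, dtype=float))
    n = x.shape[0]
    dx = (x[:, None] - x[None, :]) ** 2  # squared distances within x

    def matched_cost(y_matched):
        dy = (y_matched[:, None] - y_matched[None, :]) ** 2
        return np.sum((dx - dy) ** 2) / n**2

    # identity: x_(i) <-> y_(i); anti-identity: x_(i) <-> y_(n-1-i)
    return min(matched_cost(y), matched_cost(y[::-1]))


def sliced_gw(xs, xt, n_proj=50, seed=0):
    """Monte Carlo sliced GW between clouds xs (n, p) and xt (n, q)."""
    xs = np.asarray(xs, dtype=float)
    xt = np.asarray(xt, dtype=float)
    if xs.shape[0] != xt.shape[0]:
        raise ValueError("this sketch assumes the same number of points")
    d = max(xs.shape[1], xt.shape[1])
    # Assumption: embed the lower-dimensional cloud by zero-padding.
    xs = np.pad(xs, ((0, 0), (0, d - xs.shape[1])))
    xt = np.pad(xt, ((0, 0), (0, d - xt.shape[1])))
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(d, n_proj))
    theta /= np.linalg.norm(theta, axis=0, keepdims=True)  # unit directions
    ps, pt = xs @ theta, xt @ theta  # (n, n_proj) projections
    costs = [gw_1d(ps[:, k], pt[:, k]) for k in range(n_proj)]
    return float(np.mean(costs))
```

For instance, `sliced_gw(X, Y)` with `X` of shape (n, 2) and `Y` of shape (n, 3) returns a nonnegative float, and comparing `X` with itself gives 0. Because the 1D cost only involves pairwise differences, it is unchanged by translating or reflecting either sample, which is why the anti-identity matching has to be considered.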
Feel free to ask if you have any questions.
If you use this toolbox in your research and find it useful, please cite SGW using the following BibTeX reference:
@incollection{vay2019sgw,
title = {Sliced Gromov-Wasserstein},
author = {{Vayer}, Titouan and {Flamary}, R{\'e}mi and {Tavenard}, Romain and
{Chapel}, Laetitia and {Courty}, Nicolas},
booktitle = {Advances in Neural Information Processing Systems 32},
year = {2019}
}
Prerequisites:
- NumPy (>= 1.11)
- Matplotlib (>= 1.5)
- PyTorch (>= 1.1.0)
- Python Optimal Transport, POT (>= 0.5.1), for optimal transport
For examples with RISGW:
- Python (>= 3.6)
- geoopt
What is included:
- SGW function on both CPU and GPU (with PyTorch).
- Entropic Gromov-Wasserstein in PyTorch.
- Runtime comparison with POT's Gromov-Wasserstein and with Entropic Gromov-Wasserstein. For example, to compute all runtimes (in the expe_paper folder):
    python3 runtime.py -p '../res' -ln 200 500 1000 -pr 10 20
  To plot the results (in the expe_paper folder):
    python plot_runtimes.py -p '../res/runtime_2019_10_16_14_26_32/'
- Rotational Invariant SGW (RISGW) in PyTorch using geoopt.
- Demo notebooks:
  - sgw_example.ipynb: SGW between random measures and between 3D meshes
  - risgw_example.ipynb: RISGW between random measures and on a spiral dataset
- Integration of SGW in the POT library [1]
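A device-agnostic PyTorch variant of SGW can be sketched as follows, again assuming equal sample sizes, uniform weights, zero-padding of the smaller dimension, and the identity/anti-identity comparison. It is not the repository's SGW function; `gw_1d_matched` and `sgw_torch` are hypothetical names. To avoid building n × n matrices, it expands the 1D double sum into O(n) sums over the centred samples (centring is allowed because the cost only involves differences), and it processes all projections as one batch on whichever device holds the inputs.

```python
import torch
import torch.nn.functional as F


def gw_1d_matched(a, b):
    """sum_{i,k} ((a_i - a_k)^2 - (b_i - b_k)^2)^2 / n^2 along the last dim.

    a_i is matched with b_i (hypothetical helper). The double sum is expanded
    into O(n) sums of the centred samples.
    """
    n = a.shape[-1]
    a = a - a.mean(dim=-1, keepdim=True)
    b = b - b.mean(dim=-1, keepdim=True)
    sa2 = (a**2).sum(dim=-1)
    sb2 = (b**2).sum(dim=-1)
    sab = (a * b).sum(dim=-1)
    total = (2 * n * ((a**2 - b**2) ** 2).sum(dim=-1)
             + 6 * sa2**2 + 6 * sb2**2 - 4 * sa2 * sb2 - 8 * sab**2)
    return total / n**2


def sgw_torch(xs, xt, n_proj=100, generator=None):
    """Sliced GW sketch running on the device of xs and xt (CPU or GPU)."""
    if xs.shape[0] != xt.shape[0]:
        raise ValueError("this sketch assumes the same number of points")
    d = max(xs.shape[1], xt.shape[1])
    # Assumption: zero-pad the feature dimension of the smaller cloud.
    xs = F.pad(xs, (0, d - xs.shape[1]))
    xt = F.pad(xt, (0, d - xt.shape[1]))
    theta = torch.randn(d, n_proj, generator=generator,
                        device=xs.device, dtype=xs.dtype)
    theta = theta / theta.norm(dim=0, keepdim=True)  # unit directions
    # Project, then sort each projection: shapes (n_proj, n)
    ps = torch.sort((xs @ theta).T, dim=-1).values
    pt = torch.sort((xt @ theta).T, dim=-1).values
    identity = gw_1d_matched(ps, pt)       # x_(i) <-> y_(i)
    anti = gw_1d_matched(ps, pt.flip(-1))  # x_(i) <-> y_(n-1-i)
    return torch.minimum(identity, anti).mean()
```

Moving the inputs with `.to("cuda")` runs the same code on a GPU, and since gradients flow through `torch.sort`, the result can be used as a training loss. In float32 the moment expansion can lose precision through cancellation, so float64 is safer when exact values matter.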
[1] Rémi Flamary and Nicolas Courty. POT: Python Optimal Transport library.
[2] Facundo Mémoli. Gromov–Wasserstein Distances and the Metric Approach to Object Matching.