Skip to content

Commit

Permalink
First clean version
Browse files Browse the repository at this point in the history
  • Loading branch information
tvayer committed Oct 16, 2019
1 parent 0a4fcab commit 78e0203
Show file tree
Hide file tree
Showing 20 changed files with 689 additions and 407 deletions.
90 changes: 90 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@

# SGW

Python3 implementation of the paper [Sliced Gromov-Wasserstein](https://arxiv.org/abs/1905.10124) (NeurIPS 2019)

Sliced Gromov-Wasserstein (SGW) is an Optimal Transport discrepancy between measures whose supports do not necessarily live in the same metric space. It is based on a closed-form expression of the Gromov-Wasserstein distance (GW) [2] for 1D measures, which allows a sliced version of GW akin to the Sliced Wasserstein distance. SGW can be applied to large-scale problems (about 1s between two measures of 1 million points each on a standard GPU) and can be easily plugged into a deep learning architecture.

Feel free to ask if you have any questions.

If you use this toolbox in your research and find it useful, please cite SGW using the following bibtex reference:

```
@ARTICLE{vay2019sgw,
author = {{Vayer}, Titouan and {Flamary}, R{\'e}mi and {Tavenard}, Romain and
{Chapel}, Laetitia and {Courty}, Nicolas},
title = "{Sliced Gromov-Wasserstein}",
journal = {arXiv e-prints},
keywords = {Statistics - Machine Learning, Computer Science - Machine Learning},
year = "2019",
month = "May",
eid = {arXiv:1905.10124},
pages = {arXiv:1905.10124},
archivePrefix = {arXiv},
eprint = {1905.10124},
primaryClass = {stat.ML}
}
```

### Prerequisites

* Numpy (>=1.11)
* Matplotlib (>=1.5)
* Pytorch (>= 1.0.1)
* For Optimal transport [Python Optimal Transport](https://pot.readthedocs.io/en/stable/) POT (>=0.5.1)

For examples with RISGW:
* [Pymanopt](https://pymanopt.github.io)

### What is included ?


* SGW function both in CPU and GPU (with Pytorch):

<p align="center">
<img src="https://github.com/tvayer/SGW/blob/master/sgw.png" width="340" >
</p>

* Entropic Gromov-Wasserstein in Pytorch.

* Runtime comparison with the Gromov-Wasserstein of [POT](https://github.com/rflamary/POT) and with Entropic Gromov-Wasserstein. For example, to calculate the runtimes:

```
python3 runtime.py -p '../res' -ln 200 500 1000 -pr 10 20
```

To plot the results, e.g.:

```
python plot_runtimes.py -p '../res/runtime_2019_10_16_14_26_32/'
```

* A demo notebook:
- [sgw_example.ipynb](sgw_example.ipynb): SGW between random measures and 3D meshes

* An example of optimization on the Stiefel manifold for computing RISGW. This implementation is a CPU implementation using autograd and is not efficient for large scale applications.

```
python run_rot_scale.py
```

### What will be added ?
* Further work on RISGW for larger-scale applications.
* Integration of SGW in the POT library [1]


### Authors

* [Titouan Vayer](https://github.com/tvayer)
* [Rémi Flamary](https://github.com/rflamary)
* [Romain Tavenard](https://github.com/rtavenar)
* [Laetitia Chapel](https://github.com/lchapel)
* [Nicolas Courty](https://github.com/ncourty)


## References

[1] Flamary Rémi and Courty Nicolas [POT Python Optimal Transport library](https://github.com/rflamary/POT)

[2] Facundo Mémoli [Gromov–Wasserstein Distances and the Metric Approach to Object Matching](https://media.adelaide.edu.au/acvt/Publications/2011/2011-Gromov%E2%80%93Wasserstein%20Distances%20and%20the%20Metric%20Approach%20to%20Object%20Matching.pdf)
229 changes: 0 additions & 229 deletions SGW example.ipynb

This file was deleted.

130 changes: 66 additions & 64 deletions expe_paper/plot_runtimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,76 +8,78 @@
import pylab as pl
import torch
import matplotlib as mpl
import argparse

#%%
path='../res/runtime_2019_04_23_14_56_55/'

d=torch.load(path+'runtime.pt',map_location='cpu')
if __name__ == '__main__':

all_samples=d['all_samples']
""" Plot the runtimes
----------
path : path to the results previously calculated
#%%
legen=[]

pl.figure(figsize=(15,8))

pl.subplot(1,2,1)

s=60
fs=12
colors={True:'r',False:'black'}

pl.scatter(all_samples,d['t_all_gw'],c=[colors[x] for x in d['all_converged']],s=s)
pl.plot(all_samples, d['t_all_gw'],'r')
legen.append('Time GW POT')


norm = mpl.colors.Normalize(vmin=0, vmax=3)
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.YlOrBr)
cmap.set_array([])

norm = mpl.colors.Normalize(vmin=0, vmax=3)
cmap2 = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Purples)
cmap2.set_array([])

for i,proj in enumerate(d['projs']):
pl.scatter(all_samples,d[('t_all_sgw',proj)],s=s, c=cmap.to_rgba(i + 1))
pl.plot(all_samples, d[('t_all_sgw',proj)], c=cmap.to_rgba(i + 1))
legen.append('Time SGW n_proj={}'.format(proj))
Example
----------
python plot_runtimes.py -p '../res/runtime_2019_10_16_14_26_32/'
for i,proj in enumerate(d['projs']):
pl.scatter(all_samples,d[('t_all_sgw_numpy',proj)],s=s, c=cmap2.to_rgba(i + 1))
pl.plot(all_samples, d[('t_all_sgw_numpy',proj)], c=cmap2.to_rgba(i + 1))
legen.append('Time SGW numpy n_proj={}'.format(proj))
"""

pl.legend(legen,loc='upper left',fontsize=fs)
pl.ylabel('Time in s. Semi log axis')
pl.xlabel('Number of samples in each distri')
pl.xlim(0,max(all_samples)+10)

pl.subplot(1,2,2)
s=60
colors={True:'r',False:'black'}

pl.scatter(all_samples,d['t_all_gw'],c=[colors[x] for x in d['all_converged']],s=s)
pl.plot(all_samples, d['t_all_gw'],'r')


for i,proj in enumerate(d['projs']):
pl.scatter(all_samples,d[('t_all_sgw',proj)],s=s, c=cmap.to_rgba(i + 1))
pl.plot(all_samples, d[('t_all_sgw',proj)], c=cmap.to_rgba(i + 1))

for i,proj in enumerate(d['projs']):
pl.scatter(all_samples,d[('t_all_sgw_numpy',proj)],s=s, c=cmap2.to_rgba(i + 1))
pl.plot(all_samples, d[('t_all_sgw_numpy',proj)], c=cmap2.to_rgba(i + 1))
pl.ylabel('Time in s')
pl.xlabel('Number of samples in each distri')
pl.xlim(0,max(all_samples)+10)
pl.ylim(0,0.1)
#pl.yscale('symlog')

pl.suptitle('Running time')
parser = argparse.ArgumentParser(description='Runtime')
parser.add_argument('-p','--path',type=str,help='Path to te results',required=True)
args = parser.parse_args()

path=args.path

d=torch.load(path+'runtime.pt',map_location='cpu')

all_samples=d['all_samples']

pl.savefig('../res/running_times.pdf')
pl.show()
legen=[]

fig = pl.figure(figsize=(15,8))
ax = pl.axes()

s=60
fs=12
colors={True:'r',False:'black'}

ax.scatter(all_samples,d['t_all_gw'],c=[colors[x] for x in d['all_converged']],s=s)
ax.plot(all_samples, d['t_all_gw'],'r')
legen.append('Time GW POT')


norm = mpl.colors.Normalize(vmin=0, vmax=3)
cmap = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.YlOrBr)
cmap.set_array([])

norm = mpl.colors.Normalize(vmin=0, vmax=3)
cmap2 = mpl.cm.ScalarMappable(norm=norm, cmap=mpl.cm.Purples)
cmap2.set_array([])

for i,proj in enumerate(d['projs']):
ax.scatter(all_samples,d[('t_all_sgw',proj)],s=s, c=cmap.to_rgba(i + 1))
ax.plot(all_samples, d[('t_all_sgw',proj)], c=cmap.to_rgba(i + 1))
legen.append('Time SGW Pytorch n_proj={}'.format(proj))

for i,proj in enumerate(d['projs']):
ax.scatter(all_samples,d[('t_all_sgw_numpy',proj)],s=s, c=cmap2.to_rgba(i + 1))
ax.plot(all_samples, d[('t_all_sgw_numpy',proj)], c=cmap2.to_rgba(i + 1))
legen.append('Time SGW numpy n_proj={}'.format(proj))

ax.legend(legen,loc='upper left',fontsize=fs)
ax.set_xscale("log")
ax.set_yscale("log")

ax.set_ylabel('Seconds',fontsize=fs)
ax.set_xlabel('Number of samples n in each distribution',fontsize=fs-3)

ax.set_title('Running time',fontsize=fs)

pl.xticks(fontsize=20)
pl.yticks(fontsize=20)

pl.title('Running time')

pl.savefig('../res/running_times.pdf')
pl.show()

Loading

0 comments on commit 78e0203

Please sign in to comment.