Kernel_samples is done.
jeanfeydy committed Mar 4, 2019
1 parent cac4c35 commit c4853ad
Showing 44 changed files with 347 additions and 87 deletions.
2 changes: 2 additions & 0 deletions doc/api/geomloss.rst
@@ -0,0 +1,2 @@
Maths and algorithms
=====================
3 changes: 3 additions & 0 deletions doc/api/install.rst
@@ -0,0 +1,3 @@
Get started
============

13 changes: 13 additions & 0 deletions doc/api/pytorch-api.rst
@@ -0,0 +1,13 @@
PyTorch API
============

:mod:`geomloss` - :doc:`Geometric Loss functions <geomloss>`,
with full support of PyTorch's ``autograd`` engine:

.. currentmodule:: geomloss
.. autosummary::

SamplesLoss

.. automodule:: geomloss
:members:
17 changes: 10 additions & 7 deletions doc/index.rst
@@ -21,7 +21,9 @@ The **GeomLoss** library provides efficient GPU implementations for:

These loss functions, defined between positive measures,
are available through the custom `PyTorch <https://pytorch.org/>`_ layers
-   ``SamplesLoss``, ``ImagesLoss`` and ``VolumesLoss``
+   :class:`SamplesLoss <geomloss.SamplesLoss>`,
+   :class:`ImagesLoss <geomloss.ImagesLoss>` and
+   :class:`VolumesLoss <geomloss.VolumesLoss>`
which allow you to work with weighted **point clouds** (of any dimension),
**density maps** and **volumetric segmentation masks**.
Geometric losses come with three backends each:
@@ -46,6 +48,7 @@ A typical sample of code looks like:
# Apply it to large point clouds in 3D
x = torch.randn(100000, 3, requires_grad=True).cuda()
y = torch.randn(200000, 3).cuda()
L = loss(x, y) # By default, use constant weights = 1/number of samples
g_x, = torch.autograd.grad(L, [x]) # GeomLoss fully supports autograd
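
As a reading aid (not part of the commit): the construction of the ``loss`` layer used above sits in a collapsed hunk. A minimal end-to-end sketch with illustrative parameter values, assuming the documented ``SamplesLoss`` signature:

    import torch
    from geomloss import SamplesLoss

    # Define a Sinkhorn divergence between sampled measures:
    loss = SamplesLoss(loss="sinkhorn", p=2, blur=.05)

    x = torch.randn(1000, 3, requires_grad=True)  # CPU-sized point clouds
    y = torch.randn(2000, 3)

    L = loss(x, y)                      # uniform weights by default
    g_x, = torch.autograd.grad(L, [x])  # gradient w.r.t. the samples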
@@ -63,7 +66,7 @@ algorithms. It provides:
* Efficient computation of the **gradients**, which bypasses the naive
backpropagation algorithm.
* Support for `unbalanced <https://link.springer.com/article/10.1007/s00222-017-0759-8>`_
-   `Optimal Transport <https://arxiv.org/pdf/1506.06430.pdf>`_,
+   Optimal `Transport <https://arxiv.org/pdf/1506.06430.pdf>`_,
with a softening of the marginal constraints
through a maximum **reach** parameter.
* Support for the `ε-scaling heuristic <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.228.9750&rep=rep1&type=pdf>`_
@@ -73,7 +76,7 @@ algorithms. It provides:
the standard `SoftAssign/Sinkhorn algorithm <https://arxiv.org/abs/1306.0895>`_.
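
As a reading aid for the two bullets above, here is a hedged sketch of how the **reach** and ε-scaling parameters are exposed by ``SamplesLoss`` (the values are illustrative, not prescriptive):

    from geomloss import SamplesLoss

    # Unbalanced OT: soften the marginal constraints, so that mass is
    # only transported up to distances of the order of `reach`:
    loss_unbalanced = SamplesLoss("sinkhorn", p=2, blur=.05, reach=.3)

    # ε-scaling: anneal the temperature by a factor `scaling` per
    # iteration; values closer to 1 are slower but more accurate:
    loss_annealed = SamplesLoss("sinkhorn", p=2, blur=.05, scaling=.9)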


- Note, however, that ``SamplesLoss`` does *not* implement the
+ Note, however, that :class:`SamplesLoss <geomloss.SamplesLoss>` does *not* implement the
`Fast Multipole <http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.129.7826&rep=rep1&type=pdf>`_
or `Fast Gauss <http://users.umiacs.umd.edu/~morariu/figtree/>`_ transforms.
If you are aware of a well-packaged implementation
@@ -97,8 +100,8 @@ to **machine learning** (kernel methods, GANs...)
and **image processing**.
Details and examples are provided below:

- * :doc:`Documentation <api/geomloss>`.
- * :doc:`API <api/api>`.
+ * :doc:`Maths and algorithms <api/geomloss>`
+ * :doc:`PyTorch API <api/pytorch-api>`
* `Source code <https://github.com/jeanfeydy/geomloss>`_
* :doc:`Simple examples <_auto_examples/index>`
* :doc:`Advanced tutorials <_auto_tutorials/index>`
@@ -119,9 +122,9 @@ Table of contents
.. toctree::
:maxdepth: 2

-   api/installation
+   api/install
    api/geomloss
-   api/api
+   api/pytorch-api
_auto_examples/index
_auto_tutorials/index

19 changes: 0 additions & 19 deletions doc/readme.md

This file was deleted.

4 changes: 4 additions & 0 deletions geomloss/__init__.py
@@ -1,3 +1,7 @@
import sys, os.path

__version__ = '0.1'

from .samples_loss import SamplesLoss

__all__ = sorted(["SamplesLoss"])
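
A quick, hedged smoke test of the public API that this ``__init__.py`` exposes:

    import geomloss

    print(geomloss.__version__)  # -> '0.1'
    print(geomloss.__all__)      # -> ['SamplesLoss']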
2 changes: 1 addition & 1 deletion geomloss/examples/README.txt
@@ -2,4 +2,4 @@
Gallery of examples
====================

- These short examples showcase the features of the ``pykeops`` module.
+ These short examples showcase the features of the ``geomloss`` module.
Binary file added geomloss/examples/data/OAI_a.nii.gz
Binary file added geomloss/examples/data/OAI_b.nii.gz
Binary file added geomloss/examples/data/Ring_a.png
Binary file added geomloss/examples/data/Ring_b.png
Binary file added geomloss/examples/data/Worm_a.png
Binary file added geomloss/examples/data/Worm_b.png
Binary file added geomloss/examples/data/bar_a.png
Binary file added geomloss/examples/data/bar_b.png
Binary file added geomloss/examples/data/bar_c.png
Binary file added geomloss/examples/data/blobs_a.png
Binary file added geomloss/examples/data/blobs_b.png
Binary file added geomloss/examples/data/crescent_a.png
Binary file added geomloss/examples/data/crescent_b.png
Binary file added geomloss/examples/data/density_a.png
Binary file added geomloss/examples/data/density_b.png
Binary file added geomloss/examples/data/ell_a.png
Binary file added geomloss/examples/data/ell_b.png
Binary file added geomloss/examples/data/knee_a.png
Binary file added geomloss/examples/data/knee_b.png
Binary file added geomloss/examples/data/moon_a.png
Binary file added geomloss/examples/data/moon_b.png
Binary file added geomloss/examples/data/morse_a.png
Binary file added geomloss/examples/data/morse_b.png
Binary file added geomloss/examples/data/reach_a.png
Binary file added geomloss/examples/data/reach_b.png
Binary file added geomloss/examples/data/ring_a.png
Binary file added geomloss/examples/data/ring_b.png
Binary file added geomloss/examples/data/slope_a.png
Binary file added geomloss/examples/data/slope_b.png
Binary file added geomloss/examples/data/worm_a.png
Binary file added geomloss/examples/data/worm_b.png
162 changes: 162 additions & 0 deletions geomloss/examples/plot_gradient_flows_1D.py
@@ -0,0 +1,162 @@
"""
Gradient flows in 1D
====================
This example showcases the properties of **kernel MMDs**, **Hausdorff**
and **Sinkhorn** divergences on a simple toy problem:
the registration of an interval onto another.
"""



##############################################
# Setup
# ---------------------

import numpy as np
import matplotlib.pyplot as plt
from sklearn.neighbors import KernelDensity # display as density curves

import torch
from geomloss import SamplesLoss

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

###############################################
# Display routine
# ~~~~~~~~~~~~~~~~~

t_plot = np.linspace(-0.1, 1.1, 1000)[:,np.newaxis]

def display_samples(ax, x, color):
    """Displays samples on the unit interval using a density curve."""
    kde = KernelDensity(kernel='gaussian', bandwidth=.005).fit(x.data.cpu().numpy())
    dens = np.exp(kde.score_samples(t_plot))
    dens[0] = 0 ; dens[-1] = 0
    ax.fill(t_plot, dens, color=color)


###############################################
# Dataset
# ~~~~~~~~~~~~~~~~~~
#
# Our source and target samples are drawn from intervals of the real line
# and define discrete probability measures:
#
# .. math::
#    \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
#    \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

N, M = (250, 250) if not use_cuda else (10000, 10000)

t_i = torch.linspace(0, 1, N).type(dtype).view(-1,1)
t_j = torch.linspace(0, 1, M).type(dtype).view(-1,1)

X_i, Y_j = 0.2 * t_i, 0.4 * t_j + 0.6

###############################################
# Wasserstein gradient flow
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# To study the influence of the :math:`\text{Loss}` function in measure-fitting
# applications, we perform gradient descent on the positions
# :math:`x_i` of the samples that make up :math:`\alpha`
# as we minimize the cost :math:`\text{Loss}(\alpha,\beta)`.
# This procedure can be understood as a discrete (Lagrangian)
# `Wasserstein gradient flow <https://arxiv.org/abs/1609.03890>`_
# and as a "model-free" machine learning program, where
# we optimize directly on the samples' locations.

def gradient_flow(loss, lr=.05):
    """Flows along the gradient of the cost function, using a simple Euler scheme.

    Parameters:
        loss ((x_i,y_j) -> torch float number):
            Real-valued loss function.
        lr (float, default = .05):
            Learning rate, i.e. time step.
    """

    # Parameters for the gradient descent
    Nsteps = int(5/lr)+1
    display_its = [int(t/lr) for t in [0, .25, .50, 1., 2., 5.]]

    # Make sure that we won't modify the reference samples
    x_i, y_j = X_i.clone(), Y_j.clone()

    # We're going to perform gradient descent on Loss(α, β)
    # wrt. the positions x_i of the Dirac masses that make up α:
    x_i.requires_grad = True

    plt.figure(figsize=(12,8)) ; k = 1
    for i in range(Nsteps):  # Euler scheme ===============
        # Compute cost and gradient
        L_αβ = loss(x_i, y_j)
        [g] = torch.autograd.grad(L_αβ, [x_i])

        if i in display_its:  # display
            ax = plt.subplot(2,3,k) ; k = k+1

            display_samples(ax, y_j, (.55,.55,.95))
            display_samples(ax, x_i, (.95,.55,.55))

            ax.set_title("t = {:1.2f}".format(lr*i))
            plt.axis([-.1,1.1,-.1,5.5])
            plt.xticks([], []); plt.yticks([], [])
            plt.tight_layout()

        # in-place modification of the tensor's values
        x_i.data -= lr * len(x_i) * g



###############################################
# Kernel norms, MMDs
# ------------------------------------
#
# Gaussian MMD
# ~~~~~~~~~~~~~~~
#
# The smooth Gaussian kernel
# :math:`k(x,y) = \exp(-\|x-y\|^2/2\sigma^2)`
# is blind to details which are smaller than the blurring scale :math:`\sigma`:
# its gradient stops being informative when :math:`\alpha`
# and :math:`\beta` become equal "up to the high frequencies".

gradient_flow( SamplesLoss("gaussian", blur=.5) )


###############################################
# On the other hand, if the radius :math:`\sigma`
# of the kernel is too small, particles :math:`x_i`
# won't be attracted to the target, and may **spread out**
# to minimize the auto-correlation term
# :math:`\tfrac{1}{2}\langle \alpha, k\star\alpha\rangle`.

gradient_flow( SamplesLoss("gaussian", blur=.1) )


###############################################
# Laplacian MMD
# ~~~~~~~~~~~~~~~~
#
# The pointy exponential kernel
# :math:`k(x,y) = \exp(-\|x-y\|/\sigma)`
# tends to provide a better fit, but decays to zero at infinity
# and is still very prone to **screening artifacts**.

gradient_flow( SamplesLoss("laplacian", blur=.1) )


###############################################
# Energy Distance MMD
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The scale-equivariant kernel
# :math:`k(x,y)=-\|x-y\|` provides a robust baseline:
# the Energy Distance.


# sphinx_gallery_thumbnail_number = 4
gradient_flow( SamplesLoss("energy") )
15 changes: 15 additions & 0 deletions geomloss/examples/plot_gradient_flows_2D.py
@@ -0,0 +1,15 @@
"""
Gradient flows in 2D
====================
This example showcases the properties of kernel MMDs, Hausdorff
and Sinkhorn divergences on a simple toy problem:
the registration of one blob onto another.
"""


################################
# Hello!
#

print("Done.")
53 changes: 53 additions & 0 deletions geomloss/examples/plot_profile.py
@@ -0,0 +1,53 @@
"""
Profiling the GeomLoss routines
===================================
This example explains how to **profile** the geometric losses
to select the backend and truncation/scaling values that
are best suited to your data.
"""



##############################################
# Setup
# ---------------------

import torch
from geomloss import SamplesLoss
from time import time

use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

##############################################
# Sample points on the unit sphere:
#

N, M = (250, 250) if not use_cuda else (5000, 5000)
x, y = torch.randn(N,3).type(dtype), torch.randn(M,3).type(dtype)
x, y = x / x.norm(dim=1,keepdim=True), y / y.norm(dim=1,keepdim=True)
x.requires_grad = True

##########################################################
# Use the PyTorch profiler to output Chrome trace files:

for backend in ["tensorized", "online", "multiscale"]:
with torch.autograd.profiler.profile(use_cuda=use_cuda) as prof:
loss = SamplesLoss("gaussian", blur=.1, backend=backend, truncate=3)
t_0 = time()
L_xy = loss(x, y)
L_xy.backward()
t_1 = time()
print("{:.2f}s, cost = {:.6f}".format( t_1-t_0, L_xy.item()) )

prof.export_chrome_trace("output/profile_"+backend+".json")


######################################################################
# Now, all you have to do is open the "Easter egg" address
# ``chrome://tracing`` in Google Chrome/Chromium,
# and load the ``profile_*`` files one after
# another. Enjoy :-)
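
######################################################################
# (Hedged alternative, assuming the standard ``torch.autograd.profiler``
# API: the last recorded trace can also be summarized in the terminal,
# without leaving Python.)

# print(prof.key_averages().table(sort_by="cuda_time_total"))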

print("Done.")