Skip to content

Commit

Permalink
Merge branch 'images'
Browse files Browse the repository at this point in the history
  • Loading branch information
jeanfeydy committed Mar 29, 2022
2 parents e0c5236 + 630b6f8 commit 84e9961
Show file tree
Hide file tree
Showing 14 changed files with 1,417 additions and 315 deletions.
5 changes: 5 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,8 @@ geomloss/examples/brain_tractograms/data/*

geomloss/examples/optimal_transport/data/wasserstein_3D_models*
geomloss/examples/optimal_transport/output*

*.ipynb
launch.json
*doc.zip
*debug*.py
4 changes: 3 additions & 1 deletion geomloss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,7 @@
__version__ = "0.2.4"

from .samples_loss import SamplesLoss
from .wasserstein_barycenter_images import ImagesBarycenter
from .sinkhorn_images import sinkhorn_divergence

__all__ = sorted(["SamplesLoss"])
__all__ = sorted(["SamplesLoss, ImagesBarycenter"])
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def load_ply_file(fname):
#

import SimpleITK as sitk
from skimage.measure import marching_cubes_lewiner as marching_cubes
from skimage.measure import marching_cubes


def load_nii_file(fname, threshold=0.5):
Expand Down
13 changes: 5 additions & 8 deletions geomloss/examples/performances/plot_benchmarks_samplesloss_3D.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def bench_config(Loss, dev):
print("Backend : {}, Device : {} -------------".format(Loss.backend, dev))

times = []

def run_bench():
try:
Nloops = [100, 10, 1]
Expand All @@ -133,22 +133,20 @@ def run_bench():

except IndexError:
print("**\nToo slow !")



try:
run_bench()

except RuntimeError as err:
if str(err)[:4] == "CUDA":
print("**\nMemory overflow !")

else:
# CUDA memory overflows semi-break the internal
# torch state and may cause some strange bugs.
# In this case, best option is simply to re-launch
# the benchmark.
run_bench()


return times + (len(NS) - len(times)) * [np.nan]

Expand All @@ -161,7 +159,7 @@ def full_bench(loss, *args, **kwargs):
lines = [NS]
backends = ["tensorized", "online", "multiscale"]
for backend in backends:
Loss = SamplesLoss(*args, **kwargs, backend = backend)
Loss = SamplesLoss(*args, **kwargs, backend=backend)
lines.append(bench_config(Loss, "cuda" if use_cuda else "cpu"))

benches = np.array(lines).T
Expand Down Expand Up @@ -226,7 +224,6 @@ def full_bench(loss, *args, **kwargs):
full_bench(SamplesLoss, "sinkhorn", p=2, blur=0.05, diameter=1)



##############################################
# With a small blurring scale, at one hundredth of the
# configuration's diameter:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def load_image(fname):
return 1 - img


def draw_samples(fname, sampling, dtype=torch.FloatTensor):
def draw_samples(fname, sampling, dtype=dtype):
A = load_image(fname)
A = A[::sampling, ::sampling]
A[A <= 0] = 1e-8
Expand Down
2 changes: 2 additions & 0 deletions geomloss/examples/sinkhorn_multiscale/plot_transport_blur.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,10 @@ def rweight():
x_, _ = X_i.sort(dim=0)
y_, _ = Y_j.sort(dim=0)
true_wass = (0.5 / len(X_i)) * ((x_ - y_) ** 2).sum()
true_wass = true_wass.item()
# and when blur = +infinity:
mean_diff = 0.5 * ((X_i.mean(0) - Y_j.mean(0)) ** 2).sum()
mean_diff = mean_diff.item()

blurs = [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
sink, bwass = [], []
Expand Down
43 changes: 29 additions & 14 deletions geomloss/kernel_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def double_grad(x):
# All backends
# ==============================================================================


def gaussian_kernel(x, y, blur=0.05, use_keops=False, ranges=None):
C2 = squared_distances(x / blur, y / blur, use_keops=use_keops)
K = (-C2 / 2).exp()
Expand Down Expand Up @@ -89,37 +90,51 @@ def energy_kernel(x, y, blur=None, use_keops=False, ranges=None):


def kernel_loss(
α, x, β, y, blur=0.05, kernel=None, name=None, potentials=False, use_keops=False,
ranges_xx=None, ranges_yy=None, ranges_xy=None, **kwargs
α,
x,
β,
y,
blur=0.05,
kernel=None,
name=None,
potentials=False,
use_keops=False,
ranges_xx=None,
ranges_yy=None,
ranges_xy=None,
**kwargs
):
if kernel is None:
kernel = kernel_routines[name]

# Center the point clouds just in case, to prevent numeric overflows:
# N.B.: This may break user-provided kernels and comes at a non-negligible
# N.B.: This may break user-provided kernels and comes at a non-negligible
# cost for small problems, so let's disable this by default.
# center = (x.mean(-2, keepdim=True) + y.mean(-2, keepdim=True)) / 2
# x, y = x - center, y - center

# (B,N,N) tensor
K_xx = kernel(double_grad(x), x.detach(), blur=blur, use_keops=use_keops, ranges=ranges_xx)
K_xx = kernel(
double_grad(x), x.detach(), blur=blur, use_keops=use_keops, ranges=ranges_xx
)
# (B,M,M) tensor
K_yy = kernel(double_grad(y), y.detach(), blur=blur, use_keops=use_keops, ranges=ranges_yy)
K_yy = kernel(
double_grad(y), y.detach(), blur=blur, use_keops=use_keops, ranges=ranges_yy
)
# (B,N,M) tensor
K_xy = kernel(x, y, blur=blur, use_keops=use_keops, ranges=ranges_xy)
K_xy = kernel(x, y, blur=blur, use_keops=use_keops, ranges=ranges_xy)

# (B,N,N) @ (B,N) = (B,N)
a_x = (K_xx @ α.detach().unsqueeze(-1)).squeeze(-1)
a_x = (K_xx @ α.detach().unsqueeze(-1)).squeeze(-1)
# (B,M,M) @ (B,M) = (B,M)
b_y = (K_yy @ β.detach().unsqueeze(-1)).squeeze(-1)
b_y = (K_yy @ β.detach().unsqueeze(-1)).squeeze(-1)
# (B,N,M) @ (B,M) = (B,N)
b_x = (K_xy @ β.unsqueeze(-1)).squeeze(-1)

b_x = (K_xy @ β.unsqueeze(-1)).squeeze(-1)

if potentials:
# (B,M,N) @ (B,N) = (B,M)
Kt = K_xy.t() if use_keops else K_xy.transpose(1, 2)
a_y = (Kt @ α.unsqueeze(-1)).squeeze(-1)
a_y = (Kt @ α.unsqueeze(-1)).squeeze(-1)
return a_x - b_x, b_y - a_y

else: # Return the Kernel norm. N.B.: we assume that 'kernel' is symmetric:
Expand All @@ -130,6 +145,7 @@ def kernel_loss(
- scal(α, b_x, batch=batch)
)


# ==============================================================================
# backend == "tensorized"
# ==============================================================================
Expand Down Expand Up @@ -188,14 +204,13 @@ def kernel_multiscale(
**kwargs
)

# Renormalize our point cloud so that blur = 1:
# Renormalize our point cloud so that blur = 1:
# Center the point clouds just in case, to prevent numeric overflows:
center = (x.mean(-2, keepdim=True) + y.mean(-2, keepdim=True)) / 2
x, y = x - center, y - center
x_ = x / blur
y_ = y / blur


# Don't forget to normalize the diameter too!
if cluster_scale is None:
D = x.shape[-1]
Expand Down
29 changes: 17 additions & 12 deletions geomloss/samples_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,21 @@ def forward(self, *args):
backend = "online" # Play it safe, without kernel truncation

# Check compatibility between the batchsize and the backend --------------------------

if backend in ["multiscale"]: # multiscale routines work on single measures
if B == 1:
α, x, β, y = α.squeeze(0), x.squeeze(0), β.squeeze(0), y.squeeze(0)
elif B > 1:
warnings.warn("The 'multiscale' backend do not support batchsize > 1. " \
+"Using 'tensorized' instead: beware of memory overflows!")
warnings.warn(
"The 'multiscale' backend do not support batchsize > 1. "
+ "Using 'tensorized' instead: beware of memory overflows!"
)
backend = "tensorized"

if B == 0 and backend in ["tensorized", "online"]: # tensorized and online routines work on batched tensors
if B == 0 and backend in [
"tensorized",
"online",
]: # tensorized and online routines work on batched tensors
α, x, β, y = α.unsqueeze(0), x.unsqueeze(0), β.unsqueeze(0), y.unsqueeze(0)

# Run --------------------------------------------------------------------------------
Expand Down Expand Up @@ -287,8 +292,12 @@ def forward(self, *args):

else: # Return a scalar cost value
if backend in ["multiscale"]: # KeOps backends return a single scalar value
if B == 0: return values # The user expects a scalar value
else: return values.view(-1) # The user expects a "batch list" of distances
if B == 0:
return values # The user expects a scalar value
else:
return values.view(
-1
) # The user expects a "batch list" of distances

else: # "tensorized" backend returns a "batch vector" of values
if B == 0:
Expand Down Expand Up @@ -402,17 +411,13 @@ def check_shapes(self, l_x, α, x, l_y, β, y):
B,
N,
D,
) = (
x.shape
)
) = x.shape
# Batchsize, number of "i" samples, dimension of the feature space
(
B2,
M,
_,
) = (
y.shape
)
) = y.shape
# Batchsize, number of "j" samples, dimension of the feature space
if B != B2:
raise ValueError("Samples 'x' and 'y' should have the same batchsize.")
Expand Down
Loading

0 comments on commit 84e9961

Please sign in to comment.