Skip to content

Commit

Permalink
Horse
Browse files Browse the repository at this point in the history
  • Loading branch information
tvayer committed Oct 17, 2019
1 parent 38276e3 commit fb936ed
Show file tree
Hide file tree
Showing 54 changed files with 119 additions and 102 deletions.
6 changes: 6 additions & 0 deletions .ipynb_checkpoints/Untitled-checkpoint.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"cells": [],
"metadata": {},
"nbformat": 4,
"nbformat_minor": 1
}
5 changes: 2 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,8 @@ For examples with RISGW:

* SGW function both in CPU and GPU (with Pytorch):

<p align="center">
<img src="https://github.com/tvayer/SGW/blob/master/sgw.png" width="600" >
</p>
![](horses.gif)


* Entropic Gromov-Wasserstein in Pytorch.

Expand Down
Binary file added horse.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
81 changes: 0 additions & 81 deletions lib/sgw_numpy2.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,6 @@ def dist(x1, x2):
x2p2 = np.sum(np.square(x2), 1)
return x1p2.reshape((-1, 1)) + x2p2.reshape((1, -1)) - 2 * np.dot(x1, x2.T)

def sw0(xs,xt,P): # xs.shape(1)<=xt.shape(1), P.shape[0]=xt.shape[1]

L=P.shape[1]
n=xs.shape[0]
xsp=np.dot(xs,P[:xs.shape[1],:])
xtp=np.dot(xt,P)

#xsp=np.sort(xsp,0)
#xtp=np.sort(xtp,0)

res=0

for l in range(L):

x1=np.sort(xsp[:,l])
x2=np.sort(xtp[:,l])

D1=dist(x1[:,None],x2[:,None])

l1=np.sum(np.square(D1))/n/n
#l2=np.sum(D1[::-1,::-1])/n/n

res+=l1

return res/L

def sgw0(xs,xt,P): # xs.shape(1)<=xt.shape(1), P.shape[0]=xt.shape[1]

Expand All @@ -62,9 +37,6 @@ def sgw0(xs,xt,P): # xs.shape(1)<=xt.shape(1), P.shape[0]=xt.shape[1]
xsp=np.dot(xs,P[:xs.shape[1],:])
xtp=np.dot(xt,P)

#xsp=np.sort(xsp,0)
#xtp=np.sort(xtp,0)

res=0

for l in range(L):
Expand Down Expand Up @@ -189,61 +161,8 @@ def loss(delta):

return loss(Xopt), Xopt

def risw(xs,xt,P):

def loss(delta):
return sw0(np.dot(xs,delta),xt,P)

manifold = Stiefel(xt.shape[1], xs.shape[1])


problem = Problem(manifold=manifold, cost=loss)

# (3) Instantiate a Pymanopt solver
solver = SteepestDescent(logverbosity=0)

# let Pymanopt do the rest
Xopt = solver.solve(problem,x=np.eye( xs.shape[1],xt.shape[1]))

return loss(Xopt)

def risw2(xs,xt,P,X0=None):

def loss(delta):
return sw0(np.dot(xs,delta),xt,P)

if X0 is None:
X0=np.eye( xs.shape[1],xt.shape[1])


manifold = Stiefel(xt.shape[1], xs.shape[1])


problem = Problem(manifold=manifold, cost=loss)

# (3) Instantiate a Pymanopt solver
solver = SteepestDescent(logverbosity=0)

# let Pymanopt do the rest
Xopt = solver.solve(problem,x=X0)

return loss(Xopt), Xopt


def gw0(xs,xt): # xs.shape(1)<=xt.shape(1), P.shape[0]=xt.shape[1]

n=xs.shape[0]
u=ot.unif(n)
D1=dist(xs,xs)
D2=dist(xt,xt)

return ot.gromov.gromov_wasserstein2(D1,D2,u,u,'square_loss')[0]

def w0(xs,xt): # xs.shape(1)<=xt.shape(1), P.shape[0]=xt.shape[1]

n=xs.shape[0]
u=ot.unif(n)
D1=dist(xs,xt)

return ot.emd2(u,u,D1)

Binary file added res/0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/10.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/11.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/12.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/13.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/14.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/15.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/16.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/17.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/18.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/19.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/20.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/21.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/22.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/23.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/24.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/25.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/26.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/27.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/28.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/29.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added res/30.png
Binary file added res/31.png
Binary file added res/32.png
Binary file added res/33.png
Binary file added res/34.png
Binary file added res/35.png
Binary file added res/36.png
Binary file added res/37.png
Binary file added res/38.png
Binary file added res/39.png
Binary file added res/4.png
Binary file added res/40.png
Binary file added res/41.png
Binary file added res/42.png
Binary file added res/43.png
Binary file added res/44.png
Binary file added res/45.png
Binary file added res/46.png
Binary file added res/47.png
Binary file added res/5.png
Binary file added res/6.png
Binary file added res/7.png
Binary file added res/8.png
Binary file added res/9.png
Binary file removed sgw.png
Diff not rendered.
129 changes: 111 additions & 18 deletions sgw_example.ipynb

Large diffs are not rendered by default.

0 comments on commit fb936ed

Please sign in to comment.