Skip to content

Commit 47c5925

Browse files
[MRG] srFGW barycenters (#659)
* init commit - integrating sr(F)GW barycenter
* correct asymmetries in srgw
* fix tests srFGW bary
* fix pep8
* complete tests for srFGW barycenters and utils
* last updates
* update barycenter update functions and remove old ones
* take review into account
* fix pep8
* ot/gromov/__init__.py
* update test
1 parent d0849a4 commit 47c5925

File tree

10 files changed

+713
-227
lines changed

10 files changed

+713
-227
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ POT provides the following generic OT solvers (links to examples):
4646
* [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/plot_compute_wasserstein_circle.html) [44, 45]
4747
* [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
4848
* [Graph Dictionary Learning solvers](https://pythonot.github.io/auto_examples/gromov/plot_gromov_wasserstein_dictionary_learning.html) [38].
49-
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) (exact and regularized [48]).
49+
* [Semi-relaxed (Fused) Gromov-Wasserstein divergences](https://pythonot.github.io/auto_examples/gromov/plot_semirelaxed_fgw.html) with corresponding barycenter solvers (exact and regularized [48]).
5050
* [Quantized (Fused) Gromov-Wasserstein distances](https://pythonot.github.io/auto_examples/gromov/plot_quantized_gromov_wasserstein.html) [68].
5151
* [Efficient Discrete Multi Marginal Optimal Transport Regularization](https://pythonot.github.io/auto_examples/others/plot_demd_gradient_minimize.html) [50].
5252
* [Several backends](https://pythonot.github.io/quickstart.html#solving-ot-with-multiple-backends) for easy use of POT with [Pytorch](https://pytorch.org/)/[jax](https://github.com/google/jax)/[Numpy](https://numpy.org/)/[Cupy](https://cupy.dev/)/[Tensorflow](https://www.tensorflow.org/) arrays.

RELEASES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#### New features
66
- Add feature `mass=True` for `nx.kl_div` (PR #654)
7+
- Add feature `semirelaxed_fgw_barycenters` and generic FGW-related barycenter updates `update_barycenter_structure` and `update_barycenter_feature` (PR #659)
78

89
#### Closed issues
910

ot/gromov/__init__.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@
1111

1212
# All submodules and packages
1313
from ._utils import (init_matrix, tensor_product, gwloss, gwggrad,
14-
update_square_loss, update_kl_loss, update_feature_matrix,
15-
init_matrix_semirelaxed)
14+
init_matrix_semirelaxed,
15+
update_barycenter_structure, update_barycenter_feature,
16+
)
1617

1718
from ._gw import (gromov_wasserstein, gromov_wasserstein2,
1819
fused_gromov_wasserstein, fused_gromov_wasserstein2,
@@ -40,14 +41,16 @@
4041
entropic_semirelaxed_gromov_wasserstein,
4142
entropic_semirelaxed_gromov_wasserstein2,
4243
entropic_semirelaxed_fused_gromov_wasserstein,
43-
entropic_semirelaxed_fused_gromov_wasserstein2)
44+
entropic_semirelaxed_fused_gromov_wasserstein2,
45+
semirelaxed_fgw_barycenters)
4446

4547
from ._dictionary import (gromov_wasserstein_dictionary_learning,
4648
gromov_wasserstein_linear_unmixing,
4749
fused_gromov_wasserstein_dictionary_learning,
4850
fused_gromov_wasserstein_linear_unmixing)
4951

50-
from ._lowrank import (_flat_product_operator, lowrank_gromov_wasserstein_samples)
52+
from ._lowrank import (_flat_product_operator,
53+
lowrank_gromov_wasserstein_samples)
5154

5255

5356
from ._quantized import (quantized_fused_gromov_wasserstein_partitioned,
@@ -60,8 +63,9 @@
6063
quantized_fused_gromov_wasserstein_samples
6164
)
6265

63-
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad', 'update_square_loss',
64-
'update_kl_loss', 'update_feature_matrix', 'init_matrix_semirelaxed',
66+
__all__ = ['init_matrix', 'tensor_product', 'gwloss', 'gwggrad',
67+
'init_matrix_semirelaxed',
68+
'update_barycenter_structure', 'update_barycenter_feature',
6569
'gromov_wasserstein', 'gromov_wasserstein2', 'fused_gromov_wasserstein',
6670
'fused_gromov_wasserstein2', 'solve_gromov_linesearch', 'gromov_barycenters',
6771
'fgw_barycenters', 'entropic_gromov_wasserstein', 'entropic_gromov_wasserstein2',
@@ -80,4 +84,5 @@
8084
'quantized_fused_gromov_wasserstein_partitioned', 'get_graph_partition',
8185
'get_graph_representants', 'format_partitioned_graph',
8286
'quantized_fused_gromov_wasserstein', 'get_partition_and_representants_samples',
83-
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples']
87+
'format_partitioned_samples', 'quantized_fused_gromov_wasserstein_samples',
88+
'semirelaxed_fgw_barycenters']

ot/gromov/_bregman.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from ..backend import get_backend
2020

2121
from ._utils import init_matrix, gwloss, gwggrad
22-
from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
22+
from ._utils import update_barycenter_structure, update_barycenter_feature
2323

2424

2525
def entropic_gromov_wasserstein(
@@ -807,10 +807,8 @@ def entropic_gromov_barycenters(
807807
curr_loss = np.sum([output[1]['gw_dist'] for output in res])
808808

809809
# update barycenters
810-
if loss_fun == 'square_loss':
811-
C = update_square_loss(p, lambdas, T, Cs, nx)
812-
elif loss_fun == 'kl_loss':
813-
C = update_kl_loss(p, lambdas, T, Cs, nx)
810+
C = update_barycenter_structure(
811+
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)
814812

815813
# update convergence criterion
816814
if stop_criterion == 'barycenter':
@@ -1651,13 +1649,14 @@ def entropic_fused_gromov_barycenters(
16511649
# Initialization of C: random Euclidean distance matrix (if not provided by user)
16521650
if fixed_structure:
16531651
if init_C is None:
1654-
raise UndefinedParameter('If C is fixed it must be initialized')
1652+
raise UndefinedParameter(
1653+
'If C is fixed it must be provided in init_C')
16551654
else:
16561655
C = init_C
16571656
else:
16581657
if init_C is None:
1659-
generator = check_random_state(random_state)
1660-
xalea = generator.randn(N, 2)
1658+
rng = check_random_state(random_state)
1659+
xalea = rng.randn(N, 2)
16611660
C = dist(xalea, xalea)
16621661
C = nx.from_numpy(C, type_as=ps[0])
16631662
else:
@@ -1666,7 +1665,8 @@ def entropic_fused_gromov_barycenters(
16661665
# Initialization of Y
16671666
if fixed_features:
16681667
if init_Y is None:
1669-
raise UndefinedParameter('If Y is fixed it must be initialized')
1668+
raise UndefinedParameter(
1669+
'If Y is fixed it must be provided in init_Y')
16701670
else:
16711671
Y = init_Y
16721672
else:
@@ -1681,20 +1681,12 @@ def entropic_fused_gromov_barycenters(
16811681
if warmstartT:
16821682
T = [None] * S
16831683

1684-
cpt = 0
1685-
16861684
if stop_criterion == 'barycenter':
16871685
inner_log = False
1688-
err_feature = 1e15
1689-
err_structure = 1e15
1690-
err_rel_loss = 0.
16911686

16921687
else:
16931688
inner_log = True
1694-
err_feature = 0.
1695-
err_structure = 0.
16961689
curr_loss = 1e15
1697-
err_rel_loss = 1e15
16981690

16991691
if log:
17001692
log_ = {}
@@ -1706,7 +1698,8 @@ def entropic_fused_gromov_barycenters(
17061698
log_['loss'] = []
17071699
log_['err_rel_loss'] = []
17081700

1709-
while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
1701+
for cpt in range(max_iter): # break if specified errors are below tol.
1702+
17101703
if stop_criterion == 'barycenter':
17111704
Cprev = C
17121705
Yprev = Y
@@ -1732,16 +1725,14 @@ def entropic_fused_gromov_barycenters(
17321725

17331726
# update barycenters
17341727
if not fixed_features:
1735-
Ys_temp = [y.T for y in Ys]
1736-
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
1728+
X = update_barycenter_feature(
1729+
T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx)
1730+
17371731
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
17381732

17391733
if not fixed_structure:
1740-
if loss_fun == 'square_loss':
1741-
C = update_square_loss(p, lambdas, T, Cs, nx)
1742-
1743-
elif loss_fun == 'kl_loss':
1744-
C = update_kl_loss(p, lambdas, T, Cs, nx)
1734+
C = update_barycenter_structure(
1735+
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)
17451736

17461737
# update convergence criterion
17471738
if stop_criterion == 'barycenter':
@@ -1761,6 +1752,9 @@ def entropic_fused_gromov_barycenters(
17611752
'It.', 'Err') + '\n' + '-' * 19)
17621753
print('{:5d}|{:8e}|'.format(cpt, err_structure))
17631754
print('{:5d}|{:8e}|'.format(cpt, err_feature))
1755+
1756+
if (err_feature <= tol) or (err_structure <= tol):
1757+
break
17641758
else:
17651759
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
17661760
if log:
@@ -1773,7 +1767,8 @@ def entropic_fused_gromov_barycenters(
17731767
'It.', 'Err') + '\n' + '-' * 19)
17741768
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))
17751769

1776-
cpt += 1
1770+
if err_rel_loss <= tol:
1771+
break
17771772

17781773
if log:
17791774
log_['T'] = T

ot/gromov/_gw.py

Lines changed: 27 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from ..backend import get_backend, NumpyBackend
2222

2323
from ._utils import init_matrix, gwloss, gwggrad
24-
from ._utils import update_square_loss, update_kl_loss, update_feature_matrix
24+
from ._utils import update_barycenter_structure, update_barycenter_feature
2525

2626

2727
def gromov_wasserstein(C1, C2, p=None, q=None, loss_fun='square_loss', symmetric=None, log=False, armijo=False, G0=None,
@@ -833,17 +833,14 @@ def gromov_barycenters(
833833

834834
# Initialization of C: random Euclidean distance matrix (if not provided by user)
835835
if init_C is None:
836-
generator = check_random_state(random_state)
837-
xalea = generator.randn(N, 2)
836+
rng = check_random_state(random_state)
837+
xalea = rng.randn(N, 2)
838838
C = dist(xalea, xalea)
839839
C /= C.max()
840840
C = nx.from_numpy(C, type_as=p)
841841
else:
842842
C = init_C
843843

844-
cpt = 0
845-
err = 1e15 # either the error on 'barycenter' or 'loss'
846-
847844
if warmstartT:
848845
T = [None] * S
849846

@@ -859,7 +856,8 @@ def gromov_barycenters(
859856
if stop_criterion == 'loss':
860857
log_['loss'] = []
861858

862-
while (err > tol and cpt < max_iter):
859+
for cpt in range(max_iter):
860+
863861
if stop_criterion == 'barycenter':
864862
Cprev = C
865863
else:
@@ -883,11 +881,8 @@ def gromov_barycenters(
883881
curr_loss = np.sum([output[1]['gw_dist'] for output in res])
884882

885883
# update barycenters
886-
if loss_fun == 'square_loss':
887-
C = update_square_loss(p, lambdas, T, Cs, nx)
888-
889-
elif loss_fun == 'kl_loss':
890-
C = update_kl_loss(p, lambdas, T, Cs, nx)
884+
C = update_barycenter_structure(
885+
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)
891886

892887
# update convergence criterion
893888
if stop_criterion == 'barycenter':
@@ -907,7 +902,8 @@ def gromov_barycenters(
907902
'It.', 'Err') + '\n' + '-' * 19)
908903
print('{:5d}|{:8e}|'.format(cpt, err))
909904

910-
cpt += 1
905+
if err <= tol:
906+
break
911907

912908
if log:
913909
log_['T'] = T
@@ -1046,21 +1042,23 @@ def fgw_barycenters(
10461042

10471043
if fixed_structure:
10481044
if init_C is None:
1049-
raise UndefinedParameter('If C is fixed it must be initialized')
1045+
raise UndefinedParameter(
1046+
'If C is fixed it must be provided in init_C')
10501047
else:
10511048
C = init_C
10521049
else:
10531050
if init_C is None:
1054-
generator = check_random_state(random_state)
1055-
xalea = generator.randn(N, 2)
1051+
rng = check_random_state(random_state)
1052+
xalea = rng.randn(N, 2)
10561053
C = dist(xalea, xalea)
10571054
C = nx.from_numpy(C, type_as=ps[0])
10581055
else:
10591056
C = init_C
10601057

10611058
if fixed_features:
10621059
if init_X is None:
1063-
raise UndefinedParameter('If X is fixed it must be initialized')
1060+
raise UndefinedParameter(
1061+
'If X is fixed it must be provided in init_X')
10641062
else:
10651063
X = init_X
10661064
else:
@@ -1075,20 +1073,12 @@ def fgw_barycenters(
10751073
if warmstartT:
10761074
T = [None] * S
10771075

1078-
cpt = 0
1079-
10801076
if stop_criterion == 'barycenter':
10811077
inner_log = False
1082-
err_feature = 1e15
1083-
err_structure = 1e15
1084-
err_rel_loss = 0.
10851078

10861079
else:
10871080
inner_log = True
1088-
err_feature = 0.
1089-
err_structure = 0.
10901081
curr_loss = 1e15
1091-
err_rel_loss = 1e15
10921082

10931083
if log:
10941084
log_ = {}
@@ -1100,7 +1090,8 @@ def fgw_barycenters(
11001090
log_['loss'] = []
11011091
log_['err_rel_loss'] = []
11021092

1103-
while ((err_feature > tol or err_structure > tol or err_rel_loss > tol) and cpt < max_iter):
1093+
for cpt in range(max_iter): # break if specified errors are below tol.
1094+
11041095
if stop_criterion == 'barycenter':
11051096
Cprev = C
11061097
Xprev = X
@@ -1126,16 +1117,14 @@ def fgw_barycenters(
11261117

11271118
# update barycenters
11281119
if not fixed_features:
1129-
Ys_temp = [y.T for y in Ys]
1130-
X = update_feature_matrix(lambdas, Ys_temp, T, p, nx).T
1120+
X = update_barycenter_feature(
1121+
T, Ys, lambdas, p, target=False, check_zeros=False, nx=nx)
1122+
11311123
Ms = [dist(X, Ys[s]) for s in range(len(Ys))]
11321124

11331125
if not fixed_structure:
1134-
if loss_fun == 'square_loss':
1135-
C = update_square_loss(p, lambdas, T, Cs, nx)
1136-
1137-
elif loss_fun == 'kl_loss':
1138-
C = update_kl_loss(p, lambdas, T, Cs, nx)
1126+
C = update_barycenter_structure(
1127+
T, Cs, lambdas, p, loss_fun, target=False, check_zeros=False, nx=nx)
11391128

11401129
# update convergence criterion
11411130
if stop_criterion == 'barycenter':
@@ -1155,6 +1144,9 @@ def fgw_barycenters(
11551144
'It.', 'Err') + '\n' + '-' * 19)
11561145
print('{:5d}|{:8e}|'.format(cpt, err_structure))
11571146
print('{:5d}|{:8e}|'.format(cpt, err_feature))
1147+
1148+
if (err_feature <= tol) or (err_structure <= tol):
1149+
break
11581150
else:
11591151
err_rel_loss = abs(curr_loss - prev_loss) / prev_loss if prev_loss != 0. else np.nan
11601152
if log:
@@ -1167,7 +1159,8 @@ def fgw_barycenters(
11671159
'It.', 'Err') + '\n' + '-' * 19)
11681160
print('{:5d}|{:8e}|'.format(cpt, err_rel_loss))
11691161

1170-
cpt += 1
1162+
if err_rel_loss <= tol:
1163+
break
11711164

11721165
if log:
11731166
log_['T'] = T

0 commit comments

Comments
 (0)