Skip to content

Commit

Permalink
Merge pull request GUDHI#807 from mglisse/fit-call
Browse files Browse the repository at this point in the history
[representations] __call__ and fit changes
  • Loading branch information
VincentRouvreau authored Apr 7, 2023
2 parents db8880a + ca716f9 commit 6dbf761
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 46 deletions.
22 changes: 10 additions & 12 deletions src/python/gudhi/representations/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,12 +368,6 @@ def __init__(self, order=1, internal_p=np.inf, mode="hera", delta=0.01, n_jobs=N
n_jobs (int): number of jobs to use for the computation. See :func:`pairwise_persistence_diagram_distances` for details.
"""
self.order, self.internal_p, self.mode = order, internal_p, mode
if mode == "pot":
self.metric = "pot_wasserstein"
elif mode == "hera":
self.metric = "hera_wasserstein"
else:
raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
self.delta = delta
self.n_jobs = n_jobs

Expand All @@ -385,6 +379,8 @@ def fit(self, X, y=None):
X (list of n x 2 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
if self.mode not in ("pot", "hera"):
raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
self.diagrams_ = X
return self

Expand All @@ -398,10 +394,10 @@ def transform(self, X):
Returns:
numpy array of shape (number of diagrams in **diagrams**) x (number of diagrams in X): matrix of pairwise Wasserstein distances.
"""
if self.metric == "hera_wasserstein":
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, delta=self.delta, n_jobs=self.n_jobs)
if self.mode == "hera":
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="hera_wasserstein", order=self.order, internal_p=self.internal_p, delta=self.delta, n_jobs=self.n_jobs)
else:
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric=self.metric, order=self.order, internal_p=self.internal_p, matching=False, n_jobs=self.n_jobs)
Xfit = pairwise_persistence_diagram_distances(X, self.diagrams_, metric="pot_wasserstein", order=self.order, internal_p=self.internal_p, matching=False, n_jobs=self.n_jobs)
return Xfit

def __call__(self, diag1, diag2):
Expand All @@ -415,12 +411,14 @@ def __call__(self, diag1, diag2):
Returns:
float: Wasserstein distance.
"""
if self.metric == "hera_wasserstein":
if self.mode == "hera":
return hera_wasserstein_distance(diag1, diag2, order=self.order, internal_p=self.internal_p, delta=self.delta)
else:
elif self.mode == "pot":
try:
from gudhi.wasserstein import wasserstein_distance as pot_wasserstein_distance
return pot_wasserstein_distance(diag1, diag2, order=self.order, internal_p=self.internal_p, matching=False)
except ImportError:
print("POT (Python Optimal Transport) is not installed. Please install POT or use metric='wasserstein' or metric='hera_wasserstein'")
print("POT (Python Optimal Transport) is not installed. Please install POT or use mode='hera'")
raise
else:
raise NameError("Unknown mode. Current available values for mode are 'hera' and 'pot'")
30 changes: 22 additions & 8 deletions src/python/gudhi/representations/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@
# Utils #####################################
#############################################

def _maybe_fit_transform(obj, attr, diag):
"""
In __call__, use transform on the object itself if it has been fitted,
otherwise fit_transform on a clone of the object so it doesn't affect future calls.
"""
if hasattr(obj, attr):
result = obj.transform([diag])
else:
result = obj.__class__(**obj.get_params()).fit_transform([diag])
return result[0]

class Clamping(BaseEstimator, TransformerMixin):
"""
This is a class for clamping a list of values. It is not meant to be called directly on (a list of) persistence diagrams, but it is rather meant to be used as a parameter for the DiagramScaler class. As such it has the same methods and purpose as common scalers from sklearn.preprocessing such as MinMaxScaler, RobustScaler, StandardScaler, etc. A typical use would be for instance if you want to clamp abscissae or ordinates (or both) of persistence diagrams within a pre-defined interval.
Expand Down Expand Up @@ -108,7 +119,7 @@ def __call__(self, diag):
Returns:
n x 2 numpy array: transformed persistence diagram.
"""
return self.fit_transform([diag])[0]
return self.transform([diag])[0]

class DiagramScaler(BaseEstimator, TransformerMixin):
"""
Expand All @@ -133,6 +144,7 @@ def fit(self, X, y=None):
X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
self.is_fitted_ = True
if self.use:
if len(X) == 1:
P = X[0]
Expand Down Expand Up @@ -164,14 +176,15 @@ def transform(self, X):
def __call__(self, diag):
"""
Apply DiagramScaler on a single persistence diagram and outputs the result.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
Parameters:
diag (n x 2 numpy array): input persistence diagram.
Returns:
n x 2 numpy array: transformed persistence diagram.
"""
return self.fit_transform([diag])[0]
return _maybe_fit_transform(self, 'is_fitted_', diag)

class Padding(BaseEstimator, TransformerMixin):
"""
Expand All @@ -188,13 +201,13 @@ def __init__(self, use=False):

def fit(self, X, y=None):
"""
Fit the Padding class on a list of persistence diagrams (this function actually does nothing but is useful when Padding is included in a scikit-learn Pipeline).
Fit the Padding class on a list of persistence diagrams.
Parameters:
X (list of n x 2 or n x 1 numpy arrays): input persistence diagrams.
y (n x 1 array): persistence diagram labels (unused).
"""
self.max_pts = max(len(diag) for diag in X)
self.max_pts_ = max(len(diag) for diag in X)
return self

def transform(self, X):
Expand All @@ -210,7 +223,7 @@ def transform(self, X):
if self.use:
Xfit, num_diag = [], len(X)
for diag in X:
diag_pad = np.pad(diag, ((0,max(0, self.max_pts - diag.shape[0])), (0,1)), "constant", constant_values=((0,0),(0,0)))
diag_pad = np.pad(diag, ((0,max(0, self.max_pts_ - diag.shape[0])), (0,1)), "constant", constant_values=((0,0),(0,0)))
diag_pad[:diag.shape[0],2] = np.ones(diag.shape[0])
Xfit.append(diag_pad)
else:
Expand All @@ -220,14 +233,15 @@ def transform(self, X):
def __call__(self, diag):
"""
Apply Padding on a single persistence diagram and outputs the result.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
Parameters:
diag (n x 2 numpy array): input persistence diagram.
Returns:
n x 2 numpy array: padded persistence diagram.
"""
return self.fit_transform([diag])[0]
return _maybe_fit_transform(self, 'max_pts_', diag)

class ProminentPoints(BaseEstimator, TransformerMixin):
"""
Expand Down Expand Up @@ -312,7 +326,7 @@ def __call__(self, diag):
Returns:
n x 2 numpy array: thresholded persistence diagram.
"""
return self.fit_transform([diag])[0]
return self.transform([diag])[0]

class DiagramSelector(BaseEstimator, TransformerMixin):
"""
Expand Down Expand Up @@ -369,7 +383,7 @@ def __call__(self, diag):
Returns:
n x 2 numpy array: extracted persistence diagram.
"""
return self.fit_transform([diag])[0]
return self.transform([diag])[0]


# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
Expand Down
52 changes: 29 additions & 23 deletions src/python/gudhi/representations/vector_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
# Will be removed in 1.3
from sklearn.neighbors import DistanceMetric

from .preprocessing import DiagramScaler, BirthPersistenceTransform
from .preprocessing import DiagramScaler, BirthPersistenceTransform, _maybe_fit_transform

#############################################
# Finite Vectorization methods ##############
Expand Down Expand Up @@ -53,14 +53,15 @@ def fit(self, X, y=None):
y (n x 1 array): persistence diagram labels (unused).
"""
if np.isnan(np.array(self.im_range)).any():
try:
if all(len(d) == 0 for d in X):
self.im_range_fixed_ = self.im_range
else:
new_X = BirthPersistenceTransform().fit_transform(X)
pre = DiagramScaler(use=True, scalers=[([0], MinMaxScaler()), ([1], MinMaxScaler())]).fit(new_X,y)
[mx,my],[Mx,My] = [pre.scalers[0][1].data_min_[0], pre.scalers[1][1].data_min_[0]], [pre.scalers[0][1].data_max_[0], pre.scalers[1][1].data_max_[0]]
self.im_range = np.where(np.isnan(np.array(self.im_range)), np.array([mx, Mx, my, My]), np.array(self.im_range))
except ValueError:
# Empty persistence diagram case - https://github.com/GUDHI/gudhi-devel/issues/507
pass
self.im_range_fixed_ = np.where(np.isnan(np.array(self.im_range)), np.array([mx, Mx, my, My]), np.array(self.im_range))
else:
self.im_range_fixed_ = self.im_range
return self

def transform(self, X):
Expand All @@ -84,7 +85,7 @@ def transform(self, X):
for j in range(num_pts_in_diag):
w[j] = self.weight(diagram[j,:])

x_values, y_values = np.linspace(self.im_range[0], self.im_range[1], self.resolution[0]), np.linspace(self.im_range[2], self.im_range[3], self.resolution[1])
x_values, y_values = np.linspace(self.im_range_fixed_[0], self.im_range_fixed_[1], self.resolution[0]), np.linspace(self.im_range_fixed_[2], self.im_range_fixed_[3], self.resolution[1])
Xs, Ys = np.tile((diagram[:,0][:,np.newaxis,np.newaxis]-x_values[np.newaxis,np.newaxis,:]),[1,self.resolution[1],1]), np.tile(diagram[:,1][:,np.newaxis,np.newaxis]-y_values[np.newaxis,:,np.newaxis],[1,1,self.resolution[0]])
image = np.tensordot(w, np.exp((-np.square(Xs)-np.square(Ys))/(2*np.square(self.bandwidth)))/(np.square(self.bandwidth)*2*np.pi), 1)

Expand All @@ -97,14 +98,15 @@ def transform(self, X):
def __call__(self, diag):
"""
Apply PersistenceImage on a single persistence diagram and outputs the result.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
Parameters:
diag (n x 2 numpy array): input persistence diagram.
Returns:
numpy array with shape (number of pixels = **resolution[0]** x **resolution[1]**): output persistence image.
"""
return self.fit_transform([diag])[0,:]
return _maybe_fit_transform(self, 'im_range_fixed_', diag)

def _automatic_sample_range(sample_range, X):
"""
Expand Down Expand Up @@ -139,14 +141,14 @@ def _trim_endpoints(x, are_endpoints_nan):

def _grid_from_sample_range(self, X):
sample_range = np.array(self.sample_range)
self.nan_in_range = np.isnan(sample_range)
self.new_resolution = self.resolution
self.nan_in_range_ = np.isnan(sample_range)
self.new_resolution_ = self.resolution
if not self.keep_endpoints:
self.new_resolution += self.nan_in_range.sum()
self.sample_range_fixed = _automatic_sample_range(sample_range, X)
self.grid_ = np.linspace(self.sample_range_fixed[0], self.sample_range_fixed[1], self.new_resolution)
self.new_resolution_ += self.nan_in_range_.sum()
self.sample_range_fixed_ = _automatic_sample_range(sample_range, X)
self.grid_ = np.linspace(self.sample_range_fixed_[0], self.sample_range_fixed_[1], self.new_resolution_)
if not self.keep_endpoints:
self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range)
self.grid_ = _trim_endpoints(self.grid_, self.nan_in_range_)


class Landscape(BaseEstimator, TransformerMixin):
Expand Down Expand Up @@ -214,14 +216,15 @@ def transform(self, X):
def __call__(self, diag):
"""
Apply Landscape on a single persistence diagram and outputs the result.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
Parameters:
diag (n x 2 numpy array): input persistence diagram.
Returns:
numpy array with shape (number of samples = **num_landscapes** x **resolution**): output persistence landscape.
"""
return self.fit_transform([diag])[0, :]
return _maybe_fit_transform(self, 'grid_', diag)

class Silhouette(BaseEstimator, TransformerMixin):
"""
Expand Down Expand Up @@ -281,14 +284,15 @@ def transform(self, X):
def __call__(self, diag):
"""
Apply Silhouette on a single persistence diagram and outputs the result.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
Parameters:
diag (n x 2 numpy array): input persistence diagram.
Returns:
numpy array with shape (**resolution**): output persistence silhouette.
"""
return self.fit_transform([diag])[0,:]
return _maybe_fit_transform(self, 'grid_', diag)


class BettiCurve(BaseEstimator, TransformerMixin):
Expand Down Expand Up @@ -445,8 +449,9 @@ def fit_transform(self, X):
def __call__(self, diag):
"""
Shorthand for transform on a single persistence diagram.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
"""
return self.fit_transform([diag])[0, :]
return _maybe_fit_transform(self, 'grid_', diag)



Expand Down Expand Up @@ -509,8 +514,8 @@ def transform(self, X):
ent = np.zeros(self.resolution)
for j in range(num_pts_in_diag):
[px,py] = orig_diagram[j,:2]
min_idx = np.clip(np.ceil((px - self.sample_range_fixed[0]) / self.step_).astype(int), 0, self.resolution)
max_idx = np.clip(np.ceil((py - self.sample_range_fixed[0]) / self.step_).astype(int), 0, self.resolution)
min_idx = np.clip(np.ceil((px - self.sample_range_fixed_[0]) / self.step_).astype(int), 0, self.resolution)
max_idx = np.clip(np.ceil((py - self.sample_range_fixed_[0]) / self.step_).astype(int), 0, self.resolution)
ent[min_idx:max_idx]-=p[j]*np.log(p[j])
if self.normalized:
ent = ent / np.linalg.norm(ent, ord=1)
Expand All @@ -522,14 +527,15 @@ def transform(self, X):
def __call__(self, diag):
"""
Apply Entropy on a single persistence diagram and outputs the result.
If :func:`fit` hasn't been run, this uses `fit_transform` on a clone of the object and thus does not affect later calls.
Parameters:
diag (n x 2 numpy array): input persistence diagram.
Returns:
numpy array with shape (1 if **mode** = "scalar" else **resolution**): output entropy.
"""
return self.fit_transform([diag])[0,:]
return _maybe_fit_transform(self, 'grid_', diag)

class TopologicalVector(BaseEstimator, TransformerMixin):
"""
Expand Down Expand Up @@ -600,7 +606,7 @@ def __call__(self, diag):
Returns:
numpy array with shape (**threshold**): output topological vector.
"""
return self.fit_transform([diag])[0,:]
return self.transform([diag])[0,:]

class ComplexPolynomial(BaseEstimator, TransformerMixin):
"""
Expand Down Expand Up @@ -672,7 +678,7 @@ def __call__(self, diag):
Returns:
numpy array with shape (**threshold**): output complex vector of coefficients.
"""
return self.fit_transform([diag])[0,:]
return self.transform([diag])[0,:]

def _lapl_contrast(measure, centers, inertias):
"""contrast function for vectorising `measure` in ATOL"""
Expand Down Expand Up @@ -785,7 +791,7 @@ def fit(self, X, y=None, sample_weight=None):

def __call__(self, measure, sample_weight=None):
"""
Apply measure vectorisation on a single measure.
Apply measure vectorisation on a single measure. Only available after `fit` has been called.
Parameters:
measure (n x d numpy array): input measure in R^d.
Expand Down
6 changes: 3 additions & 3 deletions src/python/test/test_representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,9 +299,9 @@ def test_landscape_numeric():
def test_landscape_nan_range():
dgm = np.array([[2., 6.], [3., 5.]])
lds = Landscape(num_landscapes=2, resolution=9, sample_range=[np.nan, 6.])
lds_dgm = lds(dgm)
assert (lds.sample_range_fixed[0] == 2) & (lds.sample_range_fixed[1] == 6)
assert lds.new_resolution == 10
lds_dgm = lds.fit([dgm])
assert lds.sample_range_fixed_[0] == 2 and lds.sample_range_fixed_[1] == 6
assert lds.new_resolution_ == 10

def test_endpoints():
diags = [ np.array([[2., 3.]]) ]
Expand Down

0 comments on commit 6dbf761

Please sign in to comment.