Cech persistence sklearn #1126

Open: wants to merge 9 commits into master
49 changes: 49 additions & 0 deletions src/python/doc/delaunay_complex_sklearn_itf_ref.rst
@@ -0,0 +1,49 @@
:orphan:

.. To get rid of WARNING: document isn't included in any toctree

Čech complex persistence scikit-learn like interface
####################################################

.. list-table::
   :width: 100%
   :header-rows: 0

   * - :Since: GUDHI 3.11.0
     - :License: MIT
     - :Requires: `Scikit-learn <installation.html#scikit-learn>`_

Čech complex persistence scikit-learn like interface example
------------------------------------------------------------

In this example, we build a dataset `X` composed of 1000 circles whose radii are drawn uniformly at random between 1.0 and 10.0.
N points are subsampled randomly on each circle, where N is drawn randomly between 100 and 300 for each circle.
To complicate things, some noise (+/- 5% of the radius value) is added to the point coordinates.
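The dataset construction described above can be sketched with numpy alone; the snippet below is a hypothetical stand-in for the `gudhi.datasets.generators.points.sphere` generator that the actual example uses, sampling circle points directly from angles:

```python
import numpy as np

rng = np.random.default_rng(0)
dataset_size = 1000
radii = 1.0 + 9.0 * rng.random(dataset_size)           # radii uniform in [1., 10.)
nb_points = rng.integers(100, 301, size=dataset_size)  # 100 to 300 points per circle

X = []
for r, n in zip(radii, nb_points):
    angles = 2.0 * np.pi * rng.random(n)
    pts = r * np.column_stack([np.cos(angles), np.sin(angles)])
    # +/- 5% of the radius as uniform coordinate noise (noise level 0.1)
    pts += 0.1 * r * (rng.random((n, 2)) - 0.5)
    X.append(pts)
```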

The TDA scikit-learn pipeline is constructed and is composed of:

#. :class:`~gudhi.sklearn.cech_persistence.CechPersistence` that computes the Čech persistence of the inputs and
   returns the persistence diagrams in dimension 1.
#. :class:`~gudhi.representations.vector_methods.PersistenceLengths` that returns, here, the length of the longest
   persistence bar in dimension 1.
#. `LinearRegression <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html>`_,
   an ordinary least squares linear regression from scikit-learn.

The model is trained on 75% of the dataset, with the squared radius of each circle as the regression target, and the
regression line can be appreciated when the model predicts on the remaining 25% of the dataset.

.. literalinclude:: ../../python/example/cech_complex_sklearn_itf.py
   :lines: 1-15,22-
   :language: python

.. figure:: ./img/cech_persistence_sklearn_itf.png
   :figclass: align-center

   Regression line of the model

Čech complex persistence scikit-learn like interface reference
-----------------------------------------------------------------

.. autoclass:: gudhi.sklearn.cech_persistence.CechPersistence
   :members:
   :show-inheritance:
29 changes: 17 additions & 12 deletions src/python/doc/delaunay_complex_sum.inc
@@ -1,15 +1,20 @@
.. table::
   :widths: 30 40 30

+----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+
| .. figure:: | Delaunay complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
| ../../doc/Alpha_complex/alpha_complex_doc.png | cells of a Delaunay Triangulation. The Simplicial complex filtration | |
| :alt: Delaunay complex representation | values can be computed with different filtrations (Delaunay complex, | :Since: GUDHI 2.0.0 |
| :figclass: align-center | Delaunay Čech complex or Alpha complex) | |
| | | :License: MIT (`GPL v3 </licensing/>`_) |
| | When the simplicial complex filtration values are computed, it has the | |
| | same persistent homology as the Čech complex, while being significantly | :Requires: `Eigen <installation.html#eigen>`_ and `CGAL <installation.html#cgal>`_ |
| | smaller. | |
+----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+
| * :doc:`delaunay_complex_user` | * :doc:`delaunay_complex_ref` |
+----------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------+
+-----------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+
| .. figure:: | Delaunay complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau |
| ../../doc/Alpha_complex/alpha_complex_doc.png | cells of a Delaunay Triangulation. The Simplicial complex filtration | |
| :alt: Delaunay complex representation | values can be computed with different filtrations (Delaunay complex, | :Since: GUDHI 2.0.0 |
| :figclass: align-center | Delaunay Čech complex or Alpha complex) | |
| | | :License: MIT (`GPL v3 </licensing/>`_) |
| | When the simplicial complex filtration values are computed, it has the | |
| | same persistent homology as the Čech complex, while being significantly | :Requires: `Eigen <installation.html#eigen>`_ and `CGAL <installation.html#cgal>`_ |
| | smaller. | |
+-----------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+
| * :doc:`delaunay_complex_user` | * :doc:`delaunay_complex_ref` |
+-----------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------+
| .. image:: | * :doc:`delaunay_complex_sklearn_itf_ref` |
| img/sklearn.png | |
| :target: installation.html#scikit-learn | |
| :height: 30 | |
+-----------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------+
62 changes: 62 additions & 0 deletions src/python/example/cech_complex_sklearn_itf.py
@@ -0,0 +1,62 @@
# Standard data science imports
import numpy as np
from sklearn.pipeline import Pipeline
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

# Import TDA pipeline requirements
from gudhi.sklearn.cech_persistence import CechPersistence
from gudhi.representations.vector_methods import PersistenceLengths
# To build the dataset
from gudhi.datasets.generators import points

no_plot = False

# Mainly for test purposes - removed from documentation
import argparse
parser = argparse.ArgumentParser(description="Čech persistence scikit-learn interface example")
parser.add_argument("--no-plot", default=False, action="store_true")
args = parser.parse_args()
no_plot = args.no_plot

# Build the dataset
dataset_size = 1000
# Noise is expressed in percentage of radius - set it to 0. for no noise
noise = 0.1
# target is a list of 1000 random circle radii (between 1. and 10.)
target = 1. + 9. * np.random.rand(dataset_size, 1)
# use also a random number of points (between 100 and 300)
nb_points = np.random.randint(100,high=300, size=dataset_size)
X = []

for idx in range(dataset_size):
    pts = points.sphere(nb_points[idx], 2, radius=target[idx])
    # Center the uniform noise so it spans +/- (noise / 2) of the radius
    ns = noise * target[idx] * (np.random.rand(nb_points[idx], 2) - 0.5)
    X.append(pts + ns)

# Čech filtration values are squared radii, so transform the targets to squared radii for train/predict purposes
target = target * target

# Split the dataset for train/predict
Xtrain, Xtest, ytrain, ytest = train_test_split(X, target, test_size=0.25)

pipe = Pipeline(
    [
        ("cech_pers", CechPersistence(homology_dimensions=1, n_jobs=-2)),
        ("max_pers", PersistenceLengths(num_lengths=1)),
        ("regression", LinearRegression()),
    ]
)

model = pipe.fit(Xtrain, ytrain)

# Let's see how our model predicts squared radii from points on random circles
predictions = model.predict(Xtest)
_, ax = plt.subplots()
ax.set_xlabel('target')
ax.set_ylabel('prediction')
ax.scatter(ytest, predictions)
ax.set_aspect("equal")
if not no_plot:
    plt.show()
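The same train/predict flow can be sketched without gudhi, with a hypothetical `FunctionTransformer` standing in for the two TDA steps (it extracts an exact squared-radius feature from each noise-free point cloud, so this illustrates the pipeline plumbing, not the persistence computation):

```python
import numpy as np
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)

# Toy dataset: point clouds sampled on circles of random radii
radii = 1.0 + 9.0 * rng.random(200)
X = []
for r in radii:
    angles = 2.0 * np.pi * rng.random(150)
    X.append(r * np.column_stack([np.cos(angles), np.sin(angles)]))
target = radii.reshape(-1, 1) ** 2  # squared radii, as in the example

def radius_feature(clouds):
    # Stand-in for CechPersistence + PersistenceLengths: one feature per cloud
    return np.array([[np.mean(np.linalg.norm(c, axis=1)) ** 2] for c in clouds])

pipe = Pipeline([
    ("feature", FunctionTransformer(radius_feature)),
    ("regression", LinearRegression()),
])

Xtrain, Xtest, ytrain, ytest = train_test_split(X, target, test_size=0.25, random_state=0)
model = pipe.fit(Xtrain, ytrain)
predictions = model.predict(Xtest)
```

Since the stand-in feature is exactly the squared radius here, the regression fits nearly perfectly; with real persistence lengths, the fit quality is what the plot above visualizes.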
145 changes: 145 additions & 0 deletions src/python/gudhi/sklearn/cech_persistence.py
@@ -0,0 +1,145 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s): Vincent Rouvreau
#
# Copyright (C) 2024 Inria
#
# Modification(s):
# - YYYY/MM Author: Description of the modification

from .. import DelaunayComplex
from sklearn.base import BaseEstimator, TransformerMixin

# joblib is required by scikit-learn
from joblib import Parallel, delayed

# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/
# sequenceDiagram
# participant USER
# participant R as CechPersistence
# USER->>R: fit_transform(X)
# Note right of R: homology_dimensions=[i,j]
# R->>thread1: __transform(X[0])
# R->>thread2: __transform(X[1])
# Note right of R: ...
# thread1->>R: [array( Hi(X[0]) ), array( Hj(X[0]) )]
# thread2->>R: [array( Hi(X[1]) ), array( Hj(X[1]) )]
# Note right of R: ...
# R->>USER: [[array( Hi(X[0]) ), array( Hj(X[0]) )],<br/> [array( Hi(X[1]) ), array( Hj(X[1]) )],<br/>...]


class CechPersistence(BaseEstimator, TransformerMixin):
"""
This is a class for constructing Čech complexes and computing the persistence diagrams from them.

Review comment (Member): I think this sentence is very misleading: no Cech complex is built in the process;
that's the point of the alpha and delaunay-cech, they build smaller filtrations that still have the same diagram.

Reply (Contributor, Author): I reworked the class explanation on 322487c.

"""

def __init__(
self,
homology_dimensions,
input_type="point cloud",
filtration=True,
precision="safe",
max_alpha_square=float("inf"),
homology_coeff_field=11,
n_jobs=None,
):
"""
Constructor for the CechPersistence class.

Parameters:
homology_dimensions (int or list of int): The returned persistence diagrams dimension(s).
Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one
dimension matters (in other words, when `homology_dimensions` is an int).
input_type (str): Can be 'point cloud' when inputs are point clouds, or 'weighted point cloud', when
inputs are point clouds plus the weight (as the last column value). Default is 'point cloud'.
filtration (bool): If True, the filtration value of each simplex is computed; if False, filtration values are
left uncomputed (set to NaN). Default is True.
precision (str): Delaunay complex precision can be 'fast', 'safe' or 'exact'. Default is 'safe'.
max_alpha_square (float): The maximum alpha square threshold the simplices shall not exceed. Default is set
to infinity, and there is very little point using anything else since it does not save time.
homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11.
n_jobs (int): Number of jobs to run in parallel. `None` (default value) means `n_jobs = 1` unless in a
joblib.parallel_backend context. `-1` means using all processors. cf.
https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html for more details.
"""
self.homology_dimensions = homology_dimensions
self.input_type = input_type
self.filtration = filtration
self.precision = precision
self.max_alpha_square = max_alpha_square
self.homology_coeff_field = homology_coeff_field
self.n_jobs = n_jobs
if self.input_type not in ["point cloud", "weighted point cloud"]:
raise ValueError("Unknown input type")
if self.precision not in ["safe", "fast", "exact"]:
raise ValueError("Unknown precision")

def fit(self, X, Y=None):
"""
Nothing to be done, but useful when included in a scikit-learn Pipeline.
"""
return self

def __transform(self, inputs):
max_dimension = max(self.dim_list_) + 1

# Default filtration value
fltr = None
if self.input_type == "point cloud":
pts = inputs
cech = DelaunayComplex(points=pts, precision=self.precision)
if self.filtration:
fltr = "cech"

elif self.input_type == "weighted point cloud":
wgts = inputs[:, -1]
pts = inputs[:, :-1]
cech = DelaunayComplex(points=pts, weights=wgts, precision=self.precision)
if self.filtration:
fltr = "alpha"

stree = cech.create_simplex_tree(max_alpha_square=self.max_alpha_square, filtration=fltr)

persistence_dim_max = False
# Specific case where, despite expansion(max_dimension), stree has a lower dimension
if max_dimension > stree.dimension():
persistence_dim_max = True

stree.compute_persistence(
homology_coeff_field=self.homology_coeff_field, persistence_dim_max=persistence_dim_max
)

return [stree.persistence_intervals_in_dimension(dim) for dim in self.dim_list_]

def transform(self, X, Y=None):
"""Compute all the Čech complexes and their associated persistence diagrams.

:param X: list of point clouds as Euclidean coordinates, plus the weight (as the last column value) if
:paramref:`~gudhi.sklearn.cech_persistence.CechPersistence.input_type` was set to 'weighted point cloud'.
:type X: list of list of float OR list of numpy.ndarray

:return: Persistence diagrams in the format:

- If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]`
- If `homology_dimensions` was set to `[i, j]`:
`[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]`
:rtype: list of numpy ndarray of shape (,2) or list of list of numpy ndarray of shape (,2)
"""
# Handle homology_dimensions being an integer or a list of integers (else case)
if isinstance(self.homology_dimensions, int):
    unwrap = True
    self.dim_list_ = [self.homology_dimensions]
else:
    unwrap = False
    self.dim_list_ = self.homology_dimensions

# threads are preferred, as the Delaunay complex construction and persistence computation release the GIL
res = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(inputs) for inputs in X)

if unwrap:
res = [d[0] for d in res]
return res

def get_feature_names_out(self):
"""Provide column names for implementing sklearn's set_output API."""
return [f"H{i}" for i in self.dim_list_]
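The int-versus-list contract of `transform` can be illustrated in isolation; this is a minimal sketch of the unwrapping step with made-up diagram data, no gudhi required:

```python
import numpy as np

def unwrap_diagrams(homology_dimensions, diagrams_per_input):
    # Mirrors the end of CechPersistence.transform: when a single int was
    # requested, the per-input list wrapper is stripped.
    if isinstance(homology_dimensions, int):
        return [d[0] for d in diagrams_per_input]
    return diagrams_per_input

# Two inputs, each with (made-up) diagrams for dimensions [0, 1]
diags = [
    [np.array([[0.0, 1.0]]), np.array([[0.2, 0.8]])],
    [np.array([[0.0, 2.0]]), np.array([[0.5, 1.5]])],
]
nested = unwrap_diagrams([0, 1], diags)             # list of lists of arrays
flat = unwrap_diagrams(0, [[d[0]] for d in diags])  # list of arrays
```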