Cech persistence sklearn #1126
Open. VincentRouvreau wants to merge 9 commits into GUDHI:master from VincentRouvreau:cech_persistence_sklearn
Changes from 2 commits
Commits (9):
06af78a First Cech persistence version (VincentRouvreau)
e933a57 First version of the Cech persistence sklearn like interfaces (VincentRouvreau)
e9356cc Merge branch 'master' into cech_persistence_sklearn (VincentRouvreau)
322487c doc review: review class explanation (VincentRouvreau)
3861535 Remove filtration True/False mechanism (not sure it is needed as we c… (VincentRouvreau)
b0abb71 Merge branch 'master' into cech_persistence_sklearn (VincentRouvreau)
151b898 Add Cech sklearn test and reduce to 500 the number of points (VincentRouvreau)
b82ffd9 code review: add some type hints (VincentRouvreau)
0ced3d0 Merge branch 'master' into cech_persistence_sklearn (VincentRouvreau)
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
:orphan: | ||
|
||
.. To get rid of WARNING: document isn't included in any toctree | ||
|
||
Čech complex persistence scikit-learn like interface | ||
#################################################### | ||
|
||
.. list-table:: | ||
:width: 100% | ||
:header-rows: 0 | ||
|
||
* - :Since: GUDHI 3.11.0 | ||
- :License: MIT | ||
- :Requires: `Scikit-learn <installation.html#scikit-learn>`_ | ||
|
||
Čech complex persistence scikit-learn like interface example | ||
------------------------------------------------------------ | ||
|
||
In this example, we build a dataset `X` composed of 1000 circles, each with a radius drawn randomly between 1.0 and 10.0. | ||
N points are randomly sampled on each circle, where N is drawn randomly between 100 and 300 for each circle. | ||
To complicate things, some noise (+/- 5% of the radius value) is added to the point coordinates. | ||
|
||
The TDA scikit-learn pipeline is constructed and is composed of: | ||
|
||
#. :class:`~gudhi.sklearn.cech_persistence.CechPersistence` that builds a Čech complex from the inputs and | ||
returns its persistence diagrams in dimension 1. | ||
#. :class:`~gudhi.representations.vector_methods.PersistenceLengths` that returns here the length of the longest persistence bar in | ||
dimension 1. | ||
#. `LinearRegression <https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LinearRegression.html>`_, | ||
an ordinary least squares linear regression from scikit-learn. | ||
|
||
The model is trained on 75% of the dataset, with the squared radii of the circles as targets, and you can appreciate the | ||
regression line of the model when predicting on the other 25% of the dataset. | ||
|
||
.. literalinclude:: ../../python/example/cech_complex_sklearn_itf.py | ||
:lines: 1-15,22- | ||
:language: python | ||
|
||
.. figure:: ./img/cech_persistence_sklearn_itf.png | ||
:figclass: align-center | ||
|
||
Regression line of the model | ||
|
||
Čech complex persistence scikit-learn like interface reference | ||
----------------------------------------------------------------- | ||
|
||
.. autoclass:: gudhi.sklearn.cech_persistence.CechPersistence | ||
:members: | ||
:show-inheritance: |
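Why the targets are squared radii can be illustrated with a small standalone sketch. It assumes that the filtration value of a triangle is the squared radius of its smallest enclosing ball, which is how the example script describes Čech filtration values. The helper `squared_meb_radius` is hypothetical and written only for this note; it is not part of GUDHI.

```python
import math


def squared_meb_radius(a, b, c):
    """Squared radius of the smallest ball enclosing a 2D triangle (a, b, c)."""
    # Squared edge lengths, sorted so that the longest one comes last
    d2 = sorted([
        (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2,
        (b[0] - c[0]) ** 2 + (b[1] - c[1]) ** 2,
        (c[0] - a[0]) ** 2 + (c[1] - a[1]) ** 2,
    ])
    # Right, obtuse or flat triangle: the longest edge is a diameter of the ball
    if d2[2] >= d2[0] + d2[1]:
        return d2[2] / 4.0
    # Acute triangle: circumscribed ball, R^2 = (abc)^2 / (16 * area^2)
    twice_area = abs((b[0] - a[0]) * (c[1] - a[1]) - (c[0] - a[0]) * (b[1] - a[1]))
    return d2[0] * d2[1] * d2[2] / (4.0 * twice_area ** 2)


# An equilateral triangle inscribed in a circle of radius r contains the center,
# so its smallest enclosing ball is the circle itself.
r = 3.0
tri = [(r * math.cos(t), r * math.sin(t)) for t in (0.0, 2 * math.pi / 3, 4 * math.pi / 3)]
print(squared_meb_radius(*tri))  # close to r ** 2
```

Under this assumption, the triangle that finally fills the hole of a sampled circle enters the filtration near `r ** 2`, which suggests why the longest dimension 1 bar is regressed against the squared radius rather than the radius.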
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,20 @@ | ||
.. table:: | ||
:widths: 30 40 30 | ||
|
||
+----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+ | ||
| .. figure:: | Delaunay complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau | | ||
| ../../doc/Alpha_complex/alpha_complex_doc.png | cells of a Delaunay Triangulation. The Simplicial complex filtration | | | ||
| :alt: Delaunay complex representation | values can be computed with different filtrations (Delaunay complex, | :Since: GUDHI 2.0.0 | | ||
| :figclass: align-center | Delaunay Čech complex or Alpha complex) | | | ||
| | | :License: MIT (`GPL v3 </licensing/>`_) | | ||
| | When the simplicial complex filtration values are computed, it has the | | | ||
| | same persistent homology as the Čech complex, while being significantly | :Requires: `Eigen <installation.html#eigen>`_ and `CGAL <installation.html#cgal>`_ | | ||
| | smaller. | | | ||
+----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+ | ||
| * :doc:`delaunay_complex_user` | * :doc:`delaunay_complex_ref` | | ||
+----------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ||
+-----------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+ | ||
| .. figure:: | Delaunay complex is a simplicial complex constructed from the finite | :Author: Vincent Rouvreau | | ||
| ../../doc/Alpha_complex/alpha_complex_doc.png | cells of a Delaunay Triangulation. The Simplicial complex filtration | | | ||
| :alt: Delaunay complex representation | values can be computed with different filtrations (Delaunay complex, | :Since: GUDHI 2.0.0 | | ||
| :figclass: align-center | Delaunay Čech complex or Alpha complex) | | | ||
| | | :License: MIT (`GPL v3 </licensing/>`_) | | ||
| | When the simplicial complex filtration values are computed, it has the | | | ||
| | same persistent homology as the Čech complex, while being significantly | :Requires: `Eigen <installation.html#eigen>`_ and `CGAL <installation.html#cgal>`_ | | ||
| | smaller. | | | ||
+-----------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------+ | ||
| * :doc:`delaunay_complex_user` | * :doc:`delaunay_complex_ref` | | ||
+-----------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------+ | ||
| .. image:: | * :doc:`delaunay_complex_sklearn_itf_ref` | | ||
| img/sklearn.png | | | ||
| :target: installation.html#scikit-learn | | | ||
| :height: 30 | | | ||
+-----------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------+ |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Standard data science imports | ||
import numpy as np | ||
from sklearn.pipeline import Pipeline | ||
import matplotlib.pyplot as plt | ||
from sklearn.linear_model import LinearRegression | ||
from sklearn.model_selection import train_test_split | ||
|
||
# Import TDA pipeline requirements | ||
from gudhi.sklearn.cech_persistence import CechPersistence | ||
from gudhi.representations.vector_methods import PersistenceLengths | ||
# To build the dataset | ||
from gudhi.datasets.generators import points | ||
|
||
no_plot = False | ||
|
||
# Mainly for testing purposes - excluded from the documentation | ||
import argparse | ||
parser = argparse.ArgumentParser(description='Cech persistence scikit-learn like interface example') | ||
parser.add_argument("--no-plot", default=False, action="store_true") | ||
args = parser.parse_args() | ||
no_plot = args.no_plot | ||
|
||
# Build the dataset | ||
dataset_size = 1000 | ||
# Noise amplitude is expressed as a fraction of the radius (0.1 means +/- 5%) - set it to 0. for no noise | ||
noise = 0.1 | ||
# target is an array of 1000 random circle radii (between 1. and 10.) | ||
target = 1. + 9. * np.random.rand(dataset_size,1) | ||
# also use a random number of points (between 100 and 300, upper bound excluded) | ||
nb_points = np.random.randint(100,high=300, size=dataset_size) | ||
X = [] | ||
|
||
for idx in range(dataset_size): | ||
    pts = points.sphere(nb_points[idx], 2, radius=target[idx]) | ||
    # Center the uniform noise on 0 so that it spans +/- (noise / 2) of the radius | ||
    ns = noise * target[idx] * (np.random.rand(nb_points[idx], 2) - 0.5) | ||
    X.append(pts + ns) | ||
|
||
# Čech filtration values are squared radii, so square the targets for train/predict purposes | ||
target = target * target | ||
|
||
# Split the dataset for train/predict | ||
Xtrain, Xtest, ytrain, ytest = train_test_split(X, target, test_size = 0.25) | ||
|
||
pipe = Pipeline( | ||
[ | ||
("cech_pers", CechPersistence(homology_dimensions=1, n_jobs=-2)), | ||
("max_pers", PersistenceLengths(num_lengths=1)), | ||
("regression", LinearRegression()), | ||
] | ||
) | ||
|
||
model = pipe.fit(Xtrain, ytrain) | ||
|
||
# Let's see how well our model predicts squared radii from points on random circles | ||
predictions = model.predict(Xtest) | ||
_, ax = plt.subplots() | ||
ax.set_xlabel('target') | ||
ax.set_ylabel('prediction') | ||
ax.scatter(ytest, predictions) | ||
ax.set_aspect("equal") | ||
if not no_plot: | ||
    plt.show() |
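The noise step in the loop above can be sketched without NumPy or GUDHI. This is a minimal standard-library illustration, assuming the intent stated in the documentation (noise of +/- 5% of the radius when `noise = 0.1`). The helper `noisy_circle` is hypothetical and stands in for `points.sphere` plus the noise array.

```python
import math
import random


def noisy_circle(n_points, radius, noise=0.1, rng=random):
    """Sample n_points on a circle, then shift each coordinate by a uniform
    offset in [-noise * radius / 2, +noise * radius / 2]."""
    pts = []
    for _ in range(n_points):
        theta = rng.uniform(0.0, 2.0 * math.pi)
        # Subtracting 0.5 before scaling keeps the noise centered on zero
        dx = noise * radius * (rng.random() - 0.5)
        dy = noise * radius * (rng.random() - 0.5)
        pts.append((radius * math.cos(theta) + dx, radius * math.sin(theta) + dy))
    return pts


cloud = noisy_circle(200, radius=4.0, noise=0.1, rng=random.Random(0))
worst = max(abs(math.hypot(x, y) - 4.0) for x, y in cloud)
print(f"largest deviation from the circle: {worst:.3f}")
```

The centering matters: subtracting a constant `noise / 2` after scaling by the radius would bias the points toward one quadrant for large circles instead of spreading the noise evenly around them.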
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. | ||
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. | ||
# Author(s): Vincent Rouvreau | ||
# | ||
# Copyright (C) 2024 Inria | ||
# | ||
# Modification(s): | ||
# - YYYY/MM Author: Description of the modification | ||
|
||
from .. import DelaunayComplex | ||
from sklearn.base import BaseEstimator, TransformerMixin | ||
|
||
# joblib is required by scikit-learn | ||
from joblib import Parallel, delayed | ||
|
||
# Mermaid sequence diagram - https://mermaid-js.github.io/mermaid-live-editor/ | ||
# sequenceDiagram | ||
# participant USER | ||
# participant R as CechPersistence | ||
# USER->>R: fit_transform(X) | ||
# Note right of R: homology_dimensions=[i,j] | ||
# R->>thread1: _transform(X[0]) | ||
# R->>thread2: _transform(X[1]) | ||
# Note right of R: ... | ||
# thread1->>R: [array( Hi(X[0]) ), array( Hj(X[0]) )] | ||
# thread2->>R: [array( Hi(X[1]) ), array( Hj(X[1]) )] | ||
# Note right of R: ... | ||
# R->>USER: [[array( Hi(X[0]) ), array( Hj(X[0]) )],<br/> [array( Hi(X[1]) ), array( Hj(X[1]) )],<br/>...] | ||
|
||
|
||
class CechPersistence(BaseEstimator, TransformerMixin): | ||
""" | ||
This is a class for constructing Čech complexes and computing the persistence diagrams from them. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
homology_dimensions, | ||
input_type="point cloud", | ||
filtration=True, | ||
precision="safe", | ||
max_alpha_square=float("inf"), | ||
homology_coeff_field=11, | ||
n_jobs=None, | ||
): | ||
""" | ||
Constructor for the CechPersistence class. | ||
|
||
Parameters: | ||
homology_dimensions (int or list of int): The returned persistence diagrams dimension(s). | ||
Short circuit the use of :class:`~gudhi.representations.preprocessing.DimensionSelector` when only one | ||
dimension matters (in other words, when `homology_dimensions` is an int). | ||
input_type (str): Can be 'point cloud' when inputs are point clouds, or 'weighted point cloud', when | ||
inputs are point clouds plus the weight (as the last column value). Default is 'point cloud'. | ||
filtration (bool): Can be True, the filtration value of each simplex is computed, or False, the filtration | ||
value of each simplex is not computed (set to NaN). Default is True. | ||
precision (str): Delaunay complex precision can be 'fast', 'safe' or 'exact'. Default is 'safe'. | ||
max_alpha_square (float): The maximum alpha square threshold the simplices shall not exceed. Default is set | ||
to infinity, and there is very little point using anything else since it does not save time. | ||
homology_coeff_field (int): The homology coefficient field. Must be a prime number. Default value is 11. | ||
n_jobs (int): Number of jobs to run in parallel. `None` (default value) means `n_jobs = 1` unless in a | ||
joblib.parallel_backend context. `-1` means using all processors. cf. | ||
https://joblib.readthedocs.io/en/latest/generated/joblib.Parallel.html for more details. | ||
""" | ||
self.homology_dimensions = homology_dimensions | ||
self.input_type = input_type | ||
self.filtration = filtration | ||
self.precision = precision | ||
self.max_alpha_square = max_alpha_square | ||
self.homology_coeff_field = homology_coeff_field | ||
self.n_jobs = n_jobs | ||
if self.input_type not in ["point cloud", "weighted point cloud"]: | ||
raise ValueError("Unknown input type") | ||
if self.precision not in ["safe", "fast", "exact"]: | ||
raise ValueError("Unknown precision") | ||
|
||
def fit(self, X, Y=None): | ||
""" | ||
Nothing to be done, but useful when included in a scikit-learn Pipeline. | ||
""" | ||
return self | ||
|
||
def __transform(self, inputs): | ||
max_dimension = max(self.dim_list_) + 1 | ||
|
||
# Default filtration value | ||
fltr = None | ||
if self.input_type == "point cloud": | ||
pts = inputs | ||
cech = DelaunayComplex(points=pts, precision=self.precision) | ||
if self.filtration: | ||
fltr = "cech" | ||
|
||
elif self.input_type == "weighted point cloud": | ||
wgts = inputs[:, -1] | ||
pts = inputs[:, :-1] | ||
cech = DelaunayComplex(points=pts, weights=wgts, precision=self.precision) | ||
if self.filtration: | ||
fltr = "alpha" | ||
|
||
stree = cech.create_simplex_tree(max_alpha_square=self.max_alpha_square, filtration=fltr) | ||
|
||
persistence_dim_max = False | ||
# Specific case where stree has a lower dimension than max_dimension | ||
if max_dimension > stree.dimension(): | ||
persistence_dim_max = True | ||
|
||
stree.compute_persistence( | ||
homology_coeff_field=self.homology_coeff_field, persistence_dim_max=persistence_dim_max | ||
) | ||
|
||
return [stree.persistence_intervals_in_dimension(dim) for dim in self.dim_list_] | ||
|
||
def transform(self, X, Y=None): | ||
"""Compute all the Čech complexes and their associated persistence diagrams. | ||
|
||
:param X: list of point clouds as Euclidean coordinates, plus the weight (as the last column value) if | ||
:paramref:`~gudhi.sklearn.cech_persistence.CechPersistence.input_type` was set to 'weighted point cloud'. | ||
:type X: list of list of float OR list of numpy.ndarray | ||
|
||
:return: Persistence diagrams in the format: | ||
|
||
- If `homology_dimensions` was set to `n`: `[array( Hn(X[0]) ), array( Hn(X[1]) ), ...]` | ||
- If `homology_dimensions` was set to `[i, j]`: | ||
`[[array( Hi(X[0]) ), array( Hj(X[0]) )], [array( Hi(X[1]) ), array( Hj(X[1]) )], ...]` | ||
:rtype: list of numpy ndarray of shape (,2) or list of list of numpy ndarray of shape (,2) | ||
""" | ||
# Depends on whether homology_dimensions is an integer or a list of integers (else case) | ||
if isinstance(self.homology_dimensions, int): | ||
unwrap = True | ||
self.dim_list_ = [self.homology_dimensions] | ||
else: | ||
unwrap = False | ||
self.dim_list_ = self.homology_dimensions | ||
|
||
# threads are preferred as complex construction and persistence computation release the GIL | ||
res = Parallel(n_jobs=self.n_jobs, prefer="threads")(delayed(self.__transform)(inputs) for inputs in X) | ||
|
||
if unwrap: | ||
res = [d[0] for d in res] | ||
return res | ||
|
||
def get_feature_names_out(self): | ||
"""Provide column names for implementing sklearn's set_output API.""" | ||
return [f"H{i}" for i in self.dim_list_] |
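The int-versus-list handling and the per-sample fan-out in `transform` can be sketched with the standard library only. This is an illustration under assumptions, not the class itself: `ThreadPoolExecutor` stands in for joblib's `Parallel(..., prefer="threads")`, and `fake_diagrams` is a hypothetical placeholder for the per-sample persistence computation.

```python
from concurrent.futures import ThreadPoolExecutor


def normalize_dimensions(homology_dimensions):
    """Return (list of dimensions, unwrap flag) for an int or a list of ints."""
    if isinstance(homology_dimensions, int):
        return [homology_dimensions], True
    return list(homology_dimensions), False


def fake_diagrams(point_cloud, dim_list):
    # Placeholder: one string per requested dimension instead of a real diagram
    return [f"H{dim}({len(point_cloud)} pts)" for dim in dim_list]


def transform(X, homology_dimensions, max_workers=2):
    dim_list, unwrap = normalize_dimensions(homology_dimensions)
    # One task per point cloud; map() returns results in input order
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        res = list(pool.map(lambda pc: fake_diagrams(pc, dim_list), X))
    # A single requested dimension yields a flat list, as in the docstring above
    return [d[0] for d in res] if unwrap else res


X = [[(0.0, 0.0)], [(0.0, 0.0), (1.0, 1.0)]]
print(transform(X, 1))
print(transform(X, [0, 1]))
```

Returning a flat list for an integer `homology_dimensions` is what lets the next pipeline step consume the diagrams directly, without a `DimensionSelector` in between.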
I think this sentence is very misleading: no Čech complex is built in the process. That is the point of the alpha and Delaunay-Čech complexes: they build smaller filtrations that still have the same diagram.
I reworked the class explanation in 322487c
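To put rough numbers on the "smaller filtrations" point from the review, here is a back-of-the-envelope count for planar point clouds like the ones in the example. It is a sketch, not a measurement: the function names are hypothetical, the Čech count assumes no threshold on the filtration values, and the Delaunay count is the classical planar upper bound (at most 3n - 6 edges and 2n - 5 triangles for n >= 3 points).

```python
import math


def cech_2_skeleton_size(n):
    """Vertices, edges and triangles of an unthresholded Čech complex on n points."""
    return n + math.comb(n, 2) + math.comb(n, 3)


def planar_delaunay_max_size(n):
    """Upper bound on the simplices of a Delaunay triangulation of n >= 3 planar points."""
    return n + (3 * n - 6) + (2 * n - 5)


n = 300
print(cech_2_skeleton_size(n), planar_delaunay_max_size(n))
```

For the largest point clouds of the example (300 points), the gap is several orders of magnitude, which is consistent with the reviewer's remark that the Delaunay-based filtrations are the ones actually built.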