diff --git a/.gitignore b/.gitignore
index 6aab7337ed..0e74960540 100644
--- a/.gitignore
+++ b/.gitignore
@@ -18,3 +18,12 @@ data/points/human.off_sc.txt
 .idea/
 cmake-build-debug/
 
+# Poetry files
+*.lock
+*.toml
+./venv
+
+# Ripsnet files
+ripsnet_doc_test.py
+src/python/gudhi/tensorflow/tutorial.ipynb
+
diff --git a/biblio/bibliography.bib b/biblio/bibliography.bib
index b5afff5202..20985daf07 100644
--- a/biblio/bibliography.bib
+++ b/biblio/bibliography.bib
@@ -1,3 +1,38 @@
+@article{RipsNet_arXiv,
+  author     = {Thibault de Surrel and
+                Felix Hensel and
+                Mathieu Carri{\`{e}}re and
+                Th{\'{e}}o Lacombe and
+                Yuichi Ike and
+                Hiroaki Kurihara and
+                Marc Glisse and
+                Fr{\'{e}}d{\'{e}}ric Chazal},
+  title      = {RipsNet: a general architecture for fast and robust estimation of
+                the persistent homology of point clouds},
+  journal    = {CoRR},
+  volume     = {abs/2202.01725},
+  year       = {2022},
+  url        = {https://arxiv.org/abs/2202.01725},
+  eprinttype = {arXiv},
+  eprint     = {2202.01725},
+  timestamp  = {Wed, 09 Feb 2022 15:43:35 +0100},
+  biburl     = {https://dblp.org/rec/journals/corr/abs-2202-01725.bib},
+  bibsource  = {dblp computer science bibliography, https://dblp.org}
+}
+
+@inproceedings{DeepSets17,
+  author    = {Zaheer, Manzil and Kottur, Satwik and Ravanbakhsh, Siamak and Poczos, Barnabas and Salakhutdinov, Russ R and Smola, Alexander J},
+  booktitle = {Advances in Neural Information Processing Systems},
+  editor    = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
+  publisher = {Curran Associates, Inc.},
+  title     = {Deep Sets},
+  url       = {https://proceedings.neurips.cc/paper/2017/file/f22e4747da1aa27e363d86d40ff442fe-Paper.pdf},
+  volume    = {30},
+  year      = {2017}
+}
+
+
 @inproceedings{gudhilibrary_ICMS14,
   author    = {Cl\'ement Maria and Jean-Daniel Boissonnat and Marc Glisse and Mariette
                Yvinec},
diff --git a/src/python/CMakeLists.txt b/src/python/CMakeLists.txt
index 8eb7478ecc..cc9aaaae9f 100644
--- a/src/python/CMakeLists.txt
+++ b/src/python/CMakeLists.txt
@@ -70,6 +70,7 @@ if(PYTHONINTERP_FOUND)
   set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ")
   # Modules that should not be auto-imported in __init__.py
   set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
+  set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'tensorflow', ")
   set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
   set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ")
   set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'weighted_rips_complex', ")
@@ -285,6 +286,7 @@ if(PYTHONINTERP_FOUND)
   file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
   file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
   file(COPY "gudhi/datasets" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
+  file(COPY "gudhi/tensorflow" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
 
   # Some files for pip package
 
@@ -551,6 +553,11 @@ if(PYTHONINTERP_FOUND)
     add_gudhi_py_test(test_representations)
   endif()
 
+  # RipsNet
+  if(TENSORFLOW_FOUND)
+    add_gudhi_py_test(test_ripsnet)
+  endif()
+
   # Betti curves
   if(SKLEARN_FOUND AND SCIPY_FOUND)
     add_gudhi_py_test(test_betti_curve_representations)
diff --git a/src/python/doc/installation.rst b/src/python/doc/installation.rst
index 35c344e3b5..fee74fd531 100644
--- a/src/python/doc/installation.rst
+++ b/src/python/doc/installation.rst
@@ -395,6 +395,8 @@ TensorFlow
 `TensorFlow <https://www.tensorflow.org/>`_ is currently only used in some automatic differentiation tests.
 
+The :doc:`RipsNet <ripsnet>` module requires `TensorFlow <https://www.tensorflow.org/>`_.
+
 Bug reports and contributions
 *****************************
diff --git a/src/python/doc/ripsnet.inc b/src/python/doc/ripsnet.inc
new file mode 100644
index 0000000000..c2342749fe
--- /dev/null
+++ b/src/python/doc/ripsnet.inc
@@ -0,0 +1,15 @@
+.. table::
+   :widths: 30 40 30
+
+   +--------------------+----------------------------------------------------------------------------+------------------------------------------------------------+
+   |                    | RipsNet is a general architecture for fast and robust estimation of the   | :Author: Felix Hensel, Mathieu Carrière                    |
+   |                    | persistent homology of point clouds.                                       |                                                            |
+   |                    |                                                                            | :Since: GUDHI                                              |
+   |                    |                                                                            |                                                            |
+   |                    |                                                                            | :License: MIT                                              |
+   |                    |                                                                            |                                                            |
+   |                    |                                                                            | :Requires: `TensorFlow <https://www.tensorflow.org/>`_     |
+   |                    |                                                                            |                                                            |
+   +--------------------+----------------------------------------------------------------------------+------------------------------------------------------------+
+   | * :doc:`ripsnet`   |                                                                                                                                         |
+   +--------------------+-----------------------------------------------------------------------------------------------------------------------------------------+
diff --git a/src/python/doc/ripsnet.rst b/src/python/doc/ripsnet.rst
new file mode 100644
index 0000000000..c8f7ca307f
--- /dev/null
+++ b/src/python/doc/ripsnet.rst
@@ -0,0 +1,102 @@
+:orphan:
+
+.. To get rid of WARNING: document isn't included in any toctree
+
+RipsNet user manual
+===================
+Definition
+----------
+
+.. include:: ripsnet.inc
+
+:class:`~gudhi.tensorflow.RipsNet` constructs a TensorFlow model for fast and robust estimation of the
+persistent homology of point clouds.
+RipsNet is based on the Deep Sets architecture :cite:`DeepSets17`; for details, see the RipsNet paper :cite:`RipsNet_arXiv`.
+
+Example
+-------
+
+This example instantiates a RipsNet model, which can then be trained like any TensorFlow model.
+
+.. testcode::
+
+    import gudhi.tensorflow as gtf
+    import tensorflow as tf
+    from tensorflow.keras import regularizers, layers
+
+    ragged_layers_size = [20, 10]
+    dense_layers_size = [10, 20]
+    output_units = 25
+    activation_fct = 'gelu'
+    output_activation = 'sigmoid'
+    dropout = 0
+    kernel_regularization = 0
+
+    ragged_layers = []
+    dense_layers = []
+
+    for n_units in ragged_layers_size:
+        ragged_layers.append(gtf.DenseRagged(units=n_units, use_bias=True, activation=activation_fct))
+
+    for n_units in dense_layers_size:
+        dense_layers.append(layers.Dense(n_units, activation=activation_fct,
+                                         kernel_regularizer=regularizers.l2(kernel_regularization)))
+        dense_layers.append(layers.Dropout(dropout))
+
+    dense_layers.append(layers.Dense(output_units, activation=output_activation))
+
+    phi_1 = gtf.TFBlock(ragged_layers)
+    perm_op = 'mean'  # can also be 'sum' (or a user-specified permutation invariant function)
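+    # A user-defined permutation invariant function, e.g. perm_op = tf.math.reduce_max,
+    # can also be passed here (illustrative note, not used in this example).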
+    phi_2 = gtf.TFBlock(dense_layers)
+    input_dim = 2
+
+    RN = gtf.RipsNet(phi_1, phi_2, input_dim, perm_op=perm_op)
+
+    data_test = [[[-7.04493841, 9.60285858],
+                  [-13.14389003, -13.21854157],
+                  [-3.21137961, -1.28593644]],
+                 [[10.40324933, -0.80540584],
+                  [16.54752459, 0.70355361],
+                  [6.410207, -10.63175183],
+                  [2.96613799, -11.97463568]],
+                 [[4.85041719, -2.93820024],
+                  [2.15379915, -5.39669696],
+                  [5.83968556, -5.67350982],
+                  [5.25955172, -6.36860269]]]
+
+    tf_data_test = tf.ragged.constant(data_test, ragged_rank=1)
+
+    RN.predict(tf_data_test)
+
+Once RN is properly trained (which we skip in this documentation), it can be used to make predictions.
+In this example, RipsNet estimates persistence vectorizations (of output size 25) for a list of 3 point clouds in 2D.
+It yields an output of shape ``nb_input_pointclouds x output_units``.
+``ragged_layers_size`` and ``dense_layers_size`` define the architecture of the network;
+to reach the best performance, they should be tuned to the dataset.
+A possible output is:
+
+.. code-block::
+
+    [[0.58554363 0.6054868  0.44672886 0.5216672  0.5814481  0.48068565
+      0.49626726 0.5285395  0.4805212  0.37918684 0.49745193 0.49247316
+      0.4706078  0.5491477  0.47016636 0.55804974 0.46501246 0.4065692
+      0.5386659  0.5660226  0.52014357 0.5329493  0.52178216 0.5156043
+      0.48742113]
+     [0.9446074  0.99024785 0.1316272  0.3013248  0.98174655 0.52285945
+      0.33727515 0.997285   0.3711884  0.00388432 0.63181967 0.5377489
+      0.22074646 0.7681194  0.04337704 0.80116796 0.02139336 0.04605395
+      0.8911999  0.9570045  0.5789719  0.8221929  0.7742506  0.4596561
+      0.08529088]
+     [0.8230771  0.9320036  0.25120026 0.48027694 0.8988322  0.5789062
+      0.38307947 0.9252455  0.39485127 0.06090912 0.5786307  0.51115406
+      0.28706372 0.70552015 0.16929033 0.7028084  0.12379596 0.1867683
+      0.6969584  0.84437454 0.6172329  0.66728634 0.630455   0.47643042
+      0.27172992]]
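+
+The training step itself is skipped above. As a minimal sketch (assuming a list of training point clouds
+``train_clouds`` and their precomputed target vectorizations ``train_targets``, e.g. persistence images
+obtained with the :mod:`gudhi.representations` module, are available), RN could be trained with:
+
+.. code-block:: python
+
+    # Hypothetical training sketch; 'train_clouds' and 'train_targets' are assumed to exist.
+    tf_train_clouds = tf.ragged.constant(train_clouds, ragged_rank=1)
+    RN.compile(optimizer='adam', loss='mse')
+    RN.fit(tf_train_clouds, tf.constant(train_targets), epochs=50, batch_size=32)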
+
+Detailed documentation
+----------------------
+.. automodule:: gudhi.tensorflow.ripsnet
+   :members:
+   :special-members:
+   :show-inheritance:
diff --git a/src/python/gudhi/tensorflow/__init__.py b/src/python/gudhi/tensorflow/__init__.py
new file mode 100644
index 0000000000..7f806120af
--- /dev/null
+++ b/src/python/gudhi/tensorflow/__init__.py
@@ -0,0 +1,3 @@
+from .ripsnet import *
+
+__all__ = ["RipsNet", "PermopRagged", "TFBlock", "DenseRagged"]
\ No newline at end of file
diff --git a/src/python/gudhi/tensorflow/ripsnet.py b/src/python/gudhi/tensorflow/ripsnet.py
new file mode 100644
index 0000000000..8cb3e2df24
--- /dev/null
+++ b/src/python/gudhi/tensorflow/ripsnet.py
@@ -0,0 +1,263 @@
+# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
+# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
+# Author(s):       Felix Hensel, Mathieu Carrière
+#
+# Copyright (C) 2022 Inria
+#
+# Modification(s):
+#   - YYYY/MM Author: Description of the modification
+
+import tensorflow as tf
+
+
+class DenseRagged(tf.keras.layers.Layer):
+    """
+    This is a class for the ragged layer in the RipsNet architecture, processing the input point clouds.
+    """
+
+    def __init__(self, units, input_dim=None, use_bias=True, activation='gelu', kernel_initializer=None, bias_initializer=None, **kwargs):
+        """
+        Constructor for the DenseRagged class.
+
+        Parameters:
+            units (int): number of units in the layer.
+            use_bias (bool): flag, indicating whether to use bias or not.
+            activation (string or function): identifier of a keras activation function, e.g. 'relu'.
+            kernel_initializer: tensorflow kernel initializer.
+            bias_initializer: tensorflow bias initializer.
+        """
+        super().__init__(dynamic=True, **kwargs)
+        self._supports_ragged_inputs = True
+        self.units = units
+        self.use_bias = use_bias
+        self.activation = tf.keras.activations.get(activation)
+        self.kernel_initializer = kernel_initializer
+        self.bias_initializer = bias_initializer
+
+    def get_config(self):
+        config = super().get_config().copy()
+        config.update(
+            {
+                'units': self.units,
+                'use_bias': self.use_bias,
+                'activation': self.activation,
+                '_supports_ragged_inputs': self._supports_ragged_inputs,
+                'kernel_initializer': self.kernel_initializer,
+                'bias_initializer': self.bias_initializer,
+            }
+        )
+        return config
+
+    def build(self, input_shape):
+        last_dim = input_shape[-1]
+        self.kernel = self.add_weight('kernel', shape=[last_dim, self.units], trainable=True, initializer=self.kernel_initializer)
+        if self.use_bias:
+            self.bias = self.add_weight('bias', shape=[self.units, ], trainable=True, initializer=self.bias_initializer)
+        else:
+            self.bias = None
+        super().build(input_shape)
+
+    def call(self, inputs):
+        """
+        Apply the DenseRagged layer on a ragged input tensor.
+
+        Parameters:
+            inputs: ragged tensor (e.g. containing a point cloud).
+
+        Returns:
+            ragged tensor containing the output of the layer.
+        """
+        outputs = tf.ragged.map_flat_values(tf.matmul, inputs, self.kernel)
+        if self.use_bias:
+            outputs = tf.ragged.map_flat_values(tf.nn.bias_add, outputs, self.bias)
+        outputs = tf.ragged.map_flat_values(self.activation, outputs)
+        return outputs
+
+
+class TFBlock(tf.keras.layers.Layer):
+    """
+    This class is a block of tensorflow layers.
+    If the first layer is an instance of DenseRagged, it will automatically support ragged inputs.
+
+    Parameters:
+        layers (list): a list of either plain TensorFlow layers or :class:`~gudhi.tensorflow.DenseRagged` layers.
+    """
+
+    def __init__(self, layers, **kwargs):
+        """
+        Constructor for the TFBlock class.
+
+        Parameters:
+            layers (list): a list of either plain TensorFlow layers or :class:`~gudhi.tensorflow.DenseRagged` layers.
+ """ + super().__init__(dynamic=True, **kwargs) + self.layers = layers + if isinstance(layers[0], DenseRagged): + self._supports_ragged_inputs = True + + def get_config(self): + config = super().get_config().copy() + config.update( + { + 'layers': self.layers, + '_supports_ragged_inputs': self._supports_ragged_inputs, + } + ) + return config + + def build(self, input_shape): + # super().build(input_shape) + return self + + def call(self, inputs): + """ + Apply the sequence of layers on an input tensor. + + Parameters: + inputs: any input tensor. + + Returns: + output tensor containing the output of the sequence of layers. + """ + outputs = inputs + for layer in self.layers: + outputs = layer(outputs) + return outputs + + +class PermopRagged(tf.keras.layers.Layer): + """ + This is a class for the permutation invariant layer in the RipsNet architecture. + """ + + def __init__(self, perm_op, **kwargs): + """ + Constructor for the PermopRagged class. + + Parameters: + perm_op: permutation invariant function, such as `tf.math.reduce_sum`, `tf.math.reduce_mean`. + """ + super().__init__(dynamic=True, **kwargs) + self._supports_ragged_inputs = True + self.perm_op = perm_op + + def get_config(self): + config = super().get_config().copy() + config.update( + { + 'perm_op': self.perm_op, + '_supports_ragged_inputs': self._supports_ragged_inputs, + } + ) + return config + + def build(self, input_shape): + super().build(input_shape) + + def call(self, inputs): + """ + Apply PermopRagged on an input tensor. + """ + out = self.perm_op(inputs, axis=1) + return out + + +class RipsNet(tf.keras.Model): + """ + This is a TensorFlow model for estimating vectorizations of persistence diagrams of point clouds. + This class implements the RipsNet described in the following article . + """ + + def __init__(self, phi_1, phi_2, input_dim, perm_op='mean', **kwargs): + """ + Constructor for the RipsNet class. + + Parameters: + phi_1 (layers): any block of DenseRagged layers. Can be a custom block built from :class:`~gudhi.tensorflow.DenseRagged` layers. + phi_2 (layers): Can be any (block of) TensorFlow layer(s), e.g. :class:`~gudhi.tensorflow.TFBlock`. + input_dim (int): dimension of the input point clouds. + perm_op (str or function): Permutation invariant operation. + Can be 'mean' or 'sum', or any user defined (permutation invariant) function. + """ + super().__init__(dynamic=True, **kwargs) + self.phi_1 = phi_1 + self.perm_op = perm_op + self.phi_2 = phi_2 + self.input_dim = input_dim + + # if perm_op not in ['mean', 'sum']: + # raise ValueError(f'Permutation invariant operation: {self.perm_op} is not allowed, must be "mean" or "sum".') + def get_config(self): + config = super().get_config().copy() + config.update( + { + 'phi_1': self.phi_1, + 'phi_2': self.phi_2, + 'perm_op': self.perm_op, + 'input_dim': self.input_dim, + } + ) + return config + + def build(self, input_shape): + return self + + def call(self, pointclouds): + """ + Apply RipsNet on a ragged tensor containing a list of pointclouds. + + Parameters: + point clouds (n x None x input_dimension): ragged tensor containing n pointclouds in dimension `input_dimension`. The second dimension is ragged since point clouds can have different numbers of points. + + Returns: + output (n x output_shape): tensor containing predicted vectorizations of the persistence diagrams of pointclouds. 
+ """ + if self.perm_op == 'mean': + perm_op_ragged = PermopRagged(tf.math.reduce_mean) + elif self.perm_op == 'sum': + perm_op_ragged = PermopRagged(tf.math.reduce_sum) + # else: + # raise ValueError(f'Permutation invariant operation: {self.perm_op} is not allowed, must be "mean" or "sum".') + + inputs = tf.keras.layers.InputLayer(input_shape=(None, self.input_dim), dtype="float32", ragged=True)( + pointclouds) + output = self.phi_1(inputs) + output = perm_op_ragged(output) + output = self.phi_2(output) + return output diff --git a/src/python/test/test_ripsnet.py b/src/python/test/test_ripsnet.py new file mode 100644 index 0000000000..c57ca8eade --- /dev/null +++ b/src/python/test/test_ripsnet.py @@ -0,0 +1,112 @@ +""" This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT. + See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details. + Author(s): Felix Hensel + Copyright (C) 2022 Inria + Modification(s): + - YYYY/MM Author: Description of the modification +""" + +import numpy as np +import tensorflow as tf +from gudhi.tensorflow import * + + +def test_ripsnet(): + ragged_layers_size = [4] + dense_layers_size = [4] + output_units = 4 + activation_fct = 'gelu' + output_activation = 'sigmoid' + dropout = 0 + kernel_regularization = 0 + initializer = None# tf.keras.initializers.Constant(value=1) + + ragged_layers = [] + dense_layers = [] + for n_units in ragged_layers_size: + ragged_layers.append(DenseRagged(units=n_units, use_bias=True, activation=activation_fct, + kernel_initializer=initializer, bias_initializer=initializer)) + + for n_units in dense_layers_size: + dense_layers.append(tf.keras.layers.Dense(n_units, activation=activation_fct, + kernel_initializer=initializer, bias_initializer=initializer)) + + dense_layers.append(tf.keras.layers.Dense(output_units, activation=output_activation, + kernel_initializer=initializer, bias_initializer=initializer)) + + weights_vect = [np.array([[-0.3868327 , 0.5431584 , 0.7523476 , 0.80209386], + [ 0.22491306, 0.4626178 , 0.34193814, -0.04737851]]), + np.array([ 0.5047069 , -0.11543324, -0.03882882, -0.16129738]), + np.array([[-0.7956421 , 0.2326832 , -0.5405302 , 0.096256964], + [ 0.06973686 , 0.0251764 , -0.05733281 , 0.3528394 ], + [-0.77462643 , 0.03330394 , -0.8688136 , -0.22296508 ], + [-0.5054477 , 0.7201048 , 0.1857564 , 0.65894866 ]]), + np.array([-0.30565566 , -0.77507186 , -0.049963538, 0.5765676 ]), + np.array([[-0.25560755 , 0.71504813 , 0.0047909063, -0.1595783 ], + [-0.71575665 , 0.6139034 , -0.47060093 , 0.087501734 ], + [ 0.1588738 , -0.593038 , 0.48378325 , -0.777213 ], + [ 0.6206032 , -0.20880768 , 0.14528894 , 0.18696047 ]]), + np.array([-0.17761804, -0.6905532 , 0.64367545, -0.2173939 ])] + + phi_1 = TFBlock(ragged_layers) #DenseRaggedBlock(ragged_layers) + perm_op = 'mean' + phi_2 = TFBlock(dense_layers) + input_dim = 2 + + model = RipsNet(phi_1, phi_2, input_dim, perm_op=perm_op) + + #test_input_raw = [np.array([[1.,2.],[3.,4.]])] + # test_input = tf.ragged.constant([ + # [list(c) for c in list(test_input_raw[i])] for i in range(len(test_input_raw))], ragged_rank=1) + + test_input_raw = [[[1., 2.], [3., 4.]]] + test_input = tf.ragged.constant(test_input_raw, ragged_rank=1) + + model.predict(test_input) + + model.set_weights(weights_vect) + + + clean_data_test = [np.array([[ -7.04493841, 9.60285858], + [-13.14389003, -13.21854157], + [ -3.21137961, -1.28593644]]), + np.array([[ 10.40324933, -0.80540584], + [ 16.54752459, 0.70355361], + [ 6.410207 , 
+                                 [  2.96613799, -11.97463568]]),
+                       np.array([[  4.85041719,  -2.93820024],
+                                 [  2.15379915,  -5.39669696],
+                                 [  5.83968556,  -5.67350982],
+                                 [  5.25955172,  -6.36860269]])]
+
+    noisy_data_test = [np.array([[ -8.93311026,   1.52317533],
+                                 [-16.80344139,  -3.76871298],
+                                 [-11.58448573,  -2.76311122],
+                                 [-15.06107796,   5.05253587]]),
+                       np.array([[-3.834947,   -5.1897498],
+                                 [-3.51701182, -4.23539191],
+                                 [-2.68678747, -1.63902703],
+                                 [-4.65070816, -3.96363227]]),
+                       np.array([[ 4.7841113,  19.2922069],
+                                 [10.5164214,   5.50246605],
+                                 [-9.38163622,  7.03682948]])]
+
+    tf_clean_data_test = tf.ragged.constant([
+        [list(c) for c in list(clean_data_test[i])] for i in range(len(clean_data_test))], ragged_rank=1)
+    tf_noisy_data_test = tf.ragged.constant([
+        [list(c) for c in list(noisy_data_test[i])] for i in range(len(noisy_data_test))], ragged_rank=1)
+
+    clean_prediction = np.array([[0.5736222,  0.3047213,  0.6746019,  0.49468565],
+                                 [0.4522748,  0.7521156,  0.3061385,  0.77519494],
+                                 [0.5349713,  0.49940312, 0.51753736, 0.6435147]])
+    noisy_prediction = np.array([[0.53986603, 0.33934325, 0.64809155, 0.49939266],
+                                 [0.5637899,  0.28744557, 0.67097586, 0.50240993],
+                                 [0.5689339,  0.55574864, 0.47079712, 0.68721807]])
+
+    assert clean_prediction.shape == model.predict(tf_clean_data_test).shape
+    assert noisy_prediction.shape == model.predict(tf_noisy_data_test).shape
+    assert np.linalg.norm(clean_prediction - model.predict(tf_clean_data_test)) <= 1e-6
+    assert np.linalg.norm(noisy_prediction - model.predict(tf_noisy_data_test)) <= 1e-6
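+
+
+def test_ripsnet_custom_perm_op():
+    # Illustrative smoke test: checks that a user-defined permutation invariant
+    # function (here tf.math.reduce_max) runs end to end and produces an output
+    # of the expected shape. No reference values are checked.
+    phi_1 = TFBlock([DenseRagged(units=4, use_bias=True, activation='gelu')])
+    phi_2 = TFBlock([tf.keras.layers.Dense(4, activation='sigmoid')])
+    model = RipsNet(phi_1, phi_2, input_dim=2, perm_op=tf.math.reduce_max)
+    test_input = tf.ragged.constant([[[1., 2.], [3., 4.]], [[5., 6.]]], ragged_rank=1)
+    assert model.predict(test_input).shape == (2, 4)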