
ripsnet #587

Open
wants to merge 31 commits into base: master
Changes from 14 commits

Commits (31)
352fce1 - ripsnet first commit (hensel-f, Feb 16, 2022)
51eeb03 - ripsnet typos fixed (hensel-f, Feb 16, 2022)
f879d1d - changed perm_op API and added documentation (hensel-f, Feb 22, 2022)
8ff14bf - added test (hensel-f, Mar 1, 2022)
9eb3231 - added test (hensel-f, Mar 1, 2022)
8e5003f - updated test (hensel-f, Mar 2, 2022)
496a141 - Update src/python/doc/ripsnet.inc (hensel-f, Mar 7, 2022)
fcf7edf - Fixed imports (hensel-f, Mar 7, 2022)
e23e929 - Fixed imports (hensel-f, Mar 7, 2022)
25b842e - updated documentation (hensel-f, Mar 7, 2022)
f01e3cc - updated CMakeLists.txt (hensel-f, Mar 7, 2022)
9281794 - updated documentation and bibliography (hensel-f, Mar 7, 2022)
6b9ecce - updated CMakeLists.txt (hensel-f, Mar 10, 2022)
ec19f2d - Update src/python/doc/ripsnet.inc (hensel-f, Mar 10, 2022)
aef67bd - documentation fix (hensel-f, Mar 10, 2022)
28d7320 - increased error margin from 1e-7 to 1e-6 (hensel-f, Mar 10, 2022)
32ba96b - fixed imports (hensel-f, Mar 10, 2022)
33424c8 - fixed imports (hensel-f, Mar 10, 2022)
146b34e - fixed imports (hensel-f, Mar 10, 2022)
f29ca81 - fixed testoutput check (hensel-f, Mar 11, 2022)
6edf471 - removed print statement (hensel-f, Mar 11, 2022)
59dc256 - update of gitignore (hensel-f, Apr 26, 2022)
db348e6 - removede numpy import and added explanation (hensel-f, Apr 26, 2022)
dd5e94e - changed name of pop to perm_op (hensel-f, Apr 26, 2022)
25222ca - changed TFBlock to support regged inputs and commented DenseRaggedBlock (hensel-f, Apr 26, 2022)
63edfc2 - changed to TFBlock in test_ripsnet.py (hensel-f, Apr 26, 2022)
1cca679 - fixed __init__.py (hensel-f, Apr 26, 2022)
1122d11 - fixed documentation (hensel-f, Apr 26, 2022)
a25d3f8 - allowing user specified permop functions (hensel-f, Jun 7, 2022)
a05f825 - updated comments in the documentation and changed imports (hensel-f, Jun 7, 2022)
67cb4d4 - added get_config() (hensel-f, Jun 20, 2022)
35 changes: 35 additions & 0 deletions biblio/bibliography.bib
@@ -1,3 +1,38 @@
@article{RipsNet_arXiv,
author = {Thibault de Surrel and
Felix Hensel and
Mathieu Carri{\`{e}}re and
Th{\'{e}}o Lacombe and
Yuichi Ike and
Hiroaki Kurihara and
Marc Glisse and
Fr{\'{e}}d{\'{e}}ric Chazal},
title = {RipsNet: a general architecture for fast and robust estimation of
the persistent homology of point clouds},
journal = {CoRR},
volume = {abs/2202.01725},
year = {2022},
url = {https://arxiv.org/abs/2202.01725},
eprinttype = {arXiv},
eprint = {2202.01725},
timestamp = {Wed, 09 Feb 2022 15:43:35 +0100},
biburl = {https://dblp.org/rec/journals/corr/abs-2202-01725.bib},
bibsource = {dblp computer science bibliography, https://dblp.org}
}

@inproceedings{DeepSets17,
author = {Zaheer, Manzil and Kottur, Satwik and Ravanbakhsh, Siamak and Poczos, Barnabas and Salakhutdinov, Russ R and Smola, Alexander J},
booktitle = {Advances in Neural Information Processing Systems},
editor = {I. Guyon and U. V. Luxburg and S. Bengio and H. Wallach and R. Fergus and S. Vishwanathan and R. Garnett},
pages = {},
publisher = {Curran Associates, Inc.},
title = {Deep Sets},
url = {https://proceedings.neurips.cc/paper/2017/file/f22e4747da1aa27e363d86d40ff442fe-Paper.pdf},
volume = {30},
year = {2017}
}


@inproceedings{gudhilibrary_ICMS14,
author = {Cl\'ement Maria and Jean-Daniel Boissonnat and
Marc Glisse and Mariette Yvinec},
7 changes: 7 additions & 0 deletions src/python/CMakeLists.txt
@@ -70,6 +70,7 @@ if(PYTHONINTERP_FOUND)
set(GUDHI_PYTHON_MODULES "${GUDHI_PYTHON_MODULES}'euclidean_strong_witness_complex', ")
# Modules that should not be auto-imported in __init__.py
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'representations', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'tensorflow', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'wasserstein', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'point_cloud', ")
set(GUDHI_PYTHON_MODULES_EXTRA "${GUDHI_PYTHON_MODULES_EXTRA}'weighted_rips_complex', ")
@@ -285,6 +286,7 @@ if(PYTHONINTERP_FOUND)
file(COPY "gudhi/dtm_rips_complex.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")
file(COPY "gudhi/hera/__init__.py" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi/hera")
file(COPY "gudhi/datasets" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi" FILES_MATCHING PATTERN "*.py")
file(COPY "gudhi/tensorflow" DESTINATION "${CMAKE_CURRENT_BINARY_DIR}/gudhi")


# Some files for pip package
@@ -551,6 +553,11 @@ if(PYTHONINTERP_FOUND)
add_gudhi_py_test(test_representations)
endif()

# Differentiation
if(TENSORFLOW_FOUND)
add_gudhi_py_test(test_ripsnet)
endif()

# Betti curves
if(SKLEARN_FOUND AND SCIPY_FOUND)
add_gudhi_py_test(test_betti_curve_representations)
15 changes: 15 additions & 0 deletions src/python/doc/ripsnet.inc
@@ -0,0 +1,15 @@
.. table::
   :widths: 30 40 30

+----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
| | RipsNet is a general architecture for fast and robust estimation of the | :Author: Felix Hensel, Mathieu Carrière |
| | persistent homology of point clouds. | |
| | | :Since: GUDHI |
| | | |
| | | :License: MIT |
| | | |
| | | :Requires: `TensorFlow <installation.html#tensorflow>`_ |
| | | |
+----------------------------------------------------------------+-------------------------------------------------------------------------+------------------------------------------------------------------------------------------------------------------------+
| * :doc:`ripsnet` | |
+----------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
99 changes: 99 additions & 0 deletions src/python/doc/ripsnet.rst
@@ -0,0 +1,99 @@
:orphan:

.. To get rid of WARNING: document isn't included in any toctree

RipsNet user manual
=========================
Definition
----------

.. include:: ripsnet.inc

:class:`~gudhi.ripsnet` constructs a TensorFlow model for fast and robust estimation of the persistent homology of
point clouds.
RipsNet is based on a `Deep Sets <https://papers.nips.cc/paper/2017/file/f22e4747da1aa27e363d86d40ff442fe-Paper.pdf>`_
architecture :cite:`DeepSets17`; for details, see the `RipsNet paper <https://arxiv.org/abs/2202.01725>`_ :cite:`RipsNet_arXiv`.

Example
-------------------

This example instantiates a RipsNet model, which can then be trained like any TensorFlow model.

.. testcode::
from gudhi.tensorflow import *
Contributor: It may be better to be explicit, i.e. naming the imported functions, or to do import gudhi.tensorflow as gtf, just to make clear to the user which of the functions below indeed come from the gudhi.tensorflow package.

Author: I agree, that makes it a bit clearer.

import tensorflow as tf
from tensorflow.keras import regularizers, layers
import numpy as np

ragged_layers_size = [20, 10]
dense_layers_size = [10, 20]
output_units = 25
activation_fct = 'gelu'
output_activation = 'sigmoid'
dropout = 0
kernel_regularization = 0
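# These hyper-parameters should be tuned according to the specific dataset
# in order to reach a better performance (see the review comment below).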
Contributor: Perhaps it is worth commenting briefly (not in detail) on what these hyper-parameters are; in particular, how one is supposed to choose ragged_layers_size and dense_layers_size.

Author: Sure, I added a comment saying they should be tuned according to the specific dataset in order to reach a better performance.


ragged_layers = []
dense_layers = []

for n_units in ragged_layers_size:
    ragged_layers.append(DenseRagged(units=n_units, use_bias=True, activation=activation_fct))

for n_units in dense_layers_size:
    dense_layers.append(layers.Dense(n_units, activation=activation_fct,
                                     kernel_regularizer=regularizers.l2(kernel_regularization)))
    dense_layers.append(layers.Dropout(dropout))

dense_layers.append(layers.Dense(output_units, activation=output_activation))

phi_1 = DenseRaggedBlock(ragged_layers)
perm_op = 'mean' # can also be 'sum'.
phi_2 = TFBlock(dense_layers)
input_dim = 2

RN = RipsNet(phi_1, phi_2, input_dim, perm_op=perm_op)

data_test = [np.array([[-7.04493841, 9.60285858],
                       [-13.14389003, -13.21854157],
                       [-3.21137961, -1.28593644]]),
             np.array([[10.40324933, -0.80540584],
                       [16.54752459, 0.70355361],
                       [6.410207, -10.63175183],
                       [2.96613799, -11.97463568]]),
             np.array([[4.85041719, -2.93820024],
                       [2.15379915, -5.39669696],
                       [5.83968556, -5.67350982],
                       [5.25955172, -6.36860269]])]

tf_data_test = tf.ragged.constant([
    [list(c) for c in list(data_test[i])] for i in range(len(data_test))], ragged_rank=1)
Member: Uh, so we build data_test using numpy (from lists), only to convert it again to lists here, and finally build a tensorflow object from those lists? Would it be possible to skip some of those conversions? We probably don't need to import numpy at all.

Author: Yes, you're right, thanks. I updated it.


print(RN.predict(tf_data_test))

Once RN is properly trained (which we skip in this documentation), it can be used to make predictions.
Contributor: Would it be cumbersome to provide a sort of minimal working example to train RipsNet?

A user may not be familiar with TensorFlow and may, at this stage, have no clue how to train the RipsNet model. Perhaps just setting a typical optimizer = ..., loss_function = ..., and doing a single step of gradient descent here would help, and not discourage users unfamiliar with tf.

Another option is to write a tutorial (https://github.com/GUDHI/TDA-tutorial) reproducing, say, the synthetic experiment of the paper (multiple_circles) and to refer to it in this doc (I understand that we don't want this doc to be too long).

A final option (which requires more development) would be to provide a train method for RipsNet that does the job with some default parameters, so that one could get started by simply going for something like:

RN = ripsnet.RipsNet(...)
RN.train(train_data)
RN.predict(test_data)

Of course, these are just suggestions.

Author: I think the best option is to link to the notebook containing the synthetic examples. It provides a very nice example where one can see the workflow.

Author: I've adapted Mathieu's original tutorial on the synthetic data so that it illustrates the use (including setup and training) of a RipsNet architecture. So, as @tlacombe suggested, I think it would be nice to include and link to this tutorial somewhere.

Author (@hensel-f, Jun 7, 2022): I've opened a PR (GUDHI/TDA-tutorial#59) to include the tutorial notebook so that we can then link to it.

Here, data_test is a list of three point clouds in 2D, and each row of the output below is the predicted vectorization of the persistence diagram of the corresponding point cloud. A possible output is:
Member: It isn't obvious what the example is computing. Maybe adding a comment or two would help (or an image). Is data_test a list of 3 point sets in 2D, and is the output some kind of vectorized persistence diagram? It becomes clear once we read the detailed doc, but I think some minimal comments in the example would still make sense.

Author: Good point, I added a sentence to describe it.


.. testoutput::

[[0.58554363 0.6054868 0.44672886 0.5216672 0.5814481 0.48068565
0.49626726 0.5285395 0.4805212 0.37918684 0.49745193 0.49247316
0.4706078 0.5491477 0.47016636 0.55804974 0.46501246 0.4065692
0.5386659 0.5660226 0.52014357 0.5329493 0.52178216 0.5156043
0.48742113]
[0.9446074 0.99024785 0.1316272 0.3013248 0.98174655 0.52285945
0.33727515 0.997285 0.3711884 0.00388432 0.63181967 0.5377489
0.22074646 0.7681194 0.04337704 0.80116796 0.02139336 0.04605395
0.8911999 0.9570045 0.5789719 0.8221929 0.7742506 0.4596561
0.08529088]
[0.8230771 0.9320036 0.25120026 0.48027694 0.8988322 0.5789062
0.38307947 0.9252455 0.39485127 0.06090912 0.5786307 0.51115406
0.28706372 0.70552015 0.16929033 0.7028084 0.12379596 0.1867683
0.6969584 0.84437454 0.6172329 0.66728634 0.630455 0.47643042
0.27172992]]
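
As discussed in the review thread above, the full training workflow is deferred to the linked tutorial notebook. For orientation, here is a minimal sketch of a training step; the optimizer and loss choices, and the variables tf_data_train / targets_train (a ragged tensor of training point clouds and an array of precomputed target vectorizations, e.g. persistence images), are illustrative assumptions rather than part of this PR:

# Hypothetical training sketch. tf_data_train is a ragged tensor built like
# tf_data_test above; targets_train contains precomputed vectorizations of
# the true persistence diagrams of the training point clouds.
RN.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
           loss=tf.keras.losses.MeanSquaredError())
RN.fit(tf_data_train, targets_train, epochs=10, batch_size=32)
print(RN.predict(tf_data_test))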

Detailed documentation
----------------------
.. automodule:: gudhi.tensorflow.ripsnet
   :members:
   :special-members:
   :show-inheritance:
3 changes: 3 additions & 0 deletions src/python/gudhi/tensorflow/__init__.py
@@ -0,0 +1,3 @@
from .ripsnet import *

__all__ = ["RipsNet", "PermopRagged", "TFBlock", "DenseRaggedBlock", "DenseRagged"]
209 changes: 209 additions & 0 deletions src/python/gudhi/tensorflow/ripsnet.py
@@ -0,0 +1,209 @@
# This file is part of the Gudhi Library - https://gudhi.inria.fr/ - which is released under MIT.
# See file LICENSE or go to https://gudhi.inria.fr/licensing/ for full license details.
# Author(s): Felix Hensel, Mathieu Carrière
#
# Copyright (C) 2022 Inria
#
# Modification(s):
#   - YYYY/MM Author: Description of the modification

import tensorflow as tf

class DenseRagged(tf.keras.layers.Layer):
    """
    This is a class for the ragged layers in the RipsNet architecture, which process the input point clouds.
    """

    def __init__(self, units, input_dim=None, use_bias=True, activation='gelu', kernel_initializer=None, bias_initializer=None, **kwargs):
        """
        Constructor for the DenseRagged class.

        Parameters:
            units (int): number of units in the layer.
            use_bias (bool): flag indicating whether to use a bias or not.
            activation (string or function): identifier of a keras activation function, e.g. 'relu'.
            kernel_initializer: tensorflow kernel initializer.
            bias_initializer: tensorflow bias initializer.
        """
        super().__init__(dynamic=True, **kwargs)
        self._supports_ragged_inputs = True
        self.units = units
        self.use_bias = use_bias
        self.activation = tf.keras.activations.get(activation)
        self.kernel_initializer = kernel_initializer
        self.bias_initializer = bias_initializer

    def build(self, input_shape):
        last_dim = input_shape[-1]
        self.kernel = self.add_weight('kernel', shape=[last_dim, self.units], trainable=True, initializer=self.kernel_initializer)
        if self.use_bias:
            self.bias = self.add_weight('bias', shape=[self.units, ], trainable=True, initializer=self.bias_initializer)
        else:
            self.bias = None
        super().build(input_shape)

    def call(self, inputs):
        """
        Apply the DenseRagged layer to a ragged input tensor.

        Parameters:
            inputs: ragged tensor (e.g. containing a point cloud).

        Returns:
            ragged tensor containing the output of the layer.
        """
        outputs = tf.ragged.map_flat_values(tf.matmul, inputs, self.kernel)
        if self.use_bias:
            outputs = tf.ragged.map_flat_values(tf.nn.bias_add, outputs, self.bias)
        outputs = tf.ragged.map_flat_values(self.activation, outputs)
Comment on lines +69 to +72 - Member: This looks like an ad hoc reimplementation of a dense layer, applied to each element. Would it make sense to first define a small network that takes an input of size 2 (if we work with 2D points), possibly composed of several layers, and only then apply (map) it to all the points in the tensor, so that there is a single call to a map function for the whole phi_1? It seems conceptually simpler, but it might be slower if TensorFlow doesn't realize that it is equivalent.

Author (@hensel-f, Apr 26, 2022): Sorry, I'm not exactly sure what you mean, or whether it would be faster. If you have a concrete change in mind, please adapt it directly or let me know.
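
A sketch of the alternative suggested in this thread, illustrative only and not code from this PR: define phi_1 as an ordinary Keras model acting on single points, and apply it to the flat values of the ragged tensor in one step.

# Hypothetical illustration of the reviewer's suggestion: a plain dense
# network on points, mapped once over all points of all clouds.
phi_1_flat = tf.keras.Sequential([
    tf.keras.layers.Dense(20, activation='gelu'),
    tf.keras.layers.Dense(10, activation='gelu'),
])

def apply_pointwise(model, ragged_points):
    # ragged_points: (n_clouds, None, input_dim) ragged tensor of points.
    return ragged_points.with_flat_values(model(ragged_points.flat_values))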

        return outputs


class DenseRaggedBlock(tf.keras.layers.Layer):
Member: This looks identical to TFBlock except for _supports_ragged_inputs? Would TFBlock(ragged=True) make sense? Or could it even be implicit: ragged iff the first layer is?

Author (@hensel-f, Apr 26, 2022): Indeed, I have changed TFBlock so that it supports ragged inputs if its first layer is an instance of DenseRagged. So DenseRaggedBlock is no longer needed; I have commented it out for now, and it can be deleted once the change is confirmed.

"""
This is a block of DenseRagged layers.
"""

def __init__(self, dense_ragged_layers, **kwargs):
"""
Constructor for the DenseRaggedBlock class.

Parameters:
dense_ragged_layers (list): a list of DenseRagged layers :class:`~gudhi.tensorflow.DenseRagged`.
input_dim (int): dimension of the pointcloud, if the input consists of pointclouds.
"""
super().__init__(dynamic=True, **kwargs)
self._supports_ragged_inputs = True
self.dr_layers = dense_ragged_layers

def build(self, input_shape):
return self

def call(self, inputs):
"""
Apply the sequence of DenseRagged layers on a ragged input tensor.

Parameters:
ragged tensor (e.g. containing a point cloud).

Returns:
ragged tensor containing the output of the sequence of layers.
"""
outputs = inputs
for dr_layer in self.dr_layers:
outputs = dr_layer(outputs)
return outputs


class TFBlock(tf.keras.layers.Layer):
Member: This seems very general. Is it related to tf.keras.Sequential, or is there some other utility already providing this composition?

Author: It is possible that there is an existing utility for this composition, but I couldn't find an explanation of how to make it work with ragged inputs.

"""
This class is a block of tensorflow layers.
"""

def __init__(self, layers, **kwargs):
"""
Constructor for the TFBlock class.

Parameters:
dense_layers (list): a list of dense tensorflow layers.
"""
super().__init__(dynamic=True, **kwargs)
self.layers = layers

def build(self, input_shape):
super().build(input_shape)
Member: Isn't that exactly what happens if you don't define this function? I am also trying to understand the difference between this and what you did for DenseRaggedBlock and RipsNet.

Author: I'm also not exactly sure what the effect of this is; I modeled it after an example I saw somewhere. But I've changed it to match the cases of RipsNet and DenseRagged. Is that fine, or what would you suggest?


    def call(self, inputs):
        """
        Apply the sequence of layers on an input tensor.

        Parameters:
            inputs: any input tensor.

        Returns:
            output tensor containing the output of the sequence of layers.
        """
        outputs = inputs
        for layer in self.layers:
            outputs = layer(outputs)
        return outputs
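
For comparison with the thread above: on plain (non-ragged) tensors, this composition behaves like tf.keras.Sequential over the same layers; the ragged case is exactly what remains unclear there. A sketch:

# Hypothetical comparison: for ordinary tensors, TFBlock(dense_layers) and
# tf.keras.Sequential(dense_layers) compute the same composition.
phi_2_seq = tf.keras.Sequential(dense_layers)  # dense_layers as in the doc example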


class PermopRagged(tf.keras.layers.Layer):
    """
    This is a class for the permutation invariant layer in the RipsNet architecture.
    """

    def __init__(self, perm_op, **kwargs):
        """
        Constructor for the PermopRagged class.

        Parameters:
            perm_op: permutation invariant function, such as `tf.math.reduce_sum` or `tf.math.reduce_mean`.
        """
        super().__init__(dynamic=True, **kwargs)
        self._supports_ragged_inputs = True
        self.pop = perm_op
Member: Using the name pop is confusing (I was wondering from which list you were removing an element). Could we stick to perm_op, or anything that isn't already an English word with an unrelated meaning?

Author: Yes, I've updated it to avoid confusion, thanks.


    def build(self, input_shape):
        super().build(input_shape)

    def call(self, inputs):
        """
        Apply PermopRagged on an input tensor.
        """
        out = self.pop(inputs, axis=1)
        return out
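
A standalone sketch of what this layer computes, assuming a ragged batch of 2D point clouds:

# Hypothetical check: averaging over the ragged point axis reduces each
# point cloud to a single 2D vector.
pop = PermopRagged(tf.math.reduce_mean)
rt = tf.ragged.constant([[[0., 0.], [2., 2.]], [[1., 1.]]], ragged_rank=1)
print(pop(rt))  # tf.Tensor([[1. 1.] [1. 1.]], shape=(2, 2), dtype=float32)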


class RipsNet(tf.keras.Model):
    """
    This is a TensorFlow model for estimating vectorizations of persistence diagrams of point clouds.
    This class implements the RipsNet architecture described in `this article <https://arxiv.org/abs/2202.01725>`_.
    """

    def __init__(self, phi_1, phi_2, input_dim, perm_op='mean', **kwargs):
        """
        Constructor for the RipsNet class.

        Parameters:
            phi_1 (layers): any block of DenseRagged layers. Can be a :class:`~gudhi.tensorflow.DenseRaggedBlock`, or a custom block built from :class:`~gudhi.tensorflow.DenseRagged` layers.
            phi_2 (layers): any (block of) TensorFlow layer(s), e.g. :class:`~gudhi.tensorflow.TFBlock`.
            input_dim (int): dimension of the input point clouds.
            perm_op (str): permutation invariant operation. Can be 'mean' or 'sum'.
        """
        super().__init__(dynamic=True, **kwargs)
        self.phi_1 = phi_1
        self.pop = perm_op
        self.phi_2 = phi_2
        self.input_dim = input_dim

        if perm_op not in ['mean', 'sum']:
            raise ValueError(f'Permutation invariant operation: {self.pop} is not allowed, must be "mean" or "sum".')

    def build(self, input_shape):
        return self

    def call(self, pointclouds):
        """
        Apply RipsNet to a ragged tensor containing a list of point clouds.

        Parameters:
            pointclouds (n x None x input_dim): ragged tensor containing n point clouds in dimension `input_dim`. The second dimension is ragged, since point clouds can have different numbers of points.

        Returns:
            output (n x output_shape): tensor containing the predicted vectorizations of the persistence diagrams of the point clouds.
        """
        if self.pop == 'mean':
            pop_ragged = PermopRagged(tf.math.reduce_mean)
        elif self.pop == 'sum':
            pop_ragged = PermopRagged(tf.math.reduce_sum)
        else:
            raise ValueError(f'Permutation invariant operation: {self.pop} is not allowed, must be "mean" or "sum".')
Member: If perm_op is not one of those two strings, should we assume that it is a function, so users can pass tf.math.reduce_max if they want? Or is that useless?

Author: Well, my only concern is that if we leave this entirely up to the user, their input function may be such that our theoretical guarantees for RipsNet no longer hold. I'm not sure what the best option is here. What do you think?

Contributor: I think it's good to let users do whatever they want (as long as it is properly documented). Our theoretical results are of the form "if ..., then ..."; nothing prevents using RipsNet with some other perm_op, and a user may be interested in doing so.

Author: Okay, sounds good. I've removed the requirement that the permutation invariant function has to be 'mean' or 'sum', so that users can specify their own functions.


        inputs = tf.keras.layers.InputLayer(input_shape=(None, self.input_dim), dtype="float32", ragged=True)(pointclouds)
        output = self.phi_1(inputs)
        output = pop_ragged(output)
        output = self.phi_2(output)
        return output
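
Following the resolution of the thread above (the 'mean'/'sum' restriction is lifted in commit a25d3f8, "allowing user specified permop functions"), the relaxed dispatch might look like the following hedged sketch, which is not the exact code of that commit:

# Hypothetical relaxed dispatch based on the review discussion; the paper's
# theoretical guarantees are only stated for mean and sum.
if self.pop == 'mean':
    pop_ragged = PermopRagged(tf.math.reduce_mean)
elif self.pop == 'sum':
    pop_ragged = PermopRagged(tf.math.reduce_sum)
elif callable(self.pop):
    pop_ragged = PermopRagged(self.pop)  # e.g. perm_op=tf.math.reduce_max
else:
    raise ValueError(f'perm_op must be "mean", "sum", or a callable, got {self.pop}.')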