diff --git a/test/classes/test_combinatorial_complex.py b/test/classes/test_combinatorial_complex.py index de00d773..50b36565 100644 --- a/test/classes/test_combinatorial_complex.py +++ b/test/classes/test_combinatorial_complex.py @@ -398,6 +398,7 @@ def test_remove_nodes(self): frozenset({6}): {"weight": 1}, } } + example.remove_nodes(HyperEdge([3])) assert example._complex_set.hyperedge_dict == { 0: {frozenset({4}): {"weight": 1}, frozenset({6}): {"weight": 1}} diff --git a/test/classes/test_simplex_trie.py b/test/classes/test_simplex_trie.py new file mode 100644 index 00000000..215f8e4b --- /dev/null +++ b/test/classes/test_simplex_trie.py @@ -0,0 +1,131 @@ +"""Tests for the `simplex_trie` module.""" + +import pytest + +from toponetx.classes.simplex_trie import SimplexTrie + + +class TestSimplexTrie: + """Tests for the `SimplexTree` class.""" + + def test_insert(self) -> None: + """Test that the internal data structures of the simplex trie are correct after insertion.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + + assert trie.shape == [3, 3, 1] + + assert set(trie.root.children.keys()) == {1, 2, 3} + assert set(trie.root.children[1].children.keys()) == {2, 3} + assert set(trie.root.children[1].children[2].children.keys()) == {3} + assert set(trie.root.children[2].children.keys()) == {3} + + # the label list should contain the nodes of each depth according to their label + label_to_simplex = { + 1: {1: [(1,)], 2: [(2,)], 3: [(3,)]}, + 2: {2: [(1, 2)], 3: [(1, 3), (2, 3)]}, + 3: {3: [(1, 2, 3)]}, + } + + assert len(trie.label_lists) == len(label_to_simplex) + for depth, label_list in trie.label_lists.items(): + assert depth in label_to_simplex + assert len(label_list) == len(label_to_simplex[depth]) + for label, nodes in label_list.items(): + assert len(nodes) == len(label_to_simplex[depth][label]) + for node, expected in zip(nodes, label_to_simplex[depth][label]): + assert node.simplex.elements == expected + + def test_iter(self) -> None: + """Test the iteration of the trie.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((2, 3, 4)) + trie.insert((0, 1)) + + # We guarantee a specific ordering of the simplices when iterating. Hence, we explicitly compare lists here. + assert list(map(lambda node: node.simplex.elements, trie)) == [ + (0,), + (1,), + (2,), + (3,), + (4,), + (0, 1), + (1, 2), + (1, 3), + (2, 3), + (2, 4), + (3, 4), + (1, 2, 3), + (2, 3, 4), + ] + + def test_cofaces(self) -> None: + """Test the cofaces method.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((1, 2, 4)) + + # no ordering is guaranteed for the cofaces method + assert set(map(lambda node: node.simplex.elements, trie.cofaces((1,)))) == { + (1,), + (1, 2), + (1, 3), + (1, 4), + (1, 2, 3), + (1, 2, 4), + } + assert set(map(lambda node: node.simplex.elements, trie.cofaces((2,)))) == { + (2,), + (1, 2), + (2, 3), + (2, 4), + (1, 2, 3), + (1, 2, 4), + } + + def test_is_maximal(self) -> None: + """Test the `is_maximal` method.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((1, 2, 4)) + + assert trie.is_maximal((1, 2, 3)) + assert trie.is_maximal((1, 2, 4)) + assert not trie.is_maximal((1, 2)) + assert not trie.is_maximal((1, 3)) + assert not trie.is_maximal((1, 4)) + assert not trie.is_maximal((2, 3)) + + with pytest.raises(ValueError): + trie.is_maximal((5,)) + + def test_skeleton(self) -> None: + """Test the skeleton method.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((1, 2, 4)) + + # no ordering is guaranteed for the skeleton method + assert set(map(lambda node: node.simplex.elements, trie.skeleton(0))) == { + (1,), + (2,), + (3,), + (4,), + } + assert set(map(lambda node: node.simplex.elements, trie.skeleton(1))) == { + (1, 2), + (1, 3), + (1, 4), + (2, 3), + (2, 4), + } + assert set(map(lambda node: node.simplex.elements, trie.skeleton(2))) == { + (1, 2, 3), + (1, 2, 4), + } + + with pytest.raises(ValueError): + _ = next(trie.skeleton(-1)) + with pytest.raises(ValueError): + _ = next(trie.skeleton(3)) diff --git a/test/classes/test_simplicial_complex.py b/test/classes/test_simplicial_complex.py index ddb499d2..b876f917 100644 --- a/test/classes/test_simplicial_complex.py +++ b/test/classes/test_simplicial_complex.py @@ -4,8 +4,6 @@ import networkx as nx import numpy as np import pytest -import spharapy.datasets as sd -import spharapy.spharabasis as sb import spharapy.trimesh as tm from gudhi import SimplexTree @@ -33,11 +31,19 @@ def test_shape_property(self): sc = SimplicialComplex() assert sc.shape == tuple() + # make sure that empty dimensions are not part of the shape after removal + sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) + sc.remove_nodes([2, 3]) + assert sc.shape == (3, 1) + def test_dim_property(self): """Test dim property.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) assert sc.dim == 2 + sc.remove_nodes([2, 3]) + assert sc.dim == 1 + def test_nodes_property(self): """Test nodes property.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) @@ -60,6 +66,8 @@ def test_is_maximal(self): """Test is_maximal method.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) assert sc.is_maximal([1, 2, 3]) + assert not sc.is_maximal([1, 2]) + assert not sc.is_maximal([3]) with pytest.raises(ValueError): sc.is_maximal([1, 2, 3, 4]) @@ -79,20 +87,16 @@ def test_contructor_using_graph(self): def test_skeleton_raise_errors(self): """Test skeleton raises.""" + G = nx.Graph() + G.add_edge(0, 1) + G.add_edge(2, 5) + G.add_edge(5, 4, weight=5) + SC = SimplicialComplex(G) + with pytest.raises(ValueError): - G = nx.Graph() - G.add_edge(0, 1) - G.add_edge(2, 5) - G.add_edge(5, 4, weight=5) - SC = SimplicialComplex(G) SC.skeleton(-2) with pytest.raises(ValueError): - G = nx.Graph() - G.add_edge(0, 1) - G.add_edge(2, 5) - G.add_edge(5, 4, weight=5) - SC = SimplicialComplex(G) SC.skeleton(2) def test_str(self): @@ -127,7 +131,7 @@ def test_getittem__(self): # with pytest.raises(ValueError): assert SC[(1, 2, 3)]["heat"] == 5 with pytest.raises(KeyError): - SC[(1, 2, 3, 4, 5)]["heat"] + _ = SC[(1, 2, 3, 4, 5)]["heat"] def test_setting_simplex_attributes(self): """Test setting simplex attributes through a `SimplicialComplex` object.""" @@ -137,38 +141,21 @@ def test_setting_simplex_attributes(self): G.add_edge(5, 4, weight=5) SC = SimplicialComplex(G, name="graph complex") SC.add_simplex((1, 2, 3), heat=5) - # with pytest.raises(ValueError): SC[(1, 2, 3)]["heat"] = 6 - assert SC[(1, 2, 3)]["heat"] == 6 SC[(2, 5)]["heat"] = 1 - assert SC[(2, 5)]["heat"] == 1 s = Simplex((1, 2, 3, 4), heat=1) SC.add_simplex(s) assert SC[(1, 2, 3, 4)]["heat"] == 1 - s = Simplex(("A"), heat=1) + s = Simplex(("A",), heat=1) SC.add_simplex(s) assert SC["A"]["heat"] == 1 - def test_maxdim(self): - """Test deprecated maxdim property for deprecation warning.""" - with pytest.deprecated_call(): - # Cause a warning by accessing the deprecated maxdim property - SC = SimplicialComplex() - max_dim = SC.maxdim - assert max_dim == -1 - - with pytest.deprecated_call(): - # Cause a warning by accessing the deprecated maxdim property - SC = SimplicialComplex([[1, 2, 3]]) - max_dim = SC.maxdim - assert max_dim == 2 - def test_add_simplices_from(self): """Test add simplices from.""" with pytest.raises(TypeError): @@ -195,13 +182,11 @@ def test_add_node(self): assert SC.dim == -1 SC.add_node(9) assert SC.dim == 0 - assert SC[9]["is_maximal"] is True SC = SimplicialComplex() assert SC.dim == -1 SC.add_node(9) assert SC.dim == 0 - assert SC[9]["is_maximal"] is True def test_add_simplex(self): """Test add_simplex method.""" @@ -249,22 +234,14 @@ def test_add_simplex(self): def test_remove_maximal_simplex(self): """Test remove_maximal_simplex method.""" - # create a SimplicialComplex object with a few simplices SC = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) - - # remove a maximal simplex using the remove_maximal_simplex() method SC.remove_maximal_simplex([1, 2, 3]) - # check that the simplex was removed correctly (tuple) assert (1, 2, 3) not in SC.simplices - # create a SimplicialComplex object with a few simplices SC = SimplicialComplex([[1, 2, 3], [2, 3, 4]]) - - # remove a maximal simplex from the complex SC.remove_maximal_simplex([2, 3, 4]) - # check that the simplex was removed correctly (list) assert [2, 3, 4] not in SC.simplices # check after the add_simplex method @@ -284,7 +261,7 @@ def test_remove_maximal_simplex(self): assert (1, 2, 3, 4, 5) not in SC # check error when simplex not in complex - with pytest.raises(KeyError): + with pytest.raises(ValueError): SC = SimplicialComplex() SC.add_simplex((1, 2, 3, 4), weight=1) SC.remove_maximal_simplex([5, 6, 7]) @@ -310,7 +287,6 @@ def test_remove_nodes(self) -> None: assert SC.is_maximal([0, 1]) assert SC.is_maximal([1, 3]) assert SC.is_maximal([3, 4]) - assert SC.is_maximal([4]) assert not SC.is_maximal([1]) def test_skeleton_and_cliques(self): @@ -333,15 +309,10 @@ def test_skeleton_and_cliques(self): def test_incidence_matrix_1(self): """Test incidence_matrix shape and values.""" - # create a SimplicialComplex object with a few simplices SC = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) - # compute the incidence matrix using the boundary_matrix() method B2 = SC.incidence_matrix(rank=2) - assert B2.shape == (6, 2) - - # assert that the incidence matrix is correct np.testing.assert_array_equal( B2.toarray(), np.array([[0, 1, -1, 1, 0, 0], [0, 0, 0, 1, -1, 1]]).T, @@ -350,8 +321,6 @@ def test_incidence_matrix_1(self): # repeat the same test, but with signed=False B2 = SC.incidence_matrix(rank=2, signed=False) assert B2.shape == (6, 2) - - # assert that the incidence matrix is correct np.testing.assert_array_equal( B2.toarray(), np.array([[0, 1, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1]]).T, @@ -507,9 +476,11 @@ def test_get_cofaces(self): SC.add_simplex([1, 2, 3, 4]) SC.add_simplex([1, 2, 4]) SC.add_simplex([3, 4, 8]) + cofaces = SC.get_cofaces([1, 2, 4], codimension=1) - assert frozenset({1, 2, 3, 4}) in cofaces - assert frozenset({3, 4, 8}) not in cofaces + cofaces = list(map(lambda simplex: simplex.elements, cofaces)) + assert (1, 2, 3, 4) in cofaces + assert (3, 4, 8) not in cofaces # ... add more assertions based on the expected cofaces def test_get_star(self): @@ -518,9 +489,12 @@ def test_get_star(self): SC.add_simplex([1, 2, 3, 4]) SC.add_simplex([1, 2, 4]) SC.add_simplex([3, 4, 8]) + star = SC.get_star([1, 2, 4]) - assert frozenset({1, 2, 4}) in star - assert frozenset({1, 2, 3, 4}) in star + star = list(map(lambda simplex: simplex.elements, star)) + + assert (1, 2, 4) in star + assert (1, 2, 3, 4) in star # ... add more assertions based on the expected star def test_set_simplex_attributes(self): @@ -616,6 +590,8 @@ def test_coincidence_matrix_2(self): def test_is_triangular_mesh(self): """Test is_triangular_mesh.""" SC = stanford_bunny("simplicial") + print("dim", SC.dim) + print(list(SC.get_all_maximal_simplices())) assert SC.is_triangular_mesh() # test for non triangular mesh @@ -707,13 +683,10 @@ def test_to_hypergraph(self): def test_to_cell_complex(self): """Test to convert SimplicialComplex to Cell Complex.""" - c1 = Simplex((1, 2, 3)) - c2 = Simplex((1, 2, 4)) - c3 = Simplex((2, 5)) - SC = SimplicialComplex([c1, c2, c3]) + SC = SimplicialComplex([(1, 2, 3), (1, 2, 4), (2, 5)]) CC = SC.to_cell_complex() - assert set(CC.edges) == {(2, 5), (2, 3), (2, 1), (2, 4), (3, 1), (1, 4)} - assert set(CC.nodes) == {2, 5, 3, 1, 4} + assert set(CC.nodes) == {1, 2, 3, 4, 5} + assert set(CC.edges) == {(2, 1), (3, 1), (2, 3), (2, 4), (1, 4), (2, 5)} def test_to_combinatorial_complex(self): """Convert a SimplicialComplex to a CombinatorialComplex and compare the number of cells and nodes.""" @@ -782,11 +755,13 @@ def test_restrict_to_nodes(self): def test_get_all_maximal_simplices(self): """Retrieve all maximal simplices from a SimplicialComplex and compare the number of simplices.""" - c1 = Simplex((1, 2, 3)) - c2 = Simplex((1, 2, 4)) - c3 = Simplex((1, 2, 5)) - SC = SimplicialComplex([c1, c2, c3]) - result = SC.get_all_maximal_simplices() + SC = SimplicialComplex([(1, 2)]) + assert set( + map(lambda simplex: simplex.elements, SC.get_all_maximal_simplices()) + ) == {(1, 2)} + + SC = SimplicialComplex([(1, 2, 3), (1, 2, 4), (1, 2, 5)]) + result = list(SC.get_all_maximal_simplices()) assert len(result) == 3 def test_coincidence_matrix(self): @@ -809,31 +784,21 @@ def test_coincidence_matrix(self): def test_down_laplacian_matrix(self): """Test the down_laplacian_matrix method of SimplicialComplex.""" - # Test case 1: Rank is within valid range SC = SimplicialComplex() SC.add_simplex([1, 2, 3]) SC.add_simplex([4, 5, 6]) - rank = 1 - signed = True - weight = None - index = False - result = SC.down_laplacian_matrix(rank, signed, weight, index) - - # Assert the result is of type scipy.sparse.csr.csr_matrix + # Test case 1: Rank is within valid range + result = SC.down_laplacian_matrix(rank=1, signed=True, weight=None, index=False) assert result.shape == (6, 6) # Test case 2: Rank is below the valid range - rank = 0 - with pytest.raises(ValueError): - SC.down_laplacian_matrix(rank, signed, weight, index) + SC.down_laplacian_matrix(rank=0, signed=True, weight=None, index=False) # Test case 3: Rank is above the valid range - rank = 5 - with pytest.raises(ValueError): - SC.down_laplacian_matrix(rank, signed, weight, index) + SC.down_laplacian_matrix(rank=5, signed=True, weight=None, index=False) def test_adjacency_matrix2(self): """Test the adjacency_matrix method of SimplicialComplex.""" @@ -842,27 +807,17 @@ def test_adjacency_matrix2(self): SC.add_simplex([4, 5, 6]) # Test case 1: Rank is within valid range - rank = 1 - signed = False - weight = None - index = False - - result = SC.adjacency_matrix(rank, signed, weight, index) + result = SC.adjacency_matrix(rank=1, signed=True, weight=None, index=False) - # Assert the result is of type scipy.sparse.csr.csr_matrix assert result.shape == (6, 6) # Test case 2: Rank is below the valid range - rank = -1 - with pytest.raises(ValueError): - SC.adjacency_matrix(rank, signed, weight, index) + SC.adjacency_matrix(rank=-1, signed=False, weight=None, index=False) # Test case 3: Rank is above the valid range - rank = 5 - with pytest.raises(ValueError): - SC.adjacency_matrix(rank, signed, weight, index) + SC.adjacency_matrix(rank=5, signed=False, weight=None, index=False) def test_coadjacency_matrix(self): """Test the coadjacency_matrix method of SimplicialComplex.""" diff --git a/toponetx/classes/colored_hypergraph.py b/toponetx/classes/colored_hypergraph.py index e2bbca51..c1730bbb 100644 --- a/toponetx/classes/colored_hypergraph.py +++ b/toponetx/classes/colored_hypergraph.py @@ -146,7 +146,12 @@ def nodes(self): NodeView """ - return NodeView(self._complex_set.hyperedge_dict, cell_type=HyperEdge) + return NodeView( + self._complex_set.hyperedge_dict[0] + if 0 in self._complex_set.hyperedge_dict + else {}, + cell_type=HyperEdge, + ) @property def incidence_dict(self): diff --git a/toponetx/classes/combinatorial_complex.py b/toponetx/classes/combinatorial_complex.py index 83ff2142..27a8c00f 100644 --- a/toponetx/classes/combinatorial_complex.py +++ b/toponetx/classes/combinatorial_complex.py @@ -295,7 +295,12 @@ def nodes(self): NodeView """ - return NodeView(self._complex_set.hyperedge_dict, cell_type=HyperEdge) + return NodeView( + self._complex_set.hyperedge_dict[0] + if 0 in self._complex_set.hyperedge_dict + else {}, + cell_type=HyperEdge, + ) @property def incidence_dict(self): diff --git a/toponetx/classes/reportviews.py b/toponetx/classes/reportviews.py index 6e95c603..25c47e9c 100644 --- a/toponetx/classes/reportviews.py +++ b/toponetx/classes/reportviews.py @@ -16,6 +16,8 @@ __all__ = ["HyperEdgeView", "CellView", "SimplexView", "NodeView"] +from toponetx.classes.simplex_trie import SimplexTrie + class CellView: """A CellView class for cells of a CellComplex. @@ -391,24 +393,14 @@ class SimplexView: Parameters ---------- - name : str, optional - Name of the SimplexView instance. - - Attributes - ---------- - max_dim : int - Maximum dimension of the simplices in the SimplexView instance. - faces_dict : list of dict - A list containing dictionaries of faces for each dimension. + simplex_trie : SimplexTrie + A SimplexTrie instance containing the simplices in the simplex view. """ - def __init__(self, name: str = "") -> None: - self.name = name - - self.max_dim = -1 - self.faces_dict = [] + def __init__(self, simplex_trie: SimplexTrie) -> None: + self._simplex_trie = simplex_trie - def __getitem__(self, simplex): + def __getitem__(self, simplex) -> Simplex: """Get the dictionary of properties associated with the given simplex. Parameters @@ -421,19 +413,15 @@ def __getitem__(self, simplex): dict or list or dict A dictionary of properties associated with the given simplex. """ - if isinstance(simplex, Simplex): - if simplex.elements in self.faces_dict[len(simplex) - 1]: - return self.faces_dict[len(simplex) - 1][simplex.elements] - elif isinstance(simplex, Iterable): - simplex = frozenset(simplex) - if simplex in self.faces_dict[len(simplex) - 1]: - return self.faces_dict[len(simplex) - 1][simplex] - else: - raise KeyError(f"input {simplex} is not in the simplex dictionary") + if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + simplex = (simplex,) + else: + simplex = tuple(sorted(simplex)) - elif isinstance(simplex, Hashable): - if frozenset({simplex}) in self: - return self.faces_dict[0][frozenset({simplex})] + node = self._simplex_trie.find(simplex) + if node is None: + raise KeyError(f"Simplex {simplex} is not in the simplex view.") + return node.simplex @property def shape(self) -> tuple[int, ...]: @@ -444,15 +432,15 @@ def shape(self) -> tuple[int, ...]: tuple of ints A tuple of integers representing the number of simplices in each dimension. """ - return tuple(len(self.faces_dict[i]) for i in range(len(self.faces_dict))) + return tuple(self._simplex_trie.shape) def __len__(self) -> int: """Return the number of simplices in the SimplexView instance.""" - return sum(self.shape) + return len(self._simplex_trie) def __iter__(self) -> Iterator: """Return an iterator over all simplices in the simplex view.""" - return chain.from_iterable(self.faces_dict) + return iter(map(lambda node: node.simplex, self._simplex_trie)) def __contains__(self, item) -> bool: """Check if a simplex is in the simplex view. @@ -490,14 +478,11 @@ def __contains__(self, item) -> bool: >>> {1, 2, 3} in view False """ - if isinstance(item, Iterable): - item = frozenset(item) - if not 0 < len(item) <= self.max_dim + 1: - return False - return item in self.faces_dict[len(item) - 1] - elif isinstance(item, Hashable): - return frozenset({item}) in self.faces_dict[0] - return False + if isinstance(item, Hashable) and not isinstance(item, Iterable): + item = (item,) + else: + item = tuple(sorted(item)) + return item in self._simplex_trie def __repr__(self) -> str: """Return string representation that can be used to recreate it.""" @@ -519,12 +504,9 @@ def __str__(self) -> str: class NodeView: """Node view class.""" - def __init__(self, objectdict, cell_type, name: str = "") -> None: + def __init__(self, nodes, cell_type, name: str = "") -> None: self.name = name - if len(objectdict) != 0: - self.nodes = objectdict[0] - else: - self.nodes = {} + self.nodes = nodes if cell_type is None: raise ValueError("cell_type cannot be None") diff --git a/toponetx/classes/simplex.py b/toponetx/classes/simplex.py index a9670411..33ab29f9 100644 --- a/toponetx/classes/simplex.py +++ b/toponetx/classes/simplex.py @@ -1,15 +1,20 @@ """Simplex Class.""" -from collections.abc import Collection, Hashable, Iterable, Iterator +from collections.abc import Collection, Hashable, Iterable +from functools import total_ordering from itertools import combinations -from typing import Any +from typing import Generic, TypeVar from toponetx.classes.complex import Atom +from toponetx.utils.iterable import is_ordered_subset __all__ = ["Simplex"] +ElementType = TypeVar("ElementType", bound=Hashable) -class Simplex(Atom): + +@total_ordering +class Simplex(Atom, Generic[ElementType]): """ A class representing a simplex in a simplicial complex. @@ -41,23 +46,28 @@ class Simplex(Atom): """ def __init__( - self, elements: Collection, name: str = "", construct_tree: bool = True, **attr + self, + elements: Collection[ElementType], + name: str = "", + construct_tree: bool = True, + **attr, ) -> None: for i in elements: if not isinstance(i, Hashable): raise ValueError(f"All nodes of a simplex must be hashable, got {i}") - super().__init__(frozenset(sorted(elements)), name, **attr) - if len(elements) != len(self.elements): + if len(elements) != len(set(elements)): raise ValueError("A simplex cannot contain duplicate nodes.") + super().__init__(tuple(sorted(elements)), name, **attr) + self.construct_tree = construct_tree if construct_tree: self._faces = self.construct_simplex_tree(elements) else: self._faces = frozenset() - def __contains__(self, item: Any) -> bool: + def __contains__(self, item: ElementType | Iterable[ElementType]) -> bool: """Return True if the given element is a subset of the nodes. Parameters @@ -83,12 +93,21 @@ def __contains__(self, item: Any) -> bool: False """ if isinstance(item, Iterable): - return frozenset(item) <= self.elements + item = tuple(sorted(item)) + return is_ordered_subset(item, self.elements) return super().__contains__(item) + def __le__(self, other) -> bool: + """Return True if the simplex comes before the other simplex in the lexicographic order.""" + if not isinstance(other, Simplex): + return NotImplemented + return self.elements <= other.elements + @staticmethod - def construct_simplex_tree(elements: Collection) -> frozenset["Simplex"]: - """Return set of Simplex objects representing the faces.""" + def construct_simplex_tree( + elements: Collection[ElementType], + ) -> frozenset["Simplex[ElementType]"]: + """Return the set of Simplex objects representing the faces.""" faceset = set() for r in range(len(elements), 0, -1): for face in combinations(elements, r): @@ -98,7 +117,7 @@ def construct_simplex_tree(elements: Collection) -> frozenset["Simplex"]: return frozenset(faceset) @property - def boundary(self) -> frozenset["Simplex"]: + def boundary(self) -> frozenset["Simplex[ElementType]"]: """Return a set of Simplex objects representing the boundary faces.""" if self.construct_tree: return frozenset(i for i in self._faces if len(i) == len(self) - 1) @@ -106,7 +125,7 @@ def boundary(self) -> frozenset["Simplex"]: faces = Simplex.construct_simplex_tree(self.elements) return frozenset(i for i in faces if len(i) == len(self) - 1) - def sign(self, face) -> int: + def sign(self, face: "Simplex[ElementType]") -> int: """Calculate the sign of the simplex with respect to a given face. Parameters @@ -117,7 +136,7 @@ def sign(self, face) -> int: raise NotImplementedError() @property - def faces(self): + def faces(self) -> frozenset["Simplex[ElementType]"]: """Get the set of faces of the simplex. If `construct_tree` is True, return the precomputed set of faces `_faces`. @@ -149,7 +168,7 @@ def __str__(self) -> str: """ return f"Nodes set: {tuple(self.elements)}, attrs: {self._properties}" - def clone(self) -> "Simplex": + def clone(self) -> "Simplex[ElementType]": """Return a copy of the simplex. The clone method by default returns an independent shallow copy of the simplex and attributes. That is, if an diff --git a/toponetx/classes/simplex_trie.py b/toponetx/classes/simplex_trie.py new file mode 100644 index 00000000..30260993 --- /dev/null +++ b/toponetx/classes/simplex_trie.py @@ -0,0 +1,465 @@ +""" +Implementation of a simplex trie datastructure for simplicial complexes as presented in [1]_. + +This module is intended for internal use by the `SimplicialComplex` class only. Any direct interactions with this +module or its classes may break at any time. In particular, this also means that the `SimplicialComplex` class must not +leak any object references to the trie or its nodes. + +Some implementation details: +- Inside this module, simplices are represented as ordered sequences with unique elements. It is expected that all + inputs from outside are already pre-processed and ordered accordingly. This is not checked and the behavior is + undefined if this is not the case. + +References +---------- +.. [1] Jean-Daniel Boissonnat and Clément Maria. The Simplex Tree: An Efficient Data Structure for General Simplicial + Complexes. Algorithmica, pages 1–22, 2014 +""" +from collections.abc import Hashable, Iterable, Sequence +from itertools import combinations +from typing import Any, Generator, Generic, Iterator, TypeVar + +from toponetx.classes.simplex import Simplex +from toponetx.utils.iterable import is_ordered_subset + +__all__ = ["SimplexNode", "SimplexTrie"] + +ElementType = TypeVar("ElementType", bound=Hashable) + + +class SimplexNode(Generic[ElementType]): + """A node in a simplex tree. + + Parameters + ---------- + label : ElementType or None + The label of the node. May only be `None` for the root node. + parent : SimplexNode, optional + The parent node of this node. If `None`, this node is the root node. + """ + + label: ElementType | None + elements: tuple[ElementType, ...] + attributes: dict[Hashable, Any] + + depth: int + parent: "SimplexNode | None" + children: dict[ElementType, "SimplexNode[ElementType]"] + + def __init__( + self, + label: ElementType | None, + parent: "SimplexNode[ElementType] | None" = None, + ) -> None: + self.label = label + self.attributes = {} + + self.children = {} + + self.parent = parent + if parent is not None: + parent.children[label] = self + self.elements = parent.elements + (label,) + self.depth = parent.depth + 1 + else: + self.elements = tuple() + self.depth = 0 + + def __len__(self) -> int: + """Return the number of elements in this node.""" + return len(self.elements) + + def __repr__(self) -> str: + """Return a string representation of this node.""" + return f"SimplexNode({self.label}, {self.elements}, {self.children})" + + @property + def simplex(self) -> Simplex[ElementType] | None: + """Return a `Simplex` object representing this node.""" + if self.label is None: + return None + simplex = Simplex(self.elements) + simplex._properties = self.attributes + return simplex + + def iter_all(self) -> Generator["SimplexNode[ElementType]", None, None]: + """Iterate over all nodes in the subtree rooted at this node. + + Ordering is according to breadth-first search, i.e., simplices are yielded in increasing order of dimension and + increasing order of labels within each dimension. + + Yields + ------ + SimplexNode + """ + queue = [self] + while queue: + node = queue.pop(0) + if node.label is not None: + # skip root node + yield node + queue += [node.children[label] for label in sorted(node.children.keys())] + + +class SimplexTrie(Generic[ElementType]): + """ + Implementation of the simplex tree data structure as presented in [1]_. + + This class is intended for internal use by the `SimplicialComplex` class only. Any direct interactions with this + class may break at any time. + + References + ---------- + .. [1] Jean-Daniel Boissonnat and Clément Maria. The Simplex Tree: An Efficient Data Structure for General + Simplicial Complexes. Algorithmica, pages 1–22, 2014 + """ + + root: SimplexNode[ElementType] + label_lists: dict[int, dict[ElementType, list[SimplexNode[ElementType]]]] + shape: list[int] + + def __init__(self) -> None: + self.root = SimplexNode(None) + self.label_lists = dict() + self.shape = [] + + def __len__(self) -> int: + """Return the number of simplices in the trie. + + Returns + ------- + int + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> len(trie) + 7 + """ + return sum(self.shape) + + def __contains__(self, item: Iterable[ElementType, ...]) -> bool: + """Check if a simplex is contained in the trie. + + Parameters + ---------- + item : Iterable of ElementType + The simplex to check for. Must be ordered and contain unique elements. + + Returns + ------- + bool + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> (1, 2, 3) in trie + True + >>> (1, 2, 4) in trie + False + """ + return self.find(item) is not None + + def __getitem__(self, item: Iterable[ElementType, ...]) -> SimplexNode[ElementType]: + """Return the simplex node for a given simplex. + + Parameters + ---------- + item : Iterable of ElementType + The simplex to return the node for. Must be ordered and contain only unique elements. + + Returns + ------- + SimplexNode + """ + node = self.find(item) + if node is None: + raise KeyError(f"Simplex {item} not found in trie.") + return node + + def __iter__(self) -> Iterator[SimplexNode[ElementType]]: + """Iterate over all simplices in the trie. + + Simplices are ordered by increasing dimension and increasing order of labels within each dimension. + + Yields + ------ + tuple of ElementType + """ + yield from self.root.iter_all() + + def insert(self, item: Sequence[ElementType], **kwargs) -> None: + """Insert a simplex into the trie. + + Any lower-dimensional simplices that do not exist in the trie are also inserted to fulfill + the simplex property. If the simplex already exists, its properties are updated. + + Parameters + ---------- + item : Sequence of ElementType + The simplex to insert. Must be ordered and contain only unique elements. + kwargs + Optional properties of the simplex. + """ + self._insert_helper(self.root, item) + self.find(item).attributes.update(kwargs) + + def insert_raw(self, simplex: Sequence[ElementType], **kwargs) -> None: + """Insert a simplex into the trie without guaranteeing the simplex property. + + Sub-simplices are not guaranteed to be inserted, which may break the simplex property. This allows for faster + insertion of a sequence of simplices without considering their order. In case not all sub-simplices are + inserted manually, the simplex property must be restored by calling `restore_simplex_property`. + + Parameters + ---------- + simplex : Sequence of ElementType + The simplex to insert. Must be ordered and contain only unique elements. + kwargs + Optional properties of the simplex. + """ + current_node = self.root + for label in simplex: + if label in current_node.children: + current_node = current_node.children[label] + else: + current_node = self._insert_child(current_node, label) + current_node.attributes.update(kwargs) + + def _insert_helper( + self, subtree: SimplexNode[ElementType], items: Sequence[ElementType] + ) -> None: + for i, label in enumerate(items): + if label not in subtree.children: + self._insert_child(subtree, label) + self._insert_helper(subtree.children[label], items[i + 1 :]) + + def _insert_child( + self, parent: SimplexNode[ElementType], label: ElementType + ) -> SimplexNode[ElementType]: + """Insert a child node with a given label. + + Parameters + ---------- + parent : SimplexNode + The parent node. + label : ElementType + The label of the child node. + + Returns + ------- + SimplexNode + The new child node. + """ + node = SimplexNode(label, parent=parent) + + # update label lists + if node.depth not in self.label_lists: + self.label_lists[node.depth] = {} + if label in self.label_lists[node.depth]: + self.label_lists[node.depth][label].append(node) + else: + self.label_lists[node.depth][label] = [node] + + # update shape property + if node.depth > len(self.shape): + self.shape += [0] + self.shape[node.depth - 1] += 1 + + return node + + def restore_simplex_property(self) -> None: + """Restore the simplex property after using `insert_raw`.""" + all_simplices = set() + for node in self.root.iter_all(): + if len(node.children) == 0: + for r in range(1, len(node.elements)): + all_simplices.update(combinations(node.elements, r)) + + for simplex in all_simplices: + self.insert_raw(simplex) + + def find(self, search: Iterable[ElementType]) -> SimplexNode[ElementType] | None: + """Find the node in the trie that matches the search. + + Parameters + ---------- + search : Iterable of ElementType + The simplex to search for. Must be ordered and contain only unique elements. + + Returns + ------- + SimplexNode or None + The node that matches the search, or `None` if no such node exists. + """ + node = self.root + for item in search: + if item not in node.children: + return None + node = node.children[item] + return node + + def cofaces( + self, simplex: Sequence[ElementType] + ) -> Generator[SimplexNode[ElementType], None, None]: + """Return the cofaces of the given simplex. + + No ordering is guaranteed by this method. + + Parameters + ---------- + simplex : Sequence of ElementType + The simplex to find the cofaces of. Must be ordered and contain only unique elements. + + Yields + ------ + SimplexNode + The cofaces of the given simplex, including the simplex itself. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> trie.insert((1, 2, 4)) + >>> sorted(map(lambda node: node.elements, trie.cofaces((1,)))) + [(1,), (1, 2), (1, 2, 3), (1, 2, 4), (1, 3), (1, 4)] + """ + # Find simplices of the form [*, l_1, *, l_2, ..., *, l_n], i.e. simplices that contain all elements of the + # given simplex plus some additional elements, but sharing the same largest element. This can be done by the + # label lists. + simplex_nodes = self._coface_roots(simplex) + + # Found all simplices of the form [*, l_1, *, l_2, ..., *, l_n] in the simplex trie. All nodes in the subtrees + # rooted at these nodes are cofaces of the given simplex. + for simplex_node in simplex_nodes: + yield from simplex_node.iter_all() + + def _coface_roots( + self, simplex: Sequence[ElementType] + ) -> Generator[SimplexNode[ElementType], None, None]: + """Return the roots of the coface subtrees.""" + for depth in range(len(simplex), len(self.shape) + 1): + if simplex[-1] not in self.label_lists[depth]: + continue + for simplex_node in self.label_lists[depth][simplex[-1]]: + if is_ordered_subset(simplex, simplex_node.elements): + yield simplex_node + + def is_maximal(self, simplex: tuple[ElementType, ...]) -> bool: + """Return True if the given simplex is maximal. + + A simplex is maximal if it has no cofaces. + + Parameters + ---------- + simplex : tuple of ElementType or SimplexNode + The simplex to check. Must be ordered and contain only unique elements. + + Returns + ------- + bool + True if the given simplex is maximal, False otherwise. + + Raises + ------ + ValueError + If the simplex does not exist in the trie. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> trie.is_maximal((1, 2, 3)) + True + >>> trie.is_maximal((1, 2)) + False + """ + if simplex not in self: + raise ValueError(f"Simplex {simplex} does not exist.") + + gen = self._coface_roots(simplex) + try: + first_next = next(gen) + # sometimes the first root is the simplex itself + if first_next.elements == simplex: + if len(first_next.children) > 0: + return False + next(gen) + return False + except StopIteration: + return True + + def skeleton(self, rank: int) -> Generator[SimplexNode, None, None]: + """Return the simplices of the given rank. + + No particular ordering is guaranteed and is dependent on insertion order. + + Parameters + ---------- + rank : int + The rank of the simplices to return. + + Yields + ------ + SimplexNode + The simplices of the given rank. + + Raises + ------ + ValueError + If the given rank is negative or exceeds the maximum rank of the trie. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> sorted(map(lambda node: node.elements, trie.skeleton(0))) + [(1,), (2,), (3,)] + >>> sorted(map(lambda node: node.elements, trie.skeleton(1))) + [(1, 2), (1, 3), (2, 3)] + >>> sorted(map(lambda node: node.elements, trie.skeleton(2))) + [(1, 2, 3)] + """ + print(rank) + if rank < 0: + raise ValueError(f"`rank` must be a positive integer, got {rank}.") + if rank >= len(self.shape): + raise ValueError(f"`rank` {rank} exceeds maximum rank {len(self.shape)}.") + + for nodes in self.label_lists[rank + 1].values(): + yield from nodes + + def remove_simplex(self, simplex: Sequence[ElementType]) -> None: + """Remove the given simplex and all its cofaces from the trie. + + This method ensures that the simplicial property is maintained by removing any simplex that is no longer valid + after removing the given simplex. + + Parameters + ---------- + simplex : Sequence of ElementType + The simplex to remove. Must be ordered and contain only unique elements. + + Raises + ------ + ValueError + If the simplex does not exist in the trie. + """ + # Locate all roots of subtrees containing cofaces of this simplex. They are invalid after removal of the given + # simplex and thus need to be removed as well. + coface_roots = self._coface_roots(simplex) + for subtree_root in coface_roots: + for node in subtree_root.iter_all(): + self.shape[node.depth - 1] -= 1 + self.label_lists[node.depth][node.label].remove(node) + + # Detach the subtree from the trie. This effectively destroys the subtree as no reference exists anymore, + # and garbage collection will take care of the rest. + subtree_root.parent.children.pop(subtree_root.label) + + # After removal of some simplices, the maximal dimension may have decreased. Make sure that there are no empty + # shape entries at the end of the shape list. + while len(self.shape) > 0 and self.shape[-1] == 0: + self.shape.pop() diff --git a/toponetx/classes/simplicial_complex.py b/toponetx/classes/simplicial_complex.py index 12cdde99..05f38801 100644 --- a/toponetx/classes/simplicial_complex.py +++ b/toponetx/classes/simplicial_complex.py @@ -4,8 +4,8 @@ """ from collections.abc import Hashable, Iterable, Iterator -from itertools import chain, combinations -from warnings import warn +from itertools import combinations +from typing import Any, Generator, Generic, TypeVar import networkx as nx import numpy as np @@ -16,12 +16,17 @@ from toponetx.classes.complex import Complex from toponetx.classes.reportviews import NodeView, SimplexView from toponetx.classes.simplex import Simplex +from toponetx.classes.simplex_trie import SimplexTrie from toponetx.exception import TopoNetXError __all__ = ["SimplicialComplex"] +ElementType = TypeVar( + "ElementType", bound=Hashable +) # TODO: Also bound to SupportsLessThanT but that is not accessible? -class SimplicialComplex(Complex): + +class SimplicialComplex(Complex, Generic[ElementType]): """Class representing a simplicial complex. Class for construction boundary operators, Hodge Laplacians, @@ -101,16 +106,16 @@ class SimplicialComplex(Complex): 4 """ + _simplex_trie: SimplexTrie[ElementType] + def __init__(self, simplices=None, name: str = "", **kwargs) -> None: super().__init__(name, **kwargs) - self._simplex_set = SimplexView() + self._simplex_trie = SimplexTrie() if isinstance(simplices, nx.Graph): - _simplices = {} - for simplex, data in simplices.nodes( - data=True - ): # `simplices` is a networkx graph + _simplices: dict[tuple[Hashable, ...], Any] = {} + for simplex, data in simplices.nodes(data=True): _simplices[(simplex,)] = data for u, v, data in simplices.edges(data=True): _simplices[(u, v)] = data @@ -129,49 +134,54 @@ def __init__(self, simplices=None, name: str = "", **kwargs) -> None: @property def shape(self) -> tuple[int, ...]: - """Shape of simplicial complex. - - (number of simplices[i], for i in range(0,dim(Sc)) ) + """Shape of simplicial complex, i.e., the number of simplices for each dimension. Returns ------- tuple of ints + + Examples + -------- + >>> SC = SimplicialComplex([[0, 1], [1, 2, 3], [2, 3, 4]]) + >>> SC.shape + (5, 5, 2) """ - return self._simplex_set.shape + return tuple(self._simplex_trie.shape) @property def dim(self) -> int: - """Dimension. - - This is the highest dimension of any simplex in the complex. - """ - return self._simplex_set.max_dim + """Maximal dimension of the simplices in this complex. - @property - def maxdim(self) -> int: - """Maximum dimension. + Returns + ------- + int - This is the highest dimension of any simplex in the complex + Examples + -------- + >>> SC = SimplicialComplex([[0, 1], [1, 2, 3], [2, 3, 4]]) + >>> SC.dim + 2 """ - warn( - "`SimplicialComplex.maxdim` is deprecated and will be removed in the future, use `SimplicialComplex.max_dim` instead.", - DeprecationWarning, - stacklevel=2, - ) - return self._simplex_set.max_dim + return len(self._simplex_trie.shape) - 1 @property def nodes(self): """Nodes.""" - return NodeView(self._simplex_set.faces_dict, cell_type=Simplex) + return NodeView( + { + frozenset({node.label}): node.simplex + for node in self._simplex_trie.skeleton(0) + }, + Simplex, + ) @property def simplices(self) -> SimplexView: """Set of all simplices.""" - return self._simplex_set + return SimplexView(self._simplex_trie) - def is_maximal(self, simplex: Iterable) -> bool: - """Check if simplex is maximal. + def is_maximal(self, simplex: Iterable[ElementType]) -> bool: + """Check if simplex is maximal, i.e., not a face of any other simplex in the complex. Parameters ---------- @@ -195,28 +205,26 @@ def is_maximal(self, simplex: Iterable) -> bool: >>> SC.is_maximal([1, 2]) False """ - if simplex not in self: - raise ValueError(f"Simplex {simplex} is not in the simplicial complex.") - return self[simplex]["is_maximal"] + return self._simplex_trie.is_maximal(tuple(sorted(simplex))) def get_maximal_simplices_of_simplex(self, simplex): """Get maximal simplices of simplex.""" return self[simplex]["membership"] - def skeleton(self, rank): - """Compute skeleton. + def skeleton(self, rank: int) -> list[Simplex[ElementType]]: + """Compute the rank-skeleton of the simplicial complex containing simplices of given rank. + + Parameters + ---------- + rank : int + The rank of the skeleton to compute. Returns ------- - Set of simplices of dimension n. + list of Simplex + The rank-skeleton of the simplicial complex. """ - if rank < len(self._simplex_set.faces_dict) and rank >= 0: - return sorted( - tuple(sorted(i)) for i in self._simplex_set.faces_dict[rank].keys() - ) - if rank < 0: - raise ValueError(f"input must be a postive integer, got {rank}") - raise ValueError(f"input {rank} exceeds max dim") + return sorted(map(lambda node: node.simplex, self._simplex_trie.skeleton(rank))) def __str__(self) -> str: """Return detailed string representation.""" @@ -234,76 +242,40 @@ def __len__(self) -> int: int Number of vertices in the complex. """ - return len(self.skeleton(0)) + return self.shape[0] - def __getitem__(self, simplex): + def __getitem__(self, simplex: Iterable[ElementType]) -> dict[Hashable, Any]: """Get simplex.""" - if simplex in self: - return self._simplex_set[simplex] - else: - raise KeyError("simplex is not in the simplicial complex") + if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + simplex = (simplex,) + return self._simplex_trie[tuple(sorted(simplex))].attributes - def __iter__(self) -> Iterator: + def __iter__(self) -> Iterator[Simplex[ElementType]]: """Iterate over all faces of the simplicial complex. - Returns - ------- - dict_keyiterator - """ - return chain.from_iterable(self._simplex_set.faces_dict) - - def __contains__(self, item) -> bool: - """Return boolean indicating if item is in self.face_set. + The order of simplices is not guaranteed. - Parameters - ---------- - item : tuple, list + Yields + ------ + tuple of AtomType """ - return item in self._simplex_set - - def _update_faces_dict_length(self, simplex) -> None: - if len(simplex) > len(self._simplex_set.faces_dict): - diff = len(simplex) - len(self._simplex_set.faces_dict) - for _ in range(diff): - self._simplex_set.faces_dict.append(dict()) + return iter(map(lambda node: node.simplex, self._simplex_trie)) - def _update_faces_dict_entry(self, face, simplex, maximal_faces) -> None: - """Update faces dictionary entry. + def __contains__(self, item: Iterable[ElementType]) -> bool: + """Return whether the given simplex is in the complex. Parameters ---------- - face : an iterable, typically a list, tuple, set or a Simplex - simplex : an iterable, typically a list, tuple, set or a Simplex + item : Iterable of AtomType + The simplex to check. - Notes - ----- - the input `face` is a face of the input `simplex`. + Returns + ------- + bool """ - face = frozenset(face) - k = len(face) - - if face not in self._simplex_set.faces_dict[k - 1]: - if k == len(simplex): - self._simplex_set.faces_dict[k - 1][face] = { - "is_maximal": True, - "membership": set(), - } - else: - self._simplex_set.faces_dict[k - 1][face] = { - "is_maximal": False, - "membership": {simplex}, - } - else: - if k != len(simplex): - self._simplex_set.faces_dict[k - 1][face]["membership"].add(simplex) - if self._simplex_set.faces_dict[k - 1][face]["is_maximal"]: - maximal_faces.add(face) - self._simplex_set.faces_dict[k - 1][face]["is_maximal"] = False - else: - # make sure all children of previous maximal simplices do not have that membership anymore - self._simplex_set.faces_dict[k - 1][face][ - "membership" - ] -= maximal_faces + if not isinstance(item, Iterable): + return (item,) in self._simplex_trie + return tuple(sorted(item)) in self._simplex_trie @staticmethod def get_boundaries( @@ -365,49 +337,22 @@ def remove_maximal_simplex(self, simplex) -> None: >>> SC.add_simplex((1, 2, 3, 4, 5)) >>> SC.remove_maximal_simplex((1, 2, 3, 4, 5)) """ - if isinstance(simplex, Iterable): - if not isinstance(simplex, Simplex): - simplex_ = frozenset(simplex) - else: - simplex_ = simplex.elements - if simplex_ in self._simplex_set.faces_dict[len(simplex_) - 1]: - if self.is_maximal(simplex): - del self._simplex_set.faces_dict[len(simplex_) - 1][simplex_] - faces = Simplex(simplex_).faces - for s in faces: - if len(s) == len(simplex_): - continue - else: - s = s.elements - self._simplex_set.faces_dict[len(s) - 1][s]["membership"] -= { - simplex_ - } - if ( - len( - self._simplex_set.faces_dict[len(s) - 1][s][ - "membership" - ] - ) - == 0 - and len(s) == len(simplex) - 1 - ): - self._simplex_set.faces_dict[len(s) - 1][s][ - "is_maximal" - ] = True - - if ( - len(self._simplex_set.faces_dict[len(simplex_) - 1]) == 0 - and len(simplex_) - 1 == self._simplex_set.max_dim - ): - del self._simplex_set.faces_dict[len(simplex_) - 1] - self._simplex_set.max_dim = len(self._simplex_set.faces_dict) - 1 - - else: - raise ValueError( - "Only maximal simplices can be deleted, input simplex is not maximal" - ) + if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + simplex_ = (simplex,) + elif isinstance(simplex, Iterable): + simplex_ = sorted(simplex) else: - raise KeyError("simplex is not a part of the simplicial complex") + raise TypeError("`simplex` must be Hashable or Iterable.") + + node = self._simplex_trie.find(simplex_) + if node is None: + raise ValueError( + f"Simplex {simplex_} is not a part of the simplicial complex." + ) + if len(node.children) != 0: + raise ValueError(f"Simplex {simplex_} is not maximal.") + + self._simplex_trie.remove_simplex(simplex_) def remove_nodes(self, node_set: Iterable[Hashable]) -> None: """Remove the given nodes from the simplicial complex. @@ -426,14 +371,8 @@ def remove_nodes(self, node_set: Iterable[Hashable]) -> None: >>> SC.simplices SimplexView([(2,), (3,), (4,), (2, 3), (3, 4)]) """ - removed_simplices = set() - for simplex in self: - if any(node in simplex for node in node_set): - removed_simplices.add(simplex) - - # Delete the simplices from largest to smallest. This way they are maximal when they are deleted. - for simplex in sorted(removed_simplices, key=len, reverse=True): - self.remove_maximal_simplex(simplex) + for node in node_set: + self._simplex_trie.remove_simplex((node,)) def add_node(self, node: Hashable, **attr) -> None: """Add node to simplicial complex. @@ -455,7 +394,7 @@ def add_node(self, node: Hashable, **attr) -> None: else: self.add_simplex([node], **attr) - def add_simplex(self, simplex, **attr) -> None: + def add_simplex(self, simplex: Iterable[ElementType] | ElementType, **attr) -> None: """Add simplex to simplicial complex. Parameters @@ -467,82 +406,89 @@ def add_simplex(self, simplex, **attr) -> None: simplex = [simplex] if isinstance(simplex, str): simplex = [simplex] - if isinstance(simplex, Iterable) or isinstance(simplex, Simplex): - if not isinstance(simplex, Simplex): - simplex_ = frozenset(simplex) - if len(simplex_) != len(simplex): - raise ValueError("a simplex cannot contain duplicate nodes") - else: - simplex_ = simplex.elements - self._update_faces_dict_length(simplex_) - - if ( - simplex_ in self._simplex_set.faces_dict[len(simplex_) - 1] - ): # simplex is already in the complex, just update the properties if needed - self._simplex_set.faces_dict[len(simplex_) - 1][simplex_].update(attr) - return - - if self._simplex_set.max_dim < len(simplex) - 1: - self._simplex_set.max_dim = len(simplex) - 1 - - numnodes = len(simplex_) - maximal_faces = set() - - for r in range(numnodes, 0, -1): - for face in combinations(simplex_, r): - self._update_faces_dict_entry(face, simplex_, maximal_faces) - self._simplex_set.faces_dict[len(simplex_) - 1][simplex_].update(attr) - if isinstance(simplex, Simplex): - self._simplex_set.faces_dict[len(simplex_) - 1][simplex_].update( - simplex._properties - ) - else: - self._simplex_set.faces_dict[len(simplex_) - 1][simplex_].update(attr) + + if not isinstance(simplex, Iterable) and not isinstance(simplex, Simplex): + raise TypeError( + f"Input simplex must be Iterable or Simplex, got {type(simplex)} instead." + ) + + if isinstance(simplex, Simplex): + simplex_ = simplex.elements + attr.update(simplex._properties) + else: + simplex_ = frozenset(simplex) + if len(simplex_) != len(simplex): + raise ValueError("a simplex cannot contain duplicate nodes") + + self._simplex_trie.insert(sorted(simplex_), **attr) def add_simplices_from(self, simplices) -> None: """Add simplices from iterable to simplicial complex.""" - for s in simplices: - self.add_simplex(s) + # for simplex in simplices: + # if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + # simplex = [simplex] + # if isinstance(simplex, str): + # simplex = [simplex] + # + # if not isinstance(simplex, Iterable) and not isinstance(simplex, Simplex): + # raise TypeError( + # f"Input simplex must be Iterable or Simplex, got {type(simplex)} instead." + # ) + # + # if isinstance(simplex, Simplex): + # simplex_ = simplex.elements + # attr = simplex._properties + # else: + # simplex_ = frozenset(simplex) + # if len(simplex_) != len(simplex): + # raise ValueError("a simplex cannot contain duplicate nodes") + # attr = {} + # + # self._simplex_trie.insert_raw(sorted(simplex_), **attr) + # self._simplex_trie.restore_simplex_property() + for simplex in simplices: + self.add_simplex(simplex) - def get_cofaces(self, simplex, codimension): + def get_cofaces( + self, simplex: Iterable[ElementType], codimension: int + ) -> list[Simplex[ElementType]]: """Get cofaces of simplex. Parameters ---------- - simplex : list, tuple or simplex - DESCRIPTION. the n simplex represented by a list of its nodes + simplex : Iterable of AtomType + The simplex for which to compute the cofaces. codimension : int - DESCRIPTION. The codimension. If codimension = 0, all cofaces are returned + The codimension of the returned cofaces. If zero, all cofaces are returned Returns ------- list of tuples(simplex). """ - entire_tree = self.get_boundaries( - self.get_maximal_simplices_of_simplex(simplex) - ) + simplex = tuple(sorted(simplex)) + # TODO: This can be optimized inside the simplex trie. + cofaces = self._simplex_trie.cofaces(simplex) return [ - i - for i in entire_tree - if frozenset(simplex).issubset(i) and len(i) - len(simplex) >= codimension + coface.simplex + for coface in cofaces + if len(coface) - len(simplex) >= codimension ] - def get_star(self, simplex): + def get_star(self, simplex: Iterable[ElementType]) -> list[Simplex[ElementType]]: """Get star. + Notes + ----- + This function is equivalent to calling `get_cofaces(simplex, 0)`. + Parameters ---------- - simplex : list, tuple or simplex - DESCRIPTION. the n simplex represented by a list of its nodes + simplex : Iterable of AtomType + The simplex for which to compute the star. Returns ------- - TYPE - list of tuples(simplex), - - Note : return of this function is - same as get_cofaces(simplex,0) . - + list of tuples """ return self.get_cofaces(simplex, 0) @@ -609,7 +555,7 @@ def set_simplex_attributes(self, values, name: str | None = None) -> None: return def get_node_attributes(self, name: str): - """Get node attributes from combinatorial complex. + """Get node attributes from simplicial complex. Parameters ---------- @@ -632,7 +578,9 @@ def get_node_attributes(self, name: str): """ return {tuple(n): self[n][name] for n in self.skeleton(0) if name in self[n]} - def get_simplex_attributes(self, name: str, rank=None): + def get_simplex_attributes( + self, name: str, rank=None + ) -> dict[tuple[ElementType, ...], Any]: """Get node attributes from simplical complex. Parameters @@ -655,11 +603,13 @@ def get_simplex_attributes(self, name: str, rank=None): >>> d={(1, 2): "red", (2, 3): "blue", (3, 4): "black"} >>> SC.set_simplex_attributes(d, name="color") >>> SC.get_simplex_attributes("color") - {frozenset({1, 2}): 'red', frozenset({2, 3}): 'blue', frozenset({3, 4}): 'black'} + {(1, 2): 'red', (2, 3): 'blue', (3, 4): 'black'} """ if rank is None: - return {n: self[n][name] for n in self if name in self[n]} - return {n: self[n][name] for n in self.skeleton(rank) if name in self[n]} + return {n.elements: self[n][name] for n in self if name in self[n]} + return { + n.elements: self[n][name] for n in self.skeleton(rank) if name in self[n] + } @staticmethod def get_edges_from_matrix(matrix): @@ -705,26 +655,28 @@ def incidence_matrix( >>> B2 = SC.incidence_matrix(2) """ if rank < 0: - raise ValueError(f"input dimension d must be positive integer, got {rank}") + raise ValueError(f"Input dimension must be positive, got {rank}.") if rank > self.dim: raise ValueError( - f"input dimenion cannat be larger than the dimension of the complex, got {rank}" + f"Input dimension cannot be larger than the dimension of the complex, got {rank}." ) + # TODO: Get rid of this special case... if rank == 0: - boundary = dok_matrix( - (1, len(self._simplex_set.faces_dict[rank].items())), dtype=np.float32 - ) - boundary[0, 0 : len(self._simplex_set.faces_dict[rank].items())] = 1 + boundary = dok_matrix((1, self._simplex_trie.shape[rank]), dtype=np.float32) + boundary[0, 0 : self._simplex_trie.shape[rank]] = 1 return boundary.tocsr() idx_simplices, idx_faces, values = [], [], [] - simplex_dict_d = {simplex: i for i, simplex in enumerate(self.skeleton(rank))} + simplex_dict_d = { + tuple(sorted(simplex)): i for i, simplex in enumerate(self.skeleton(rank)) + } + print(simplex_dict_d) simplex_dict_d_minus_1 = { - simplex: i for i, simplex in enumerate(self.skeleton(rank - 1)) + tuple(sorted(simplex)): i + for i, simplex in enumerate(self.skeleton(rank - 1)) } for simplex, idx_simplex in simplex_dict_d.items(): - # for simplex, idx_simplex in self._simplex_set.faces_dict[d].items(): for i, left_out in enumerate(np.sort(list(simplex))): idx_simplices.append(idx_simplex) values.append((-1) ** i) @@ -739,24 +691,13 @@ def incidence_matrix( len(simplex_dict_d), ), ) + + if not signed: + boundary = abs(boundary) if index: - if signed: - return ( - simplex_dict_d_minus_1, - simplex_dict_d, - boundary, - ) - else: - return ( - simplex_dict_d_minus_1, - simplex_dict_d, - abs(boundary), - ) + return simplex_dict_d_minus_1, simplex_dict_d, boundary else: - if signed: - return boundary - else: - return abs(boundary) + return boundary def coincidence_matrix( self, rank, signed: bool = True, weight=None, index: bool = False @@ -1084,7 +1025,9 @@ def restrict_to_simplices(self, cell_set, name: str = "") -> "SimplicialComplex" return SimplicialComplex(simplices=rns, name=name) - def restrict_to_nodes(self, node_set, name: str = ""): + def restrict_to_nodes( + self, node_set, name: str = "" + ) -> "SimplicialComplex[ElementType]": """Construct a new simplicial complex by restricting the simplices. The simplices are restricted to the nodes referenced by node_set. @@ -1098,14 +1041,12 @@ def restrict_to_nodes(self, node_set, name: str = ""): Returns ------- - new Simplicial Complex : SimplicialComplex + SimplicialComplex + New simplicial complex instance restricted to simplices containing only the given nodes. Examples -------- - >>> c1 = Simplex((1, 2, 3)) - >>> c2 = Simplex((1, 2, 4)) - >>> c3 = Simplex((1, 2, 5)) - >>> SC = SimplicialComplex([c1, c2, c3]) + >>> SC = SimplicialComplex([(1, 2, 3), (1, 2, 4), (1, 2, 5)]) >>> new_complex = SC.restrict_to_nodes([1, 2, 3, 4]) >>> new_complex.simplices SimplexView([(1,), (2,), (3,), (4,), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (1, 2, 3), (1, 2, 4)]) @@ -1122,24 +1063,19 @@ def restrict_to_nodes(self, node_set, name: str = ""): return SimplicialComplex(all_sim, name=name) - def get_all_maximal_simplices(self): + def get_all_maximal_simplices(self) -> Generator[Simplex[ElementType], None, None]: """Get all maximal simplices. Examples -------- - >>> c0 = Simplex((1, 2)) - >>> c1 = Simplex((1, 2, 3)) - >>> c2 = Simplex((1, 2, 4)) - >>> c3 = Simplex((2, 5)) - >>> SC = SimplicialComplex([c1, c2, c3]) + >>> SC = SimplicialComplex([(1, 2), (1, 2, 3), (1, 2, 4), (2, 5)]) >>> SC.get_all_maximal_simplices() [(2, 5), (1, 2, 3), (1, 2, 4)] """ - maximals = [] - for s in self: - if self.is_maximal(s): - maximals.append(tuple(s)) - return maximals + for node in self._simplex_trie: + simplex = node.simplex + if self.is_maximal(simplex): + yield simplex @classmethod def from_spharpy(cls, mesh) -> "SimplicialComplex": @@ -1262,7 +1198,6 @@ def load_mesh( def is_triangular_mesh(self) -> bool: """Check if the simplicial complex is a triangular mesh.""" if self.dim <= 2: - lst = self.get_all_maximal_simplices() for i in lst: if len(i) == 2: # gas edges that are not part of a face @@ -1279,17 +1214,17 @@ def to_trimesh(self, vertex_position_name: str = "position"): raise TopoNetXError( "input simplicial complex has dimension higher than 2 and hence it cannot be converted to a trimesh object" ) - else: - vertices = list( - dict( - sorted(self.get_node_attributes(vertex_position_name).items()) - ).values() - ) + vertices = list( + dict( + sorted(self.get_node_attributes(vertex_position_name).items()) + ).values() + ) + faces = list( + map(lambda simplex: simplex.elements, self.get_all_maximal_simplices()) + ) - return trimesh.Trimesh( - faces=self.get_all_maximal_simplices(), vertices=vertices, process=False - ) + return trimesh.Trimesh(faces=faces, vertices=vertices, process=False) def to_spharapy(self, vertex_position_name: str = "position"): """Convert to sharapy. @@ -1312,15 +1247,16 @@ def to_spharapy(self, vertex_position_name: str = "position"): "input simplicial complex has dimension higher than 2 and hence it cannot be converted to a trimesh object" ) - else: - - vertices = list( - dict( - sorted(self.get_node_attributes(vertex_position_name).items()) - ).values() - ) + vertices = list( + dict( + sorted(self.get_node_attributes(vertex_position_name).items()) + ).values() + ) + triangles = list( + map(lambda simplex: simplex.elements, self.get_all_maximal_simplices()) + ) - return tm.TriMesh(self.get_all_maximal_simplices(), vertices) + return tm.TriMesh(triangles, vertices) def laplace_beltrami_operator(self, mode: str = "inv_euclidean"): """Compute a laplacian matrix for a triangular mesh. @@ -1365,9 +1301,14 @@ def from_nx_graph(cls, G) -> "SimplicialComplex": """ return cls(G, name=G.name) - def is_connected(self): + def is_connected(self) -> bool: """Check if the simplicial complex is connected. + Returns + ------- + bool + True if the simplicial complex is connected, False otherwise. + Notes ----- A simplicial complex is connected iff its 1-skeleton G is connected. @@ -1421,7 +1362,9 @@ def to_cell_complex(self): """ from toponetx.classes.cell_complex import CellComplex - return CellComplex(self.get_all_maximal_simplices()) + return CellComplex( + map(lambda simplex: simplex.elements, self.get_all_maximal_simplices()) + ) def to_hypergraph(self): """Convert a simplicial complex to a hyperG. diff --git a/toponetx/utils/iterable.py b/toponetx/utils/iterable.py new file mode 100644 index 00000000..b617275d --- /dev/null +++ b/toponetx/utils/iterable.py @@ -0,0 +1,38 @@ +"""Module with iterable-related utility functions.""" +from collections.abc import Sequence +from typing import TypeVar + +__all__ = ["is_ordered_subset"] + +T = TypeVar("T") + + +def is_ordered_subset(one: Sequence[T], other: Sequence[T]) -> bool: + """Return True if the first iterable is a subset of the second iterable. + + This method is specifically optimized for ordered iterables to use return as early as possible for non-subsets. + + Examples + -------- + >>> is_ordered_subset((2,), (1, 2)) + True + >>> is_ordered_subset((1, 2), (1, 2, 3)) + True + >>> is_ordered_subset((1, 2, 3), (1, 2, 3)) + True + >>> is_ordered_subset((1, 2, 3), (1, 2)) + False + >>> is_ordered_subset((1, 2, 3), (1, 2, 4)) + False + """ + index = 0 + for item in one: + while ( + index < len(other) + and type(item) is type(other[index]) + and other[index] < item + ): + index += 1 + if index >= len(other) or other[index] != item: + return False + return True