From bdde841ddc06fafafe56d6e0a97ae7127d9424a5 Mon Sep 17 00:00:00 2001 From: Florian Frantzen Date: Thu, 20 Jul 2023 17:45:58 +0200 Subject: [PATCH] Implementation of a Simplex Trie for Better Performance --- test/classes/test_combinatorial_complex.py | 1 + test/classes/test_reportviews.py | 21 +- test/classes/test_simplex.py | 13 + test/classes/test_simplex_trie.py | 201 ++++++++ test/classes/test_simplicial_complex.py | 196 ++++---- test/transform/test_delaunay.py | 37 +- toponetx/classes/colored_hypergraph.py | 4 +- toponetx/classes/combinatorial_complex.py | 4 +- toponetx/classes/reportviews.py | 117 +++-- toponetx/classes/simplex.py | 62 ++- toponetx/classes/simplex_trie.py | 543 +++++++++++++++++++++ toponetx/classes/simplicial_complex.py | 419 +++++++--------- toponetx/utils/iterable.py | 51 ++ 13 files changed, 1223 insertions(+), 446 deletions(-) create mode 100644 test/classes/test_simplex_trie.py create mode 100644 toponetx/classes/simplex_trie.py create mode 100644 toponetx/utils/iterable.py diff --git a/test/classes/test_combinatorial_complex.py b/test/classes/test_combinatorial_complex.py index be9079cc..3a8efd8c 100644 --- a/test/classes/test_combinatorial_complex.py +++ b/test/classes/test_combinatorial_complex.py @@ -477,6 +477,7 @@ def test_remove_nodes(self): frozenset({6}): {"weight": 1}, } } + example.remove_nodes(HyperEdge([3])) assert example._complex_set.hyperedge_dict == { 0: {frozenset({4}): {"weight": 1}, frozenset({6}): {"weight": 1}} diff --git a/test/classes/test_reportviews.py b/test/classes/test_reportviews.py index 62bf6e2d..c453fd4d 100644 --- a/test/classes/test_reportviews.py +++ b/test/classes/test_reportviews.py @@ -576,12 +576,10 @@ class TestReportViews_SimplexView: def test_getitem(self): """Test __getitem__ method of the SimplexView.""" - assert self.simplex_view.__getitem__((1, 2)) == { - "is_maximal": True, - "membership": set(), - } + assert self.simplex_view[(1, 2)] == {} + assert self.simplex_view[1] == {} with 
pytest.raises(KeyError): - self.simplex_view.__getitem__([5]) + _ = self.simplex_view[(5,)] def test_str(self): """Test __str__ method of the SimplexView.""" @@ -590,6 +588,10 @@ def test_str(self): == "SimplexView([(1,), (2,), (3,), (4,), (1, 2), (2, 3), (3, 4)])" ) + def test_shape(self) -> None: + """Test the shape method of the SimplexView.""" + assert self.simplex_view.shape == (4, 3) + class TestReportViews_NodeView: """Test the NodeView class of the ReportViews module.""" @@ -610,15 +612,12 @@ def test_repr(self): def test_getitem(self): """Test __getitem__ method of the NodeView.""" - assert self.node_view.__getitem__(Simplex([1])) == { - "is_maximal": False, - "membership": {frozenset({1, 2})}, - } + assert self.node_view[Simplex([1])] == {} with pytest.raises(KeyError): - self.node_view.__getitem__([1, 2]) + _ = self.node_view[(1, 2)] # test for nodes of ColoredHyperGraph. - assert self.node_view_1.__getitem__([1]) == {"weight": 1} + assert self.node_view_1[(1,)] == {"weight": 1} class TestReportViews_PathView: diff --git a/test/classes/test_simplex.py b/test/classes/test_simplex.py index e10fa46d..71030ab6 100644 --- a/test/classes/test_simplex.py +++ b/test/classes/test_simplex.py @@ -45,6 +45,19 @@ def test_contains(self): assert 3 not in s assert (1, 2) in s + def test_le(self) -> None: + """Test the __le__ method of the simplex.""" + s1 = Simplex([1, 2]) + s2 = Simplex([1, 2]) + s3 = Simplex([1, 2, 3]) + + assert s1 <= s2 + assert s1 <= s3 + assert not s3 <= s1 + + with pytest.raises(TypeError): + s1 <= 1 + def test_boundary(self): """Test the boundary property of the simplex.""" s = Simplex((1, 2, 3)) diff --git a/test/classes/test_simplex_trie.py b/test/classes/test_simplex_trie.py new file mode 100644 index 00000000..fa90674a --- /dev/null +++ b/test/classes/test_simplex_trie.py @@ -0,0 +1,201 @@ +"""Tests for the `simplex_trie` module.""" + +import pytest + +from toponetx.classes.simplex import Simplex +from toponetx.classes.simplex_trie import 
SimplexNode, SimplexTrie + + +class TestSimplexNode: + """Tests for the `SimplexNode` class.""" + + def test_invalid_root(self) -> None: + """Test that invalid root nodes are rejected.""" + with pytest.raises(ValueError): + _ = SimplexNode(label="invalid", parent=None) + + def test_node_len(self) -> None: + """Test the length of the node.""" + root_node: SimplexNode[int] = SimplexNode(label=None) + assert len(root_node) == 0 + + child_node: SimplexNode[int] = SimplexNode(label=1, parent=root_node) + assert len(child_node) == 1 + + grandchild_node: SimplexNode[int] = SimplexNode(label=2, parent=child_node) + assert len(grandchild_node) == 2 + + def test_node_repr(self) -> None: + """Test the string representation of the node.""" + root_node = SimplexNode(label=None) + assert repr(root_node) == "SimplexNode(None, None)" + + child_node = SimplexNode(label=1, parent=root_node) + assert repr(child_node) == "SimplexNode(1, SimplexNode(None, None))" + + def test_simplex_property(self) -> None: + """Test the `simplex` property of a node.""" + root_node = SimplexNode(label=None) + assert root_node.simplex is None + + child_node = SimplexNode(label=1, parent=root_node) + assert child_node.simplex == Simplex((1,)) + + grandchild_node = SimplexNode(label=2, parent=child_node) + assert grandchild_node.simplex == Simplex((1, 2)) + + +class TestSimplexTrie: + """Tests for the `SimplexTree` class.""" + + def test_insert(self) -> None: + """Test that the internal data structures of the simplex trie are correct after insertion.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + + assert trie.shape == [3, 3, 1] + + assert set(trie.root.children.keys()) == {1, 2, 3} + assert set(trie.root.children[1].children.keys()) == {2, 3} + assert set(trie.root.children[1].children[2].children.keys()) == {3} + assert set(trie.root.children[2].children.keys()) == {3} + + # the label list should contain the nodes of each depth according to their label + label_to_simplex = { + 1: {1: [(1,)], 2: 
[(2,)], 3: [(3,)]}, + 2: {2: [(1, 2)], 3: [(1, 3), (2, 3)]}, + 3: {3: [(1, 2, 3)]}, + } + + assert len(trie.label_lists) == len(label_to_simplex) + for depth, label_list in trie.label_lists.items(): + assert depth in label_to_simplex + assert len(label_list) == len(label_to_simplex[depth]) + for label, nodes in label_list.items(): + assert len(nodes) == len(label_to_simplex[depth][label]) + for node, expected in zip( + nodes, label_to_simplex[depth][label], strict=True + ): + assert node.simplex.elements == expected + + def test_len(self) -> None: + """Test the `__len__` method of a simplex trie.""" + trie = SimplexTrie() + assert len(trie) == 0 + + trie.insert((1, 2, 3)) + assert len(trie) == 7 + + def test_getitem(self) -> None: + """Test the `__getitem__` method of a simplex trie.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + + assert trie[(1,)].simplex == Simplex((1,)) + assert trie[(1, 2)].simplex == Simplex((1, 2)) + assert trie[(1, 2, 3)].simplex == Simplex((1, 2, 3)) + + with pytest.raises(KeyError): + _ = trie[(0,)] + + def test_iter(self) -> None: + """Test the iteration of the trie.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((2, 3, 4)) + trie.insert((0, 1)) + + # We guarantee a specific ordering of the simplices when iterating. Hence, we explicitly compare lists here. 
+ assert [node.simplex.elements for node in trie] == [ + (0,), + (1,), + (2,), + (3,), + (4,), + (0, 1), + (1, 2), + (1, 3), + (2, 3), + (2, 4), + (3, 4), + (1, 2, 3), + (2, 3, 4), + ] + + def test_cofaces(self) -> None: + """Test the cofaces method.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((1, 2, 4)) + + # no ordering is guaranteed for the cofaces method + assert {node.simplex.elements for node in trie.cofaces((1,))} == { + (1,), + (1, 2), + (1, 3), + (1, 4), + (1, 2, 3), + (1, 2, 4), + } + assert {node.simplex.elements for node in trie.cofaces((2,))} == { + (2,), + (1, 2), + (2, 3), + (2, 4), + (1, 2, 3), + (1, 2, 4), + } + + def test_is_maximal(self) -> None: + """Test the `is_maximal` method.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((1, 2, 4)) + + assert trie.is_maximal((1, 2, 3)) + assert trie.is_maximal((1, 2, 4)) + assert not trie.is_maximal((1, 2)) + assert not trie.is_maximal((1, 3)) + assert not trie.is_maximal((1, 4)) + assert not trie.is_maximal((2, 3)) + + with pytest.raises(ValueError): + trie.is_maximal((5,)) + + def test_skeleton(self) -> None: + """Test the skeleton method.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + trie.insert((1, 2, 4)) + + # no ordering is guaranteed for the skeleton method + assert {node.simplex.elements for node in trie.skeleton(0)} == { + (1,), + (2,), + (3,), + (4,), + } + assert {node.simplex.elements for node in trie.skeleton(1)} == { + (1, 2), + (1, 3), + (1, 4), + (2, 3), + (2, 4), + } + assert {node.simplex.elements for node in trie.skeleton(2)} == { + (1, 2, 3), + (1, 2, 4), + } + + with pytest.raises(ValueError): + _ = next(trie.skeleton(-1)) + with pytest.raises(ValueError): + _ = next(trie.skeleton(3)) + + def test_remove_simplex(self) -> None: + """Test the removal of simplices from the trie.""" + trie = SimplexTrie() + trie.insert((1, 2, 3)) + + trie.remove_simplex((1, 2)) + assert len(trie) == 5 diff --git a/test/classes/test_simplicial_complex.py 
b/test/classes/test_simplicial_complex.py index f41367aa..e4b68b4e 100644 --- a/test/classes/test_simplicial_complex.py +++ b/test/classes/test_simplicial_complex.py @@ -39,11 +39,29 @@ def test_shape_property(self): sc = SimplicialComplex() assert sc.shape == () + # make sure that empty dimensions are not part of the shape after removal + sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) + sc.remove_nodes([2, 3]) + assert sc.shape == (3, 1) + def test_dim_property(self): """Test dim property.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) assert sc.dim == 2 + sc.remove_nodes([2, 3]) + assert sc.dim == 1 + + def test_maxdim(self) -> None: + """Test deprecated maxdim property for deprecation warning.""" + SC = SimplicialComplex() + with pytest.deprecated_call(): + assert SC.maxdim == -1 + + SC.add_simplex([1, 2, 3]) + with pytest.deprecated_call(): + assert SC.maxdim == 2 + def test_nodes_property(self): """Test nodes property.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) @@ -55,22 +73,35 @@ def test_nodes_property(self): def test_simplices_property(self): """Test simplices property.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) - simplices = sc.simplices - assert len(simplices) == 13 - assert [1, 2, 3] in simplices - assert [2, 3, 4] in simplices - assert [0, 1] in simplices - # ... 
add more assertions based on the expected simplices + assert len(sc.simplices) == 13 + assert [1, 2, 3] in sc.simplices + assert [2, 3, 4] in sc.simplices + assert [0, 1] in sc.simplices def test_is_maximal(self): """Test is_maximal method.""" sc = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) assert sc.is_maximal([1, 2, 3]) + assert not sc.is_maximal([1, 2]) + assert not sc.is_maximal([3]) with pytest.raises(ValueError): sc.is_maximal([1, 2, 3, 4]) - def test_contructor_using_graph(self): + def test_get_maximal_simplices_of_simplex(self) -> None: + """Test the get_maximal_simplices_of_simplex method.""" + SC = SimplicialComplex([(1, 2, 3), (2, 3, 4), (0, 1), (5,)]) + + assert SC.get_maximal_simplices_of_simplex((1, 2, 3)) == {Simplex((1, 2, 3))} + assert SC.get_maximal_simplices_of_simplex((1, 2)) == {Simplex((1, 2, 3))} + assert SC.get_maximal_simplices_of_simplex((2, 3)) == { + Simplex((1, 2, 3)), + Simplex((2, 3, 4)), + } + assert SC.get_maximal_simplices_of_simplex((0,)) == {Simplex((0, 1))} + assert SC.get_maximal_simplices_of_simplex((5,)) == {Simplex((5,))} + + def test_constructor_using_graph(self): """Test input a networkx graph in the constructor.""" G = nx.Graph() G.add_edge(0, 1) @@ -85,20 +116,16 @@ def test_contructor_using_graph(self): def test_skeleton_raise_errors(self): """Test skeleton raises.""" + G = nx.Graph() + G.add_edge(0, 1) + G.add_edge(2, 5) + G.add_edge(5, 4, weight=5) + SC = SimplicialComplex(G) + with pytest.raises(ValueError): - G = nx.Graph() - G.add_edge(0, 1) - G.add_edge(2, 5) - G.add_edge(5, 4, weight=5) - SC = SimplicialComplex(G) SC.skeleton(-2) with pytest.raises(ValueError): - G = nx.Graph() - G.add_edge(0, 1) - G.add_edge(2, 5) - G.add_edge(5, 4, weight=5) - SC = SimplicialComplex(G) SC.skeleton(2) def test_str(self): @@ -126,16 +153,16 @@ def test_iter(self): SC = SimplicialComplex([[1, 2, 3], [2, 4], [5]]) simplices = set(SC) assert simplices == { - frozenset({1}), - frozenset({2}), - frozenset({3}), - 
frozenset({4}), - frozenset({5}), - frozenset({1, 2}), - frozenset({1, 3}), - frozenset({2, 3}), - frozenset({2, 4}), - frozenset({1, 2, 3}), + Simplex((1,)), + Simplex((2,)), + Simplex((3,)), + Simplex((4,)), + Simplex((5,)), + Simplex((1, 2)), + Simplex((1, 3)), + Simplex((2, 3)), + Simplex((2, 4)), + Simplex((1, 2, 3)), } def test_getittem__(self): @@ -147,7 +174,7 @@ def test_getittem__(self): assert SC[(0, 1)]["weight"] == 5 assert SC[(1, 2, 3)]["heat"] == 5 with pytest.raises(KeyError): - SC[(1, 2, 3, 4, 5)]["heat"] + _ = SC[(1, 2, 3, 4, 5)]["heat"] SC[(0, 1)]["new"] = 10 assert SC[(0, 1)]["new"] == 10 @@ -160,38 +187,21 @@ def test_setting_simplex_attributes(self): G.add_edge(5, 4, weight=5) SC = SimplicialComplex(G, name="graph complex") SC.add_simplex((1, 2, 3), heat=5) - # with pytest.raises(ValueError): SC[(1, 2, 3)]["heat"] = 6 - assert SC[(1, 2, 3)]["heat"] == 6 SC[(2, 5)]["heat"] = 1 - assert SC[(2, 5)]["heat"] == 1 s = Simplex((1, 2, 3, 4), heat=1) SC.add_simplex(s) assert SC[(1, 2, 3, 4)]["heat"] == 1 - s = Simplex(("A"), heat=1) + s = Simplex(("A",), heat=1) SC.add_simplex(s) assert SC["A"]["heat"] == 1 - def test_maxdim(self): - """Test deprecated maxdim property for deprecation warning.""" - with pytest.deprecated_call(): - # Cause a warning by accessing the deprecated maxdim property - SC = SimplicialComplex() - max_dim = SC.maxdim - assert max_dim == -1 - - with pytest.deprecated_call(): - # Cause a warning by accessing the deprecated maxdim property - SC = SimplicialComplex([[1, 2, 3]]) - max_dim = SC.maxdim - assert max_dim == 2 - def test_add_simplices_from(self): """Test add simplices from.""" with pytest.raises(TypeError): @@ -218,13 +228,11 @@ def test_add_node(self): assert SC.dim == -1 SC.add_node(9) assert SC.dim == 0 - assert SC[9]["is_maximal"] is True SC = SimplicialComplex() assert SC.dim == -1 SC.add_node(9) assert SC.dim == 0 - assert SC[9]["is_maximal"] is True def test_add_simplex(self): """Test add_simplex method.""" @@ 
-301,22 +309,16 @@ def test_contains(self): def test_remove_maximal_simplex(self): """Test remove_maximal_simplex method.""" - # create a SimplicialComplex object with a few simplices - SC = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) - - # remove a maximal simplex using the remove_maximal_simplex() method + SC = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1], [10]]) SC.remove_maximal_simplex([1, 2, 3]) + SC.remove_maximal_simplex(10) - # check that the simplex was removed correctly (tuple) assert (1, 2, 3) not in SC.simplices + assert (10,) not in SC.simplices - # create a SimplicialComplex object with a few simplices SC = SimplicialComplex([[1, 2, 3], [2, 3, 4]]) - - # remove a maximal simplex from the complex SC.remove_maximal_simplex([2, 3, 4]) - # check that the simplex was removed correctly (list) assert [2, 3, 4] not in SC.simplices # check after the add_simplex method @@ -336,7 +338,7 @@ def test_remove_maximal_simplex(self): assert (1, 2, 3, 4, 5) not in SC.simplices # check error when simplex not in complex - with pytest.raises(KeyError): + with pytest.raises(ValueError): SC = SimplicialComplex() SC.add_simplex((1, 2, 3, 4), weight=1) SC.remove_maximal_simplex([5, 6, 7]) @@ -362,7 +364,6 @@ def test_remove_nodes(self) -> None: assert SC.is_maximal([0, 1]) assert SC.is_maximal([1, 3]) assert SC.is_maximal([3, 4]) - assert SC.is_maximal([4]) assert not SC.is_maximal([1]) def test_skeleton_and_cliques(self): @@ -385,15 +386,10 @@ def test_skeleton_and_cliques(self): def test_incidence_matrix_1(self): """Test incidence_matrix shape and values.""" - # create a SimplicialComplex object with a few simplices SC = SimplicialComplex([[1, 2, 3], [2, 3, 4], [0, 1]]) - # compute the incidence matrix using the boundary_matrix() method B2 = SC.incidence_matrix(rank=2) - assert B2.shape == (6, 2) - - # assert that the incidence matrix is correct np.testing.assert_array_equal( B2.toarray(), np.array([[0, 1, -1, 1, 0, 0], [0, 0, 0, 1, -1, 1]]).T, @@ -402,8 +398,6 @@ 
def test_incidence_matrix_1(self): # repeat the same test, but with signed=False B2 = SC.incidence_matrix(rank=2, signed=False) assert B2.shape == (6, 2) - - # assert that the incidence matrix is correct np.testing.assert_array_equal( B2.toarray(), np.array([[0, 1, 1, 1, 0, 0], [0, 0, 0, 1, 1, 1]]).T, @@ -562,9 +556,11 @@ def test_get_cofaces(self): SC.add_simplex([1, 2, 3, 4]) SC.add_simplex([1, 2, 4]) SC.add_simplex([3, 4, 8]) + cofaces = SC.get_cofaces([1, 2, 4], codimension=1) - assert frozenset({1, 2, 3, 4}) in cofaces - assert frozenset({3, 4, 8}) not in cofaces + cofaces = [simplex.elements for simplex in cofaces] + assert (1, 2, 3, 4) in cofaces + assert (3, 4, 8) not in cofaces # ... add more assertions based on the expected cofaces def test_get_star(self): @@ -573,9 +569,12 @@ def test_get_star(self): SC.add_simplex([1, 2, 3, 4]) SC.add_simplex([1, 2, 4]) SC.add_simplex([3, 4, 8]) + star = SC.get_star([1, 2, 4]) - assert frozenset({1, 2, 4}) in star - assert frozenset({1, 2, 3, 4}) in star + star = [simplex.elements for simplex in star] + + assert (1, 2, 4) in star + assert (1, 2, 3, 4) in star # ... 
add more assertions based on the expected star def test_set_simplex_attributes(self): @@ -679,6 +678,8 @@ def test_coincidence_matrix_2(self): def test_is_triangular_mesh(self): """Test is_triangular_mesh.""" SC = stanford_bunny("simplicial") + print("dim", SC.dim) + print(list(SC.get_all_maximal_simplices())) assert SC.is_triangular_mesh() # test for non triangular mesh @@ -788,13 +789,10 @@ def test_to_hypergraph(self): def test_to_cell_complex(self): """Test to convert SimplicialComplex to Cell Complex.""" - c1 = Simplex((1, 2, 3)) - c2 = Simplex((1, 2, 4)) - c3 = Simplex((2, 5)) - SC = SimplicialComplex([c1, c2, c3]) + SC = SimplicialComplex([(1, 2, 3), (1, 2, 4), (2, 5)]) CC = SC.to_cell_complex() - assert set(CC.edges) == {(2, 5), (2, 3), (2, 1), (2, 4), (3, 1), (1, 4)} - assert set(CC.nodes) == {2, 5, 3, 1, 4} + assert set(CC.nodes) == {1, 2, 3, 4, 5} + assert set(CC.edges) == {(2, 1), (3, 1), (2, 3), (2, 4), (1, 4), (2, 5)} def test_to_combinatorial_complex(self): """Convert a SimplicialComplex to a CombinatorialComplex and compare the number of cells and nodes.""" @@ -884,11 +882,13 @@ def test_restrict_to_nodes(self): def test_get_all_maximal_simplices(self): """Retrieve all maximal simplices from a SimplicialComplex and compare the number of simplices.""" - c1 = Simplex((1, 2, 3)) - c2 = Simplex((1, 2, 4)) - c3 = Simplex((1, 2, 5)) - SC = SimplicialComplex([c1, c2, c3]) - result = SC.get_all_maximal_simplices() + SC = SimplicialComplex([(1, 2)]) + assert {simplex.elements for simplex in SC.get_all_maximal_simplices()} == { + (1, 2) + } + + SC = SimplicialComplex([(1, 2, 3), (1, 2, 4), (1, 2, 5)]) + result = list(SC.get_all_maximal_simplices()) assert len(result) == 3 def test_coincidence_matrix(self): @@ -911,31 +911,21 @@ def test_coincidence_matrix(self): def test_down_laplacian_matrix(self): """Test the down_laplacian_matrix method of SimplicialComplex.""" - # Test case 1: Rank is within valid range SC = SimplicialComplex() SC.add_simplex([1, 2, 3]) 
SC.add_simplex([4, 5, 6]) - rank = 1 - signed = True - weight = None - index = False - result = SC.down_laplacian_matrix(rank, signed, weight, index) - - # Assert the result is of type scipy.sparse.csr.csr_matrix + # Test case 1: Rank is within valid range + result = SC.down_laplacian_matrix(rank=1, signed=True, weight=None, index=False) assert result.shape == (6, 6) # Test case 2: Rank is below the valid range - rank = 0 - with pytest.raises(ValueError): - SC.down_laplacian_matrix(rank, signed, weight, index) + SC.down_laplacian_matrix(rank=0, signed=True, weight=None, index=False) # Test case 3: Rank is above the valid range - rank = 5 - with pytest.raises(ValueError): - SC.down_laplacian_matrix(rank, signed, weight, index) + SC.down_laplacian_matrix(rank=5, signed=True, weight=None, index=False) def test_adjacency_matrix2(self): """Test the adjacency_matrix method of SimplicialComplex.""" @@ -944,27 +934,17 @@ def test_adjacency_matrix2(self): SC.add_simplex([4, 5, 6]) # Test case 1: Rank is within valid range - rank = 1 - signed = False - weight = None - index = False - - result = SC.adjacency_matrix(rank, signed, weight, index) + result = SC.adjacency_matrix(rank=1, signed=True, weight=None, index=False) - # Assert the result is of type scipy.sparse.csr.csr_matrix assert result.shape == (6, 6) # Test case 2: Rank is below the valid range - rank = -1 - with pytest.raises(ValueError): - SC.adjacency_matrix(rank, signed, weight, index) + SC.adjacency_matrix(rank=-1, signed=False, weight=None, index=False) # Test case 3: Rank is above the valid range - rank = 5 - with pytest.raises(ValueError): - SC.adjacency_matrix(rank, signed, weight, index) + SC.adjacency_matrix(rank=5, signed=False, weight=None, index=False) def test_dirac_operator_matrix(self): """Test dirac operator.""" diff --git a/test/transform/test_delaunay.py b/test/transform/test_delaunay.py index 93527872..6fa7dc7a 100644 --- a/test/transform/test_delaunay.py +++ b/test/transform/test_delaunay.py @@ 
-2,6 +2,7 @@ import numpy as np +from toponetx.classes.simplex import Simplex from toponetx.transform import delaunay_triangulation @@ -11,13 +12,13 @@ def test_delaunay_triangulation_simple(): SC = delaunay_triangulation(points) assert set(SC.simplices) == { - frozenset([0]), - frozenset([1]), - frozenset([2]), - frozenset([0, 1]), - frozenset([0, 2]), - frozenset([1, 2]), - frozenset([0, 1, 2]), + Simplex((0,)), + Simplex((1,)), + Simplex((2,)), + Simplex((0, 1)), + Simplex((0, 2)), + Simplex((1, 2)), + Simplex((0, 1, 2)), } @@ -27,15 +28,15 @@ def test_delaunay_triangulation(): SC = delaunay_triangulation(points) assert set(SC.simplices) == { - frozenset([0]), - frozenset([1]), - frozenset([2]), - frozenset([3]), - frozenset([0, 1]), - frozenset([0, 2]), - frozenset([1, 2]), - frozenset([1, 3]), - frozenset([2, 3]), - frozenset([0, 1, 2]), - frozenset([1, 2, 3]), + Simplex((0,)), + Simplex((1,)), + Simplex((2,)), + Simplex((3,)), + Simplex((0, 1)), + Simplex((0, 2)), + Simplex((1, 2)), + Simplex((1, 3)), + Simplex((2, 3)), + Simplex((0, 1, 2)), + Simplex((1, 2, 3)), } diff --git a/toponetx/classes/colored_hypergraph.py b/toponetx/classes/colored_hypergraph.py index b23f7ac1..e72dc9e4 100644 --- a/toponetx/classes/colored_hypergraph.py +++ b/toponetx/classes/colored_hypergraph.py @@ -139,7 +139,9 @@ def nodes(self): NodeView of all nodes. """ return NodeView( - self._complex_set.hyperedge_dict, cell_type=HyperEdge, colored_nodes=True + self._complex_set.hyperedge_dict.get(0, {}), + cell_type=HyperEdge, + colored_nodes=True, ) @property diff --git a/toponetx/classes/combinatorial_complex.py b/toponetx/classes/combinatorial_complex.py index d05c39fb..62089282 100644 --- a/toponetx/classes/combinatorial_complex.py +++ b/toponetx/classes/combinatorial_complex.py @@ -235,7 +235,9 @@ def nodes(self): Returns all the nodes of the combinatorial complex. 
""" return NodeView( - self._complex_set.hyperedge_dict, cell_type=HyperEdge, colored_nodes=False + self._complex_set.hyperedge_dict.get(0, {}), + cell_type=HyperEdge, + colored_nodes=False, ) @property diff --git a/toponetx/classes/reportviews.py b/toponetx/classes/reportviews.py index f63c1b70..739e3dd2 100644 --- a/toponetx/classes/reportviews.py +++ b/toponetx/classes/reportviews.py @@ -14,6 +14,7 @@ from toponetx.classes.hyperedge import HyperEdge from toponetx.classes.path import Path from toponetx.classes.simplex import Simplex +from toponetx.classes.simplex_trie import SimplexTrie __all__ = [ "AtomView", @@ -816,19 +817,16 @@ class SimplexView(AtomView[Simplex]): These classes are used in conjunction with the SimplicialComplex class for view/read only purposes for simplices in simplicial complexes. - Attributes + Parameters ---------- - max_dim : int - Maximum dimension of the simplices in the SimplexView instance. - faces_dict : list of dict - A list containing dictionaries of faces for each dimension. + simplex_trie : SimplexTrie + A SimplexTrie instance containing the simplices in the simplex view. """ - def __init__(self) -> None: - self.max_dim = -1 - self.faces_dict = [] + def __init__(self, simplex_trie: SimplexTrie) -> None: + self._simplex_trie = simplex_trie - def __getitem__(self, simplex: Any) -> dict: + def __getitem__(self, simplex: Any) -> dict[Hashable, Any]: """Get the dictionary of attributes associated with the given simplex. Parameters @@ -846,19 +844,15 @@ def __getitem__(self, simplex: Any) -> dict: KeyError If the simplex is not in the simplex view. 
""" - if isinstance(simplex, Simplex): - simplex = simplex.elements if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): - simplex = frozenset({simplex}) - - simplex = frozenset(simplex) - if ( - len(self.faces_dict) >= len(simplex) - and simplex in self.faces_dict[len(simplex) - 1] - ): - return self.faces_dict[len(simplex) - 1][simplex] + simplex = (simplex,) + else: + simplex = tuple(sorted(simplex)) - raise KeyError(f"input {simplex} is not in the simplex dictionary") + node = self._simplex_trie.find(simplex) + if node is None: + raise KeyError(f"Simplex {simplex} is not in the simplex view.") + return node.attributes @property def shape(self) -> tuple[int, ...]: @@ -869,7 +863,7 @@ def shape(self) -> tuple[int, ...]: tuple of ints A tuple of integers representing the number of simplices in each dimension. """ - return tuple(len(self.faces_dict[i]) for i in range(len(self.faces_dict))) + return tuple(self._simplex_trie.shape) def __len__(self) -> int: """Return the number of simplices in the SimplexView instance. @@ -879,7 +873,7 @@ def __len__(self) -> int: int Returns the number of simplices in the SimplexView instance. """ - return sum(self.shape) + return len(self._simplex_trie) def __iter__(self) -> Iterator: """Return an iterator over all simplices in the simplex view. @@ -889,7 +883,7 @@ def __iter__(self) -> Iterator: Iterator Returns an iterator over all simplices in the simplex view. """ - return chain.from_iterable(self.faces_dict) + return iter(node.simplex for node in self._simplex_trie) def __contains__(self, atom: Any) -> bool: """Check if a simplex is in the simplex view. 
@@ -927,14 +921,11 @@ def __contains__(self, atom: Any) -> bool: >>> {1, 2, 3} in view False """ - if isinstance(atom, Iterable): - atom = frozenset(atom) - if not 0 < len(atom) <= self.max_dim + 1: - return False - return atom in self.faces_dict[len(atom) - 1] - if isinstance(atom, Hashable): - return frozenset({atom}) in self.faces_dict[0] - return False + if isinstance(atom, Hashable) and not isinstance(atom, Iterable): + atom = (atom,) + else: + atom = tuple(sorted(atom)) + return atom in self._simplex_trie def __repr__(self) -> str: """Return string representation that can be used to recreate it. @@ -944,11 +935,7 @@ def __repr__(self) -> str: str Returns the __repr__ representation of the object. """ - all_simplices: list[tuple[int, ...]] = [] - for i in range(len(self.faces_dict)): - all_simplices += [tuple(j) for j in self.faces_dict[i]] - - return f"SimplexView({all_simplices})" + return f"SimplexView({[tuple(simplex.elements) for simplex in self._simplex_trie]})" def __str__(self) -> str: """Return detailed string representation of the simplex view. @@ -958,11 +945,7 @@ def __str__(self) -> str: str Returns the __str__ representation of the object. """ - all_simplices: list[tuple[int, ...]] = [] - for i in range(len(self.faces_dict)): - all_simplices += [tuple(j) for j in self.faces_dict[i]] - - return f"SimplexView({all_simplices})" + return f"SimplexView({[tuple(simplex.elements) for simplex in self._simplex_trie]})" class NodeView: @@ -970,7 +953,7 @@ class NodeView: Parameters ---------- - objectdict : dict + nodes : dict[Hashable, Any] A dictionary of nodes with their attributes. cell_type : type The type of the cell. @@ -978,11 +961,10 @@ class NodeView: Whether or not the nodes are colored. 
""" - def __init__(self, objectdict, cell_type, colored_nodes: bool = False) -> None: - if len(objectdict) != 0: - self.nodes = objectdict[0] - else: - self.nodes = {} + def __init__( + self, nodes: dict[Hashable, Any], cell_type, colored_nodes: bool = False + ) -> None: + self.nodes = nodes if cell_type is None: raise ValueError("cell_type cannot be None") @@ -998,9 +980,7 @@ def __repr__(self) -> str: str Returns the __repr__ representation of the object. """ - all_nodes = [tuple(j) for j in self.nodes] - - return f"NodeView({all_nodes})" + return f"NodeView({list(map(tuple, self.nodes.keys()))})" def __iter__(self) -> Iterator: """Return an iterator over all nodes in the node view. @@ -1073,9 +1053,13 @@ def __contains__(self, e) -> bool: return False -class PathView(SimplexView): +class PathView(AtomView[Path]): """Path view class.""" + def __init__(self) -> None: + self.max_dim = -1 + self.faces_dict = [] + def __getitem__(self, path: Any) -> dict: """Get the dictionary of attributes associated with the given path. @@ -1107,6 +1091,16 @@ def __getitem__(self, path: Any) -> dict: raise KeyError(f"input {path} is not in the path dictionary") + def __iter__(self) -> Iterator: + """Return an iterator over all paths in the path view. + + Returns + ------- + Iterator + Iterator over all paths in the paths view. + """ + return chain.from_iterable(self.faces_dict) + def __contains__(self, atom: Any) -> bool: """Check if a path is in the path view. @@ -1134,6 +1128,16 @@ def __contains__(self, atom: Any) -> bool: return (atom,) in self.faces_dict[0] return False + def __len__(self) -> int: + """Return the number of simplices in the SimplexView instance. + + Returns + ------- + int + Returns the number of simplices in the SimplexView instance. + """ + return sum(self.shape) + def __repr__(self) -> str: """Return string representation that can be used to recreate it. 
@@ -1161,3 +1165,14 @@ def __str__(self) -> str: all_paths += [tuple(j) for j in self.faces_dict[i]] return f"PathView({all_paths})" + + @property + def shape(self) -> tuple[int, ...]: + """Return the number of paths in each dimension. + + Returns + ------- + tuple of ints + A tuple of integers representing the number of paths in each dimension. + """ + return tuple(len(self.faces_dict[i]) for i in range(len(self.faces_dict))) diff --git a/toponetx/classes/simplex.py b/toponetx/classes/simplex.py index d28056d5..949ec81c 100644 --- a/toponetx/classes/simplex.py +++ b/toponetx/classes/simplex.py @@ -2,17 +2,22 @@ import warnings from collections.abc import Collection, Hashable, Iterable +from functools import total_ordering from itertools import combinations -from typing import Any +from typing import Generic, TypeVar from typing_extensions import Self, deprecated from toponetx.classes.complex import Atom +from toponetx.utils.iterable import is_ordered_subset __all__ = ["Simplex"] +ElementType = TypeVar("ElementType", bound=Hashable) -class Simplex(Atom[frozenset[Hashable]]): + +@total_ordering +class Simplex(Atom, Generic[ElementType]): """A class representing a simplex in a simplicial complex. This class represents a simplex in a simplicial complex, which is a set of nodes with a specific dimension. The @@ -20,7 +25,7 @@ class Simplex(Atom[frozenset[Hashable]]): Parameters ---------- - elements : Collection + elements : Collection[Hashable] The nodes in the simplex. construct_tree : bool, default=True If True, construct the entire simplicial tree for the simplex. 
@@ -40,9 +45,11 @@ class Simplex(Atom[frozenset[Hashable]]): >>> simplex3 = tnx.Simplex((1, 2, 4, 5), weight=1) """ + elements: frozenset[Hashable] + def __init__( self, - elements: Collection, + elements: Collection[ElementType], construct_tree: bool = False, **kwargs, ) -> None: @@ -55,15 +62,12 @@ def __init__( stacklevel=2, ) - for i in elements: - if not isinstance(i, Hashable): - raise ValueError(f"All nodes of a simplex must be hashable, got {i}") - - super().__init__(frozenset(sorted(elements)), **kwargs) - if len(elements) != len(self.elements): + if len(elements) != len(set(elements)): raise ValueError("A simplex cannot contain duplicate nodes.") - def __contains__(self, item: Any) -> bool: + super().__init__(tuple(sorted(elements)), **kwargs) + + def __contains__(self, item: ElementType | Iterable[ElementType]) -> bool: """Return True if the given element is a subset of the nodes. Parameters @@ -89,9 +93,27 @@ def __contains__(self, item: Any) -> bool: False """ if isinstance(item, Iterable): - return frozenset(item) <= self.elements + item = tuple(sorted(item)) + return is_ordered_subset(item, self.elements) return super().__contains__(item) + def __le__(self, other) -> bool: + """Return True if this simplex comes before the other simplex in the lexicographic order. + + Parameters + ---------- + other : Any + The other simplex to compare with. + + Returns + ------- + bool + True if this simplex comes before the other simplex in the lexicographic order. + """ + if not isinstance(other, Simplex): + return NotImplemented + return self.elements <= other.elements + @staticmethod def validate_attributes(attributes: dict) -> None: """Validate the attributes of the simplex. @@ -113,8 +135,10 @@ def validate_attributes(attributes: dict) -> None: @staticmethod @deprecated("`Simplex.construct_simplex_tree` is deprecated.") - def construct_simplex_tree(elements: Collection) -> frozenset["Simplex"]: - """Return set of Simplex objects representing the faces. 
+ def construct_simplex_tree( + elements: Collection[ElementType], + ) -> frozenset["Simplex[ElementType]"]: + """Return the set of Simplex objects representing the faces. Parameters ---------- @@ -129,16 +153,14 @@ def construct_simplex_tree(elements: Collection) -> frozenset["Simplex"]: faceset = set() for r in range(len(elements), 0, -1): for face in combinations(elements, r): - faceset.add( - Simplex(elements=sorted(face), construct_tree=False) - ) # any face is always ordered + faceset.add(Simplex(elements=face, construct_tree=False)) return frozenset(faceset) @property @deprecated( "`Simplex.boundary` is deprecated, use `SimplicialComplex.get_boundaries()` on the simplicial complex that contains this simplex instead." ) - def boundary(self) -> frozenset["Simplex"]: + def boundary(self) -> frozenset["Simplex[ElementType]"]: """Return the set of the set of all n-1 faces in of the input n-simplex. Returns @@ -156,7 +178,7 @@ def boundary(self) -> frozenset["Simplex"]: for elements in combinations(self.elements, len(self) - 1) ) - def sign(self, face) -> int: + def sign(self, face: "Simplex[ElementType]") -> int: """Calculate the sign of the simplex with respect to a given face. Parameters @@ -170,7 +192,7 @@ def sign(self, face) -> int: @deprecated( "`Simplex.faces` is deprecated, use `SimplicialComplex.get_boundaries()` on the simplicial complex that contains this simplex instead." ) - def faces(self): + def faces(self) -> frozenset["Simplex[ElementType]"]: """Get the set of faces of the simplex. If `construct_tree` is True, return the precomputed set of faces `_faces`. diff --git a/toponetx/classes/simplex_trie.py b/toponetx/classes/simplex_trie.py new file mode 100644 index 00000000..47040e7c --- /dev/null +++ b/toponetx/classes/simplex_trie.py @@ -0,0 +1,543 @@ +""" +Implementation of a simplex trie datastructure for simplicial complexes as presented in [1]_. + +This module is intended for internal use by the `SimplicialComplex` class only. 
Any direct interactions with this +module or its classes may break at any time. In particular, this also means that the `SimplicialComplex` class must not +leak any object references to the trie or its nodes. + +Some implementation details: +- Inside this module, simplices are represented as ordered sequences with unique elements. It is expected that all + inputs from outside are already pre-processed and ordered accordingly. This is not checked and the behavior is + undefined if this is not the case. + +References +---------- +.. [1] Jean-Daniel Boissonnat and Clément Maria. The Simplex Tree: An Efficient Data Structure for General Simplicial + Complexes. Algorithmica, pages 1-22, 2014 +""" + +from collections.abc import Generator, Hashable, Iterable, Iterator, Sequence +from typing import Any, Generic, TypeVar + +from toponetx.classes.simplex import Simplex +from toponetx.utils.iterable import is_ordered_subset + +__all__ = ["SimplexNode", "SimplexTrie"] + +ElementType = TypeVar("ElementType", bound=Hashable) + + +class SimplexNode(Generic[ElementType]): + """Node in a simplex tree. + + Parameters + ---------- + label : ElementType or None + The label of the node. May only be `None` for the root node. + parent : SimplexNode, optional + The parent node of this node. If `None`, this node is the root node. + """ + + label: ElementType | None + elements: tuple[ElementType, ...] + attributes: dict[Hashable, Any] + + depth: int + parent: "SimplexNode | None" + children: dict[ElementType, "SimplexNode[ElementType]"] + + def __init__( + self, + label: ElementType | None, + parent: "SimplexNode[ElementType] | None" = None, + ) -> None: + """Node in a simplex tree. + + Parameters + ---------- + label : ElementType or None + The label of the node. May only be `None` for the root node. + parent : SimplexNode, optional + The parent node of this node. If `None`, this node is the root node. 
+ """ + self.label = label + self.attributes = {} + + self.children = {} + + self.parent = parent + if parent is not None: + parent.children[label] = self + self.elements = (*parent.elements, label) + self.depth = parent.depth + 1 + else: + if label is not None: + raise ValueError("Root node must have label `None`.") + self.elements = () + self.depth = 0 + + def __len__(self) -> int: + """Return the number of elements in this trie node. + + Returns + ------- + int + Number of elements in this trie node. + """ + return len(self.elements) + + def __repr__(self) -> str: + """Return a string representation of this trie node. + + Returns + ------- + str + A string representation of this trie node. + """ + return f"SimplexNode({self.label}, {self.parent!r})" + + @property + def simplex(self) -> Simplex[ElementType] | None: + """Return a `Simplex` object representing this node. + + Returns + ------- + Simplex or None + A `Simplex` object representing this node, or `None` if this node is the root node. + """ + if self.label is None: + return None + simplex = Simplex(self.elements, construct_tree=False) + simplex._attributes = self.attributes + return simplex + + def iter_all(self) -> Generator["SimplexNode[ElementType]", None, None]: + """Iterate over all nodes in the subtree rooted at this node. + + Ordering is according to breadth-first search, i.e., simplices are yielded in increasing order of dimension and + increasing order of labels within each dimension. + + Yields + ------ + SimplexNode + """ + queue = [self] + while queue: + node = queue.pop(0) + if node.label is not None: + # skip root node + yield node + queue += [node.children[label] for label in sorted(node.children.keys())] + + +class SimplexTrie(Generic[ElementType]): + """ + Simplex tree data structure as presented in [1]_. + + This class is intended for internal use by the `SimplicialComplex` class only. Any + direct interactions with this class may break at any time. + + References + ---------- + .. 
[1] Jean-Daniel Boissonnat and Clément Maria. The Simplex Tree: An Efficient + Data Structure for General Simplicial Complexes. Algorithmica, pages 1-22, 2014 + """ + + root: SimplexNode[ElementType] + label_lists: dict[int, dict[ElementType, list[SimplexNode[ElementType]]]] + shape: list[int] + + def __init__(self) -> None: + """Simplex tree data structure as presented in [1]_. + + This class is intended for internal use by the `SimplicialComplex` class only. + Any direct interactions with this class may break at any time. + + References + ---------- + .. [1] Jean-Daniel Boissonnat and Clément Maria. The Simplex Tree: An Efficient + Data Structure for General Simplicial Complexes. Algorithmica, pages 1-22, 2014 + """ + self.root = SimplexNode(None) + self.label_lists = {} + self.shape = [] + + def __len__(self) -> int: + """Return the number of simplices in the trie. + + Returns + ------- + int + Number of simplices in the trie. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> len(trie) + 7 + """ + return sum(self.shape) + + def __contains__(self, item: Iterable[ElementType]) -> bool: + """Check if a simplex is contained in this trie. + + Parameters + ---------- + item : Iterable of ElementType + The simplex to check for. Must be ordered and contain unique elements. + + Returns + ------- + bool + Whether the given simplex is contained in this trie. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> (1, 2, 3) in trie + True + >>> (1, 2, 4) in trie + False + """ + return self.find(item) is not None + + def __getitem__(self, item: Iterable[ElementType]) -> SimplexNode[ElementType]: + """Return the simplex node for a given simplex. + + Parameters + ---------- + item : Iterable of ElementType + The simplex to return the node for. Must be ordered and contain only unique + elements. + + Returns + ------- + SimplexNode + The trie node that represents the given simplex. 
+ + Raises + ------ + KeyError + If the given simplex does not exist in this trie. + """ + node = self.find(item) + if node is None: + raise KeyError(f"Simplex {item} not found in trie.") + return node + + def __iter__(self) -> Iterator[SimplexNode[ElementType]]: + """Iterate over all simplices in the trie. + + Simplices are ordered by increasing dimension and increasing order of labels within each dimension. + + Yields + ------ + tuple of ElementType + """ + yield from self.root.iter_all() + + def insert(self, item: Sequence[ElementType], **kwargs) -> None: + """Insert a simplex into the trie. + + Any lower-dimensional simplices that do not exist in the trie are also inserted + to fulfill the simplex property. If the simplex already exists, its properties + are updated. + + Parameters + ---------- + item : Sequence of ElementType + The simplex to insert. Must be ordered and contain only unique elements. + **kwargs : keyword arguments, optional + Properties associated with the given simplex. + """ + self._insert_helper(self.root, item) + self.find(item).attributes.update(kwargs) + + # def insert_raw(self, simplex: Sequence[ElementType], **kwargs) -> None: + # """Insert a simplex into the trie without guaranteeing the simplex property. + + # Sub-simplices are not guaranteed to be inserted, which may break the simplex + # property. This allows for faster insertion of a sequence of simplices without + # considering their order. In case not all sub-simplices are inserted manually, + # the simplex property must be restored by calling `restore_simplex_property`. + + # Parameters + # ---------- + # simplex : Sequence of ElementType + # The simplex to insert. Must be ordered and contain only unique elements. + # **kwargs : keyword arguments, optional + # Properties associated with the simplex. + + # See Also + # -------- + # restore_simplex_property + # Function to restore the simplex property after using `insert_raw`. 
+ # """ + # current_node = self.root + # for label in simplex: + # if label in current_node.children: + # current_node = current_node.children[label] + # else: + # current_node = self._insert_child(current_node, label) + # current_node.attributes.update(kwargs) + + def _insert_helper( + self, subtree: SimplexNode[ElementType], items: Sequence[ElementType] + ) -> None: + """Insert a simplex node under the given subtree. + + Parameters + ---------- + subtree : SimplexNode + The subtree to insert the simplex node under. + items : Sequence + The (partial) simplex to insert under the subtree. + """ + for i, label in enumerate(items): + if label not in subtree.children: + self._insert_child(subtree, label) + self._insert_helper(subtree.children[label], items[i + 1 :]) + + def _insert_child( + self, parent: SimplexNode[ElementType], label: ElementType + ) -> SimplexNode[ElementType]: + """Insert a child node with a given label. + + Parameters + ---------- + parent : SimplexNode + The parent node. + label : ElementType + The label of the child node. + + Returns + ------- + SimplexNode + The new child node. 
+ """ + node = SimplexNode(label, parent=parent) + + # update label lists + if node.depth not in self.label_lists: + self.label_lists[node.depth] = {} + if label in self.label_lists[node.depth]: + self.label_lists[node.depth][label].append(node) + else: + self.label_lists[node.depth][label] = [node] + + # update shape property + if node.depth > len(self.shape): + self.shape += [0] + self.shape[node.depth - 1] += 1 + + return node + + # def restore_simplex_property(self) -> None: + # """Restore the simplex property after using `insert_raw`.""" + # all_simplices = set() + # for node in self.root.iter_all(): + # if len(node.children) == 0: + # for r in range(1, len(node.elements)): + # all_simplices.update(combinations(node.elements, r)) + + # for simplex in all_simplices: + # self.insert_raw(simplex) + + def find(self, search: Iterable[ElementType]) -> SimplexNode[ElementType] | None: + """Find the node in the trie that matches the search. + + Parameters + ---------- + search : Iterable of ElementType + The simplex to search for. Must be ordered and contain only unique elements. + + Returns + ------- + SimplexNode or None + The node that matches the search, or `None` if no such node exists. + """ + node = self.root + for item in search: + if item not in node.children: + return None + node = node.children[item] + return node + + def cofaces( + self, simplex: Sequence[ElementType] + ) -> Generator[SimplexNode[ElementType], None, None]: + """Return the cofaces of the given simplex. + + No ordering is guaranteed by this method. + + Parameters + ---------- + simplex : Sequence of ElementType + The simplex to find the cofaces of. Must be ordered and contain only unique elements. + + Yields + ------ + SimplexNode + The cofaces of the given simplex, including the simplex itself. 
+ + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> trie.insert((1, 2, 4)) + >>> sorted(map(lambda node: node.elements, trie.cofaces((1,)))) + [(1,), (1, 2), (1, 2, 3), (1, 2, 4), (1, 3), (1, 4)] + """ + # Find simplices of the form [*, l_1, *, l_2, ..., *, l_n], i.e. simplices that contain all elements of the + # given simplex plus some additional elements, but sharing the same largest element. This can be done by the + # label lists. + simplex_nodes = self._coface_roots(simplex) + + # Found all simplices of the form [*, l_1, *, l_2, ..., *, l_n] in the simplex trie. All nodes in the subtrees + # rooted at these nodes are cofaces of the given simplex. + for simplex_node in simplex_nodes: + yield from simplex_node.iter_all() + + def _coface_roots( + self, simplex: Sequence[ElementType] + ) -> Generator[SimplexNode[ElementType], None, None]: + """Return the roots of the coface subtrees. + + A coface subtree is a subtree of the trie whose simplices are all cofaces of a + given simplex. + + Parameters + ---------- + simplex : Sequence of ElementType + The simplex to find the cofaces of. Must be ordered and contain only unique + elements. + + Yields + ------ + SimplexNode + The trie nodes that are roots of the coface subtrees. + """ + for depth in range(len(simplex), len(self.shape) + 1): + if simplex[-1] not in self.label_lists[depth]: + continue + for simplex_node in self.label_lists[depth][simplex[-1]]: + if is_ordered_subset(simplex, simplex_node.elements): + yield simplex_node + + def is_maximal(self, simplex: tuple[ElementType, ...]) -> bool: + """Return True if the given simplex is maximal. + + A simplex is maximal if it has no cofaces. + + Parameters + ---------- + simplex : tuple of ElementType or SimplexNode + The simplex to check. Must be ordered and contain only unique elements. + + Returns + ------- + bool + True if the given simplex is maximal, False otherwise. 
+ + Raises + ------ + ValueError + If the simplex does not exist in the trie. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> trie.is_maximal((1, 2, 3)) + True + >>> trie.is_maximal((1, 2)) + False + """ + if simplex not in self: + raise ValueError(f"Simplex {simplex} does not exist.") + + gen = self._coface_roots(simplex) + try: + first_next = next(gen) + # sometimes the first root is the simplex itself + if first_next.elements == simplex: + if len(first_next.children) > 0: + return False + next(gen) + return False + except StopIteration: + return True + + def skeleton(self, rank: int) -> Generator[SimplexNode, None, None]: + """Return the simplices of the given rank. + + No particular ordering is guaranteed and is dependent on insertion order. + + Parameters + ---------- + rank : int + The rank of the simplices to return. + + Yields + ------ + SimplexNode + The simplices of the given rank. + + Raises + ------ + ValueError + If the given rank is negative or exceeds the maximum rank of the trie. + + Examples + -------- + >>> trie = SimplexTrie() + >>> trie.insert((1, 2, 3)) + >>> sorted(map(lambda node: node.elements, trie.skeleton(0))) + [(1,), (2,), (3,)] + >>> sorted(map(lambda node: node.elements, trie.skeleton(1))) + [(1, 2), (1, 3), (2, 3)] + >>> sorted(map(lambda node: node.elements, trie.skeleton(2))) + [(1, 2, 3)] + """ + if rank < 0: + raise ValueError(f"`rank` must be a positive integer, got {rank}.") + if rank >= len(self.shape): + raise ValueError(f"`rank` {rank} exceeds maximum rank {len(self.shape)}.") + + for nodes in self.label_lists[rank + 1].values(): + yield from nodes + + def remove_simplex(self, simplex: Sequence[ElementType]) -> None: + """Remove the given simplex and all its cofaces from the trie. + + This method ensures that the simplicial property is maintained by removing any simplex that is no longer valid + after removing the given simplex. 
+ + Parameters + ---------- + simplex : Sequence of ElementType + The simplex to remove. Must be ordered and contain only unique elements. + + Raises + ------ + ValueError + If the simplex does not exist in the trie. + """ + # Locate all roots of subtrees containing cofaces of this simplex. They are invalid after removal of the given + # simplex and thus need to be removed as well. + coface_roots = self._coface_roots(simplex) + for subtree_root in coface_roots: + for node in subtree_root.iter_all(): + self.shape[node.depth - 1] -= 1 + self.label_lists[node.depth][node.label].remove(node) + + # Detach the subtree from the trie. This effectively destroys the subtree as no reference exists anymore, + # and garbage collection will take care of the rest. + subtree_root.parent.children.pop(subtree_root.label) + + # After removal of some simplices, the maximal dimension may have decreased. Make sure that there are no empty + # shape entries at the end of the shape list. + while len(self.shape) > 0 and self.shape[-1] == 0: + self.shape.pop() diff --git a/toponetx/classes/simplicial_complex.py b/toponetx/classes/simplicial_complex.py index 66a3a4eb..378dce50 100644 --- a/toponetx/classes/simplicial_complex.py +++ b/toponetx/classes/simplicial_complex.py @@ -3,9 +3,9 @@ The class also supports attaching arbitrary attributes and data to cells. 
""" -from collections.abc import Collection, Hashable, Iterable, Iterator +from collections.abc import Collection, Generator, Hashable, Iterable, Iterator from itertools import combinations -from typing import Any +from typing import Any, Generic, TypeVar import networkx as nx import numpy as np @@ -15,6 +15,7 @@ from toponetx.classes.complex import Complex from toponetx.classes.reportviews import NodeView, SimplexView from toponetx.classes.simplex import Simplex +from toponetx.classes.simplex_trie import SimplexTrie try: from gudhi import SimplexTree @@ -28,8 +29,12 @@ __all__ = ["SimplicialComplex"] +ElementType = TypeVar( + "ElementType", bound=Hashable +) # TODO: Also bound to SupportsLessThanT but that is not accessible? -class SimplicialComplex(Complex): + +class SimplicialComplex(Complex, Generic[ElementType]): """Class representing a simplicial complex. Class for construction boundary operators, Hodge Laplacians, @@ -107,16 +112,16 @@ class SimplicialComplex(Complex): 4 """ + _simplex_trie: SimplexTrie[ElementType] + def __init__(self, simplices=None, **kwargs) -> None: super().__init__(**kwargs) - self._simplex_set = SimplexView() + self._simplex_trie = SimplexTrie() if isinstance(simplices, nx.Graph): - _simplices: dict[tuple, Any] = {} - for simplex, data in simplices.nodes( - data=True - ): # `simplices` is a networkx graph + _simplices: dict[tuple[Hashable, ...], Any] = {} + for simplex, data in simplices.nodes(data=True): _simplices[(simplex,)] = data for u, v, data in simplices.edges(data=True): _simplices[(u, v)] = data @@ -134,38 +139,44 @@ def __init__(self, simplices=None, **kwargs) -> None: @property def shape(self) -> tuple[int, ...]: - """Shape of simplicial complex. - - (number of simplices[i], for i in range(0,dim(Sc)) ) + """Shape of simplicial complex, i.e., the number of simplices for each dimension. Returns ------- tuple of ints This gives the number of cells in each rank. 
+ + Examples + -------- + >>> SC = SimplicialComplex([[0, 1], [1, 2, 3], [2, 3, 4]]) + >>> SC.shape + (5, 5, 2) """ - return self._simplex_set.shape + return tuple(self._simplex_trie.shape) @property def dim(self) -> int: - """ - Dimension of the simplicial complex. - - This is the highest dimension of any simplex in the complex. + """Maximal dimension of the simplices in this complex. Returns ------- int The dimension of the simplicial complex. + + Examples + -------- + >>> SC = SimplicialComplex([[0, 1], [1, 2, 3], [2, 3, 4]]) + >>> SC.dim + 2 """ - return self._simplex_set.max_dim + return len(self._simplex_trie.shape) - 1 @property @deprecated( - "`SimplicialComplex.maxdim` is deprecated and will be removed in the future, use `SimplicialComplex.max_dim` instead." + "`SimplicialComplex.maxdim` is deprecated and will be removed in the future, use `SimplicialComplex.dim` instead." ) def maxdim(self) -> int: - """ - Maximum dimension of the simplicial complex. + """Maximum dimension of the simplicial complex. This is the highest dimension of any simplex in the complex. @@ -174,10 +185,10 @@ def maxdim(self) -> int: int The maximum dimension of the simplicial complex. """ - return self._simplex_set.max_dim + return self.dim @property - def nodes(self): + def nodes(self) -> NodeView: """Return the list of nodes in the simplicial complex. Returns @@ -185,7 +196,13 @@ def nodes(self): NodeView A NodeView object representing the nodes of the simplicial complex. """ - return NodeView(self._simplex_set.faces_dict, cell_type=Simplex) + return NodeView( + { + frozenset(node.elements): node.attributes + for node in self._simplex_trie.skeleton(0) + }, + Simplex, + ) @property def simplices(self) -> SimplexView: @@ -196,10 +213,10 @@ def simplices(self) -> SimplexView: SimplexView A SimplexView object representing the set of all simplices in the simplicial complex. 
""" - return self._simplex_set + return SimplexView(self._simplex_trie) - def is_maximal(self, simplex: Iterable) -> bool: - """Check if simplex is maximal. + def is_maximal(self, simplex: Iterable[ElementType]) -> bool: + """Check if simplex is maximal, i.e., not a face of any other simplex in the complex. Parameters ---------- @@ -225,13 +242,11 @@ def is_maximal(self, simplex: Iterable) -> bool: >>> SC.is_maximal([1, 2]) False """ - if simplex not in self.simplices: - raise ValueError(f"Simplex {simplex} is not in the simplicial complex.") - return self[simplex]["is_maximal"] + return self._simplex_trie.is_maximal(tuple(sorted(simplex))) def get_maximal_simplices_of_simplex( self, simplex: Iterable[Hashable] - ) -> set[frozenset]: + ) -> set[Simplex[ElementType]]: """Get maximal simplices of simplex. Parameters @@ -241,13 +256,17 @@ def get_maximal_simplices_of_simplex( Returns ------- - set of frozensets + set of Simplex Set of maximal simplices of the given simplex. """ - return self[simplex]["membership"] + return { + node.simplex + for node in self._simplex_trie.cofaces(tuple(sorted(simplex))) + if self._simplex_trie.is_maximal(node.elements) + } - def skeleton(self, rank: int) -> list[tuple[Hashable, ...]]: - """Compute skeleton. + def skeleton(self, rank: int) -> list[Simplex[ElementType]]: + """Compute the rank-skeleton of the simplicial complex containing simplices of given rank. Parameters ---------- @@ -256,14 +275,10 @@ def skeleton(self, rank: int) -> list[tuple[Hashable, ...]]: Returns ------- - list of tuples + list of Simplex Simplices of rank `rank` in the simplicial complex. 
""" - if len(self._simplex_set.faces_dict) > rank >= 0: - return sorted(tuple(sorted(i)) for i in self._simplex_set.faces_dict[rank]) - if rank < 0: - raise ValueError(f"input must be a postive integer, got {rank}") - raise ValueError(f"input {rank} exceeds max dim") + return sorted(node.simplex for node in self._simplex_trie.skeleton(rank)) def __str__(self) -> str: """Return a detailed string representation of the simplicial complex. @@ -294,14 +309,14 @@ def __len__(self) -> int: int Number of vertices in the complex. """ - return len(self.skeleton(0)) + return self.shape[0] - def __getitem__(self, atom: Any) -> dict[Hashable, Any]: + def __getitem__(self, simplex: Iterable[ElementType]) -> dict[Hashable, Any]: """Get the data associated with the given simplex. Parameters ---------- - atom : Any + simplex : Any The simplex to retrieve. Returns @@ -314,15 +329,22 @@ def __getitem__(self, atom: Any) -> dict[Hashable, Any]: KeyError If the simplex is not present in the simplicial complex. """ - return self._simplex_set[atom] + if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + simplex = (simplex,) + return self._simplex_trie[tuple(sorted(simplex))].attributes - def __iter__(self) -> Iterator[frozenset[Hashable]]: + def __iter__(self) -> Iterator[Simplex[ElementType]]: """Iterate over all simplices (faces) of the simplicial complex. Returns ------- Iterator[frozenset[Hashable]] An iterator over all simplices in the simplicial complex. + + Raises + ------ + KeyError + If the simplex is not present in the simplicial complex. """ return iter(self.simplices) @@ -339,62 +361,9 @@ def __contains__(self, atom: Any) -> bool: bool Returns `True` if this simplicial complex contains the atom, else `False`. """ - return atom in self._simplex_set - - def _update_faces_dict_length(self, simplex) -> None: - """Update the faces dictionary length based on the input simplex. - - Parameters - ---------- - simplex : tuple[int, ...] 
- The simplex to update the faces dictionary length. - """ - if len(simplex) > len(self._simplex_set.faces_dict): - diff = len(simplex) - len(self._simplex_set.faces_dict) - for _ in range(diff): - self._simplex_set.faces_dict.append({}) - - def _update_faces_dict_entry(self, face, simplex, maximal_faces) -> None: - """Update faces dictionary entry. - - Parameters - ---------- - face : iterable - Typically a list, tuple, set, or a Simplex representing a face. - simplex : iterable - Typically a list, tuple, set, or a Simplex representing the input simplex. - maximal_faces : iterable - The maximal faces are the the faces that cannot be extended by adding another node. - - Notes - ----- - The input `face` is a face of the input `simplex`. - """ - face = frozenset(face) - k = len(face) - - if face not in self._simplex_set.faces_dict[k - 1]: - if k == len(simplex): - self._simplex_set.faces_dict[k - 1][face] = { - "is_maximal": True, - "membership": set(), - } - else: - self._simplex_set.faces_dict[k - 1][face] = { - "is_maximal": False, - "membership": {simplex}, - } - else: - if k != len(simplex): - self._simplex_set.faces_dict[k - 1][face]["membership"].add(simplex) - if self._simplex_set.faces_dict[k - 1][face]["is_maximal"]: - maximal_faces.add(face) - self._simplex_set.faces_dict[k - 1][face]["is_maximal"] = False - else: - # make sure all children of previous maximal simplices do not have that membership anymore - self._simplex_set.faces_dict[k - 1][face]["membership"] -= ( - maximal_faces - ) + if not isinstance(atom, Iterable): + return (atom,) in self._simplex_trie + return tuple(sorted(atom)) in self._simplex_trie @staticmethod def get_boundaries( @@ -457,42 +426,22 @@ def remove_maximal_simplex(self, simplex: Collection) -> None: >>> SC.add_simplex((1, 2, 3, 4, 5)) >>> SC.remove_maximal_simplex((1, 2, 3, 4, 5)) """ - if isinstance(simplex, Iterable): - if not isinstance(simplex, Simplex): - simplex_ = frozenset(simplex) - else: - simplex_ = simplex.elements 
- if simplex_ in self._simplex_set.faces_dict[len(simplex_) - 1]: - if self.is_maximal(simplex): - del self._simplex_set.faces_dict[len(simplex_) - 1][simplex_] - faces = self.get_boundaries([simplex_]) - for s in faces: - if len(s) == len(simplex_): - continue - - self._simplex_set.faces_dict[len(s) - 1][s]["membership"] -= { - simplex_ - } - if ( - len(self._simplex_set.faces_dict[len(s) - 1][s]["membership"]) - == 0 - and len(s) == len(simplex) - 1 - ): - self._simplex_set.faces_dict[len(s) - 1][s]["is_maximal"] = True - - if ( - len(self._simplex_set.faces_dict[len(simplex_) - 1]) == 0 - and len(simplex_) - 1 == self._simplex_set.max_dim - ): - del self._simplex_set.faces_dict[len(simplex_) - 1] - self._simplex_set.max_dim = len(self._simplex_set.faces_dict) - 1 - - else: - raise ValueError( - "Only maximal simplices can be deleted, input simplex is not maximal" - ) + if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + simplex_ = (simplex,) + elif isinstance(simplex, Iterable): + simplex_ = sorted(simplex) else: - raise KeyError("simplex is not a part of the simplicial complex") + raise TypeError("`simplex` must be Hashable or Iterable.") + + node = self._simplex_trie.find(simplex_) + if node is None: + raise ValueError( + f"Simplex {simplex_} is not a part of the simplicial complex." + ) + if len(node.children) != 0: + raise ValueError(f"Simplex {simplex_} is not maximal.") + + self._simplex_trie.remove_simplex(simplex_) def remove_nodes(self, node_set: Iterable[Hashable]) -> None: """Remove the given nodes from the simplicial complex. @@ -511,14 +460,8 @@ def remove_nodes(self, node_set: Iterable[Hashable]) -> None: >>> SC.simplices SimplexView([(2,), (3,), (4,), (2, 3), (3, 4)]) """ - removed_simplices = set() - for simplex in self.simplices: - if any(node in simplex for node in node_set): - removed_simplices.add(simplex) - - # Delete the simplices from largest to smallest. This way they are maximal when they are deleted. 
- for simplex in sorted(removed_simplices, key=len, reverse=True): - self.remove_maximal_simplex(simplex) + for node in node_set: + self._simplex_trie.remove_simplex((node,)) def add_node(self, node: Hashable, **kwargs) -> None: """Add node to simplicial complex. @@ -542,7 +485,9 @@ def add_node(self, node: Hashable, **kwargs) -> None: else: self.add_simplex([node], **kwargs) - def add_simplex(self, simplex: Collection, **kwargs) -> None: + def add_simplex( + self, simplex: Iterable[ElementType] | ElementType, **kwargs + ) -> None: """Add simplex to simplicial complex. In case sub-simplices are missing, they are added without attributes to the @@ -571,34 +516,20 @@ def add_simplex(self, simplex: Collection, **kwargs) -> None: if isinstance(simplex, str): simplex = [simplex] - if isinstance(simplex, Simplex): - elements = simplex.elements - kwargs = simplex._attributes | kwargs - elif isinstance(simplex, Collection): - elements = frozenset(simplex) - if len(elements) != len(simplex): - raise ValueError("a simplex cannot contain duplicate nodes") - else: + if not isinstance(simplex, Iterable) and not isinstance(simplex, Simplex): raise TypeError( - f"Input simplex must be a collection or a `Simplex` object, got {type(simplex)}." + f"Input simplex must be Iterable or Simplex, got {type(simplex)} instead." 
) - # if the simplex is already part of this complex, update its attributes - if elements in self.simplices: - self._simplex_set.faces_dict[len(elements) - 1][elements].update(kwargs) - return - - self._update_faces_dict_length(elements) - - if self._simplex_set.max_dim < len(simplex) - 1: - self._simplex_set.max_dim = len(simplex) - 1 - - maximal_faces = set() - for r in range(len(elements), 0, -1): - for face in combinations(elements, r): - self._update_faces_dict_entry(face, elements, maximal_faces) + if isinstance(simplex, Simplex): + simplex_ = simplex.elements + kwargs = simplex._attributes | kwargs + else: + simplex_ = frozenset(simplex) + if len(simplex_) != len(simplex): + raise ValueError("a simplex cannot contain duplicate nodes") - self._simplex_set.faces_dict[len(elements) - 1][elements].update(kwargs) + self._simplex_trie.insert(sorted(simplex_), **kwargs) def add_simplices_from(self, simplices) -> None: """Add simplices from iterable to simplicial complex. @@ -608,51 +539,73 @@ def add_simplices_from(self, simplices) -> None: simplices : iterable Iterable of simplices to be added to the simplicial complex. """ + # for simplex in simplices: + # if isinstance(simplex, Hashable) and not isinstance(simplex, Iterable): + # simplex = [simplex] + # if isinstance(simplex, str): + # simplex = [simplex] + # + # if not isinstance(simplex, Iterable) and not isinstance(simplex, Simplex): + # raise TypeError( + # f"Input simplex must be Iterable or Simplex, got {type(simplex)} instead." 
+ # ) + # + # if isinstance(simplex, Simplex): + # simplex_ = simplex.elements + # attr = simplex._properties + # else: + # simplex_ = frozenset(simplex) + # if len(simplex_) != len(simplex): + # raise ValueError("a simplex cannot contain duplicate nodes") + # attr = {} + # + # self._simplex_trie.insert_raw(sorted(simplex_), **attr) + # self._simplex_trie.restore_simplex_property() for s in simplices: self.add_simplex(s) def get_cofaces( - self, simplex: Iterable[Hashable], codimension: int - ) -> list[frozenset]: + self, simplex: Iterable[ElementType], codimension: int + ) -> list[Simplex[ElementType]]: """Get cofaces of simplex. Parameters ---------- - simplex : list, tuple or simplex - The simplex to get the cofaces of. + simplex : Iterable of AtomType + The simplex for which to compute the cofaces. codimension : int - The codimension. If codimension = 0, all cofaces are returned. + The codimension of the returned cofaces. If zero, all cofaces are returned. Returns ------- list of tuples The cofaces of the given simplex. """ - entire_tree = self.get_boundaries( - self.get_maximal_simplices_of_simplex(simplex) - ) + simplex = tuple(sorted(simplex)) + # TODO: This can be optimized inside the simplex trie. + cofaces = self._simplex_trie.cofaces(simplex) return [ - i - for i in entire_tree - if frozenset(simplex).issubset(i) and len(i) - len(simplex) >= codimension + coface.simplex + for coface in cofaces + if len(coface) - len(simplex) >= codimension ] - def get_star(self, simplex) -> list[frozenset]: + def get_star(self, simplex: Iterable[ElementType]) -> list[Simplex[ElementType]]: """Get star. + Notes + ----- + This function is equivalent to calling `get_cofaces(simplex, 0)`. + Parameters ---------- - simplex : list, tuple or simplex - The simplex represented by a list of its nodes. + simplex : Iterable of AtomType + The simplex for which to compute the star. Returns ------- - list[frozenset] + list of tuples The star of the given simplex. 
- - Notes - ----- - This function is equivalent to ``get_cofaces(simplex, 0)``. """ return self.get_cofaces(simplex, 0) @@ -714,7 +667,7 @@ def set_simplex_attributes(self, values, name: str | None = None) -> None: self[simplex].update(d) def get_node_attributes(self, name: str) -> dict[Hashable, Any]: - """Get node attributes from combinatorial complex. + """Get node attributes from simplicial complex. Parameters ---------- @@ -736,11 +689,15 @@ def get_node_attributes(self, name: str) -> dict[Hashable, Any]: >>> SC.get_node_attributes("color") {1: 'red', 2: 'blue', 3: 'black'} """ - return {n[0]: self[n][name] for n in self.skeleton(0) if name in self[n]} + return { + node.label: node.attributes[name] + for node in self._simplex_trie.skeleton(0) + if name in node.attributes + } def get_simplex_attributes( self, name: str, rank: int | None = None - ) -> dict[tuple[Hashable, ...], Any]: + ) -> dict[tuple[ElementType, ...], Any]: """Get node attributes from simplical complex. Parameters @@ -764,16 +721,18 @@ def get_simplex_attributes( >>> d = {(1, 2): "red", (2, 3): "blue", (3, 4): "black"} >>> SC.set_simplex_attributes(d, name="color") >>> SC.get_simplex_attributes("color") - {frozenset({1, 2}): 'red', frozenset({2, 3}): 'blue', frozenset({3, 4}): 'black'} + {(1, 2): 'red', (2, 3): 'blue', (3, 4): 'black'} """ if rank is None: return { - n: self.simplices[n][name] + n.elements: self.simplices[n][name] for n in self.simplices if name in self.simplices[n] } return { - n: self.simplices[n][name] for n in self.skeleton(rank) if name in self[n] + n.elements: self.simplices[n][name] + for n in self.skeleton(rank) + if name in self[n] } @staticmethod @@ -852,24 +811,25 @@ def incidence_matrix( f"Rank cannot be larger than the dimension of the complex, got {rank}." ) + # TODO: Get rid of this special case... 
if rank == 0: - boundary = dok_matrix( - (1, len(self._simplex_set.faces_dict[rank].items())), dtype=np.float32 - ) - boundary[0, 0 : len(self._simplex_set.faces_dict[rank].items())] = 1 + boundary = dok_matrix((1, self._simplex_trie.shape[rank]), dtype=np.float32) + boundary[0, 0 : self._simplex_trie.shape[rank]] = 1 if index: simplex_dict_d = { - simplex: i for i, simplex in enumerate(self.skeleton(0)) + simplex.elements: i for i, simplex in enumerate(self.skeleton(0)) } return {}, simplex_dict_d, boundary.tocsr() return boundary.tocsr() - idx_simplices, idx_faces, values = [], [], [] - simplex_dict_d = {simplex: i for i, simplex in enumerate(self.skeleton(rank))} + simplex_dict_d = { + tuple(sorted(simplex)): i for i, simplex in enumerate(self.skeleton(rank)) + } simplex_dict_d_minus_1 = { - simplex: i for i, simplex in enumerate(self.skeleton(rank - 1)) + tuple(sorted(simplex)): i + for i, simplex in enumerate(self.skeleton(rank - 1)) } for simplex, idx_simplex in simplex_dict_d.items(): for i, left_out in enumerate(np.sort(list(simplex))): @@ -890,22 +850,11 @@ def incidence_matrix( ), ) + if not signed: + boundary = abs(boundary) if index: - if signed: - return ( - simplex_dict_d_minus_1, - simplex_dict_d, - boundary, - ) - return ( - simplex_dict_d_minus_1, - simplex_dict_d, - abs(boundary), - ) - - if signed: - return boundary - return abs(boundary) + return simplex_dict_d_minus_1, simplex_dict_d, boundary + return boundary def coincidence_matrix( self, rank, signed: bool = True, weight=None, index: bool = False @@ -1346,7 +1295,7 @@ def restrict_to_simplices(self, cell_set) -> Self: rns = [cell for cell in cell_set if cell in self.simplices] return self.__class__(simplices=rns) - def restrict_to_nodes(self, node_set): + def restrict_to_nodes(self, node_set) -> "SimplicialComplex[ElementType]": """Construct a new simplicial complex by restricting the simplices. The simplices are restricted to the nodes referenced by node_set. 
@@ -1363,10 +1312,7 @@ def restrict_to_nodes(self, node_set): Examples -------- - >>> c1 = tnx.Simplex((1, 2, 3)) - >>> c2 = tnx.Simplex((1, 2, 4)) - >>> c3 = tnx.Simplex((1, 2, 5)) - >>> SC = tnx.SimplicialComplex([c1, c2, c3]) + >>> SC = tnx.SimplicialComplex([(1, 2, 3), (1, 2, 4), (1, 2, 5)]) >>> new_complex = SC.restrict_to_nodes([1, 2, 3, 4]) >>> new_complex.simplices SimplexView([(1,), (2,), (3,), (4,), (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (1, 2, 3), (1, 2, 4)]) @@ -1382,7 +1328,7 @@ def restrict_to_nodes(self, node_set): return SimplicialComplex(all_sim) - def get_all_maximal_simplices(self): + def get_all_maximal_simplices(self) -> Generator[Simplex[ElementType], None, None]: """Get all maximal simplices of this simplicial complex. A simplex is maximal if it is not a face of any other simplex in the complex. @@ -1398,7 +1344,10 @@ def get_all_maximal_simplices(self): >>> SC.get_all_maximal_simplices() [(2, 5), (1, 2, 3), (1, 2, 4)] """ - return [tuple(s) for s in self.simplices if self.is_maximal(s)] + for node in self._simplex_trie: + simplex = node.simplex + if self.is_maximal(simplex): + yield simplex @classmethod def from_spharpy(cls, mesh) -> Self: @@ -1606,10 +1555,9 @@ def to_trimesh(self, vertex_position_name: str = "position"): ).values() ) ) + faces = [simplex.elements for simplex in self.get_all_maximal_simplices()] - return trimesh.Trimesh( - vertices=vertices, faces=self.get_all_maximal_simplices(), process=False - ) + return trimesh.Trimesh(faces=faces, vertices=vertices, process=False) def to_spharapy(self, vertex_position_name: str = "position"): """Convert to sharapy. 
@@ -1657,8 +1605,9 @@ def to_spharapy(self, vertex_position_name: str = "position"): sorted(self.get_node_attributes(vertex_position_name).items()) ).values() ) + triangles = [simplex.elements for simplex in self.get_all_maximal_simplices()] - return tm.TriMesh(self.get_all_maximal_simplices(), vertices) + return tm.TriMesh(triangles, vertices) def laplace_beltrami_operator(self, mode: str = "inv_euclidean"): """Compute a laplacian matrix for a triangular mesh. @@ -1725,12 +1674,7 @@ def is_connected(self) -> bool: ----- A simplicial complex is connected iff its 1-skeleton is connected. """ - G = nx.Graph() - for edge in self.skeleton(1): - G.add_edge(edge[0], edge[1]) - for node in self.skeleton(0): - G.add_node(next(iter(node))) - return nx.is_connected(G) + return nx.is_connected(self.graph_skeleton()) @classmethod def simplicial_closure_of_hypergraph(cls, H) -> Self: @@ -1776,7 +1720,9 @@ def to_cell_complex(self): """ from toponetx.classes.cell_complex import CellComplex - return CellComplex(self.get_all_maximal_simplices()) + return CellComplex( + simplex.elements for simplex in self.get_all_maximal_simplices() + ) def to_hypergraph(self) -> Hypergraph: """Convert a simplicial complex to a hypergraph. 
@@ -1839,7 +1785,7 @@ def graph_skeleton(self) -> nx.Graph:
         """
         G = nx.Graph()
         for node in self.skeleton(0):
-            G.add_node(node[0], **self[node])
+            G.add_node(next(iter(node)), **self[node])
         for edge in self.skeleton(1):
             G.add_edge(*edge, **self[edge])
         return G
diff --git a/toponetx/utils/iterable.py b/toponetx/utils/iterable.py
new file mode 100644
index 00000000..98b34936
--- /dev/null
+++ b/toponetx/utils/iterable.py
@@ -0,0 +1,65 @@
+"""Module with iterable-related utility functions."""
+
+from collections.abc import Sequence
+from typing import TypeVar
+
+__all__ = ["is_ordered_subset"]
+
+T = TypeVar("T")
+
+
+def is_ordered_subset(one: Sequence[T], other: Sequence[T]) -> bool:
+    """Return True if every element of ``one`` also occurs in ``other``.
+
+    Both sequences must be sorted in ascending order. This precondition allows a
+    single merge-style pass over both inputs that returns as early as possible for
+    non-subsets.
+
+    Elements are matched by equality, so values of different but mutually orderable
+    types (for example ``int`` and ``float``) can match. If an element of ``other``
+    cannot be ordered relative to the current element of ``one``, the scan stops at
+    that position and only an equality match there is accepted.
+
+    Parameters
+    ----------
+    one : Sequence
+        The sorted sequence that is tested for being a subset.
+    other : Sequence
+        The sorted sequence that is tested for being a superset.
+
+    Returns
+    -------
+    bool
+        True if all elements of ``one`` are contained in ``other``, False otherwise.
+
+    Examples
+    --------
+    >>> is_ordered_subset((2,), (1, 2))
+    True
+    >>> is_ordered_subset((1, 2), (1, 2, 3))
+    True
+    >>> is_ordered_subset((1, 2, 3), (1, 2, 3))
+    True
+    >>> is_ordered_subset((1, 2, 3), (1, 2))
+    False
+    >>> is_ordered_subset((1, 2, 3), (1, 2, 4))
+    False
+    >>> is_ordered_subset((2,), (1.5, 2))
+    True
+    """
+    size = len(other)
+    index = 0
+    for item in one:
+        # Skip the elements of ``other`` that are strictly smaller than ``item``.
+        # ``index`` is never reset because both inputs are sorted.
+        while index < size:
+            try:
+                if not other[index] < item:
+                    break
+            except TypeError:
+                # Unorderable pair: stop here and let the equality check decide.
+                break
+            index += 1
+        if index >= size or other[index] != item:
+            return False
+    return True