Skip to content

Commit

Permalink
Implementation of a Simplex Trie for Better Performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ffl096 committed Dec 17, 2024
1 parent dc8252d commit c15f9c0
Show file tree
Hide file tree
Showing 13 changed files with 1,203 additions and 443 deletions.
1 change: 1 addition & 0 deletions test/classes/test_combinatorial_complex.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,6 +477,7 @@ def test_remove_nodes(self):
frozenset({6}): {"weight": 1},
}
}

example.remove_nodes(HyperEdge([3]))
assert example._complex_set.hyperedge_dict == {
0: {frozenset({4}): {"weight": 1}, frozenset({6}): {"weight": 1}}
Expand Down
21 changes: 10 additions & 11 deletions test/classes/test_reportviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,12 +576,10 @@ class TestReportViews_SimplexView:

def test_getitem(self):
"""Test __getitem__ method of the SimplexView."""
assert self.simplex_view.__getitem__((1, 2)) == {
"is_maximal": True,
"membership": set(),
}
assert self.simplex_view[(1, 2)] == {}
assert self.simplex_view[1] == {}
with pytest.raises(KeyError):
self.simplex_view.__getitem__([5])
_ = self.simplex_view[(5,)]

def test_str(self):
"""Test __str__ method of the SimplexView."""
Expand All @@ -590,6 +588,10 @@ def test_str(self):
== "SimplexView([(1,), (2,), (3,), (4,), (1, 2), (2, 3), (3, 4)])"
)

def test_shape(self) -> None:
"""Test the shape method of the SimplexView."""
assert self.simplex_view.shape == (4, 3)


class TestReportViews_NodeView:
"""Test the NodeView class of the ReportViews module."""
Expand All @@ -610,15 +612,12 @@ def test_repr(self):

def test_getitem(self):
"""Test __getitem__ method of the NodeView."""
assert self.node_view.__getitem__(Simplex([1])) == {
"is_maximal": False,
"membership": {frozenset({1, 2})},
}
assert self.node_view[Simplex([1])] == {}
with pytest.raises(KeyError):
self.node_view.__getitem__([1, 2])
_ = self.node_view[(1, 2)]

# test for nodes of ColoredHyperGraph.
assert self.node_view_1.__getitem__([1]) == {"weight": 1}
assert self.node_view_1[(1,)] == {"weight": 1}


class TestReportViews_PathView:
Expand Down
13 changes: 13 additions & 0 deletions test/classes/test_simplex.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,19 @@ def test_contains(self):
assert 3 not in s
assert (1, 2) in s

def test_le(self) -> None:
"""Test the __le__ method of the simplex."""
s1 = Simplex([1, 2])
s2 = Simplex([1, 2])
s3 = Simplex([1, 2, 3])

assert s1 <= s2
assert s1 <= s3
assert not s3 <= s1

with pytest.raises(TypeError):
s1 <= 1

def test_boundary(self):
"""Test the boundary property of the simplex."""
s = Simplex((1, 2, 3))
Expand Down
201 changes: 201 additions & 0 deletions test/classes/test_simplex_trie.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Tests for the `simplex_trie` module."""

import pytest

from toponetx.classes.simplex import Simplex
from toponetx.classes.simplex_trie import SimplexNode, SimplexTrie


class TestSimplexNode:
"""Tests for the `SimplexNode` class."""

def test_invalid_root(self) -> None:
"""Test that invalid root nodes are rejected."""
with pytest.raises(ValueError):
_ = SimplexNode(label="invalid", parent=None)

def test_node_len(self) -> None:
"""Test the length of the node."""
root_node: SimplexNode[int] = SimplexNode(label=None)
assert len(root_node) == 0

child_node: SimplexNode[int] = SimplexNode(label=1, parent=root_node)
assert len(child_node) == 1

grandchild_node: SimplexNode[int] = SimplexNode(label=2, parent=child_node)
assert len(grandchild_node) == 2

def test_node_repr(self) -> None:
"""Test the string representation of the node."""
root_node = SimplexNode(label=None)
assert repr(root_node) == "SimplexNode(None, None)"

child_node = SimplexNode(label=1, parent=root_node)
assert repr(child_node) == "SimplexNode(1, SimplexNode(None, None))"

def test_simplex_property(self) -> None:
"""Test the `simplex` property of a node."""
root_node = SimplexNode(label=None)
assert root_node.simplex is None

child_node = SimplexNode(label=1, parent=root_node)
assert child_node.simplex == Simplex((1,))

grandchild_node = SimplexNode(label=2, parent=child_node)
assert grandchild_node.simplex == Simplex((1, 2))


class TestSimplexTrie:
"""Tests for the `SimplexTree` class."""

def test_insert(self) -> None:
"""Test that the internal data structures of the simplex trie are correct after insertion."""
trie = SimplexTrie()
trie.insert((1, 2, 3))

assert trie.shape == [3, 3, 1]

assert set(trie.root.children.keys()) == {1, 2, 3}
assert set(trie.root.children[1].children.keys()) == {2, 3}
assert set(trie.root.children[1].children[2].children.keys()) == {3}
assert set(trie.root.children[2].children.keys()) == {3}

# the label list should contain the nodes of each depth according to their label
label_to_simplex = {
1: {1: [(1,)], 2: [(2,)], 3: [(3,)]},
2: {2: [(1, 2)], 3: [(1, 3), (2, 3)]},
3: {3: [(1, 2, 3)]},
}

assert len(trie.label_lists) == len(label_to_simplex)
for depth, label_list in trie.label_lists.items():
assert depth in label_to_simplex
assert len(label_list) == len(label_to_simplex[depth])
for label, nodes in label_list.items():
assert len(nodes) == len(label_to_simplex[depth][label])
for node, expected in zip(
nodes, label_to_simplex[depth][label], strict=True
):
assert node.simplex.elements == expected

def test_len(self) -> None:
"""Test the `__len__` method of a simplex trie."""
trie = SimplexTrie()
assert len(trie) == 0

trie.insert((1, 2, 3))
assert len(trie) == 7

def test_getitem(self) -> None:
"""Test the `__getitem__` method of a simplex trie."""
trie = SimplexTrie()
trie.insert((1, 2, 3))

assert trie[(1,)].simplex == Simplex((1,))
assert trie[(1, 2)].simplex == Simplex((1, 2))
assert trie[(1, 2, 3)].simplex == Simplex((1, 2, 3))

with pytest.raises(KeyError):
_ = trie[(0,)]

def test_iter(self) -> None:
"""Test the iteration of the trie."""
trie = SimplexTrie()
trie.insert((1, 2, 3))
trie.insert((2, 3, 4))
trie.insert((0, 1))

# We guarantee a specific ordering of the simplices when iterating. Hence, we explicitly compare lists here.
assert [node.simplex.elements for node in trie] == [
(0,),
(1,),
(2,),
(3,),
(4,),
(0, 1),
(1, 2),
(1, 3),
(2, 3),
(2, 4),
(3, 4),
(1, 2, 3),
(2, 3, 4),
]

def test_cofaces(self) -> None:
"""Test the cofaces method."""
trie = SimplexTrie()
trie.insert((1, 2, 3))
trie.insert((1, 2, 4))

# no ordering is guaranteed for the cofaces method
assert {node.simplex.elements for node in trie.cofaces((1,))} == {
(1,),
(1, 2),
(1, 3),
(1, 4),
(1, 2, 3),
(1, 2, 4),
}
assert {node.simplex.elements for node in trie.cofaces((2,))} == {
(2,),
(1, 2),
(2, 3),
(2, 4),
(1, 2, 3),
(1, 2, 4),
}

def test_is_maximal(self) -> None:
"""Test the `is_maximal` method."""
trie = SimplexTrie()
trie.insert((1, 2, 3))
trie.insert((1, 2, 4))

assert trie.is_maximal((1, 2, 3))
assert trie.is_maximal((1, 2, 4))
assert not trie.is_maximal((1, 2))
assert not trie.is_maximal((1, 3))
assert not trie.is_maximal((1, 4))
assert not trie.is_maximal((2, 3))

with pytest.raises(ValueError):
trie.is_maximal((5,))

def test_skeleton(self) -> None:
"""Test the skeleton method."""
trie = SimplexTrie()
trie.insert((1, 2, 3))
trie.insert((1, 2, 4))

# no ordering is guaranteed for the skeleton method
assert {node.simplex.elements for node in trie.skeleton(0)} == {
(1,),
(2,),
(3,),
(4,),
}
assert {node.simplex.elements for node in trie.skeleton(1)} == {
(1, 2),
(1, 3),
(1, 4),
(2, 3),
(2, 4),
}
assert {node.simplex.elements for node in trie.skeleton(2)} == {
(1, 2, 3),
(1, 2, 4),
}

with pytest.raises(ValueError):
_ = next(trie.skeleton(-1))
with pytest.raises(ValueError):
_ = next(trie.skeleton(3))

def test_remove_simplex(self) -> None:
"""Test the removal of simplices from the trie."""
trie = SimplexTrie()
trie.insert((1, 2, 3))

trie.remove_simplex((1, 2))
assert len(trie) == 5
Loading

0 comments on commit c15f9c0

Please sign in to comment.