diff --git a/codebasin/preprocessor.py b/codebasin/preprocessor.py index 1b1d342..ce45311 100644 --- a/codebasin/preprocessor.py +++ b/codebasin/preprocessor.py @@ -11,7 +11,9 @@ import hashlib import logging import os +from collections.abc import Iterable from copy import copy +from typing import Self import numpy as np @@ -583,6 +585,18 @@ def evaluate_for_platform(self, **kwargs): """ return False + def walk(self) -> Iterable[Self]: + """ + Returns + ------- + Iterable[Self] + An Iterable visiting all descendants of this node via a preorder + traversal. + """ + yield self + for child in self.children: + yield from child.walk() + class FileNode(Node): """ @@ -2330,6 +2344,16 @@ def __init__(self, filename): self.root = FileNode(filename) self._latest_node = self.root + def walk(self) -> Iterable[Node]: + """ + Returns + ------- + Iterable[Node] + An Iterable visiting all nodes in the tree via a preorder + traversal. + """ + yield from self.root.walk() + def associate_file(self, filename): self.root.filename = filename diff --git a/tests/source-tree/__init__.py b/tests/source-tree/__init__.py new file mode 100644 index 0000000..94adb81 --- /dev/null +++ b/tests/source-tree/__init__.py @@ -0,0 +1,2 @@ +# Copyright (C) 2019-2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause diff --git a/tests/source-tree/test_source_tree.py b/tests/source-tree/test_source_tree.py new file mode 100644 index 0000000..4d9463a --- /dev/null +++ b/tests/source-tree/test_source_tree.py @@ -0,0 +1,79 @@ +# Copyright (C) 2019-2024 Intel Corporation +# SPDX-License-Identifier: BSD-3-Clause + +import logging +import tempfile +import unittest +import warnings + +from codebasin.file_parser import FileParser +from codebasin.preprocessor import CodeNode, DirectiveNode, FileNode + + +class TestSourceTree(unittest.TestCase): + """ + Test SourceTree class. + """ + + def setUp(self): + logging.getLogger("codebasin").disabled = False + warnings.simplefilter("ignore", ResourceWarning) + + def test_walk(self): + """Check that walk() visits nodes in the expected order""" + + # TODO: Revisit this when SourceTree can be built without a file. + with tempfile.NamedTemporaryFile( + mode="w", + delete_on_close=False, + suffix=".cpp", + ) as f: + source = """ + #if defined(FOO) + void foo(); + #elif defined(BAR) + void bar(); + #else + void baz(); + #endif + + void qux(); + """ + f.write(source) + f.close() + + # TODO: Revisit this when __str__() is more reliable. + tree = FileParser(f.name).parse_file(summarize_only=False) + expected_types = [ + FileNode, + DirectiveNode, + CodeNode, + DirectiveNode, + CodeNode, + DirectiveNode, + CodeNode, + DirectiveNode, + CodeNode, + ] + expected_contents = [ + f.name, + "FOO", + "foo", + "BAR", + "bar", + "else", + "baz", + "endif", + "qux", + ] + for i, node in enumerate(tree.walk()): + self.assertTrue(isinstance(node, expected_types[i])) + if isinstance(node, CodeNode): + contents = node.spelling()[0] + else: + contents = str(node) + self.assertTrue(expected_contents[i] in contents) + + +if __name__ == "__main__": + unittest.main()