diff --git a/jac/jaclang/plugin/default.py b/jac/jaclang/plugin/default.py index 9fc1da55e..91cf8ec55 100644 --- a/jac/jaclang/plugin/default.py +++ b/jac/jaclang/plugin/default.py @@ -42,7 +42,11 @@ ) from jaclang.runtimelib.importer import ImportPathSpec, JacImporter, PythonImporter from jaclang.runtimelib.machine import JacMachine, JacProgram -from jaclang.runtimelib.utils import collect_node_connections, traverse_graph +from jaclang.runtimelib.utils import ( + all_issubclass, + collect_node_connections, + traverse_graph, +) import pluggy @@ -377,58 +381,75 @@ def spawn_call(op1: Architype, op2: Architype) -> WalkerArchitype: walker.path = [] walker.next = [node] - if walker.next: - current_node = walker.next[-1].architype - for i in warch._jac_entry_funcs_: - if not i.trigger: - if i.func: - i.func(warch, current_node) - else: - raise ValueError(f"No function {i.name} to call.") + current_node = node.architype + + # walker entry + for i in warch._jac_entry_funcs_: + if i.func and not i.trigger: + i.func(warch, current_node) + if walker.disengaged: + return warch + while len(walker.next): if current_node := walker.next.pop(0).architype: - for i in current_node._jac_entry_funcs_: - if not i.trigger or isinstance(warch, i.trigger): - if i.func: - i.func(current_node, warch) - else: - raise ValueError(f"No function {i.name} to call.") - if walker.disengaged: - return warch + # walker entry with for i in warch._jac_entry_funcs_: - if not i.trigger or isinstance(current_node, i.trigger): - if i.func and i.trigger: - i.func(warch, current_node) - elif not i.trigger: - continue - else: - raise ValueError(f"No function {i.name} to call.") - if walker.disengaged: - return warch - for i in warch._jac_exit_funcs_: - if not i.trigger or isinstance(current_node, i.trigger): - if i.func and i.trigger: - i.func(warch, current_node) - elif not i.trigger: - continue - else: - raise ValueError(f"No function {i.name} to call.") - if walker.disengaged: - return warch + if ( + i.func + and i.trigger + and all_issubclass(i.trigger, NodeArchitype) + and isinstance(current_node, i.trigger) + ): + i.func(warch, current_node) + + # node entry + for i in current_node._jac_entry_funcs_: + if i.func and not i.trigger: + i.func(current_node, warch) + + # node entry with + for i in current_node._jac_entry_funcs_: + if ( + i.func + and i.trigger + and all_issubclass(i.trigger, WalkerArchitype) + and isinstance(warch, i.trigger) + ): + i.func(current_node, warch) + + # node exit with for i in current_node._jac_exit_funcs_: - if not i.trigger or isinstance(warch, i.trigger): - if i.func: - i.func(current_node, warch) - else: - raise ValueError(f"No function {i.name} to call.") + if ( + i.func + and i.trigger + and all_issubclass(i.trigger, WalkerArchitype) + and isinstance(warch, i.trigger) + ): + i.func(current_node, warch) + + # node exit + for i in current_node._jac_exit_funcs_: + if i.func and not i.trigger: + i.func(current_node, warch) + + # walker exit with + for i in warch._jac_exit_funcs_: + if ( + i.func + and i.trigger + and all_issubclass(i.trigger, NodeArchitype) + and isinstance(current_node, i.trigger) + ): + i.func(warch, current_node) if walker.disengaged: return warch + # walker exit for i in warch._jac_exit_funcs_: - if not i.trigger: - if i.func: - i.func(warch, current_node) - else: - raise ValueError(f"No function {i.name} to call.") + if i.func and not i.trigger: + i.func(warch, current_node) + if walker.disengaged: + return warch + walker.ignores = [] return warch diff --git a/jac/jaclang/runtimelib/architype.py b/jac/jaclang/runtimelib/architype.py index 6c1ad9a55..0a674afe1 100644 --- a/jac/jaclang/runtimelib/architype.py +++ b/jac/jaclang/runtimelib/architype.py @@ -219,9 +219,9 @@ class WalkerAnchor(Anchor): """Walker Anchor.""" architype: WalkerArchitype - path: list[Anchor] = field(default_factory=list) - next: list[Anchor] = field(default_factory=list) - ignores: list[Anchor] = field(default_factory=list) + path: list[NodeAnchor] = field(default_factory=list) + next: list[NodeAnchor] = field(default_factory=list) + ignores: list[NodeAnchor] = field(default_factory=list) disengaged: bool = False diff --git a/jac/jaclang/runtimelib/utils.py b/jac/jaclang/runtimelib/utils.py index 010016640..0e19d91e7 100644 --- a/jac/jaclang/runtimelib/utils.py +++ b/jac/jaclang/runtimelib/utils.py @@ -5,6 +5,7 @@ import ast as ast3 import sys from contextlib import contextmanager +from types import UnionType from typing import Callable, Iterator, TYPE_CHECKING import jaclang.compiler.absyntree as ast @@ -156,6 +157,21 @@ def extract_type(node: ast.AstNode) -> list[str]: return extracted_type +def all_issubclass( + classes: type | UnionType | tuple[type | UnionType, ...], target: type +) -> bool: + """Check if all classes is subclass of target type.""" + match classes: + case type(): + return issubclass(classes, target) + case UnionType(): + return all((all_issubclass(cls, target) for cls in classes.__args__)) + case tuple(): + return all((all_issubclass(cls, target) for cls in classes)) + case _: + return False + + def extract_params( body: ast.FuncCall, ) -> tuple[dict[str, ast.Expr], list[tuple[str, ast3.AST]], list[tuple[str, ast3.AST]]]: diff --git a/jac/jaclang/tests/fixtures/visit_sequence.jac b/jac/jaclang/tests/fixtures/visit_sequence.jac new file mode 100644 index 000000000..a78d4c635 --- /dev/null +++ b/jac/jaclang/tests/fixtures/visit_sequence.jac @@ -0,0 +1,57 @@ +node Node { + has val: str; + + can entry1 with entry { + print(f"{self.val}-4"); + } + + ###################################################### + # NOT SUPPORTED YET IF IT'S DECLARED FIRST # + ###################################################### + # + # can entry2 with "Walker" entry { + # print(5); + # } + # + # can exit1 with "Walker" exit { + # print(6); + # } + # + ###################################################### + # -------------------------------------------------- # + ###################################################### + + + can exit2 with exit { + print(f"{self.val}-7"); + } +} +walker Walker { + can entry1 with entry { + print(1); + } + + can entry2 with `root entry { + print(2); + visit [-->]; + } + + can entry3 with Node entry { + print(f"{here.val}-3"); + } + + can exit1 with Node exit { + print(f"{here.val}-8"); + } + + can exit2 with exit { + print(9); + } +} +with entry{ + root ++> Node(val = "a"); + root ++> Node(val = "b"); + root ++> Node(val = "c"); + + Walker() spawn root; +} \ No newline at end of file diff --git a/jac/jaclang/tests/test_language.py b/jac/jaclang/tests/test_language.py index e24da0b13..0565265fb 100644 --- a/jac/jaclang/tests/test_language.py +++ b/jac/jaclang/tests/test_language.py @@ -1156,3 +1156,14 @@ def test_visit_order(self) -> None: sys.stdout = sys.__stdout__ stdout_value = captured_output.getvalue() self.assertEqual("[MyNode(Name='End'), MyNode(Name='Middle')]\n", stdout_value) + + def test_visit_sequence(self) -> None: + """Test conn assign on edges.""" + captured_output = io.StringIO() + sys.stdout = captured_output + jac_import("visit_sequence", base_path=self.fixture_abs_path("./")) + sys.stdout = sys.__stdout__ + self.assertEqual( + "1\n2\na-3\na-4\na-7\na-8\nb-3\nb-4\nb-7\nb-8\nc-3\nc-4\nc-7\nc-8\n9\n", + captured_output.getvalue(), + )