diff --git a/jaclang/compiler/absyntree.py b/jaclang/compiler/absyntree.py index aa83d28d9..048ef7d8d 100644 --- a/jaclang/compiler/absyntree.py +++ b/jaclang/compiler/absyntree.py @@ -110,11 +110,15 @@ def get_all_sub_nodes(self, typ: Type[T], brute_force: bool = True) -> list[T]: return Pass.get_all_sub_nodes(node=self, typ=typ, brute_force=brute_force) - def has_parent_of_type(self, typ: Type[T]) -> Optional[T]: + def parent_of_type(self, typ: Type[T]) -> T: """Get parent of type.""" from jaclang.compiler.passes import Pass - return Pass.has_parent_of_type(node=self, typ=typ) + ret = Pass.has_parent_of_type(node=self, typ=typ) + if isinstance(ret, typ): + return ret + else: + raise ValueError(f"Parent of type {typ} not found.") def format(self) -> str: """Get all sub nodes of type.""" @@ -592,15 +596,15 @@ class Import(ElementStmt, CodeBlockStmt): def __init__( self, hint: SubTag[Name], - paths: list[ModulePath], - items: Optional[SubNodeList[ModuleItem]], + from_loc: Optional[ModulePath], + items: SubNodeList[ModuleItem] | SubNodeList[ModulePath], is_absorb: bool, # For includes kid: Sequence[AstNode], doc: Optional[String] = None, ) -> None: """Initialize import node.""" self.hint = hint - self.paths = paths + self.from_loc = from_loc self.items = items self.is_absorb = is_absorb AstNode.__init__(self, kid=kid) @@ -611,9 +615,8 @@ def normalize(self, deep: bool = False) -> bool: res = True if deep: res = self.hint.normalize(deep) - for p in self.paths: - res = res and p.normalize(deep) - res = res and self.items.normalize(deep) if self.items else res + res = res and self.from_loc.normalize(deep) if self.from_loc else res + res = res and self.items.normalize(deep) res = res and self.doc.normalize(deep) if self.doc else res new_kid: list[AstNode] = [] if self.doc: @@ -623,15 +626,11 @@ def normalize(self, deep: bool = False) -> bool: else: new_kid.append(self.gen_token(Tok.KW_IMPORT)) new_kid.append(self.hint) - if self.items: + if self.from_loc: new_kid.append(self.gen_token(Tok.KW_FROM)) - for p in self.paths: - new_kid.append(p) - new_kid.append(self.gen_token(Tok.COMMA)) - new_kid.pop() - if self.items: + new_kid.append(self.from_loc) new_kid.append(self.gen_token(Tok.COMMA)) - new_kid.append(self.items) + new_kid.append(self.items) new_kid.append(self.gen_token(Tok.SEMI)) self.set_kids(nodes=new_kid) return res diff --git a/jaclang/compiler/parser.py b/jaclang/compiler/parser.py index c56f9bd07..be3bc267c 100644 --- a/jaclang/compiler/parser.py +++ b/jaclang/compiler/parser.py @@ -287,17 +287,22 @@ def import_stmt(self, kid: list[ast.AstNode]) -> ast.Import: if len(kid) == 1 and isinstance(kid[0], ast.Import): return self.nu(kid[0]) lang = kid[1] - paths = [i for i in kid if isinstance(i, ast.ModulePath)] + from_path = kid[3] if isinstance(kid[3], ast.ModulePath) else None + if from_path: + items = kid[-2] if isinstance(kid[-2], ast.SubNodeList) else None + else: + paths = [i for i in kid if isinstance(i, ast.ModulePath)] + items = ast.SubNodeList[ast.ModulePath]( + items=paths, delim=Tok.COMMA, kid=kid[2:-1] + ) + kid = kid[:2] + [items] + kid[-1:] - items = kid[-2] if isinstance(kid[-2], ast.SubNodeList) else None is_absorb = False - if isinstance(lang, ast.SubTag) and ( - isinstance(items, ast.SubNodeList) or items is None - ): + if isinstance(lang, ast.SubTag) and (isinstance(items, ast.SubNodeList)): return self.nu( ast.Import( hint=lang, - paths=paths, + from_loc=from_path, items=items, is_absorb=is_absorb, kid=kid, @@ -342,14 +347,20 @@ def include_stmt(self, kid: list[ast.AstNode]) -> ast.Import: include_stmt: KW_INCLUDE sub_name import_path SEMI """ lang = kid[1] - paths = [i for i in kid if isinstance(i, ast.ModulePath)] + from_path = kid[2] + if not isinstance(from_path, ast.ModulePath): + raise self.ice() + items = ast.SubNodeList[ast.ModulePath]( + items=[from_path], delim=Tok.COMMA, kid=[from_path] + ) + kid = kid[:2] + [items] + kid[3:] is_absorb = True if isinstance(lang, ast.SubTag): return self.nu( ast.Import( hint=lang, - paths=paths, - items=None, + from_loc=None, + items=items, is_absorb=is_absorb, kid=kid, ) diff --git a/jaclang/compiler/passes/ir_pass.py b/jaclang/compiler/passes/ir_pass.py index dd1208ba0..3ab4d10ca 100644 --- a/jaclang/compiler/passes/ir_pass.py +++ b/jaclang/compiler/passes/ir_pass.py @@ -1,16 +1,18 @@ """Abstract class for IR Passes for Jac.""" -from typing import Optional, Type +from typing import Optional, Type, TypeVar import jaclang.compiler.absyntree as ast from jaclang.compiler.passes.transform import Transform from jaclang.utils.helpers import pascal_to_snake +T = TypeVar("T", bound=ast.AstNode) -class Pass(Transform[ast.T]): + +class Pass(Transform[T]): """Abstract class for IR passes.""" - def __init__(self, input_ir: ast.T, prior: Optional[Transform]) -> None: + def __init__(self, input_ir: T, prior: Optional[Transform]) -> None: """Initialize parser.""" self.term_signal = False self.prune_signal = False @@ -45,10 +47,10 @@ def prune(self) -> None: @staticmethod def get_all_sub_nodes( - node: ast.AstNode, typ: Type[ast.T], brute_force: bool = False - ) -> list[ast.T]: + node: ast.AstNode, typ: Type[T], brute_force: bool = False + ) -> list[T]: """Get all sub nodes of type.""" - result: list[ast.T] = [] + result: list[T] = [] # Assumes pass built the sub node table if not node: return result @@ -69,7 +71,7 @@ def get_all_sub_nodes( return result @staticmethod - def has_parent_of_type(node: ast.AstNode, typ: Type[ast.T]) -> Optional[ast.T]: + def has_parent_of_type(node: ast.AstNode, typ: Type[T]) -> Optional[T]: """Check if node has parent of type.""" while node.parent: if isinstance(node.parent, typ): @@ -97,7 +99,7 @@ def recalculate_parents(self, node: ast.AstNode) -> None: # Transform Implementations # ------------------------- - def transform(self, ir: ast.T) -> ast.AstNode: + def transform(self, ir: T) -> ast.AstNode: """Run pass.""" # Only performs passes on proper ASTs if not isinstance(ir, ast.AstNode): diff --git a/jaclang/compiler/passes/main/import_pass.py b/jaclang/compiler/passes/main/import_pass.py index b51007622..7cf66da6a 100644 --- a/jaclang/compiler/passes/main/import_pass.py +++ b/jaclang/compiler/passes/main/import_pass.py @@ -36,7 +36,8 @@ def enter_module(self, node: ast.Module) -> None: self.run_again = False all_imports = self.get_all_sub_nodes(node, ast.ModulePath) for i in all_imports: - if i.parent.hint.tag.value == "jac" and not i.sub_module: + lang = i.parent_of_type(ast.Import).hint.tag.value + if lang == "jac" and not i.sub_module: self.run_again = True mod = self.import_module( node=i, @@ -48,7 +49,7 @@ def enter_module(self, node: ast.Module) -> None: self.annex_impl(mod) i.sub_module = mod i.add_kids_right([mod], pos_update=False) - elif i.parent.hint.tag.value == "py" and settings.jac_proc_debug: + elif lang == "py" and settings.jac_proc_debug: mod = self.import_py_module(node=i, mod_path=node.loc.mod_path) i.sub_module = mod i.add_kids_right([mod], pos_update=False) diff --git a/jaclang/compiler/passes/main/pyast_gen_pass.py b/jaclang/compiler/passes/main/pyast_gen_pass.py index 13b0bedfa..3280d4cdf 100644 --- a/jaclang/compiler/passes/main/pyast_gen_pass.py +++ b/jaclang/compiler/passes/main/pyast_gen_pass.py @@ -492,17 +492,20 @@ def exit_import(self, node: ast.Import) -> None: py_nodes.append( self.sync(ast3.Expr(value=node.doc.gen.py_ast[0]), jac_node=node.doc) ) - py_compat_path_str = [] - path_alias = {} - for path in node.paths: - py_compat_path_str.append(path.path_str.lstrip(".")) - path_alias[path.path_str] = path.alias.sym_name if path.alias else None + path_alias: dict[str, Optional[str]] = ( + {node.from_loc.path_str: None} if node.from_loc else {} + ) imp_from = {} if node.items: for item in node.items.items: - imp_from[item.name.sym_name] = ( - item.alias.sym_name if item.alias else False - ) + if isinstance(item, ast.ModuleItem): + imp_from[item.name.sym_name] = ( + item.alias.sym_name if item.alias else False + ) + elif isinstance(item, ast.ModulePath): + path_alias[item.path_str] = ( + item.alias.sym_name if item.alias else None + ) keys = [] values = [] @@ -594,10 +597,13 @@ def exit_import(self, node: ast.Import) -> None: ) ) if node.is_absorb: + source = node.items.items[0] + if not isinstance(source, ast.ModulePath): + raise self.ice() py_nodes.append( self.sync( py_node=ast3.ImportFrom( - module=py_compat_path_str[0] if py_compat_path_str[0] else None, + module=(source.path_str.lstrip(".") if source else None), names=[self.sync(ast3.alias(name="*"), node)], level=0, ), @@ -608,15 +614,17 @@ def exit_import(self, node: ast.Import) -> None: self.warning( "Includes import * in target module into current namespace." ) - if not node.items: - py_nodes.append( - self.sync(ast3.Import(names=[i.gen.py_ast[0] for i in node.paths])) - ) + if not node.from_loc: + py_nodes.append(self.sync(ast3.Import(names=node.items.gen.py_ast))) else: py_nodes.append( self.sync( ast3.ImportFrom( - module=py_compat_path_str[0] if py_compat_path_str[0] else None, + module=( + node.from_loc.path_str.lstrip(".") + if node.from_loc + else None + ), names=node.items.gen.py_ast, level=0, ) diff --git a/jaclang/compiler/passes/main/pyast_load_pass.py b/jaclang/compiler/passes/main/pyast_load_pass.py index c94249421..f8a785313 100644 --- a/jaclang/compiler/passes/main/pyast_load_pass.py +++ b/jaclang/compiler/passes/main/pyast_load_pass.py @@ -1372,12 +1372,13 @@ class Import(stmt): pos_end=0, ) pytag = ast.SubTag[ast.Name](tag=lang, kid=[lang]) + items = ast.SubNodeList[ast.ModulePath](items=paths, delim=Tok.COMMA, kid=paths) ret = ast.Import( hint=pytag, - paths=paths, - items=None, + from_loc=None, + items=items, is_absorb=False, - kid=[pytag, *paths], + kid=[pytag, items], ) return ret @@ -1449,7 +1450,7 @@ class ImportFrom(stmt): pytag = ast.SubTag[ast.Name](tag=lang, kid=[lang]) ret = ast.Import( hint=pytag, - paths=[path], + from_loc=path, items=items, is_absorb=False, kid=[pytag, path, items], diff --git a/jaclang/compiler/passes/main/sym_tab_build_pass.py b/jaclang/compiler/passes/main/sym_tab_build_pass.py index 6c96d91c5..3acc6f877 100644 --- a/jaclang/compiler/passes/main/sym_tab_build_pass.py +++ b/jaclang/compiler/passes/main/sym_tab_build_pass.py @@ -381,16 +381,22 @@ def exit_import(self, node: ast.Import) -> None: is_absorb: bool, sub_module: Optional[Module], """ - if node.items: + if not node.is_absorb: for i in node.items.items: self.def_insert(i, single_decl="import item") elif node.is_absorb and node.hint.tag.value == "jac": - if not node.paths[0].sub_module or not node.paths[0].sub_module.sym_tab: + source = node.items.items[0] + if ( + not isinstance(source, ast.ModulePath) + or not source.sub_module + or not source.sub_module.sym_tab + ): self.error( - f"Module {node.paths[0].path_str} not found to include *, or ICE occurred!" + f"Module {node.from_loc.path_str if node.from_loc else 'from location'}" + f" not found to include *, or ICE occurred!" ) else: - for v in node.paths[0].sub_module.sym_tab.tab.values(): + for v in source.sub_module.sym_tab.tab.values(): self.def_insert(v.decl, table_override=self.cur_scope()) def enter_module_path(self, node: ast.ModulePath) -> None: diff --git a/jaclang/compiler/workspace.py b/jaclang/compiler/workspace.py index 18ff59642..6fdf2a2e9 100644 --- a/jaclang/compiler/workspace.py +++ b/jaclang/compiler/workspace.py @@ -171,9 +171,7 @@ def get_dependencies( [ i for i in mod_ir.get_all_sub_nodes(ast.ModulePath) - if i.parent - and isinstance(i.parent, ast.Import) - and i.parent.hint.tag.value == "jac" + if i.parent_of_type(ast.Import).hint.tag.value == "jac" ] if mod_ir else [] @@ -184,9 +182,7 @@ def get_dependencies( i for i in mod_ir.get_all_sub_nodes(ast.ModulePath) if i.loc.mod_path == file_path - and i.parent - and isinstance(i.parent, ast.Import) - and i.parent.hint.tag.value == "jac" + and i.parent_of_type(ast.Import).hint.tag.value == "jac" ] if mod_ir else []