Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Import decl from py #1327

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions jac/jaclang/compiler/absyntree.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(self, kid: Sequence[AstNode]) -> None:
self.gen: CodeGenTarget = CodeGenTarget()
self.meta: dict[str, str] = {}
self.loc: CodeLocInfo = CodeLocInfo(*self.resolve_tok_range())
self.is_raised_from_py: bool = False

@property
def sym_tab(self) -> SymbolTable:
Expand All @@ -57,7 +58,7 @@ def sym_tab(self) -> SymbolTable:
if not self._sym_tab:
raise ValueError(
f"Symbol table not set for {type(self).__name__}. Impossible.\n"
f"Node: {self.pp()}\n"
f"Node: {self.pp()}{self.loc.mod_path}\n"
f"Parent: {self.parent.pp() if self.parent else None}\n"
)
return self._sym_tab
Expand Down Expand Up @@ -214,6 +215,10 @@ def unparse(self) -> str:
raise NotImplementedError(f"Node {type(self).__name__} is not valid.")
return res

def kid_of_type(self, typ: Type[T]) -> list[T]:
"""Get kids of a specific type."""
return self._sub_node_tab.get(typ, [])


class AstSymbolNode(AstNode):
"""Nodes that have link to a symbol in symbol table."""
Expand Down Expand Up @@ -634,7 +639,6 @@ def __init__(
self.py_raise_map: dict[str, str] = {}
self.registry = registry
self.terminals: list[Token] = terminals
self.is_raised_from_py: bool = False
AstNode.__init__(self, kid=kid)
AstDocNode.__init__(self, doc=doc)

Expand Down Expand Up @@ -1053,6 +1057,7 @@ def __init__(
sym_category=SymbolType.MOD_VAR,
)
self.abs_path: Optional[str] = None
self.is_imported: bool = False

@property
def from_parent(self) -> Import:
Expand Down
245 changes: 191 additions & 54 deletions jac/jaclang/compiler/passes/main/import_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def process_import(self, i: ast.ModulePath) -> None:
self.import_jac_module(node=i)

def attach_mod_to_node(
self, node: ast.ModulePath | ast.ModuleItem, mod: ast.Module | None
self,
node: ast.ModulePath | ast.ModuleItem,
mod: ast.Module | ast.Architype | ast.ArchDef | ast.Ability | None,
) -> None:
"""Attach a module to a node."""
if mod:
Expand Down Expand Up @@ -198,10 +200,19 @@ def import_jac_mod_from_file(self, target: str) -> ast.Module | None:
class PyImportPass(JacImportPass):
"""Jac statically imports Python modules."""

__call_count: int = 0

def before_pass(self) -> None:
"""Only run pass if settings are set to raise python."""
"""Run before the pass is called."""
PyImportPass.__call_count += 1
super().before_pass()
self.__load_builtins()

def after_pass(self) -> None:
"""Run after the pass is called."""
if PyImportPass.__call_count == 1:
self.__load_builtins()
PyImportPass.__call_count -= 1
return super().after_pass()

def __get_current_module(self, node: ast.AstNode) -> str:
parent = node.find_parent_of_type(ast.Module)
Expand All @@ -221,82 +232,208 @@ def process_import(self, i: ast.ModulePath) -> None:
# Solution to that is to get the import node and check the from loc then
# handle it based on if there a from loc or not
imp_node = i.parent_of_type(ast.Import)

if imp_node.is_py and not i.sub_module:
if imp_node.from_loc:
for j in imp_node.items.items:
assert isinstance(j, ast.ModuleItem)
mod_path = f"{imp_node.from_loc.dot_path_str}.{j.name.sym_name}"
self.import_py_module(
parent_node=j,
mod_path=mod_path,
imported_mod_name=(
j.name.sym_name if not j.alias else j.alias.sym_name
),
)
self.__process_import_from(imp_node)
else:
for j in imp_node.items.items:
assert isinstance(j, ast.ModulePath)
self.import_py_module(
parent_node=j,
mod_path=j.dot_path_str,
imported_mod_name=(
j.dot_path_str.replace(".", "")
if not j.alias
else j.alias.sym_name
),
)

def import_py_module(
self.__process_import(imp_node)

def __process_import_from(self, imp_node: ast.Import) -> None:
"""Process imports in the form of `from X import I`."""
assert isinstance(self.ir, ast.Module)
assert imp_node.from_loc is not None

# Attempt to import the Python module X and process it
imported_mod = self.__import_py_module(
parent_node_path=self.__get_current_module(imp_node),
mod_path=imp_node.from_loc.dot_path_str,
temp_import=True, # Prevents saving the module during this pass
)

if not imported_mod:
return

# Update the imported module's raise map and re-run the import pass
# This is not correct, we won't follow the imports of the imported mods
# this will affect using module B that is imported in module
# A when A is imported
imported_mod.py_raise_map = self.ir.py_raise_map
PyImportPass(input_ir=imported_mod, prior=None)

for decl_item in imp_node.items.items:
assert isinstance(decl_item, ast.ModuleItem)
if decl_item.is_imported:
continue

# Try to match the declaration with a module
if self.__process_module_declaration(decl_item, imported_mod):
continue

# Try to match the declaration with an ability
if self.__process_ability_declaration(decl_item, imported_mod):
continue

# Try to match the declaration with an architype
if self.__process_architype_declaration(decl_item, imported_mod):
continue

# Try to match the declaration with an assignment
if self.__process_var_declarations(decl_item, imported_mod):
continue

def __process_module_declaration(
self, decl_item: ast.ModuleItem, imported_mod: ast.Module
) -> bool:
"""Process the case where the declaration is a module."""
for mod in imported_mod.kid_of_type(ast.Module):
if decl_item.name.sym_name == mod.name:
if decl_item.alias:
mod.name = decl_item.alias.sym_name
self.attach_mod_to_node(decl_item, mod)
self.run_again = False
SymTabBuildPass(input_ir=mod, prior=self)
decl_item.is_imported = True
return True
return False

def __process_ability_declaration(
self, decl_item: ast.ModuleItem, imported_mod: ast.Module
) -> bool:
"""Process the case where the declaration is an ability."""
for ab in imported_mod.kid_of_type(ast.Ability):
assert isinstance(ab, ast.Ability)
if decl_item.name.sym_name == ab.name_ref.sym_name:
if decl_item.alias:
ab.name_ref._sym_name = decl_item.alias.sym_name
self.attach_mod_to_node(decl_item, ab)
self.run_again = False
SymTabBuildPass(
input_ir=decl_item.parent_of_type(ast.Module), prior=self
)
decl_item.is_imported = True
ab.is_raised_from_py = True
return True
return False

def __process_architype_declaration(
self, decl_item: ast.ModuleItem, imported_mod: ast.Module
) -> bool:
"""Process the case where the declaration is an architype."""
for arch in imported_mod.kid_of_type(ast.Architype):
if decl_item.name.sym_name == arch.sym_name:
if decl_item.alias:
arch.name._sym_name = decl_item.alias.sym_name
self.attach_mod_to_node(decl_item, arch)
self.run_again = False
SymTabBuildPass(
input_ir=decl_item.parent_of_type(ast.Module), prior=self
)
arch.is_raised_from_py = True
decl_item.is_imported = True
return True
return False

def __process_var_declarations(
self, decl_item: ast.ModuleItem, imported_mod: ast.Module
) -> bool:
# An issue migh happen here if the variaable was initialized with
# a specific value then was re-assigned again and i added the ast of
# the first assignment in the current mod.
# TODO: Need to make sure that import pass won't affect the PyOutPass
for var_assignment in imported_mod.kid_of_type(ast.Assignment):
var = var_assignment.target.items[0]
if not isinstance(var, ast.Name):
continue
if decl_item.name.sym_name == var.sym_name:
if decl_item.alias:
var.name_spec._sym_name = decl_item.alias.sym_name
self.attach_mod_to_node(decl_item, var_assignment)
self.run_again = False
SymTabBuildPass(
input_ir=decl_item.parent_of_type(ast.Module), prior=self
)
var_assignment.is_raised_from_py = True
decl_item.is_imported = True
return True
return False

def __process_import(self, imp_node: ast.Import) -> None:
"""Process the imports in form of `import X`."""
# Expected that each ImportStatement will import one item
# In case of this assertion fired then we need to revisit this item
assert len(imp_node.items.items) == 1
imported_item = imp_node.items.items[0]
assert isinstance(imported_item, ast.ModulePath)
imported_mod = self.__import_py_module(
parent_node_path=self.__get_current_module(imported_item),
mod_path=imported_item.dot_path_str,
imported_mod_name=(
# TODO: Check this replace
imported_item.dot_path_str.replace(".", "")
if not imported_item.alias
else imported_item.alias.sym_name
),
)
if imported_mod:
self.attach_mod_to_node(imported_item, imported_mod)
SymTabBuildPass(input_ir=imported_mod, prior=self)

def __import_py_module(
self,
parent_node: ast.ModulePath | ast.ModuleItem,
imported_mod_name: str,
parent_node_path: str,
mod_path: str,
imported_mod_name: Optional[str] = None,
temp_import: bool = False,
) -> Optional[ast.Module]:
"""Import a module."""
"""Import a python module."""
from jaclang.compiler.passes.main import PyastBuildPass

assert isinstance(self.ir, ast.Module)

python_raise_map = self.ir.py_raise_map
file_to_raise = None
# We need this as when you create a new project the imported mods won't be
# imported until it's used and auto completion won't be able to find the imported
# decls as they aren't used yet :(
python_raise_map = self.ir.py_mod_dep_map
file_to_raise: Optional[str] = None

if mod_path in python_raise_map:
file_to_raise = python_raise_map[mod_path]
else:
resolved_mod_path = (
f"{self.__get_current_module(parent_node)}.{imported_mod_name}"
)
assert isinstance(self.ir, ast.Module)
# TODO: Is it fine to use imported_mod_name or get it from mod_path
resolved_mod_path = f"{parent_node_path}.{imported_mod_name}"
resolved_mod_path = resolved_mod_path.replace(f"{self.ir.name}.", "")
file_to_raise = python_raise_map.get(resolved_mod_path)

if file_to_raise is None:
return None

try:
if file_to_raise not in {None, "built-in", "frozen"}:
if file_to_raise in self.import_table:
return self.import_table[file_to_raise]

with open(file_to_raise, "r", encoding="utf-8") as f:
mod = PyastBuildPass(
input_ir=ast.PythonModuleAst(
py_ast.parse(f.read()), mod_path=file_to_raise
),
).ir
SubNodeTabPass(input_ir=mod, prior=self)
if mod:
mod.name = imported_mod_name
if file_to_raise in {None, "built-in", "frozen"}:
return None

if file_to_raise in self.import_table:
return self.import_table[file_to_raise]

with open(file_to_raise, "r", encoding="utf-8") as f:
mod = PyastBuildPass(
input_ir=ast.PythonModuleAst(
py_ast.parse(f.read()), mod_path=file_to_raise
),
).ir
SubNodeTabPass(input_ir=mod, prior=self)

if mod:
mod.name = imported_mod_name if imported_mod_name else mod.name
if not temp_import:
self.import_table[file_to_raise] = mod
self.attach_mod_to_node(parent_node, mod)
SymTabBuildPass(input_ir=mod, prior=self)
return mod
else:
raise self.ice(f"Failed to import python module {mod_path}")
return mod
else:
raise self.ice(f"Failed to import python module {mod_path}")

except Exception as e:
self.error(f"Failed to import python module {mod_path}")
raise e
return None

def __load_builtins(self) -> None:
"""Pyraise builtins to help with builtins auto complete."""
Expand Down
6 changes: 3 additions & 3 deletions jac/jaclang/tests/test_language.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,7 @@ def test_needs_import_1(self) -> None:
return f"Error While Jac to Py AST conversion: {e}"

ir = jac_pass_to_pass(py_ast_build_pass, schedule=py_code_gen_typed).ir
self.assertEqual(len(ir.get_all_sub_nodes(ast.Architype)), 7)
self.assertEqual(len(ir.get_all_sub_nodes(ast.Architype)), 8)
captured_output = io.StringIO()
sys.stdout = captured_output
jac_import("needs_import_1", base_path=self.fixture_abs_path("./"))
Expand Down Expand Up @@ -545,7 +545,7 @@ def test_needs_import_2(self) -> None:

ir = jac_pass_to_pass(py_ast_build_pass, schedule=py_code_gen_typed).ir
self.assertEqual(
len(ir.get_all_sub_nodes(ast.Architype)), 8
len(ir.get_all_sub_nodes(ast.Architype)), 11
) # Because of the Architype from math
captured_output = io.StringIO()
sys.stdout = captured_output
Expand Down Expand Up @@ -784,7 +784,7 @@ def test_deep_convert(self) -> None:
ir = jac_pass_to_pass(py_ast_build_pass, schedule=py_code_gen_typed).ir
jac_ast = ir.pp()
self.assertIn(' | +-- String - "Loop compl', jac_ast)
self.assertEqual(len(ir.get_all_sub_nodes(ast.SubNodeList)), 269)
self.assertEqual(len(ir.get_all_sub_nodes(ast.SubNodeList)), 307)
captured_output = io.StringIO()
sys.stdout = captured_output
jac_import("deep_convert", base_path=self.fixture_abs_path("./"))
Expand Down
Loading
Loading