Skip to content

Commit

Permalink
refactor load_compiled_module_from_source
Browse files Browse the repository at this point in the history
Summary:
Refactors `load_compiled_module_from_source` to only run the
strict analysis that requires the `mod` object from
`loader.check_source` in the case where either static or strict flags
are set.

I wanted to isolate out this behavior in this diff before I separate out
static and strict analysis to make that change a 1-liner to ensure that
I'm not creating any bugs down the line.

The diff to separate strict/static analysis will come in the next diff.

Reviewed By: carljm

Differential Revision: D49735184

fbshipit-source-id: 0d7c88c5ff9d1e2758dbf19bcf09edd5c2d29c61
  • Loading branch information
pilleye authored and facebook-github-bot committed Oct 6, 2023
1 parent 0a93f1d commit 153701b
Showing 1 changed file with 65 additions and 40 deletions.
105 changes: 65 additions & 40 deletions Lib/compiler/strict/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,9 @@ def import_module(self, name: str, optimize: int) -> Optional[ModuleTable]:
stubKind = mod.stub_kind
if STUB_KIND_MASK_TYPING & stubKind:
root = remove_annotations(root)
root = self._get_rewritten_ast(name, mod, root, optimize)
root = self._get_rewritten_ast(
name, root, getSymbolTable(mod), mod.file_name, optimize
)
log = self.log_time_func
ctx = (
log()(name, mod.file_name, "declaration_visit")
Expand All @@ -138,13 +140,17 @@ def import_module(self, name: str, optimize: int) -> Optional[ModuleTable]:
return self.modules.get(name)

def _get_rewritten_ast(
self, name: str, mod: StrictAnalysisResult, root: ast.Module, optimize: int
self,
name: str,
root: ast.Module,
symbols: PythonSymbolTable,
filename: str,
optimize: int,
) -> ast.Module:
symbols = getSymbolTable(mod)
return rewrite(
root,
symbols,
mod.file_name,
filename,
name,
optimize=optimize,
is_static=True,
Expand All @@ -169,38 +175,58 @@ def load_compiled_module_from_source(
if override_flags and override_flags.is_strict:
self.logger.debug(f"Forcibly treating module {name} as strict")
self.loader.set_force_strict_by_name(name)
# TODO(pilleye): Only call this when no side effect analysis is requested

pyast = ast.parse(source)
symbols = symtable.symtable(source, filename, "exec")
flags = FlagExtractor().get_flags(pyast).merge(override_flags)

if not flags.is_static and not flags.is_strict:
code = self._compile_basic(name, pyast, filename, optimize)
return (code, False)

# TODO: Remove the check when static is enabled in the next diff to isolate errors
is_valid_strict = False
if flags.is_strict or flags.is_static:
is_valid_strict = self._strict_analyze(
source, flags, symbols, filename, name, submodule_search_locations
)

if flags.is_static:
code = self._compile_static(pyast, symbols, filename, name, optimize)
return (code, is_valid_strict)
else:
code = self._compile_strict(pyast, symbols, filename, name, optimize)
return (code, is_valid_strict)

def _strict_analyze(
self,
source: str | bytes,
flags: Flags,
symbols: PythonSymbolTable,
filename: str,
name: str,
submodule_search_locations: Optional[List[str]] = None,
) -> bool:
mod = self.loader.check_source(
source, filename, name, submodule_search_locations or []
)
flags = FlagExtractor().get_flags(pyast).merge(override_flags)

errors = mod.errors
is_valid_strict = (
mod.is_valid and len(errors) == 0 and (flags.is_static or flags.is_strict)
)
if errors and self.raise_on_error:
# if raise on error, just raise the first error
error = errors[0]
is_valid_strict = mod.is_valid and len(mod.errors) == 0

if mod.errors and self.raise_on_error:
error = mod.errors[0]
raise StrictModuleError(error[0], error[1], error[2], error[3])
elif is_valid_strict:
symbols = symtable.symtable(source, filename, "exec")
try:
check_class_conflict(pyast, filename, symbols)
except StrictModuleError as e:
if self.raise_on_error:
raise
mod.errors.append((e.msg, e.filename, e.lineno, e.col))

if not is_valid_strict:
code = self._compile_basic(name, pyast, filename, optimize)
elif flags.is_static:
code = self._compile_static(mod, filename, name, optimize)
else:
code = self._compile_strict(mod, filename, name, optimize)

return code, is_valid_strict
# TODO: Figure out if we need to run this analysis. This should be done only for
# static analysis and not necessarily for strict modules. Keeping it for now since
# it is currently running with the strict compiler.
try:
check_class_conflict(mod.ast, filename, symbols)
except StrictModuleError as e:
if self.raise_on_error:
raise

return is_valid_strict

def _compile_basic(
self, name: str, root: ast.Module, filename: str, optimize: int
Expand All @@ -215,14 +241,14 @@ def _compile_basic(

def _compile_strict(
self,
mod: StrictAnalysisResult,
root: ast.Module,
symbols: PythonSymbolTable,
filename: str,
name: str,
optimize: int,
) -> CodeType:
symbols = getSymbolTable(mod)
tree = rewrite(
mod.ast,
root,
symbols,
filename,
name,
Expand All @@ -233,21 +259,20 @@ def _compile_strict(

def _compile_static(
self,
mod: StrictAnalysisResult,
root: ast.Module,
symbols: PythonSymbolTable,
filename: str,
name: str,
optimize: int,
) -> CodeType | None:
root = self.ast_cache.get(name)
if root is None:
root = self._get_rewritten_ast(name, mod, mod.ast, optimize)
code = None

root = self.ast_cache.get(name) or self._get_rewritten_ast(
name, root, symbols, filename, optimize
)
try:
log = self.log_time_func
ctx = log()(name, filename, "compile") if log else nullcontext()
with ctx:
code = self.compile(
return self.compile(
name,
filename,
root,
Expand All @@ -266,4 +291,4 @@ def _compile_static(
if self.raise_on_error:
raise err

return code
return None

0 comments on commit 153701b

Please sign in to comment.