Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of some string forward references: #652

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Improve handling of some string forward references (#652)
- Add hardcoded support for `pytest.raises` to avoid false
positives (#651)
- Fix crash with nested classes in stubs. For now, `Any` is
Expand Down
41 changes: 33 additions & 8 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,11 @@
show errors.

"""
from _ast import Dict
import ast
import builtins
import contextlib
import sys
from types import ModuleType
import typing
from collections.abc import Callable, Hashable
from dataclasses import dataclass, field, InitVar
Expand Down Expand Up @@ -135,6 +136,7 @@ class Context:
"""

should_suppress_undefined_names: bool = field(default=False, init=False)
module: Optional[ModuleType] = field(default=None, init=False)
"""While this is True, no errors are shown for undefined names."""
_being_evaluated: Set[int] = field(default_factory=set, init=False)

Expand All @@ -159,6 +161,16 @@ def add_evaluation(self, obj: object) -> Generator[None, None, None]:
finally:
self._being_evaluated.remove(obj_id)

@contextlib.contextmanager
def override_module(self, module: ModuleType) -> Generator[None, None, None]:
"""Temporarily override the module used for name resolution."""
old_module = self.module
self.module = module
try:
yield
finally:
self.module = old_module

def show_error(
self,
message: str,
Expand All @@ -170,6 +182,8 @@ def show_error(

def get_name(self, node: ast.Name) -> Value:
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
if self.module is not None:
return self.get_name_from_globals(node.id, self.module.__dict__)
return AnyValue(AnySource.inference)

def handle_undefined_name(self, name: str) -> Value:
Expand Down Expand Up @@ -214,6 +228,8 @@ def evaluate_value(self, node: ast.AST) -> Value:

def get_name(self, node: ast.Name) -> Value:
"""Return the :class:`Value <pyanalyze.value.Value>` corresponding to a name."""
if self.module is not None:
return self.get_name_from_globals(node.id, self.module.__dict__)
return self.get_name_from_globals(node.id, self.globals)


Expand Down Expand Up @@ -481,12 +497,19 @@ def _type_from_runtime(
if ctx.is_being_evaluted(val):
return AnyValue(AnySource.inference)
with ctx.add_evaluation(val):
# This is necessary because the forward ref may be defined in a different file, in
# which case we don't know which names are valid in it.
with ctx.suppress_undefined_names():
return _eval_forward_ref(
val.__forward_arg__, ctx, is_typeddict=is_typeddict
)
if (
hasattr(val, "__forward_module__")
and val.__forward_module__ is not None
):
mod = sys.modules.get(val.__forward_module__)
if mod is not None:
with ctx.override_module(mod):
return _eval_forward_ref(
val.__forward_arg__, ctx, is_typeddict=is_typeddict
)
return _eval_forward_ref(
val.__forward_arg__, ctx, is_typeddict=is_typeddict
)
elif val is Ellipsis:
# valid in Callable[..., ]
return AnyValue(AnySource.explicit)
Expand Down Expand Up @@ -845,6 +868,8 @@ def show_error(
self.visitor.show_error(node, message, error_code)

def get_name(self, node: ast.Name) -> Value:
if self.module is not None:
return self.get_name_from_globals(node.id, self.module.__dict__)
if self.visitor is not None:
val, _ = self.visitor.resolve_name(
node,
Expand Down Expand Up @@ -924,7 +949,7 @@ def visit_Set(self, node: ast.Set) -> Value:
elts = [(False, self.visit(elt)) for elt in node.elts]
return SequenceValue(set, elts)

def visit_Dict(self, node: Dict) -> Any:
def visit_Dict(self, node: ast.Dict) -> Any:
keys = [self.visit(key) if key is not None else None for key in node.keys]
values = [self.visit(value) for value in node.values]
kvpairs = []
Expand Down
2 changes: 2 additions & 0 deletions pyanalyze/arg_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ def __post_init__(self) -> None:
super().__init__()

def get_name(self, node: ast.Name) -> Value:
if self.module is not None:
return self.get_name_from_globals(node.id, self.module.__dict__)
if self.globals is not None:
return self.get_name_from_globals(node.id, self.globals)
return self.handle_undefined_name(node.id)
Expand Down
2 changes: 2 additions & 0 deletions pyanalyze/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,8 @@ def __post_init__(self) -> None:
super().__init__()

def get_name(self, node: ast.Name) -> Value:
if self.module is not None:
return self.get_name_from_globals(node.id, self.module.__dict__)
try:
if isinstance(self.cls, types.ModuleType):
globals = self.cls.__dict__
Expand Down
27 changes: 27 additions & 0 deletions pyanalyze/test_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,6 +474,33 @@ def capybara(x: "List[int]") -> "List[str]":
assert_is_value(capybara(x), GenericValue(list, [TypedValue(str)]))
return []

@assert_passes()
def test_nested_forward_ref(self):
from typing import List

def func(x: List["int"]) -> None:
pass

def no_such_type(x: List["doesnt_exist"]) -> None: # E: undefined_name
pass

def capybara() -> None:
func([1])
func(["x"]) # E: incompatible_argument

@skip_before((3, 9))
@assert_passes()
def test_nested_forward_ref_pep_585(self):
def func(x: list["int"]) -> None:
pass

def no_such_type(x: list["doesnt_exist"]) -> None: # E: undefined_name
pass

def capybara() -> None:
func([1])
func(["x"]) # E: incompatible_argument

@assert_passes()
def test_forward_ref_incompatible(self):
def f() -> "int":
Expand Down
16 changes: 10 additions & 6 deletions pyanalyze/typeshed.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
@dataclass
class _AnnotationContext(Context):
finder: "TypeshedFinder"
module: str
module_name: str

def show_error(
self,
Expand All @@ -106,7 +106,9 @@ def show_error(
self.finder.log(message, ())

def get_name(self, node: ast.Name) -> Value:
return self.finder.resolve_name(self.module, node.id)
if self.module is not None:
return self.get_name_from_globals(node.id, self.module.__dict__)
return self.finder.resolve_name(self.module_name, node.id)

def get_attribute(self, root_value: Value, node: ast.Attribute) -> Value:
if isinstance(root_value, KnownValue):
Expand Down Expand Up @@ -1002,7 +1004,9 @@ def _parse_param(
[
make_type_var_value(
tv,
_AnnotationContext(finder=self, module=tv.__module__),
_AnnotationContext(
finder=self, module_name=tv.__module__
),
)
for tv in typevars
],
Expand Down Expand Up @@ -1030,7 +1034,7 @@ def _parse_param(
return SigParameter(name, kind, annotation=typ, default=default_value)

def _parse_expr(self, node: ast.AST, module: str) -> Value:
ctx = _AnnotationContext(finder=self, module=module)
ctx = _AnnotationContext(finder=self, module_name=module)
return value_from_ast(node, ctx=ctx)

def _parse_type(
Expand All @@ -1042,7 +1046,7 @@ def _parse_type(
allow_unpack: bool = False,
) -> Value:
val = self._parse_expr(node, module)
ctx = _AnnotationContext(finder=self, module=module)
ctx = _AnnotationContext(finder=self, module_name=module)
typ = type_from_value(
val, ctx=ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)
Expand All @@ -1064,7 +1068,7 @@ def _parse_call_assignment(
info.ast.value, ast.Call
):
return AnyValue(AnySource.inference)
ctx = _AnnotationContext(finder=self, module=module)
ctx = _AnnotationContext(finder=self, module_name=module)
return value_from_ast(info.ast.value, ctx=ctx)

def _extract_metadata(self, module: str, node: ast.ClassDef) -> Sequence[Extension]:
Expand Down