Skip to content

Commit

Permalink
More progress on PEP 695 (#692)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Sep 30, 2023
1 parent b05bdad commit 3260bb0
Show file tree
Hide file tree
Showing 6 changed files with 227 additions and 11 deletions.
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

## Unreleased

- Partial support for PEP 695-style type aliases (#690, #692)
- Fix tests to account for new `typeshed_client` release
(#694)
- Partial support for PEP 695-style type aliases (#690)
- Add option to disable all error codes (#659)
- Add hacky fix for bugs with hashability on type objects (#689)
- Show an error on calls to `typing.Any` (#688)
Expand Down
10 changes: 7 additions & 3 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
TypeVar,
Union,
)
import typing_extensions

import qcore

Expand Down Expand Up @@ -539,11 +540,14 @@ def _type_from_runtime(


def make_type_var_value(tv: TypeVarLike, ctx: Context) -> TypeVarValue:
if tv.__bound__ is not None:
if (
isinstance(tv, (TypeVar, typing_extensions.TypeVar))
and tv.__bound__ is not None
):
bound = _type_from_runtime(tv.__bound__, ctx)
else:
bound = None
if isinstance(tv, TypeVar) and tv.__constraints__:
if isinstance(tv, (TypeVar, typing_extensions.TypeVar)) and tv.__constraints__:
constraints = tuple(
_type_from_runtime(constraint, ctx) for constraint in tv.__constraints__
)
Expand Down Expand Up @@ -656,7 +660,7 @@ def _type_from_value(
return _type_from_runtime(
value.val, ctx, is_typeddict=is_typeddict, allow_unpack=allow_unpack
)
elif isinstance(value, TypeVarValue):
elif isinstance(value, (TypeVarValue, TypeAliasValue)):
return value
elif isinstance(value, MultiValuedValue):
return unite_values(
Expand Down
99 changes: 95 additions & 4 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import sys
import traceback
import types
import typing
from argparse import ArgumentParser
from dataclasses import dataclass
from itertools import chain
Expand Down Expand Up @@ -100,6 +101,7 @@
from .predicates import EqualsPredicate, InPredicate
from .reexport import ImplicitReexportTracker
from .safe import (
all_of_type,
is_dataclass_type,
is_hashable,
safe_getattr,
Expand Down Expand Up @@ -162,6 +164,8 @@
DefiniteValueExtension,
DeprecatedExtension,
SkipDeprecatedExtension,
TypeAlias,
TypeAliasValue,
annotate_value,
AnnotatedValue,
AnySource,
Expand Down Expand Up @@ -1754,22 +1758,22 @@ def visit_ClassDef(self, node: ast.ClassDef) -> Value:
value, _ = self._set_name_in_scope(node.name, node, value)
return value

def _get_class_object(self, node: ast.ClassDef) -> Value:
def _get_local_object(self, name: str, node: ast.AST) -> Value:
if self.scopes.scope_type() == ScopeType.module_scope:
return self.scopes.get(node.name, node, self.state)
return self.scopes.get(name, node, self.state)
elif (
self.scopes.scope_type() == ScopeType.class_scope
and self.current_class is not None
and hasattr(self.current_class, "__dict__")
):
runtime_obj = self.current_class.__dict__.get(node.name)
runtime_obj = self.current_class.__dict__.get(name)
if isinstance(runtime_obj, type):
return KnownValue(runtime_obj)
return AnyValue(AnySource.inference)

def _visit_class_and_get_value(self, node: ast.ClassDef) -> Value:
if self._is_checking():
cls_obj = self._get_class_object(node)
cls_obj = self._get_local_object(node.name, node)

module = self.module
if isinstance(cls_obj, MultiValuedValue) and module is not None:
Expand Down Expand Up @@ -4506,6 +4510,93 @@ def visit_AugAssign(self, node: ast.AugAssign) -> None:
# syntax like 'x = y = 0' results in multiple targets
self.visit(node.target)

if sys.version_info >= (3, 12):

def visit_TypeAlias(self, node: ast.TypeAlias) -> Value:
assert isinstance(node.name, ast.Name)
name = node.name.id
alias_val = self._get_local_object(name, node)
if isinstance(alias_val, KnownValue) and isinstance(
alias_val.val, typing.TypeAliasType
):
alias_obj = alias_val.val
else:
alias_obj = None
type_param_values = []
if self._is_checking():
if node.type_params:
with self.scopes.add_scope(
ScopeType.annotation_scope,
scope_node=node,
scope_object=alias_obj,
):
type_param_values = [
self.visit(param) for param in node.type_params
]
assert all_of_type(type_param_values, TypeVarValue)
with self.scopes.add_scope(
ScopeType.annotation_scope,
scope_node=node,
scope_object=alias_obj,
):
value = self.visit(node.value)

else:
with self.scopes.add_scope(
ScopeType.annotation_scope,
scope_node=node,
scope_object=alias_obj,
):
value = self.visit(node.value)
else:
value = None
if alias_obj is None:
if value is None:
alias_val = AnyValue(AnySource.inference)
else:
alias_val = TypeAliasValue(
name,
self.module.__name__ if self.module is not None else "",
TypeAlias(
lambda: type_from_value(value, self, node),
lambda: tuple(val.typevar for val in type_param_values),
),
)
set_value, _ = self._set_name_in_scope(name, node, alias_val)
return set_value

def visit_TypeVar(self, node: ast.TypeVar) -> Value:
bound = constraints = None
if node.bound is not None:
if isinstance(node.bound, ast.Tuple):
constraints = [self.visit(elt) for elt in node.bound.elts]
else:
bound = self.visit(node.bound)
tv = TypeVar(node.name)
typevar = TypeVarValue(
tv,
type_from_value(bound, self, node) if bound is not None else None,
(
tuple(type_from_value(c, self, node) for c in constraints)
if constraints is not None
else ()
),
)
self._set_name_in_scope(node.name, node, typevar)
return typevar

def visit_ParamSpec(self, node: ast.ParamSpec) -> Value:
ps = typing.ParamSpec(node.name)
typevar = TypeVarValue(ps, is_paramspec=True)
self._set_name_in_scope(node.name, node, typevar)
return typevar

def visit_TypeVarTuple(self, node: ast.TypeVarTuple) -> Value:
tv = TypeVar(node.name)
typevar = TypeVarValue(tv, is_typevartuple=True)
self._set_name_in_scope(node.name, node, typevar)
return typevar

def visit_Name(self, node: ast.Name, force_read: bool = False) -> Value:
return self.composite_from_name(node, force_read=force_read).value

Expand Down
1 change: 1 addition & 0 deletions pyanalyze/stacked_scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class ScopeType(enum.Enum):
module_scope = 2
class_scope = 3
function_scope = 4
annotation_scope = 5


# Nodes as used in scopes can be any object, as long as they are hashable.
Expand Down
96 changes: 96 additions & 0 deletions pyanalyze/test_type_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# static analysis: ignore
from .test_name_check_visitor import TestNameCheckVisitorBase
from .test_node_visitor import assert_passes, skip_before


class TestRecursion(TestNameCheckVisitorBase):
@assert_passes()
def test(self):
from typing import Dict, List, Union

JSON = Union[Dict[str, "JSON"], List["JSON"], int, str, float, bool, None]

def f(x: JSON):
pass

def capybara():
f([])
f([1, 2, 3])
f([[{1}]]) # TODO this should throw an error


class TestTypeAliasType(TestNameCheckVisitorBase):
@assert_passes()
def test_typing_extensions(self):
from typing_extensions import TypeAliasType, assert_type

MyType = TypeAliasType("MyType", int)

def f(x: MyType):
assert_type(x, MyType)
assert_type(x + 1, int)

def capybara(i: int, s: str):
f(i)
f(s) # E: incompatible_argument

@assert_passes()
def test_typing_extensions_generic(self):
from typing_extensions import TypeAliasType, assert_type
from typing import TypeVar, Union, List, Set

T = TypeVar("T")
MyType = TypeAliasType("MyType", Union[List[T], Set[T]], type_params=(T,))

def f(x: MyType[int]):
assert_type(x, MyType[int])
assert_type(list(x), List[int])

def capybara(i: int, s: str):
f([i])
f([s]) # E: incompatible_argument

@skip_before((3, 12))
def test_312(self):
self.assert_passes("""
from typing_extensions import assert_type
type MyType = int
def f(x: MyType):
assert_type(x, MyType)
assert_type(x + 1, int)
def capybara(i: int, s: str):
f(i)
f(s) # E: incompatible_argument
""")

@skip_before((3, 12))
def test_312_generic(self):
self.assert_passes("""
from typing_extensions import assert_type
type MyType[T] = list[T] | set[T]
def f(x: MyType[int]):
assert_type(x, MyType[int])
assert_type(list(x), list[int])
def capybara(i: int, s: str):
f([i])
f([s]) # E: incompatible_argument
""")

@skip_before((3, 12))
def test_312_local_alias(self):
self.assert_passes("""
from typing_extensions import assert_type
def capybara():
type MyType = int
def f(x: MyType):
assert_type(x, MyType)
assert_type(x + 1, int)
f(1)
f("x") # E: incompatible_argument
""")
30 changes: 27 additions & 3 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def function(x: int, y: list[int], z: Any):
from collections import deque
from dataclasses import dataclass, field, InitVar
from itertools import chain
import sys
from types import FunctionType
from typing import (
Any,
Expand Down Expand Up @@ -60,9 +61,31 @@ def function(x: int, y: list[int], z: Any):
KNOWN_MUTABLE_TYPES = (list, set, dict, deque)
ITERATION_LIMIT = 1000

TypeVarLike = Union[
ExternalType["typing.TypeVar"], ExternalType["typing_extensions.ParamSpec"]
]
if sys.version_info >= (3, 11):
TypeVarLike = Union[
ExternalType["typing.TypeVar"],
ExternalType["typing_extensions.TypeVar"],
ExternalType["typing.ParamSpec"],
ExternalType["typing_extensions.ParamSpec"],
ExternalType["typing.TypeVarTuple"],
ExternalType["typing_extensions.TypeVarTuple"],
]
elif sys.version_info >= (3, 10):
TypeVarLike = Union[
ExternalType["typing.TypeVar"],
ExternalType["typing_extensions.TypeVar"],
ExternalType["typing.ParamSpec"],
ExternalType["typing_extensions.ParamSpec"],
ExternalType["typing_extensions.TypeVarTuple"],
]
else:
TypeVarLike = Union[
ExternalType["typing.TypeVar"],
ExternalType["typing_extensions.TypeVar"],
ExternalType["typing_extensions.ParamSpec"],
ExternalType["typing_extensions.TypeVarTuple"],
]

TypeVarMap = Mapping[TypeVarLike, ExternalType["pyanalyze.value.Value"]]
BoundsMap = Mapping[TypeVarLike, Sequence[ExternalType["pyanalyze.value.Bound"]]]
GenericBases = Mapping[Union[type, str], TypeVarMap]
Expand Down Expand Up @@ -1737,6 +1760,7 @@ class TypeVarValue(Value):
bound: Optional[Value] = None
constraints: Sequence[Value] = ()
is_paramspec: bool = False
is_typevartuple: bool = False # unsupported

def substitute_typevars(self, typevars: TypeVarMap) -> Value:
return typevars.get(self.typevar, self)
Expand Down

0 comments on commit 3260bb0

Please sign in to comment.