Skip to content

Commit

Permalink
Add support for TypeIs (PEP 742)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra committed Feb 18, 2024
1 parent 9c323f6 commit af75361
Show file tree
Hide file tree
Showing 8 changed files with 1,040 additions and 23 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Add support for `TypeIs` from PEP 742 (#718)
- More PEP 695 support: generic classes and functions. Scoping rules
are not yet fully implemented. (#703)
- Fix type inference when constructing user-defined generic classes
Expand Down
15 changes: 15 additions & 0 deletions pyanalyze/annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
KVPair,
TypeAlias,
TypeAliasValue,
TypeIsExtension,
annotate_value,
AnnotatedValue,
AnySource,
Expand Down Expand Up @@ -784,6 +785,13 @@ def _type_from_subscripted_value(
return AnnotatedValue(
TypedValue(bool), [TypeGuardExtension(_type_from_value(members[0], ctx))]
)
elif is_typing_name(root, "TypeIs"):
if len(members) != 1:
ctx.show_error("TypeIs requires a single argument")
return AnyValue(AnySource.error)
return AnnotatedValue(
TypedValue(bool), [TypeIsExtension(_type_from_value(members[0], ctx))]
)
elif is_typing_name(root, "Required"):
if not is_typeddict:
ctx.show_error("Required[] used in unsupported context")
Expand Down Expand Up @@ -1179,6 +1187,13 @@ def _value_of_origin_args(
return AnnotatedValue(
TypedValue(bool), [TypeGuardExtension(_type_from_runtime(args[0], ctx))]
)
elif is_typing_name(origin, "TypeIs"):
if len(args) != 1:
ctx.show_error("TypeIs requires a single argument")
return AnyValue(AnySource.error)
return AnnotatedValue(
TypedValue(bool), [TypeIsExtension(_type_from_runtime(args[0], ctx))]
)
elif is_typing_name(origin, "Final"):
if len(args) != 1:
ctx.show_error("Final requires a single argument")
Expand Down
4 changes: 4 additions & 0 deletions pyanalyze/error_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class ErrorCode(enum.Enum):
reveal_type = 87
missing_generic_parameters = 88
disallowed_import = 89
typeis_must_be_subtype = 90
invalid_typeguard = 91


# Allow testing unannotated functions without too much fuss
Expand Down Expand Up @@ -239,6 +241,8 @@ class ErrorCode(enum.Enum):
ErrorCode.override_does_not_override: "Method does not override any base method",
ErrorCode.missing_generic_parameters: "Missing type parameters for generic type",
ErrorCode.disallowed_import: "Disallowed import",
ErrorCode.typeis_must_be_subtype: "TypeIs narrowed type must be a subtype of the input type",
ErrorCode.invalid_typeguard: "Invalid use of TypeGuard or TypeIs",
}


Expand Down
1 change: 1 addition & 0 deletions pyanalyze/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class FunctionInfo:
is_override: bool # @typing.override
is_evaluated: bool # @pyanalyze.extensions.evaluated
is_abstractmethod: bool # has @abstractmethod
is_instancemethod: bool # is an instance method
# a list of tuples of (decorator function, applied decorator function, AST node). These are
# different for decorators that take arguments, like @asynq(): the first element will be the
# asynq function and the second will be the result of calling asynq().
Expand Down
55 changes: 55 additions & 0 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
KWARGS,
MaybeSignature,
OverloadedSignature,
ParameterKind,
Signature,
SigParameter,
)
Expand Down Expand Up @@ -178,6 +179,8 @@
SkipDeprecatedExtension,
TypeAlias,
TypeAliasValue,
TypeGuardExtension,
TypeIsExtension,
annotate_value,
AnnotatedValue,
AnySource,
Expand Down Expand Up @@ -1932,6 +1935,9 @@ def compute_function_info(
is_classmethod=is_classmethod,
is_staticmethod=is_staticmethod,
is_abstractmethod=is_abstractmethod,
is_instancemethod=is_nested_in_class
and not is_classmethod
and not is_staticmethod,
is_overload=is_overload,
is_override=is_override,
is_evaluated=is_evaluated,
Expand Down Expand Up @@ -2000,6 +2006,8 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
):
result = self._visit_function_body(info)

self.check_typeis(info)

if (
not result.has_return
and not info.is_overload
Expand Down Expand Up @@ -2056,6 +2064,53 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
self._set_argspec_to_retval(val, info, result)
return val

def check_typeis(self, info: FunctionInfo) -> None:
if info.return_annotation is None:
return
assert isinstance(info.node, (ast.FunctionDef, ast.AsyncFunctionDef))
assert info.node.returns is not None
_, ti = unannotate_value(info.return_annotation, TypeIsExtension)
for type_is in ti:
param = self._get_typeis_parameter(info)
if param is None:
self._show_error_if_checking(
info.node,
"TypeIs must be used on a function taking at least one positional parameter",
error_code=ErrorCode.invalid_typeguard,
)
continue
can_assign = param.annotation.can_assign(type_is.guarded_type, self)
if isinstance(can_assign, CanAssignError):
self._show_error_if_checking(
info.node.returns,
f"TypeIs narrowed type {type_is.guarded_type} is incompatible with parameter {param.name}",
error_code=ErrorCode.typeis_must_be_subtype,
detail=can_assign.display(),
)
_, tg = unannotate_value(info.return_annotation, TypeGuardExtension)
for _ in tg:
param = self._get_typeis_parameter(info)
if param is None:
self._show_error_if_checking(
info.node,
"TypeGuard must be used on a function taking at least one positional parameter",
error_code=ErrorCode.invalid_typeguard,
)

def _get_typeis_parameter(self, info: FunctionInfo) -> Optional[SigParameter]:
index = 0
if info.is_classmethod or info.is_instancemethod:
index = 1
if len(info.params) <= index:
return None
param = info.params[index].param
if param.kind not in (
ParameterKind.POSITIONAL_ONLY,
ParameterKind.POSITIONAL_OR_KEYWORD,
):
return None
return param

def _set_argspec_to_retval(
self, val: Value, info: FunctionInfo, result: FunctionResult
) -> None:
Expand Down
76 changes: 53 additions & 23 deletions pyanalyze/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from qcore.helpers import safe_str
from typing_extensions import assert_never, Literal, Protocol, Self

from pyanalyze.predicates import IsAssignablePredicate

from .error_code import ErrorCode
from .safe import safe_getattr
from .node_visitor import Replacement
Expand All @@ -62,6 +64,7 @@
from .typevar import resolve_bounds_map
from .value import (
SelfT,
TypeIsExtension,
annotate_value,
AnnotatedValue,
AnySource,
Expand Down Expand Up @@ -683,13 +686,18 @@ def _get_positional_parameter(self, index: int) -> Optional[SigParameter]:
return None

def _apply_annotated_constraints(
self, raw_return: Union[Value, ImplReturn], composites: Dict[str, Composite]
self,
raw_return: Union[Value, ImplReturn],
composites: Dict[str, Composite],
ctx: CheckCallContext,
) -> Value:
if isinstance(raw_return, Value):
ret = ImplReturn(raw_return)
else:
ret = raw_return
constraints = [ret.constraint]
constraints = []
if ret.constraint is not NULL_CONSTRAINT:
constraints.append(ret.constraint)
return_value = ret.return_value
no_return_unless = ret.no_return_unless
if isinstance(return_value, AnnotatedValue):
Expand All @@ -707,28 +715,31 @@ def _apply_annotated_constraints(
guard.guarded_type,
)
constraints.append(constraint)

return_value, tg = unannotate_value(return_value, TypeGuardExtension)
for guard in tg:
# This might miss some cases where we should use the second argument instead. We'll
# have to come up with additional heuristics if that comes up.
if isinstance(self.callable, MethodType) or (
isinstance(self.callable, FunctionType)
and self.callable.__name__ != self.callable.__qualname__
):
index = 1
else:
index = 0
param = self._get_positional_parameter(index)
if param is not None:
composite = composites[param.name]
if composite.varname is not None:
constraint = Constraint(
composite.varname,
ConstraintType.is_value_object,
True,
guard.guarded_type,
)
constraints.append(constraint)
varname = self._get_typeguard_varname(composites)
if varname is not None:
constraint = Constraint(
varname,
ConstraintType.is_value_object,
True,
guard.guarded_type,
)
constraints.append(constraint)

return_value, ti = unannotate_value(return_value, TypeIsExtension)
for guard in ti:
varname = self._get_typeguard_varname(composites)
if varname is not None and ctx.visitor is not None:
predicate = IsAssignablePredicate(
guard.guarded_type, ctx.visitor, positive_only=False
)
constraint = Constraint(
varname, ConstraintType.predicate, True, predicate
)
constraints.append(constraint)

return_value, hag = unannotate_value(return_value, HasAttrGuardExtension)
for guard in hag:
if guard.varname in composites:
Expand Down Expand Up @@ -768,6 +779,25 @@ def _apply_annotated_constraints(
extensions.append(NoReturnConstraintExtension(no_return_unless))
return annotate_value(return_value, extensions)

def _get_typeguard_varname(
self, composites: Dict[str, Composite]
) -> Optional[VarnameWithOrigin]:
# This might miss some cases where we should use the second argument instead. We'll
# have to come up with additional heuristics if that comes up.
if isinstance(self.callable, MethodType) or (
isinstance(self.callable, FunctionType)
and self.callable.__name__ != self.callable.__qualname__
):
index = 1
else:
index = 0
param = self._get_positional_parameter(index)
if param is not None:
composite = composites[param.name]
if composite.varname is not None:
return composite.varname
return None

def bind_arguments(
self, actual_args: ActualArguments, ctx: CheckCallContext
) -> Optional[BoundArgs]:
Expand Down Expand Up @@ -1308,7 +1338,7 @@ def check_call_with_bound_args(
)
else:
return_value = runtime_return
ret = self._apply_annotated_constraints(return_value, composites)
ret = self._apply_annotated_constraints(return_value, composites, ctx)
return CallReturn(
ret,
is_error=had_error,
Expand Down
Loading

0 comments on commit af75361

Please sign in to comment.