diff --git a/docs/changelog.md b/docs/changelog.md index 3521cf04..2a8875ef 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Add support for `TypeIs` from PEP 742 (#718) - More PEP 695 support: generic classes and functions. Scoping rules are not yet fully implemented. (#703) - Fix type inference when constructing user-defined generic classes diff --git a/pyanalyze/annotations.py b/pyanalyze/annotations.py index e1261ae0..ab8714fd 100644 --- a/pyanalyze/annotations.py +++ b/pyanalyze/annotations.py @@ -86,6 +86,7 @@ KVPair, TypeAlias, TypeAliasValue, + TypeIsExtension, annotate_value, AnnotatedValue, AnySource, @@ -784,6 +785,13 @@ def _type_from_subscripted_value( return AnnotatedValue( TypedValue(bool), [TypeGuardExtension(_type_from_value(members[0], ctx))] ) + elif is_typing_name(root, "TypeIs"): + if len(members) != 1: + ctx.show_error("TypeIs requires a single argument") + return AnyValue(AnySource.error) + return AnnotatedValue( + TypedValue(bool), [TypeIsExtension(_type_from_value(members[0], ctx))] + ) elif is_typing_name(root, "Required"): if not is_typeddict: ctx.show_error("Required[] used in unsupported context") @@ -1179,6 +1187,13 @@ def _value_of_origin_args( return AnnotatedValue( TypedValue(bool), [TypeGuardExtension(_type_from_runtime(args[0], ctx))] ) + elif is_typing_name(origin, "TypeIs"): + if len(args) != 1: + ctx.show_error("TypeIs requires a single argument") + return AnyValue(AnySource.error) + return AnnotatedValue( + TypedValue(bool), [TypeIsExtension(_type_from_runtime(args[0], ctx))] + ) elif is_typing_name(origin, "Final"): if len(args) != 1: ctx.show_error("Final requires a single argument") diff --git a/pyanalyze/error_code.py b/pyanalyze/error_code.py index 9a0def3d..627e4960 100644 --- a/pyanalyze/error_code.py +++ b/pyanalyze/error_code.py @@ -107,6 +107,8 @@ class ErrorCode(enum.Enum): reveal_type = 87 missing_generic_parameters = 88 disallowed_import = 89 + typeis_must_be_subtype = 90 + invalid_typeguard = 91 # Allow testing unannotated functions without too much fuss @@ -239,6 +241,8 @@ class ErrorCode(enum.Enum): ErrorCode.override_does_not_override: "Method does not override any base method", ErrorCode.missing_generic_parameters: "Missing type parameters for generic type", ErrorCode.disallowed_import: "Disallowed import", + ErrorCode.typeis_must_be_subtype: "TypeIs narrowed type must be a subtype of the input type", + ErrorCode.invalid_typeguard: "Invalid use of TypeGuard or TypeIs", } diff --git a/pyanalyze/functions.py b/pyanalyze/functions.py index 898aa7ab..099e20ea 100644 --- a/pyanalyze/functions.py +++ b/pyanalyze/functions.py @@ -84,6 +84,7 @@ class FunctionInfo: is_override: bool # @typing.override is_evaluated: bool # @pyanalyze.extensions.evaluated is_abstractmethod: bool # has @abstractmethod + is_instancemethod: bool # is an instance method # a list of tuples of (decorator function, applied decorator function, AST node). These are # different for decorators that take arguments, like @asynq(): the first element will be the # asynq function and the second will be the result of calling asynq(). diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 668d276a..76027fce 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -129,6 +129,7 @@ KWARGS, MaybeSignature, OverloadedSignature, + ParameterKind, Signature, SigParameter, ) @@ -178,6 +179,8 @@ SkipDeprecatedExtension, TypeAlias, TypeAliasValue, + TypeGuardExtension, + TypeIsExtension, annotate_value, AnnotatedValue, AnySource, @@ -1932,6 +1935,9 @@ def compute_function_info( is_classmethod=is_classmethod, is_staticmethod=is_staticmethod, is_abstractmethod=is_abstractmethod, + is_instancemethod=is_nested_in_class + and not is_classmethod + and not is_staticmethod, is_overload=is_overload, is_override=is_override, is_evaluated=is_evaluated, @@ -2000,6 +2006,8 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value: ): result = self._visit_function_body(info) + self.check_typeis(info) + if ( not result.has_return and not info.is_overload @@ -2056,6 +2064,53 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value: self._set_argspec_to_retval(val, info, result) return val + def check_typeis(self, info: FunctionInfo) -> None: + if info.return_annotation is None: + return + assert isinstance(info.node, (ast.FunctionDef, ast.AsyncFunctionDef)) + assert info.node.returns is not None + _, ti = unannotate_value(info.return_annotation, TypeIsExtension) + for type_is in ti: + param = self._get_typeis_parameter(info) + if param is None: + self._show_error_if_checking( + info.node, + "TypeIs must be used on a function taking at least one positional parameter", + error_code=ErrorCode.invalid_typeguard, + ) + continue + can_assign = param.annotation.can_assign(type_is.guarded_type, self) + if isinstance(can_assign, CanAssignError): + self._show_error_if_checking( + info.node.returns, + f"TypeIs narrowed type {type_is.guarded_type} is incompatible with parameter {param.name}", + error_code=ErrorCode.typeis_must_be_subtype, + detail=can_assign.display(), + ) + _, tg = unannotate_value(info.return_annotation, TypeGuardExtension) + for _ in tg: + param = self._get_typeis_parameter(info) + if param is None: + self._show_error_if_checking( + info.node, + "TypeGuard must be used on a function taking at least one positional parameter", + error_code=ErrorCode.invalid_typeguard, + ) + + def _get_typeis_parameter(self, info: FunctionInfo) -> Optional[SigParameter]: + index = 0 + if info.is_classmethod or info.is_instancemethod: + index = 1 + if len(info.params) <= index: + return None + param = info.params[index].param + if param.kind not in ( + ParameterKind.POSITIONAL_ONLY, + ParameterKind.POSITIONAL_OR_KEYWORD, + ): + return None + return param + def _set_argspec_to_retval( self, val: Value, info: FunctionInfo, result: FunctionResult ) -> None: diff --git a/pyanalyze/signature.py b/pyanalyze/signature.py index 712b6bb6..e54f7d35 100644 --- a/pyanalyze/signature.py +++ b/pyanalyze/signature.py @@ -36,6 +36,8 @@ from qcore.helpers import safe_str from typing_extensions import assert_never, Literal, Protocol, Self +from pyanalyze.predicates import IsAssignablePredicate + from .error_code import ErrorCode from .safe import safe_getattr from .node_visitor import Replacement @@ -62,6 +64,7 @@ from .typevar import resolve_bounds_map from .value import ( SelfT, + TypeIsExtension, annotate_value, AnnotatedValue, AnySource, @@ -683,13 +686,18 @@ def _get_positional_parameter(self, index: int) -> Optional[SigParameter]: return None def _apply_annotated_constraints( - self, raw_return: Union[Value, ImplReturn], composites: Dict[str, Composite] + self, + raw_return: Union[Value, ImplReturn], + composites: Dict[str, Composite], + ctx: CheckCallContext, ) -> Value: if isinstance(raw_return, Value): ret = ImplReturn(raw_return) else: ret = raw_return - constraints = [ret.constraint] + constraints = [] + if ret.constraint is not NULL_CONSTRAINT: + constraints.append(ret.constraint) return_value = ret.return_value no_return_unless = ret.no_return_unless if isinstance(return_value, AnnotatedValue): @@ -707,28 +715,31 @@ def _apply_annotated_constraints( guard.guarded_type, ) constraints.append(constraint) + return_value, tg = unannotate_value(return_value, TypeGuardExtension) for guard in tg: - # This might miss some cases where we should use the second argument instead. We'll - # have to come up with additional heuristics if that comes up. - if isinstance(self.callable, MethodType) or ( - isinstance(self.callable, FunctionType) - and self.callable.__name__ != self.callable.__qualname__ - ): - index = 1 - else: - index = 0 - param = self._get_positional_parameter(index) - if param is not None: - composite = composites[param.name] - if composite.varname is not None: - constraint = Constraint( - composite.varname, - ConstraintType.is_value_object, - True, - guard.guarded_type, - ) - constraints.append(constraint) + varname = self._get_typeguard_varname(composites) + if varname is not None: + constraint = Constraint( + varname, + ConstraintType.is_value_object, + True, + guard.guarded_type, + ) + constraints.append(constraint) + + return_value, ti = unannotate_value(return_value, TypeIsExtension) + for guard in ti: + varname = self._get_typeguard_varname(composites) + if varname is not None and ctx.visitor is not None: + predicate = IsAssignablePredicate( + guard.guarded_type, ctx.visitor, positive_only=False + ) + constraint = Constraint( + varname, ConstraintType.predicate, True, predicate + ) + constraints.append(constraint) + return_value, hag = unannotate_value(return_value, HasAttrGuardExtension) for guard in hag: if guard.varname in composites: @@ -768,6 +779,25 @@ def _apply_annotated_constraints( extensions.append(NoReturnConstraintExtension(no_return_unless)) return annotate_value(return_value, extensions) + def _get_typeguard_varname( + self, composites: Dict[str, Composite] + ) -> Optional[VarnameWithOrigin]: + # This might miss some cases where we should use the second argument instead. We'll + # have to come up with additional heuristics if that comes up. + if isinstance(self.callable, MethodType) or ( + isinstance(self.callable, FunctionType) + and self.callable.__name__ != self.callable.__qualname__ + ): + index = 1 + else: + index = 0 + param = self._get_positional_parameter(index) + if param is not None: + composite = composites[param.name] + if composite.varname is not None: + return composite.varname + return None + def bind_arguments( self, actual_args: ActualArguments, ctx: CheckCallContext ) -> Optional[BoundArgs]: @@ -1308,7 +1338,7 @@ def check_call_with_bound_args( ) else: return_value = runtime_return - ret = self._apply_annotated_constraints(return_value, composites) + ret = self._apply_annotated_constraints(return_value, composites, ctx) return CallReturn( ret, is_error=had_error, diff --git a/pyanalyze/test_typeis.py b/pyanalyze/test_typeis.py new file mode 100644 index 00000000..246591b4 --- /dev/null +++ b/pyanalyze/test_typeis.py @@ -0,0 +1,892 @@ +# static analysis: ignore +from .test_name_check_visitor import TestNameCheckVisitorBase +from .test_node_visitor import assert_passes + + +class TestTypeIs(TestNameCheckVisitorBase): + @assert_passes() + def testTypeIsBasic(self): + from typing_extensions import TypeIs, assert_type + + class Point: + pass + + def is_point(a: object) -> TypeIs[Point]: + return False + + def main(a: object) -> None: + if is_point(a): + assert_type(a, Point) + else: + assert_type(a, object) + + @assert_passes() + def testTypeIsTypeArgsNone(self): + from typing_extensions import TypeIs + + def foo(a: object) -> TypeIs: # E: invalid_annotation + return False + + @assert_passes() + def testTypeIsTypeArgsTooMany(self): + from typing_extensions import TypeIs + + def foo(a: object) -> "TypeIs[int, int]": # E: invalid_annotation + return False + + @assert_passes() + def testTypeIsTypeArgType(self): + from typing_extensions import TypeIs + + def foo(a: object) -> "TypeIs[42]": # E: invalid_annotation + return False + + @assert_passes() + def testTypeIsCallArgsNone(self): + from typing_extensions import TypeIs, assert_type + + class Point: + pass + + def is_point() -> TypeIs[Point]: # E: invalid_typeguard + return False + + def main(a: object) -> None: + if is_point(): + assert_type(a, object) + + @assert_passes() + def testTypeIsCallArgsMultiple(self): + from typing_extensions import TypeIs, assert_type + + class Point: + pass + + def is_point(a: object, b: object) -> TypeIs[Point]: + return False + + def main(a: object, b: object) -> None: + if is_point(a, b): + assert_type(a, Point) + assert_type(b, object) + + @assert_passes() + def testTypeIsWithTypeVar(self): + from typing import TypeVar, Tuple, Type + from typing_extensions import TypeIs, assert_type + + T = TypeVar("T") + + def is_tuple_of_type( + a: Tuple[object, ...], typ: Type[T] + ) -> TypeIs[Tuple[T, ...]]: + return False + + def main(a: Tuple[object, ...]): + if is_tuple_of_type(a, int): + assert_type(a, Tuple[int, ...]) + + @assert_passes() + def testTypeIsUnionIn(self): + from typing import Union + from typing_extensions import TypeIs, assert_type + + def is_foo(a: Union[int, str]) -> TypeIs[str]: + return False + + def main(a: Union[str, int]) -> None: + if is_foo(a): + assert_type(a, str) + else: + assert_type(a, int) + assert_type(a, Union[str, int]) + + @assert_passes() + def testTypeIsUnionOut(self): + from typing import Union + from typing_extensions import TypeIs, assert_type + + def is_foo(a: object) -> TypeIs[Union[int, str]]: + return False + + def main(a: object) -> None: + if is_foo(a): + assert_type(a, Union[int, str]) + + @assert_passes() + def testTypeIsNonzeroFloat(self): + from typing_extensions import TypeIs, assert_type + + def is_nonzero(a: object) -> TypeIs[float]: + return False + + def main(a: int): + if is_nonzero(a): + assert_type(a, int) + + @assert_passes() + def testTypeIsHigherOrder(self): + import collections.abc + from typing import Callable, TypeVar, Iterable, List + from typing_extensions import TypeIs + from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource + + T = TypeVar("T") + R = TypeVar("R") + + def filter(f: Callable[[T], TypeIs[R]], it: Iterable[T]) -> Iterable[R]: + return () + + def is_float(a: object) -> TypeIs[float]: + return False + + def capybara() -> None: + a: List[object] = ["a", 0, 0.0] + b = filter(is_float, a) + # TODO should be Iterable[float] + assert_is_value( + b, + GenericValue( + collections.abc.Iterable, [AnyValue(AnySource.generic_argument)] + ), + ) + + @assert_passes() + def testTypeIsMethod(self): + from typing_extensions import TypeIs, assert_type + + class C: + def main(self, a: object) -> None: + if self.is_float(a): + assert_type(self, C) + assert_type(a, float) + + def is_float(self, a: object) -> TypeIs[float]: + return False + + @assert_passes() + def testTypeIsBodyRequiresBool(self): + from typing_extensions import TypeIs + + def is_float(a: object) -> TypeIs[float]: + return "not a bool" # E: incompatible_return_value + + @assert_passes() + def testTypeIsNarrowToTypedDict(self): + from typing import Mapping, TypedDict + from typing_extensions import TypeIs, assert_type + + class User(TypedDict): + name: str + id: int + + def is_user(a: Mapping[str, object]) -> TypeIs[User]: + return isinstance(a.get("name"), str) and isinstance(a.get("id"), int) + + def main(a: Mapping[str, object]) -> None: + if is_user(a): + assert_type(a, User) + + @assert_passes() + def testTypeIsInAssert(self): + from typing_extensions import TypeIs, assert_type + + def is_float(a: object) -> TypeIs[float]: + return False + + def main(a: object) -> None: + assert is_float(a) + assert_type(a, float) + + @assert_passes() + def testTypeIsFromAny(self): + from typing import Any + from typing_extensions import TypeIs, assert_type + + def is_objfloat(a: object) -> TypeIs[float]: + return False + + def is_anyfloat(a: Any) -> TypeIs[float]: + return False + + def objmain(a: object) -> None: + if is_objfloat(a): + assert_type(a, float) + if is_anyfloat(a): + assert_type(a, float) + + def anymain(a: Any) -> None: + if is_objfloat(a): + assert_type(a, float) + if is_anyfloat(a): + assert_type(a, float) + + @assert_passes() + def testTypeIsNegatedAndElse(self): + from typing import Union + from typing_extensions import TypeIs, assert_type + + def is_int(a: object) -> TypeIs[int]: + return False + + def is_str(a: object) -> TypeIs[str]: + return False + + def intmain(a: Union[int, str]) -> None: + if not is_int(a): + assert_type(a, str) + else: + assert_type(a, int) + + def strmain(a: Union[int, str]) -> None: + if is_str(a): + assert_type(a, str) + else: + assert_type(a, int) + + @assert_passes() + def testTypeIsClassMethod(self): + from typing_extensions import TypeIs, assert_type + + class C: + @classmethod + def is_float(cls, a: object) -> TypeIs[float]: + return False + + def method(self, a: object) -> None: + if self.is_float(a): + assert_type(a, float) + + def main(a: object) -> None: + if C.is_float(a): + assert_type(a, float) + + @assert_passes() + def testTypeIsRequiresPositionalArgs(self): + from typing_extensions import TypeIs, assert_type + + def is_float(a: object, b: object = 0) -> TypeIs[float]: + return False + + def main1(a: object) -> None: + if is_float(a=a, b=1): + assert_type(a, float) + + if is_float(b=1, a=a): + assert_type(a, float) + + @assert_passes() + def testTypeIsOverload(self): + import collections.abc + from typing import Callable, Iterable, Iterator, List, Optional, TypeVar + from typing_extensions import TypeIs, assert_type, overload + from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource + + T = TypeVar("T") + R = TypeVar("R") + + @overload + def filter(f: Callable[[T], TypeIs[R]], it: Iterable[T]) -> Iterator[R]: + raise NotImplementedError + + @overload + def filter(f: Callable[[T], bool], it: Iterable[T]) -> Iterator[T]: + raise NotImplementedError + + def filter(*args): + pass + + def is_int_typeguard(a: object) -> TypeIs[int]: + return False + + def is_int_bool(a: object) -> bool: + return False + + iter_any = GenericValue( + collections.abc.Iterator, [AnyValue(AnySource.generic_argument)] + ) + + def main(a: List[Optional[int]]) -> None: + bb = filter(lambda x: x is not None, a) + # TODO Iterator[Optional[int]] + assert_is_value(bb, iter_any) + # Also, if you replace 'bool' with 'Any' in the second overload, bb is Iterator[Any] + cc = filter(is_int_typeguard, a) + # TODO Iterator[int] + assert_is_value(cc, iter_any) + dd = filter(is_int_bool, a) + # TODO Iterator[Optional[int]] + assert_is_value(dd, iter_any) + + @assert_passes() + def testTypeIsDecorated(self): + from typing import TypeVar + from typing_extensions import TypeIs, assert_type + + T = TypeVar("T") + + def decorator(f: T) -> T: + return f + + @decorator + def is_float(a: object) -> TypeIs[float]: + return False + + def main(a: object) -> None: + if is_float(a): + assert_type(a, float) + + @assert_passes() + def testTypeIsMethodOverride(self): + from typing_extensions import TypeIs + + class C: + def is_float(self, a: object) -> TypeIs[float]: + return False + + class D(C): + def is_float(self, a: object) -> bool: # TODO: incompatible_override + return False + + @assert_passes() + def testTypeIsInAnd(self): + from typing import Any + from typing_extensions import TypeIs + + def isclass(a: object) -> bool: + return False + + def ismethod(a: object) -> TypeIs[float]: + return False + + def isfunction(a: object) -> TypeIs[str]: + return False + + def isclassmethod(obj: Any) -> bool: + if ( + ismethod(obj) + and obj.__self__ is not None # E: undefined_attribute + and isclass(obj.__self__) # E: undefined_attribute + ): + return True + + return False + + def coverage(obj: Any) -> bool: + if not (ismethod(obj) or isfunction(obj)): + return True + return False + + @assert_passes() + def testAssignToTypeIsedVariable1(self): + from typing_extensions import TypeIs + + class A: + pass + + class B(A): + pass + + def guard(a: A) -> TypeIs[B]: + return False + + def capybara() -> None: + a = A() + if not guard(a): + a = A() + print(a) + + @assert_passes() + def testAssignToTypeIsedVariable2(self): + from typing_extensions import TypeIs + + class A: + pass + + class B: + pass + + def guard(a: object) -> TypeIs[B]: + return False + + def capybara() -> None: + a = A() + if not guard(a): + a = A() + print(a) + + @assert_passes() + def testAssignToTypeIsedVariable3(self): + from typing_extensions import TypeIs, assert_type, Never + + class A: + pass + + class B: + pass + + def guard(a: object) -> TypeIs[B]: + return False + + def capybara() -> None: + a = A() + if guard(a): + assert_type(a, Never) # TODO A & B + a = B() + assert_type(a, B) + a = A() + assert_type(a, A) + assert_type(a, A) + + @assert_passes() + def testTypeIsNestedRestrictionAny(self): + from typing_extensions import TypeIs, assert_type + from typing import Any, Union + + class A: ... + + def f(x: object) -> TypeIs[A]: + return False + + def g(x: object) -> None: ... + + def test(x: Any) -> None: + if not (f(x) or x): + return + assert_type(x, Union[A, Any]) + + @assert_passes() + def testTypeIsNestedRestrictionUnionOther(self): + from typing_extensions import TypeIs, assert_type + from typing import Union + + class A: ... + + class B: ... + + def f(x: object) -> TypeIs[A]: + return False + + def f2(x: object) -> TypeIs[B]: + return False + + def test(x: object) -> None: + if not (f(x) or f2(x)): + return + assert_type(x, Union[A, B]) + + @assert_passes() + def testTypeIsComprehensionSubtype(self): + from typing import List + from typing_extensions import TypeIs + + class Base: ... + + class Foo(Base): ... + + class Bar(Base): ... + + def is_foo(item: object) -> TypeIs[Foo]: + return isinstance(item, Foo) + + def is_bar(item: object) -> TypeIs[Bar]: + return isinstance(item, Bar) + + def foobar(items: List[object]) -> object: + a: List[Base] = [x for x in items if is_foo(x) or is_bar(x)] + b: List[Base] = [x for x in items if is_foo(x)] + c: List[Bar] = [x for x in items if is_foo(x)] # E: incompatible_assignment + return (a, b, c) + + @assert_passes() + def testTypeIsNestedRestrictionUnionIsInstance(self): + from typing_extensions import TypeIs, assert_type + from typing import Any, List + + class A: ... + + def f(x: List[Any]) -> TypeIs[List[str]]: + return False + + def test(x: List[Any]) -> None: + if not (f(x) or isinstance(x, A)): + return + assert_type(x, List[Any]) + + @assert_passes() + def testTypeIsMultipleCondition(self): + from typing_extensions import TypeIs, assert_type, Never + + class Foo: ... + + class Bar: ... + + def is_foo(item: object) -> TypeIs[Foo]: + return isinstance(item, Foo) + + def is_bar(item: object) -> TypeIs[Bar]: + return isinstance(item, Bar) + + def foobar(x: object): + if not isinstance(x, Foo) or not isinstance(x, Bar): + return + assert_type(x, Never) + + def foobar_typeguard(x: object): + if not is_foo(x) or not is_bar(x): + return + assert_type(x, Never) + + @assert_passes() + def testTypeIsAsFunctionArgAsBoolSubtype(self): + from typing import Callable + from typing_extensions import TypeIs + + def accepts_bool(f: Callable[[object], bool]) -> None: + pass + + def with_bool_typeguard(o: object) -> TypeIs[bool]: + return False + + def with_str_typeguard(o: object) -> TypeIs[str]: + return False + + def with_bool(o: object) -> bool: + return False + + accepts_bool(with_bool_typeguard) + accepts_bool(with_str_typeguard) + accepts_bool(with_bool) + + @assert_passes() + def testTypeIsAsFunctionArg(self): + from typing import Callable + from typing_extensions import TypeIs + + def accepts_typeguard(f: Callable[[object], TypeIs[bool]]) -> None: + pass + + def different_typeguard(f: Callable[[object], TypeIs[str]]) -> None: + pass + + def with_typeguard(o: object) -> TypeIs[bool]: + return False + + def with_bool(o: object) -> bool: + return False + + accepts_typeguard(with_typeguard) + accepts_typeguard( + with_bool + ) # E: Argument 1 to "accepts_typeguard" has incompatible type "Callable[[object], bool]"; expected "Callable[[object], TypeIs[bool]]" + + different_typeguard( + with_typeguard + ) # E: Argument 1 to "different_typeguard" has incompatible type "Callable[[object], TypeIs[bool]]"; expected "Callable[[object], TypeIs[str]]" + different_typeguard( + with_bool + ) # E: Argument 1 to "different_typeguard" has incompatible type "Callable[[object], bool]"; expected "Callable[[object], TypeIs[str]]" + + @assert_passes() + def testTypeIsAsGenericFunctionArg(self): + from typing import Callable, TypeVar + from typing_extensions import TypeIs + + T = TypeVar("T") + + def accepts_typeguard(f: Callable[[object], TypeIs[T]]) -> None: + pass + + def with_bool_typeguard(o: object) -> TypeIs[bool]: + return False + + def with_str_typeguard(o: object) -> TypeIs[str]: + return False + + def with_bool(o: object) -> bool: + return False + + accepts_typeguard(with_bool_typeguard) + accepts_typeguard(with_str_typeguard) + accepts_typeguard( + with_bool + ) # E: Argument 1 to "accepts_typeguard" has incompatible type "Callable[[object], bool]"; expected "Callable[[object], TypeIs[bool]]" + + @assert_passes() + def testTypeIsAsOverloadedFunctionArg(self): + # https://github.com/python/mypy/issues/11307 + from typing import Callable, TypeVar, Generic, Any, overload + from typing_extensions import TypeIs, assert_type + + _T = TypeVar("_T") + + class filter(Generic[_T]): + @overload + def __init__(self, function: Callable[[object], TypeIs[_T]]) -> None: + pass + + @overload + def __init__(self, function: Callable[[_T], Any]) -> None: + pass + + def __init__(self, function): + pass + + def is_int_typeguard(a: object) -> TypeIs[int]: + return False + + def returns_bool(a: object) -> bool: + return False + + def capybara() -> None: + pass + # TODO: + # assert_type(filter(is_int_typeguard), filter[int]) + # assert_type(filter(returns_bool), filter[object]) + + @assert_passes() + def testTypeIsSubtypingVariance(self): + from typing import Callable + from typing_extensions import TypeIs + + class A: + pass + + class B(A): + pass + + class C(B): + pass + + def accepts_typeguard(f: Callable[[object], TypeIs[B]]) -> None: + pass + + def with_typeguard_a(o: object) -> TypeIs[A]: + return False + + def with_typeguard_b(o: object) -> TypeIs[B]: + return False + + def with_typeguard_c(o: object) -> TypeIs[C]: + return False + + accepts_typeguard( + with_typeguard_a + ) # E: Argument 1 to "accepts_typeguard" has incompatible type "Callable[[object], TypeIs[A]]"; expected "Callable[[object], TypeIs[B]]" + accepts_typeguard(with_typeguard_b) + accepts_typeguard( + with_typeguard_c + ) # E: Argument 1 to "accepts_typeguard" has incompatible type "Callable[[object], TypeIs[C]]"; expected "Callable[[object], TypeIs[B]]" + + @assert_passes() + def testTypeIsWithIdentityGeneric(self): + from typing import TypeVar + from typing_extensions import TypeIs, assert_type + + _T = TypeVar("_T") + + def identity(val: _T) -> TypeIs[_T]: + return False + + def func1(name: _T): + assert_type(name, _T) + if identity(name): + pass # TODO: assert_type(name, _T) + + def func2(name: str): + assert_type(name, str) + if identity(name): + assert_type(name, str) + + @assert_passes() + def testTypeIsWithGenericInstance(self): + from typing import TypeVar, List, Iterable + from typing_extensions import TypeIs, assert_type + + _T = TypeVar("_T") + + def is_list_of_str(val: Iterable[_T]) -> TypeIs[List[_T]]: + return False + + def func(name: Iterable[str]): + assert_type(name, Iterable[str]) + if is_list_of_str(name): + assert_type(name, List[str]) + + @assert_passes() + def testTypeIsWithTupleGeneric(self): + from typing import TypeVar, Tuple + from typing_extensions import TypeIs, assert_type + + _T = TypeVar("_T") + + def is_two_element_tuple(val: Tuple[_T, ...]) -> TypeIs[Tuple[_T, _T]]: + return False + + def func(names: Tuple[str, ...]): + assert_type(names, Tuple[str, ...]) + if is_two_element_tuple(names): + assert_type(names, Tuple[str, ...]) # TODO: bad type narrowing + + @assert_passes() + def testTypeIsErroneousDefinitionFails(self): + from typing_extensions import TypeIs + + class Z: + def typeguard1(self, *, x: object) -> TypeIs[int]: # E: invalid_typeguard + return False + + @staticmethod + def typeguard2(x: object) -> TypeIs[int]: + return False + + @staticmethod + def typeguard3(*, x: object) -> TypeIs[int]: # E: invalid_typeguard + return False + + def bad_typeguard(*, x: object) -> TypeIs[int]: # E: invalid_typeguard + return False + + @assert_passes() + def testTypeIsWithKeywordArg(self): + from typing_extensions import TypeIs, assert_type + + class Z: + def typeguard(self, x: object) -> TypeIs[int]: + return False + + def typeguard(x: object) -> TypeIs[int]: + return False + + def capybara(n: object) -> None: + if typeguard(x=n): + assert_type(n, int) + + if Z().typeguard(x=n): + assert_type(n, int) + + @assert_passes() + def testStaticMethodTypeIs(self): + from typing_extensions import TypeIs, assert_type + + def typeguard(h: object) -> TypeIs[int]: + return False + + class Y: + @staticmethod + def typeguard(h: object) -> TypeIs[int]: + return False + + def capybara(x: object): + if Y().typeguard(x): + # This doesn't work because we treat it as a method, not a staticmethod, + # and narrow parameter 1 instead. Doesn't look easy to fix, because the Signature + # class has no way to know. + assert_type(x, object) # TODO: int + assert_type(x, object) + if Y.typeguard(x): + assert_type(x, object) # TODO: int + + @assert_passes() + def testTypeIsKwargFollowingThroughOverloaded(self): + from typing import overload, Union + from typing_extensions import TypeIs, assert_type + + @overload + def typeguard(x: object, y: str) -> TypeIs[str]: ... + + @overload + def typeguard(x: object, y: int) -> TypeIs[int]: ... + + def typeguard(x: object, y: Union[int, str]) -> Union[TypeIs[int], TypeIs[str]]: + return False + + def capybara(x: object) -> None: + if typeguard(x=x, y=42): + assert_type(x, int) + + if typeguard(y=42, x=x): + assert_type(x, int) + + if typeguard(x=x, y="42"): + assert_type(x, str) + + if typeguard(y="42", x=x): + assert_type(x, str) + + @assert_passes() + def testGenericAliasWithTypeIs(self): + from typing import Callable, List, TypeVar + from typing_extensions import TypeIs + + T = TypeVar("T") + A = Callable[[object], TypeIs[List[T]]] + + def foo(x: object) -> TypeIs[List[str]]: + return False + + def test(f: A[T]) -> T: + raise NotImplementedError + + def capybara() -> None: + pass + # TODO: assert_type(test(foo), List[str]) + + @assert_passes() + def testNoCrashOnDunderCallTypeIs(self): + from typing_extensions import TypeIs, assert_type + + class A: + def __call__(self, x) -> TypeIs[int]: + return True + + def capybara(a: A, x: object) -> None: + assert a(x=1) + + assert a(x=x) + # Seems like we drop the annotations on the __call__ return somewhere + assert_type(x, object) # TODO: int + + @assert_passes() + def testTypeIsMustBeSubtypeFunctions(self): + from typing_extensions import TypeIs + from typing import List, Sequence, TypeVar + + def f(x: str) -> TypeIs[int]: # E: typeis_must_be_subtype + return False + + T = TypeVar("T") + + def g(x: List[T]) -> TypeIs[Sequence[T]]: # E: typeis_must_be_subtype + return False + + @assert_passes() + def testTypeIsMustBeSubtypeMethods(self): + from typing_extensions import TypeIs + + class NarrowHolder: + @classmethod + def cls_narrower_good(cls, x: object) -> TypeIs[int]: + return False + + @classmethod + def cls_narrower_bad( + cls, x: str + ) -> TypeIs[int]: # E: typeis_must_be_subtype + return False + + @staticmethod + def static_narrower_good(x: object) -> TypeIs[int]: + return False + + @staticmethod + def static_narrower_bad(x: str) -> TypeIs[int]: # E: typeis_must_be_subtype + return False + + def inst_narrower_good(self, x: object) -> TypeIs[int]: + return False + + def inst_narrower_bad( + self, x: str + ) -> TypeIs[int]: # E: typeis_must_be_subtype + return False diff --git a/pyanalyze/value.py b/pyanalyze/value.py index 433e1a8c..fcf6c7e3 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -1929,6 +1929,25 @@ def walk_values(self) -> Iterable[Value]: yield from self.guarded_type.walk_values() +@dataclass(frozen=True) +class TypeIsExtension(Extension): + """An :class:`Extension` used in a function return type. Used to + indicate that the first function argument may be narrowed to type `guarded_type`. + + Corresponds to ``typing_extensions.TypeIs`` (see PEP 742). + + """ + + guarded_type: Value + + def substitute_typevars(self, typevars: TypeVarMap) -> Extension: + guarded_type = self.guarded_type.substitute_typevars(typevars) + return TypeIsExtension(guarded_type) + + def walk_values(self) -> Iterable[Value]: + yield from self.guarded_type.walk_values() + + @dataclass(frozen=True) class HasAttrGuardExtension(Extension): """An :class:`Extension` used in a function return type. Used to