diff --git a/docs/changelog.md b/docs/changelog.md index 2a8875ef..f7c9a40d 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,7 @@ ## Unreleased +- Fix some higher-order behavior of `TypeGuard` and `TypeIs` (#719) - Add support for `TypeIs` from PEP 742 (#718) - More PEP 695 support: generic classes and functions. Scoping rules are not yet fully implemented. (#703) diff --git a/pyanalyze/name_check_visitor.py b/pyanalyze/name_check_visitor.py index 144b49cc..9bab457b 100644 --- a/pyanalyze/name_check_visitor.py +++ b/pyanalyze/name_check_visitor.py @@ -1986,6 +1986,11 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value: expected_return = info.return_annotation | KnownValue(NotImplemented) else: expected_return = info.return_annotation + if isinstance(expected_return, AnnotatedValue): + expected_return, _ = unannotate_value(expected_return, TypeIsExtension) + expected_return, _ = unannotate_value( + expected_return, TypeGuardExtension + ) with self.asynq_checker.set_func_name( node.name, diff --git a/pyanalyze/test_typeis.py b/pyanalyze/test_typeis.py index 220c2392..2b3adb47 100644 --- a/pyanalyze/test_typeis.py +++ b/pyanalyze/test_typeis.py @@ -126,10 +126,8 @@ def main(a: int): @assert_passes() def testTypeIsHigherOrder(self): - import collections.abc from typing import Callable, TypeVar, Iterable, List - from typing_extensions import TypeIs - from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource + from typing_extensions import TypeIs, assert_type T = TypeVar("T") R = TypeVar("R") @@ -143,13 +141,7 @@ def is_float(a: object) -> TypeIs[float]: def capybara() -> None: a: List[object] = ["a", 0, 0.0] b = filter(is_float, a) - # TODO should be Iterable[float] - assert_is_value( - b, - GenericValue( - collections.abc.Iterable, [AnyValue(AnySource.generic_argument)] - ), - ) + assert_type(b, Iterable[float]) @assert_passes() def testTypeIsMethod(self): @@ -277,10 +269,8 @@ def main1(a: object) -> None: @assert_passes() def testTypeIsOverload(self): - import collections.abc from typing import Callable, Iterable, Iterator, List, Optional, TypeVar - from typing_extensions import TypeIs, overload - from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource + from typing_extensions import TypeIs, overload, assert_type T = TypeVar("T") R = TypeVar("R") @@ -302,21 +292,13 @@ def is_int_typeguard(a: object) -> TypeIs[int]: def is_int_bool(a: object) -> bool: return False - iter_any = GenericValue( - collections.abc.Iterator, [AnyValue(AnySource.generic_argument)] - ) - def main(a: List[Optional[int]]) -> None: bb = filter(lambda x: x is not None, a) - # TODO Iterator[Optional[int]] - assert_is_value(bb, iter_any) - # Also, if you replace 'bool' with 'Any' in the second overload, bb is Iterator[Any] + assert_type(bb, Iterator[Optional[int]]) cc = filter(is_int_typeguard, a) - # TODO Iterator[int] - assert_is_value(cc, iter_any) + assert_type(cc, Iterator[int]) dd = filter(is_int_bool, a) - # TODO Iterator[Optional[int]] - assert_is_value(dd, iter_any) + assert_type(dd, Iterator[Optional[int]]) @assert_passes() def testTypeIsDecorated(self): @@ -345,7 +327,7 @@ def is_float(self, a: object) -> TypeIs[float]: return False class D(C): - def is_float(self, a: object) -> bool: # TODO: incompatible_override + def is_float(self, a: object) -> bool: # E: incompatible_override return False @assert_passes() @@ -576,10 +558,10 @@ def with_bool(o: object) -> bool: return False accepts_typeguard(with_typeguard) - accepts_typeguard(with_bool) # TODO error + accepts_typeguard(with_bool) # E: incompatible_argument - different_typeguard(with_typeguard) # TODO error - different_typeguard(with_bool) # TODO error + different_typeguard(with_typeguard) # E: incompatible_argument + different_typeguard(with_bool) # E: incompatible_argument @assert_passes() def testTypeIsAsGenericFunctionArg(self): @@ -602,7 +584,7 @@ def with_bool(o: object) -> bool: accepts_typeguard(with_bool_typeguard) accepts_typeguard(with_str_typeguard) - accepts_typeguard(with_bool) # TODO error + accepts_typeguard(with_bool) # E: incompatible_argument @assert_passes() def testTypeIsAsOverloadedFunctionArg(self): @@ -662,9 +644,9 @@ def with_typeguard_b(o: object) -> TypeIs[B]: def with_typeguard_c(o: object) -> TypeIs[C]: return False - accepts_typeguard(with_typeguard_a) # TODO error + accepts_typeguard(with_typeguard_a) # E: incompatible_argument accepts_typeguard(with_typeguard_b) - accepts_typeguard(with_typeguard_c) # TODO error + accepts_typeguard(with_typeguard_c) # E: incompatible_argument @assert_passes() def testTypeIsWithIdentityGeneric(self): @@ -786,7 +768,7 @@ def typeguard(x: object, y: str) -> TypeIs[str]: ... @overload def typeguard(x: object, y: int) -> TypeIs[int]: ... - def typeguard(x: object, y: Union[int, str]) -> Union[TypeIs[int], TypeIs[str]]: + def typeguard(x: object, y: Union[int, str]) -> bool: return False def capybara(x: object) -> None: @@ -805,7 +787,7 @@ def capybara(x: object) -> None: @assert_passes() def testGenericAliasWithTypeIs(self): from typing import Callable, List, TypeVar - from typing_extensions import TypeIs + from typing_extensions import TypeIs, assert_type T = TypeVar("T") A = Callable[[object], TypeIs[List[T]]] @@ -817,8 +799,7 @@ def test(f: A[T]) -> T: raise NotImplementedError def capybara() -> None: - pass - # TODO: assert_type(test(foo), List[str]) + assert_type(test(foo), str) @assert_passes() def testNoCrashOnDunderCallTypeIs(self): diff --git a/pyanalyze/value.py b/pyanalyze/value.py index fcf6c7e3..636588ef 100644 --- a/pyanalyze/value.py +++ b/pyanalyze/value.py @@ -1853,6 +1853,12 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Extension": def walk_values(self) -> Iterable[Value]: return [] + def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign: + return {} + + def can_be_assigned(self, value: Value, ctx: CanAssignContext) -> CanAssign: + return {} + @dataclass(frozen=True) class CustomCheckExtension(Extension): @@ -1868,6 +1874,12 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Extension": def walk_values(self) -> Iterable[Value]: yield from self.custom_check.walk_values() + def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign: + return self.custom_check.can_assign(value, ctx) + + def can_be_assigned(self, value: Value, ctx: CanAssignContext) -> CanAssign: + return self.custom_check.can_be_assigned(value, ctx) + @dataclass(frozen=True) class ParameterTypeGuardExtension(Extension): @@ -1928,6 +1940,26 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Extension: def walk_values(self) -> Iterable[Value]: yield from self.guarded_type.walk_values() + def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign: + can_assign_maps = [] + if isinstance(value, AnnotatedValue): + for ext in value.get_metadata_of_type(Extension): + if isinstance(ext, TypeIsExtension): + return CanAssignError("TypeGuard is not compatible with TypeIs") + elif isinstance(ext, TypeGuardExtension): + # TypeGuard is covariant + left_can_assign = self.guarded_type.can_assign( + ext.guarded_type, ctx + ) + if isinstance(left_can_assign, CanAssignError): + return CanAssignError( + "Incompatible types in TypeIs", children=[left_can_assign] + ) + can_assign_maps.append(left_can_assign) + if not can_assign_maps: + return CanAssignError(f"{value} is not a TypeGuard") + return unify_bounds_maps(can_assign_maps) + @dataclass(frozen=True) class TypeIsExtension(Extension): @@ -1947,6 +1979,33 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Extension: def walk_values(self) -> Iterable[Value]: yield from self.guarded_type.walk_values() + def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign: + can_assign_maps = [] + if isinstance(value, AnnotatedValue): + for ext in value.get_metadata_of_type(Extension): + if isinstance(ext, TypeGuardExtension): + return CanAssignError("TypeGuard is not compatible with TypeIs") + elif isinstance(ext, TypeIsExtension): + # TypeIs is invariant + left_can_assign = self.guarded_type.can_assign( + ext.guarded_type, ctx + ) + if isinstance(left_can_assign, CanAssignError): + return CanAssignError( + "Incompatible types in TypeIs", children=[left_can_assign] + ) + right_can_assign = ext.guarded_type.can_assign( + self.guarded_type, ctx + ) + if isinstance(right_can_assign, CanAssignError): + return CanAssignError( + "Incompatible types in TypeIs", children=[right_can_assign] + ) + can_assign_maps += [left_can_assign, right_can_assign] + if not can_assign_maps: + return CanAssignError(f"{value} is not a TypeIs") + return unify_bounds_maps(can_assign_maps) + @dataclass(frozen=True) class HasAttrGuardExtension(Extension): @@ -2120,8 +2179,8 @@ def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign: if isinstance(can_assign, CanAssignError): return can_assign bounds_maps = [can_assign] - for custom_check in self.get_metadata_of_type(CustomCheckExtension): - custom_can_assign = custom_check.custom_check.can_assign(other, ctx) + for ext in self.get_metadata_of_type(Extension): + custom_can_assign = ext.can_assign(other, ctx) if isinstance(custom_can_assign, CanAssignError): return custom_can_assign bounds_maps.append(custom_can_assign) @@ -2132,8 +2191,8 @@ def can_be_assigned(self, other: Value, ctx: CanAssignContext) -> CanAssign: if isinstance(can_assign, CanAssignError): return can_assign bounds_maps = [can_assign] - for custom_check in self.get_metadata_of_type(CustomCheckExtension): - custom_can_assign = custom_check.custom_check.can_be_assigned(other, ctx) + for ext in self.get_metadata_of_type(Extension): + custom_can_assign = ext.can_be_assigned(other, ctx) if isinstance(custom_can_assign, CanAssignError): return custom_can_assign bounds_maps.append(custom_can_assign)