Skip to content

Commit

Permalink
Fix some higher-order behavior of TypeGuard and TypeIs (#719)
Browse files Browse the repository at this point in the history
  • Loading branch information
JelleZijlstra authored Feb 18, 2024
1 parent 2180dc1 commit 868ba56
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 39 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Unreleased

- Fix some higher-order behavior of `TypeGuard` and `TypeIs` (#719)
- Add support for `TypeIs` from PEP 742 (#718)
- More PEP 695 support: generic classes and functions. Scoping rules
are not yet fully implemented. (#703)
Expand Down
5 changes: 5 additions & 0 deletions pyanalyze/name_check_visitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,11 @@ def visit_FunctionDef(self, node: FunctionDefNode) -> Value:
expected_return = info.return_annotation | KnownValue(NotImplemented)
else:
expected_return = info.return_annotation
if isinstance(expected_return, AnnotatedValue):
expected_return, _ = unannotate_value(expected_return, TypeIsExtension)
expected_return, _ = unannotate_value(
expected_return, TypeGuardExtension
)

with self.asynq_checker.set_func_name(
node.name,
Expand Down
51 changes: 16 additions & 35 deletions pyanalyze/test_typeis.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,8 @@ def main(a: int):

@assert_passes()
def testTypeIsHigherOrder(self):
import collections.abc
from typing import Callable, TypeVar, Iterable, List
from typing_extensions import TypeIs
from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource
from typing_extensions import TypeIs, assert_type

T = TypeVar("T")
R = TypeVar("R")
Expand All @@ -143,13 +141,7 @@ def is_float(a: object) -> TypeIs[float]:
def capybara() -> None:
a: List[object] = ["a", 0, 0.0]
b = filter(is_float, a)
# TODO should be Iterable[float]
assert_is_value(
b,
GenericValue(
collections.abc.Iterable, [AnyValue(AnySource.generic_argument)]
),
)
assert_type(b, Iterable[float])

@assert_passes()
def testTypeIsMethod(self):
Expand Down Expand Up @@ -277,10 +269,8 @@ def main1(a: object) -> None:

@assert_passes()
def testTypeIsOverload(self):
import collections.abc
from typing import Callable, Iterable, Iterator, List, Optional, TypeVar
from typing_extensions import TypeIs, overload
from pyanalyze.value import assert_is_value, GenericValue, AnyValue, AnySource
from typing_extensions import TypeIs, overload, assert_type

T = TypeVar("T")
R = TypeVar("R")
Expand All @@ -302,21 +292,13 @@ def is_int_typeguard(a: object) -> TypeIs[int]:
def is_int_bool(a: object) -> bool:
return False

iter_any = GenericValue(
collections.abc.Iterator, [AnyValue(AnySource.generic_argument)]
)

def main(a: List[Optional[int]]) -> None:
bb = filter(lambda x: x is not None, a)
# TODO Iterator[Optional[int]]
assert_is_value(bb, iter_any)
# Also, if you replace 'bool' with 'Any' in the second overload, bb is Iterator[Any]
assert_type(bb, Iterator[Optional[int]])
cc = filter(is_int_typeguard, a)
# TODO Iterator[int]
assert_is_value(cc, iter_any)
assert_type(cc, Iterator[int])
dd = filter(is_int_bool, a)
# TODO Iterator[Optional[int]]
assert_is_value(dd, iter_any)
assert_type(dd, Iterator[Optional[int]])

@assert_passes()
def testTypeIsDecorated(self):
Expand Down Expand Up @@ -345,7 +327,7 @@ def is_float(self, a: object) -> TypeIs[float]:
return False

class D(C):
def is_float(self, a: object) -> bool: # TODO: incompatible_override
def is_float(self, a: object) -> bool: # E: incompatible_override
return False

@assert_passes()
Expand Down Expand Up @@ -576,10 +558,10 @@ def with_bool(o: object) -> bool:
return False

accepts_typeguard(with_typeguard)
accepts_typeguard(with_bool) # TODO error
accepts_typeguard(with_bool) # E: incompatible_argument

different_typeguard(with_typeguard) # TODO error
different_typeguard(with_bool) # TODO error
different_typeguard(with_typeguard) # E: incompatible_argument
different_typeguard(with_bool) # E: incompatible_argument

@assert_passes()
def testTypeIsAsGenericFunctionArg(self):
Expand All @@ -602,7 +584,7 @@ def with_bool(o: object) -> bool:

accepts_typeguard(with_bool_typeguard)
accepts_typeguard(with_str_typeguard)
accepts_typeguard(with_bool) # TODO error
accepts_typeguard(with_bool) # E: incompatible_argument

@assert_passes()
def testTypeIsAsOverloadedFunctionArg(self):
Expand Down Expand Up @@ -662,9 +644,9 @@ def with_typeguard_b(o: object) -> TypeIs[B]:
def with_typeguard_c(o: object) -> TypeIs[C]:
return False

accepts_typeguard(with_typeguard_a) # TODO error
accepts_typeguard(with_typeguard_a) # E: incompatible_argument
accepts_typeguard(with_typeguard_b)
accepts_typeguard(with_typeguard_c) # TODO error
accepts_typeguard(with_typeguard_c) # E: incompatible_argument

@assert_passes()
def testTypeIsWithIdentityGeneric(self):
Expand Down Expand Up @@ -786,7 +768,7 @@ def typeguard(x: object, y: str) -> TypeIs[str]: ...
@overload
def typeguard(x: object, y: int) -> TypeIs[int]: ...

def typeguard(x: object, y: Union[int, str]) -> Union[TypeIs[int], TypeIs[str]]:
def typeguard(x: object, y: Union[int, str]) -> bool:
return False

def capybara(x: object) -> None:
Expand All @@ -805,7 +787,7 @@ def capybara(x: object) -> None:
@assert_passes()
def testGenericAliasWithTypeIs(self):
from typing import Callable, List, TypeVar
from typing_extensions import TypeIs
from typing_extensions import TypeIs, assert_type

T = TypeVar("T")
A = Callable[[object], TypeIs[List[T]]]
Expand All @@ -817,8 +799,7 @@ def test(f: A[T]) -> T:
raise NotImplementedError

def capybara() -> None:
pass
# TODO: assert_type(test(foo), List[str])
assert_type(test(foo), str)

@assert_passes()
def testNoCrashOnDunderCallTypeIs(self):
Expand Down
67 changes: 63 additions & 4 deletions pyanalyze/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -1853,6 +1853,12 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Extension":
def walk_values(self) -> Iterable[Value]:
return []

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return {}

def can_be_assigned(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return {}


@dataclass(frozen=True)
class CustomCheckExtension(Extension):
Expand All @@ -1868,6 +1874,12 @@ def substitute_typevars(self, typevars: TypeVarMap) -> "Extension":
def walk_values(self) -> Iterable[Value]:
yield from self.custom_check.walk_values()

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return self.custom_check.can_assign(value, ctx)

def can_be_assigned(self, value: Value, ctx: CanAssignContext) -> CanAssign:
return self.custom_check.can_be_assigned(value, ctx)


@dataclass(frozen=True)
class ParameterTypeGuardExtension(Extension):
Expand Down Expand Up @@ -1928,6 +1940,26 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Extension:
def walk_values(self) -> Iterable[Value]:
yield from self.guarded_type.walk_values()

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
can_assign_maps = []
if isinstance(value, AnnotatedValue):
for ext in value.get_metadata_of_type(Extension):
if isinstance(ext, TypeIsExtension):
return CanAssignError("TypeGuard is not compatible with TypeIs")
elif isinstance(ext, TypeGuardExtension):
# TypeGuard is covariant
left_can_assign = self.guarded_type.can_assign(
ext.guarded_type, ctx
)
if isinstance(left_can_assign, CanAssignError):
return CanAssignError(
"Incompatible types in TypeIs", children=[left_can_assign]
)
can_assign_maps.append(left_can_assign)
if not can_assign_maps:
return CanAssignError(f"{value} is not a TypeGuard")
return unify_bounds_maps(can_assign_maps)


@dataclass(frozen=True)
class TypeIsExtension(Extension):
Expand All @@ -1947,6 +1979,33 @@ def substitute_typevars(self, typevars: TypeVarMap) -> Extension:
def walk_values(self) -> Iterable[Value]:
yield from self.guarded_type.walk_values()

def can_assign(self, value: Value, ctx: CanAssignContext) -> CanAssign:
can_assign_maps = []
if isinstance(value, AnnotatedValue):
for ext in value.get_metadata_of_type(Extension):
if isinstance(ext, TypeGuardExtension):
return CanAssignError("TypeGuard is not compatible with TypeIs")
elif isinstance(ext, TypeIsExtension):
# TypeIs is invariant
left_can_assign = self.guarded_type.can_assign(
ext.guarded_type, ctx
)
if isinstance(left_can_assign, CanAssignError):
return CanAssignError(
"Incompatible types in TypeIs", children=[left_can_assign]
)
right_can_assign = ext.guarded_type.can_assign(
self.guarded_type, ctx
)
if isinstance(right_can_assign, CanAssignError):
return CanAssignError(
"Incompatible types in TypeIs", children=[right_can_assign]
)
can_assign_maps += [left_can_assign, right_can_assign]
if not can_assign_maps:
return CanAssignError(f"{value} is not a TypeIs")
return unify_bounds_maps(can_assign_maps)


@dataclass(frozen=True)
class HasAttrGuardExtension(Extension):
Expand Down Expand Up @@ -2120,8 +2179,8 @@ def can_assign(self, other: Value, ctx: CanAssignContext) -> CanAssign:
if isinstance(can_assign, CanAssignError):
return can_assign
bounds_maps = [can_assign]
for custom_check in self.get_metadata_of_type(CustomCheckExtension):
custom_can_assign = custom_check.custom_check.can_assign(other, ctx)
for ext in self.get_metadata_of_type(Extension):
custom_can_assign = ext.can_assign(other, ctx)
if isinstance(custom_can_assign, CanAssignError):
return custom_can_assign
bounds_maps.append(custom_can_assign)
Expand All @@ -2132,8 +2191,8 @@ def can_be_assigned(self, other: Value, ctx: CanAssignContext) -> CanAssign:
if isinstance(can_assign, CanAssignError):
return can_assign
bounds_maps = [can_assign]
for custom_check in self.get_metadata_of_type(CustomCheckExtension):
custom_can_assign = custom_check.custom_check.can_be_assigned(other, ctx)
for ext in self.get_metadata_of_type(Extension):
custom_can_assign = ext.can_be_assigned(other, ctx)
if isinstance(custom_can_assign, CanAssignError):
return custom_can_assign
bounds_maps.append(custom_can_assign)
Expand Down

0 comments on commit 868ba56

Please sign in to comment.