Skip to content

Commit

Permalink
Merge pull request #118 from scipp/handle-constraints
Browse files Browse the repository at this point in the history
feat: check constraints on typevars
  • Loading branch information
jokasimr authored Feb 13, 2024
2 parents 95ef23e + 1fb4230 commit 3dc2017
Show file tree
Hide file tree
Showing 3 changed files with 236 additions and 29 deletions.
10 changes: 5 additions & 5 deletions src/sciline/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class UnsatisfiedRequirement(Exception):
class ErrorHandler(Protocol):
"""Error handling protocol for pipelines."""

def handle_unsatisfied_requirement(self, tp: Key) -> Provider:
def handle_unsatisfied_requirement(self, tp: Key, *explanation: str) -> Provider:
...


Expand All @@ -29,9 +29,9 @@ class HandleAsBuildTimeException(ErrorHandler):
ensuring that errors are caught early, before starting costly computation.
"""

def handle_unsatisfied_requirement(self, tp: Key) -> NoReturn:
def handle_unsatisfied_requirement(self, tp: Key, *explanation: str) -> NoReturn:
"""Raise an exception when a type cannot be provided."""
raise UnsatisfiedRequirement('No provider found for type', tp)
raise UnsatisfiedRequirement('No provider found for type', tp, *explanation)


class HandleAsComputeTimeException(ErrorHandler):
Expand All @@ -42,11 +42,11 @@ class HandleAsComputeTimeException(ErrorHandler):
visualization. This is helpful when visualizing a graph that is not yet complete.
"""

def handle_unsatisfied_requirement(self, tp: Key) -> Provider:
def handle_unsatisfied_requirement(self, tp: Key, *explanation: str) -> Provider:
"""Return a function that raises an exception when called."""

def unsatisfied_sentinel() -> NoReturn:
raise UnsatisfiedRequirement('No provider found for type', tp)
raise UnsatisfiedRequirement('No provider found for type', tp, *explanation)

return Provider(
func=unsatisfied_sentinel, arg_spec=ArgSpec.null(), kind='unsatisfied'
Expand Down
117 changes: 93 additions & 24 deletions src/sciline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
List,
Mapping,
Optional,
Set,
Tuple,
Type,
TypeVar,
Expand All @@ -39,6 +40,7 @@
from .scheduler import Scheduler
from .series import Series
from .typing import Graph, Item, Key, Label, get_optional, get_union
from .utils import keyname

T = TypeVar('T')
KeyType = TypeVar('KeyType', bound=Key)
Expand All @@ -51,21 +53,68 @@ class AmbiguousProvider(Exception):
"""Raised when multiple providers are found for a type."""


def _is_compatible_type_tuple(
def _extract_typevars_from_generic_type(t: type) -> Tuple[TypeVar, ...]:
"""Returns the typevars that were used in the definition of a Generic type."""
if not hasattr(t, '__orig_bases__'):
return ()
return tuple(
chain(*(get_args(b) for b in t.__orig_bases__ if get_origin(b) == Generic))
)


def _find_all_typevars(t: Union[type, TypeVar]) -> Set[TypeVar]:
"""Returns the set of all TypeVars in a type expression."""
if isinstance(t, TypeVar):
return {t}
return set(chain(*map(_find_all_typevars, get_args(t))))


def _find_bounds_to_make_compatible_type(
requested: Key,
provided: Key | TypeVar,
) -> Optional[Dict[TypeVar, Key]]:
"""
Check if a type is compatible to a provided type.
If the types are compatible, return a mapping from typevars to concrete types
that makes the provided type equal to the requested type.
"""
if provided == requested:
ret: Dict[TypeVar, Key] = {}
return ret
if isinstance(provided, TypeVar):
# If the type var has no constraints, accept anything
if not provided.__constraints__:
return {provided: requested}
for c in provided.__constraints__:
if _find_bounds_to_make_compatible_type(requested, c) is not None:
return {provided: requested}
if get_origin(provided) is not None:
if get_origin(provided) == get_origin(requested):
return _find_bounds_to_make_compatible_type_tuple(
get_args(requested), get_args(provided)
)
return None


def _find_bounds_to_make_compatible_type_tuple(
requested: tuple[Key, ...],
provided: tuple[Key | TypeVar, ...],
) -> bool:
) -> Optional[Dict[TypeVar, Key]]:
"""
Check if a tuple of requested types is compatible with a tuple of provided types.
Types in the tuples must either by equal, or the provided type must be a TypeVar.
Check if a tuple of requested types is compatible with a tuple of provided types
and return a mapping from type vars to concrete types that makes all provided
types equal to their corresponding requested type.
If any of the types is not compatible, return None.
"""
for req, prov in zip(requested, provided):
if isinstance(prov, TypeVar):
continue
if req != prov:
return False
return True
union: Dict[TypeVar, Key] = {}
for bound in map(_find_bounds_to_make_compatible_type, requested, provided):
# If no mapping from the type-var to a concrete type was found,
# or if the mapping is inconsistent,
# interrupt the search and report that no compatible types were found.
if bound is None or any(k in union and union[k] != bound[k] for k in bound):
return None
union.update(bound)
return union


def _find_all_paths(
Expand Down Expand Up @@ -494,20 +543,22 @@ def _get_provider(
self, tp: Union[Type[T], Item[T]], handler: Optional[ErrorHandler] = None
) -> Tuple[Provider, Dict[TypeVar, Key]]:
handler = handler or HandleAsBuildTimeException()
explanation: List[str] = []
if (provider := self._providers.get(tp)) is not None:
return provider, {}
elif (origin := get_origin(tp)) is not None and (
subproviders := self._subproviders.get(origin)
) is not None:
requested = get_args(tp)
matches = [
(args, subprovider)
(subprovider, bound)
for args, subprovider in subproviders.items()
if _is_compatible_type_tuple(requested, args)
]
typevar_counts = [
sum(1 for t in args if isinstance(t, TypeVar)) for args, _ in matches
if (
bound := _find_bounds_to_make_compatible_type_tuple(requested, args)
)
is not None
]
typevar_counts = [len(bound) for _, bound in matches]
min_typevar_count = min(typevar_counts, default=0)
matches = [
m
Expand All @@ -516,20 +567,38 @@ def _get_provider(
]

if len(matches) == 1:
args, provider = matches[0]
bound = {
arg: req
for arg, req in zip(args, requested)
if isinstance(arg, TypeVar)
}
provider, bound = matches[0]
return provider, bound
elif len(matches) > 1:
matching_providers = [m[1].location.name for m in matches]
matching_providers = [provider.location.name for provider, _ in matches]
raise AmbiguousProvider(
f"Multiple providers found for type {tp}."
f" Matching providers are: {matching_providers}."
)
return handler.handle_unsatisfied_requirement(tp), {}
else:
typevars_in_expression = _extract_typevars_from_generic_type(origin)
if typevars_in_expression:
explanation = [
''.join(
map(
str,
(
'Note that ',
keyname(origin[typevars_in_expression]),
' has constraints ',
(
{
keyname(tv): tuple(
map(keyname, tv.__constraints__)
)
for tv in typevars_in_expression
}
),
),
)
)
]
return handler.handle_unsatisfied_requirement(tp, *explanation), {}

def _get_unique_provider(
self, tp: Union[Type[T], Item[T]], handler: ErrorHandler
Expand Down
138 changes: 138 additions & 0 deletions tests/pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1266,3 +1266,141 @@ def f(x: int): # type: ignore[no-untyped-def]

with pytest.raises(ValueError, match='type-hint'):
sl.Pipeline([f])


def test_does_not_allow_type_argument_outside_of_constraints_flat() -> None:
T = TypeVar('T', int, float, str)
T2 = TypeVar('T2', int, float)

@dataclass
class M(Generic[T]):
value: T

def p1(value: T2) -> M[T2]:
return M(value)

pipeline = sl.Pipeline((p1,))
pipeline[str] = 'abc'
pipeline[int] = 123

pipeline.get(M[int])

with pytest.raises(sl.handler.UnsatisfiedRequirement):
pipeline.get(M[str])


def test_does_not_allow_type_argument_outside_of_constraints_nested() -> None:
T = TypeVar('T', int, float, str)

@dataclass
class M(Generic[T]):
value: T

S = TypeVar('S', M[int], M[float], M[str])
S2 = TypeVar('S2', M[int], M[float])

@dataclass
class N(Generic[S]):
value: S

def p1(value: T) -> M[T]:
return M(value)

def p2(value: S2) -> N[S2]:
return N(value)

pipeline = sl.Pipeline((p1, p2))
pipeline[str] = 'abc'
pipeline[int] = 123

pipeline.get(N[M[int]])

with pytest.raises(sl.handler.UnsatisfiedRequirement):
pipeline.get(N[M[str]])


def test_constraints_nested_multiple_typevars() -> None:
T = TypeVar('T', int, float, str)
T2 = TypeVar('T2', int, float)

@dataclass
class M(Generic[T]):
v: T

S = TypeVar('S', M[int], M[float], M[str])
S2 = TypeVar('S2', M[int], M[float])

@dataclass
class N(Generic[S, T]):
v1: S
v2: T

def p1(v: T) -> M[T]:
return M(v)

def p2(v1: S2, v2: T2) -> N[S2, T2]:
return N(v1, v2)

pipeline = sl.Pipeline((p1, p2))
pipeline[str] = 'abc'
pipeline[int] = 123
pipeline[float] = 3.14

pipeline.get(N[M[float], int])
pipeline.get(N[M[int], int])

with pytest.raises(sl.handler.UnsatisfiedRequirement):
pipeline.get(N[M[int], str])
with pytest.raises(sl.handler.UnsatisfiedRequirement):
pipeline.get(N[M[str], float])


def test_number_of_type_vars_defines_most_specialized() -> None:
Green = NewType('Green', str)
Blue = NewType('Blue', str)
Color = TypeVar('Color', Green, Blue)

@dataclass
class Likes(Generic[Color]):
color: Color

Preference = TypeVar('Preference')

@dataclass
class Person(Generic[Preference, Color]):
preference: Preference
hatcolor: Color
provided_by: str

def p(c: Color) -> Likes[Color]:
return Likes(c)

def p0(p: Preference, c: Color) -> Person[Preference, Color]:
return Person(p, c, 'p0')

def p1(c: Color) -> Person[Likes[Color], Color]:
return Person(Likes(c), c, 'p1')

def p2(p: Preference) -> Person[Preference, Green]:
return Person(p, Green('g'), 'p2')

pipeline = sl.Pipeline((p, p0, p1, p2))
pipeline[Blue] = 'b'
pipeline[Green] = 'g'

# only provided by p0
assert pipeline.compute(Person[Likes[Green], Blue]) == Person(
Likes(Green('g')), Blue('b'), 'p0'
)
# provided by p1 and p0 but p1 is preferred because it has fewer typevars
assert pipeline.compute(Person[Likes[Blue], Blue]) == Person(
Likes(Blue('b')), Blue('b'), 'p1'
)
# provided by p2 and p0 but p2 is preferred because it has fewer typevars
assert pipeline.compute(Person[Likes[Blue], Green]) == Person(
Likes(Blue('b')), Green('g'), 'p2'
)

with pytest.raises(sl.AmbiguousProvider):
# provided by p1 and p2 with the same number of typevars
pipeline.get(Person[Likes[Green], Green])

0 comments on commit 3dc2017

Please sign in to comment.