diff --git a/.pylint_dict.txt b/.pylint_dict.txt index 54a82ad1..6766d2a9 100644 --- a/.pylint_dict.txt +++ b/.pylint_dict.txt @@ -38,5 +38,6 @@ iter UnusedVariable # Misc +Github haha X'th diff --git a/spinn_utilities/config_setup.py b/spinn_utilities/config_setup.py index 40db30dd..67d7531e 100644 --- a/spinn_utilities/config_setup.py +++ b/spinn_utilities/config_setup.py @@ -16,6 +16,7 @@ from spinn_utilities.config_holder import ( add_default_cfg, clear_cfg_files) from spinn_utilities.data.utils_data_writer import UtilsDataWriter +from spinn_utilities.overrides import overrides BASE_CONFIG_FILE = "spinn_utilities.cfg" @@ -39,4 +40,5 @@ def add_spinn_utilities_cfg() -> None: """ Loads the default config values for spinn_utilities """ + overrides.check_types() add_default_cfg(os.path.join(os.path.dirname(__file__), BASE_CONFIG_FILE)) diff --git a/spinn_utilities/log.py b/spinn_utilities/log.py index f318b616..82d0d929 100644 --- a/spinn_utilities/log.py +++ b/spinn_utilities/log.py @@ -17,7 +17,7 @@ import logging import re import sys -from typing import List, Optional, Tuple +from typing import Any, List, Optional, Tuple from inspect import getfullargspec from .log_store import LogStore from .overrides import overrides @@ -214,7 +214,7 @@ def __init__(self, logger: logging.Logger, extra=None): super().__init__(logger, extra) self.do_log = logger._log # pylint: disable=protected-access - @overrides(logging.LoggerAdapter.log, extend_doc=False) + @overrides(logging.LoggerAdapter.log, extend_doc=False, adds_typing=True) def log(self, level: int, msg: object, *args, **kwargs): """ Delegate a log call to the underlying logger, applying appropriate @@ -247,8 +247,9 @@ def log(self, level: int, msg: object, *args, **kwargs): log_kwargs["exc_info"] = kwargs["exc_info"] self.do_log(level, message, (), **log_kwargs) - @overrides(logging.LoggerAdapter.process, extend_doc=False) - def process(self, msg: object, kwargs) -> Tuple[object, dict]: + @overrides(logging.LoggerAdapter.process, extend_doc=False, + adds_typing=True) + def process(self, msg: object, kwargs: Any) -> Tuple[object, dict]: """ Process the logging message and keyword arguments passed in to a logging call to insert contextual information. You can either diff --git a/spinn_utilities/overrides.py b/spinn_utilities/overrides.py index 836d600f..c8c31950 100644 --- a/spinn_utilities/overrides.py +++ b/spinn_utilities/overrides.py @@ -13,8 +13,10 @@ # limitations under the License. import inspect +import os from types import FunctionType, MethodType from typing import Any, Callable, Iterable, Optional, TypeVar + #: :meta private: Method = TypeVar("Method", bound=Callable[..., Any]) @@ -27,6 +29,9 @@ class overrides(object): copies the doc-string for the method, and enforces that the method overridden is specified, making maintenance easier. """ + # This near constant is changed by unit tests to check our code + # Github actions sets TYPE_OVERRIDES as True + __CHECK_TYPES = os.getenv("TYPE_OVERRIDES") __slots__ = [ # The method in the superclass that this method overrides @@ -40,13 +45,15 @@ class overrides(object): # True if the name check is relaxed "_relax_name_check", # The name of the thing being overridden for error messages - "_override_name" + "_override_name", + # True if this method adds typing info not in the original + "_adds_typing" ] def __init__( - self, super_class_method, extend_doc: bool = True, + self, super_class_method, *, extend_doc: bool = True, additional_arguments: Optional[Iterable[str]] = None, - extend_defaults: bool = False): + extend_defaults: bool = False, adds_typing: bool = False,): """ :param super_class_method: The method to override in the superclass :param bool extend_doc: @@ -58,6 +65,9 @@ def __init__( superclass method, e.g., that are to be injected :param bool extend_defaults: Whether the subclass may specify extra defaults for the parameters + :param adds_typing: + Allows more typing (of non additional) than in the subclass. + Should only be used for built in super classes """ if isinstance(super_class_method, property): super_class_method = super_class_method.fget @@ -73,6 +83,7 @@ def __init__( self._additional_arguments = frozenset(additional_arguments) else: self._additional_arguments = frozenset() + self._adds_typing = adds_typing @staticmethod def __match_defaults(default_args, super_defaults, extend_ok): @@ -84,6 +95,45 @@ def __match_defaults(default_args, super_defaults, extend_ok): return len(default_args) >= len(super_defaults) return len(default_args) == len(super_defaults) + def _verify_types(self, method_args, super_args, all_args): + """ + Check that the arguments match. + """ + if not self.__CHECK_TYPES: + return + if "self" in all_args: + all_args.remove("self") + elif "cls" in all_args: + all_args.remove("cls") + method_types = method_args.annotations + super_types = super_args.annotations + for arg in all_args: + if arg not in super_types and not self._adds_typing: + raise AttributeError( + f"Super Method {self._superclass_method.__name__} " + f"has untyped arguments including {arg}") + if arg not in method_types: + raise AttributeError( + f"Method {self._superclass_method.__name__} " + f"has untyped arguments including {arg}") + + if len(all_args) == 0: + if "return" not in super_types and not self._adds_typing and \ + not method_args.varkw and not method_args.varargs: + raise AttributeError( + f"Super Method {self._superclass_method.__name__} " + f"has no arguments so should declare a return type") + if "return" in super_types: + if "return" not in method_types: + raise AttributeError( + f"Method {self._superclass_method.__name__} " + f"has no return type, while super does") + else: + if "return" in method_types and not self._adds_typing: + raise AttributeError( + f"Super Method {self._superclass_method.__name__} " + f"has no return type, while this does") + def __verify_method_arguments(self, method: Method): """ Check that the arguments match. @@ -109,6 +159,7 @@ def __verify_method_arguments(self, method: Method): default_args, super_args.defaults, self._extend_defaults): raise AttributeError( f"Default arguments don't match {self._override_name}") + self._verify_types(method_args, super_args, all_args) def __call__(self, method: Method) -> Method: """ @@ -142,3 +193,12 @@ def __call__(self, method: Method) -> Method: method.__doc__ = ( self._superclass_method.__doc__ + (method.__doc__ or "")) return method + + @classmethod + def check_types(cls): + """ + If called will trigger check that all parameters are checked. + + Used for testing, to avoid users being affected by the strict checks + """ + cls.__CHECK_TYPES = True diff --git a/spinn_utilities/package_loader.py b/spinn_utilities/package_loader.py index 097ae8ab..9d6e6120 100644 --- a/spinn_utilities/package_loader.py +++ b/spinn_utilities/package_loader.py @@ -15,6 +15,7 @@ import os import sys import traceback +from spinn_utilities.overrides import overrides def all_modules(directory, prefix, remove_pyc_files=False): @@ -81,10 +82,10 @@ def load_modules( if module in exclusions: print("SKIPPING " + module) continue - print(module) try: __import__(module) except Exception: # pylint: disable=broad-except + print(f"Error with {module}") if gather_errors: errors.append((module, sys.exc_info())) else: @@ -113,6 +114,7 @@ def load_module( True if errors should be gathered, False to report on first error :return: None """ + overrides.check_types() if exclusions is None: exclusions = [] module = __import__(name) diff --git a/spinn_utilities/progress_bar.py b/spinn_utilities/progress_bar.py index a6538895..d69b8b48 100644 --- a/spinn_utilities/progress_bar.py +++ b/spinn_utilities/progress_bar.py @@ -83,10 +83,10 @@ def update(self, amount_to_add=1): self._currently_completed += amount_to_add self._check_differences() - def _print_overwritten_line(self, string): + def _print_overwritten_line(self, string: str): print("\r" + string, end="", file=self._destination) - def _print_distance_indicator(self, description): + def _print_distance_indicator(self, description: str): if description is not None: print(description, file=self._destination) @@ -112,7 +112,7 @@ def _print_distance_line(self, first_space, second_space): f"{' ' * second_space}100%{self._end_character}" print(line, end="", file=self._destination) - def _print_progress(self, length): + def _print_progress(self, length: int): chars_to_print = length if not self._in_bad_terminal: self._print_overwritten_line(self._end_character) @@ -127,7 +127,7 @@ def _print_progress_unit(self, chars_to_print): # pylint: disable=unused-argument print(self._step_character, end='', file=self._destination) - def _print_progress_done(self): + def _print_progress_done(self) -> None: self._print_progress(ProgressBar.MAX_LENGTH_IN_CHARS) if not self._in_bad_terminal: print(self._end_character, file=self._destination) @@ -324,19 +324,19 @@ class DummyProgressBar(ProgressBar): fails in exactly the same way. """ @overrides(ProgressBar._print_overwritten_line) - def _print_overwritten_line(self, string): + def _print_overwritten_line(self, string: str): pass @overrides(ProgressBar._print_distance_indicator) - def _print_distance_indicator(self, description): + def _print_distance_indicator(self, description: str): pass @overrides(ProgressBar._print_progress) - def _print_progress(self, length): + def _print_progress(self, length: int): pass @overrides(ProgressBar._print_progress_done) - def _print_progress_done(self): + def _print_progress_done(self) -> None: pass def __repr__(self): diff --git a/spinn_utilities/ranged/abstract_dict.py b/spinn_utilities/ranged/abstract_dict.py index e24eb3ba..c401c757 100644 --- a/spinn_utilities/ranged/abstract_dict.py +++ b/spinn_utilities/ranged/abstract_dict.py @@ -22,7 +22,7 @@ # Can't be Iterable[str] or Sequence[str] because that includes str itself _StrSeq: TypeAlias = Union[ MutableSequence[str], Tuple[str, ...], FrozenSet[str], Set[str]] -_Keys: TypeAlias = Union[None, str, _StrSeq] +_Keys: TypeAlias = Optional[Union[str, _StrSeq]] class AbstractDict(Generic[T], metaclass=AbstractBase): @@ -117,7 +117,7 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @abstractmethod - def iter_all_values(self, key, update_safe=False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): """ Iterates over the value(s) for all IDs covered by this view. There will be one yield for each ID even if values are repeated. @@ -181,7 +181,9 @@ def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[ ... @abstractmethod - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None + ) -> Union[Iterator[Tuple[int, int, T]], + Iterator[Tuple[int, int, Dict[str, T]]]]: """ Iterates over the ranges(s) for all IDs covered by this view. There will be one yield for each range which may cover one or diff --git a/spinn_utilities/ranged/abstract_list.py b/spinn_utilities/ranged/abstract_list.py index 112d32cc..91e6826f 100644 --- a/spinn_utilities/ranged/abstract_list.py +++ b/spinn_utilities/ranged/abstract_list.py @@ -12,20 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations -import numbers +from numbers import Number from typing import ( Any, Callable, Generic, Iterator, Optional, Sequence, Tuple, TypeVar, Union, cast) import numpy from numpy.typing import NDArray -from typing_extensions import Self, TypeAlias +from typing_extensions import Self, TypeAlias, TypeGuard from spinn_utilities.abstract_base import AbstractBase, abstractmethod from spinn_utilities.overrides import overrides from .abstract_sized import AbstractSized, Selector from .multiple_values_exception import MultipleValuesException #: :meta private: +R = TypeVar("R") +#: :meta private: T = TypeVar("T") #: :meta private: +U = TypeVar("U") +#: :meta private: IdsType: TypeAlias = Union[Sequence[int], NDArray[numpy.integer]] @@ -38,6 +42,13 @@ def _is_zero(value: Any) -> bool: return bool(numpy.isin(0, value)) +def is_number(value: T) -> TypeGuard[float]: + """ + Is the Value a simple integer or float? + """ + return isinstance(value, Number) + + class AbstractList(AbstractSized, Generic[T], metaclass=AbstractBase): """ A ranged implementation of list. @@ -496,7 +507,8 @@ def get_default(self) -> Optional[T]: """ raise NotImplementedError - def __add__(self, other) -> AbstractList[float]: + def __add__(self, other: Union[float, AbstractList[float]] + ) -> AbstractList[float]: """ Support for ``new_list = list1 + list2``. Applies the add operator over this and other to create a new list. @@ -510,14 +522,17 @@ def __add__(self, other) -> AbstractList[float]: :raises TypeError: """ if isinstance(other, AbstractList): + d_operation: Callable[[Any, float], float] = lambda x, y: x + y return DualList( - left=self, right=other, operation=lambda x, y: x + y) - if isinstance(other, numbers.Number): - return SingleList(a_list=self, operation=lambda x: x + other) + left=self, right=other, operation=d_operation) + if is_number(other): + s_operation: Callable[[Any], float] = lambda x: x + other + return SingleList(a_list=self, operation=s_operation) raise TypeError("__add__ operation only supported for other " "RangedLists and numerical Values") - def __sub__(self, other) -> AbstractList[float]: + def __sub__(self, other: Union[float, AbstractList[float]] + ) -> AbstractList[float]: """ Support for ``new_list = list1 - list2``. Applies the subtract operator over this and other to create a new list. @@ -531,14 +546,17 @@ def __sub__(self, other) -> AbstractList[float]: :raises TypeError: """ if isinstance(other, AbstractList): + d_operation: Callable[[Any, float], float] = lambda x, y: x - y return DualList( - left=self, right=other, operation=lambda x, y: x - y) - if isinstance(other, numbers.Number): - return SingleList(a_list=self, operation=lambda x: x - other) + left=self, right=other, operation=d_operation) + if is_number(other): + s_operation: Callable[[Any], float] = lambda x: x - other + return SingleList(a_list=self, operation=s_operation) raise TypeError("__sub__ operation only supported for other " "RangedLists and numerical Values") - def __mul__(self, other) -> AbstractList[float]: + def __mul__(self, other: Union[float, AbstractList[float]] + ) -> AbstractList[float]: """ Support for ``new_list = list1 * list2``. Applies the multiply operator over this and other. @@ -552,14 +570,17 @@ def __mul__(self, other) -> AbstractList[float]: :raises TypeError: """ if isinstance(other, AbstractList): + d_operation: Callable[[Any, float], float] = lambda x, y: x * y return DualList( - left=self, right=other, operation=lambda x, y: x * y) - if isinstance(other, numbers.Number): - return SingleList(a_list=self, operation=lambda x: x * other) + left=self, right=other, operation=d_operation) + if is_number(other): + s_operation: Callable[[Any], float] = lambda x: x * other + return SingleList(a_list=self, operation=s_operation) raise TypeError("__mul__ operation only supported for other " "RangedLists and numerical Values") - def __truediv__(self, other) -> AbstractList[float]: + def __truediv__(self, other: Union[float, AbstractList[float]] + ) -> AbstractList[float]: """ Support for ``new_list = list1 / list2``. Applies the division operator over this and other to create a @@ -574,16 +595,19 @@ def __truediv__(self, other) -> AbstractList[float]: :raises TypeError: """ if isinstance(other, AbstractList): + d_operation: Callable[[Any, float], float] = lambda x, y: x / y return DualList( - left=self, right=other, operation=lambda x, y: x / y) - if isinstance(other, numbers.Number): + left=self, right=other, operation=d_operation) + if is_number(other): if _is_zero(other): raise ZeroDivisionError() - return SingleList(a_list=self, operation=lambda x: x / other) + s_operation: Callable[[Any], float] = lambda x: x / other + return SingleList(a_list=self, operation=s_operation) raise TypeError("__truediv__ operation only supported for other " "RangedLists and numerical Values") - def __floordiv__(self, other) -> AbstractList[int]: + def __floordiv__(self, other: Union[float, AbstractList[float]] + ) -> AbstractList[int]: """ Support for ``new_list = list1 // list2``. Applies the floor division operator over this and other. @@ -595,17 +619,19 @@ def __floordiv__(self, other) -> AbstractList[int]: :raises TypeError: """ if isinstance(other, AbstractList): + d_operation: Callable[[Any, float], int] = lambda x, y: int(x // y) return DualList( - left=self, right=other, operation=lambda x, y: x // y) - if isinstance(other, numbers.Number): + left=self, right=other, operation=d_operation) + if is_number(other): if _is_zero(other): raise ZeroDivisionError() - return SingleList(a_list=self, operation=lambda x: x // other) + s_operation: Callable[[Any], int] = lambda x: int(x / other) + return SingleList(a_list=self, operation=s_operation) raise TypeError("__floordiv__ operation only supported for other " "RangedLists and numerical Values") def apply_operation( - self, operation: Callable[[T], T]) -> AbstractList[T]: + self, operation: Callable[[T], U]) -> AbstractList[U]: """ Applies a function on the list to create a new one. The values of the new list are created on the fly so any changes @@ -620,14 +646,17 @@ def apply_operation( return SingleList(a_list=self, operation=operation) -class SingleList(AbstractList[T], Generic[T], metaclass=AbstractBase): +class SingleList(AbstractList[R], Generic[T, R], + metaclass=AbstractBase): """ A List that performs an operation on the elements of another list. """ __slots__ = [ "_a_list", "_operation"] - def __init__(self, a_list, operation, key=None): + def __init__(self, a_list: AbstractList[T], + operation: Callable[[T], R], + key: Optional[str] = None): """ :param AbstractList a_list: The list to perform the operation on :param callable operation: @@ -641,46 +670,55 @@ def __init__(self, a_list, operation, key=None): self._operation = operation @overrides(AbstractList.range_based) - def range_based(self): + def range_based(self) -> bool: return self._a_list.range_based() @overrides(AbstractList.get_value_by_id) - def get_value_by_id(self, the_id): + def get_value_by_id(self, the_id: int) -> R: return self._operation(self._a_list.get_value_by_id(the_id)) @overrides(AbstractList.get_single_value_by_slice) - def get_single_value_by_slice(self, slice_start, slice_stop): + def get_single_value_by_slice( + self, slice_start: int, slice_stop: int) -> R: return self._operation(self._a_list.get_single_value_by_slice( slice_start, slice_stop)) @overrides(AbstractList.get_single_value_by_ids) - def get_single_value_by_ids(self, ids): + def get_single_value_by_ids(self, ids: IdsType) -> R: return self._operation(self._a_list.get_single_value_by_ids(ids)) @overrides(AbstractList.iter_ranges) - def iter_ranges(self): + def iter_ranges(self) -> Iterator[Tuple[int, int, R]]: for (start, stop, value) in self._a_list.iter_ranges(): yield (start, stop, self._operation(value)) @overrides(AbstractList.get_default) - def get_default(self): - return self._operation(self._a_list.get_default()) + def get_default(self) -> Optional[R]: + default = self._a_list.get_default() + if default is None: + return None + return self._operation(default) @overrides(AbstractList.iter_ranges_by_slice) - def iter_ranges_by_slice(self, slice_start, slice_stop): + def iter_ranges_by_slice( + self, slice_start: int, slice_stop: int + ) -> Iterator[Tuple[int, int, R]]: for (start, stop, value) in \ self._a_list.iter_ranges_by_slice(slice_start, slice_stop): yield (start, stop, self._operation(value)) -class DualList(AbstractList[T], Generic[T], metaclass=AbstractBase): +class DualList(AbstractList[R], Generic[T, U, R], + metaclass=AbstractBase): """ A list which combines two other lists with an operation. """ __slots__ = [ "_left", "_operation", "_right"] - def __init__(self, left, right, operation, key=None): + def __init__(self, left: AbstractList[T], right: AbstractList[U], + operation: Callable[[T, U], R], + key: Optional[str] = None): """ :param AbstractList left: The first list to combine :param AbstractList right: The second list to combine @@ -700,29 +738,31 @@ def __init__(self, left, right, operation, key=None): self._operation = operation @overrides(AbstractList.range_based) - def range_based(self): + def range_based(self) -> bool: return self._left.range_based() and self._right.range_based() @overrides(AbstractList.get_value_by_id) - def get_value_by_id(self, the_id): + def get_value_by_id(self, the_id: int) -> R: return self._operation( self._left.get_value_by_id(the_id), self._right.get_value_by_id(the_id)) @overrides(AbstractList.get_single_value_by_slice) - def get_single_value_by_slice(self, slice_start, slice_stop): + def get_single_value_by_slice( + self, slice_start: int, slice_stop: int) -> R: return self._operation( self._left.get_single_value_by_slice(slice_start, slice_stop), self._right.get_single_value_by_slice(slice_start, slice_stop)) @overrides(AbstractList.get_single_value_by_ids) - def get_single_value_by_ids(self, ids): + def get_single_value_by_ids(self, ids: IdsType) -> R: return self._operation( self._left.get_single_value_by_ids(ids), self._right.get_single_value_by_ids(ids)) @overrides(AbstractList.iter_by_slice) - def iter_by_slice(self, slice_start, slice_stop): + def iter_by_slice( + self, slice_start: int, slice_stop: int) -> Iterator[R]: slice_start, slice_stop = self._check_slice_in_range( slice_start, slice_stop) if self._left.range_based(): @@ -738,46 +778,52 @@ def iter_by_slice(self, slice_start, slice_stop): # Left list is range based, right is not left_iter = self._left.iter_ranges_by_slice( slice_start, slice_stop) - right_iter = self._right.iter_by_slice(slice_start, slice_stop) + right_values = self._right.iter_by_slice( + slice_start, slice_stop) for (start, stop, left_value) in left_iter: for _ in range(start, stop): - yield self._operation(left_value, next(right_iter)) + yield self._operation(left_value, next(right_values)) else: if self._right.range_based(): # Right list is range based left is not - left_iter = self._left.iter_by_slice( + left_values = self._left.iter_by_slice( slice_start, slice_stop) right_iter = self._right.iter_ranges_by_slice( slice_start, slice_stop) for (start, stop, right_value) in right_iter: for _ in range(start, stop): - yield self._operation(next(left_iter), right_value) + yield self._operation(next(left_values), right_value) else: # Neither list is range based - left_iter = self._left.iter_by_slice(slice_start, slice_stop) - right_iter = self._right.iter_by_slice(slice_start, slice_stop) + left_values = self._left.iter_by_slice(slice_start, slice_stop) + right_values = self._right.iter_by_slice( + slice_start, slice_stop) while True: try: yield self._operation( - next(left_iter), next(right_iter)) + next(left_values), next(right_values)) except StopIteration: return @overrides(AbstractList.iter_ranges) - def iter_ranges(self): + def iter_ranges(self) -> Iterator[Tuple[int, int, R]]: left_iter = self._left.iter_ranges() right_iter = self._right.iter_ranges() return self._merge_ranges(left_iter, right_iter) @overrides(AbstractList.iter_ranges_by_slice) - def iter_ranges_by_slice(self, slice_start, slice_stop): + def iter_ranges_by_slice( + self, slice_start: int, slice_stop: int) -> Iterator[ + Tuple[int, int, R]]: left_iter = self._left.iter_ranges_by_slice(slice_start, slice_stop) right_iter = self._right.iter_ranges_by_slice(slice_start, slice_stop) return self._merge_ranges(left_iter, right_iter) - def _merge_ranges(self, left_iter, right_iter): + def _merge_ranges(self, left_iter: Iterator[Tuple[int, int, T]], + right_iter: Iterator[Tuple[int, int, U]] + ) -> Iterator[Tuple[int, int, R]]: (left_start, left_stop, left_value) = next(left_iter) (right_start, right_stop, right_value) = next(right_iter) try: @@ -796,6 +842,11 @@ def _merge_ranges(self, left_iter, right_iter): return @overrides(AbstractList.get_default) - def get_default(self): - return self._operation( - self._left.get_default(), self._right.get_default()) + def get_default(self) -> Optional[R]: + l_default = self._left.get_default() + if l_default is None: + return None + r_default = self._right.get_default() + if r_default is None: + return None + return self._operation(l_default, r_default) diff --git a/spinn_utilities/ranged/ids_view.py b/spinn_utilities/ranged/ids_view.py index fa897630..3a177603 100644 --- a/spinn_utilities/ranged/ids_view.py +++ b/spinn_utilities/ranged/ids_view.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import ( Dict, Generic, Iterable, Iterator, Optional, Sequence, Tuple, - overload, TYPE_CHECKING) + overload, TYPE_CHECKING, Union) from spinn_utilities.overrides import overrides from .abstract_dict import AbstractDict, _StrSeq, _Keys from .abstract_list import IdsType @@ -49,7 +49,7 @@ def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]: ... @overrides(AbstractDict.get_value) - def get_value(self, key: _Keys): + def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]: if isinstance(key, str): return self._range_dict.get_list(key).get_single_value_by_ids( self._ids) @@ -65,7 +65,8 @@ def get_value(self, key: _Keys): for k in key} @overrides(AbstractDict.set_value) - def set_value(self, key: str, value: T, use_list_as_value=False): + def set_value( + self, key: str, value: T, use_list_as_value: bool = False): ranged_list = self._range_dict.get_list(key) for _id in self._ids: ranged_list.set_value_by_id(the_id=_id, value=value) @@ -92,7 +93,7 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @overrides(AbstractDict.iter_all_values) - def iter_all_values(self, key: _Keys, update_safe=False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): if isinstance(key, str): yield from self._range_dict.iter_values_by_ids( ids=self._ids, key=key, update_safe=update_safe) @@ -110,5 +111,7 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[Tuple[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None + ) -> Union[Iterator[Tuple[int, int, T]], + Iterator[Tuple[int, int, Dict[str, T]]]]: return self._range_dict.iter_ranges_by_ids(key=key, ids=self._ids) diff --git a/spinn_utilities/ranged/range_dictionary.py b/spinn_utilities/ranged/range_dictionary.py index a3a544c3..0ee2b7e3 100644 --- a/spinn_utilities/ranged/range_dictionary.py +++ b/spinn_utilities/ranged/range_dictionary.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations from typing import ( - Dict, Iterable, Iterator, Optional, Sequence, Tuple, Union, + Dict, Generator, Iterable, Iterator, Optional, Sequence, Tuple, Union, Generic, overload, TYPE_CHECKING) from typing_extensions import TypeAlias from spinn_utilities.overrides import overrides @@ -28,6 +28,8 @@ from .abstract_view import AbstractView _KeyType: TypeAlias = Union[int, slice, Iterable[int]] +_Keys: TypeAlias = Union[None, str, _StrSeq] + _Range: TypeAlias = Tuple[int, int, T] _SimpleRangeIter: TypeAlias = Iterator[_Range] _CompoundRangeIter: TypeAlias = Iterator[Tuple[int, int, Dict[str, T]]] @@ -170,7 +172,8 @@ def get_value(self, key: str) -> T: ... def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]: ... @overrides(AbstractDict.get_value, extend_defaults=True) - def get_value(self, key: Union[str, None, _StrSeq] = None): + def get_value(self, key: Union[str, None, _StrSeq] = None + ) -> Union[T, Dict[str, T]]: if isinstance(key, str): return self._value_lists[key].get_single_value_all() if key is None: @@ -186,7 +189,7 @@ def get_values_by_id(self, key: str, the_id: int) -> T: ... def get_values_by_id( self, key: Optional[_StrSeq], the_id: int) -> Dict[str, T]: ... - def get_values_by_id(self, key, the_id): + def get_values_by_id(self, key, the_id) -> Union[T, Dict[str, T]]: """ Same as :py:meth:`get_value` but limited to a single ID. @@ -217,15 +220,16 @@ def get_list(self, key: str) -> RangedList[T]: @overload def update_safe_iter_all_values( - self, key: str, ids: IdsType) -> Iterator[T]: ... + self, key: str, ids: IdsType) -> Generator[T, None, None]: ... @overload def update_safe_iter_all_values( self, key: Optional[_StrSeq], - ids: IdsType) -> Iterator[Dict[str, T]]: ... + ids: IdsType) -> Generator[Dict[str, T], None, None]: ... def update_safe_iter_all_values( - self, key: Union[str, Optional[_StrSeq]], ids: IdsType): + self, key: Union[str, Optional[_StrSeq]], + ids: IdsType) -> Generator[Union[T, Dict[str, T]], None, None]: """ Same as :py:meth:`iter_all_values` @@ -246,7 +250,7 @@ def iter_all_values(self, key: Optional[_StrSeq], ... @overrides(AbstractDict.iter_all_values, extend_defaults=True) - def iter_all_values(self, key=None, update_safe: bool = False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): if isinstance(key, str): if update_safe: return self._value_lists[key].iter() @@ -371,7 +375,7 @@ def keys(self) -> Iterable[str]: def _merge_ranges( self, range_iters: Dict[str, Iterator[Tuple[int, int, T]]] - ) -> _CompoundRangeIter: + ) -> Iterator[Tuple[int, int, Dict[str, T]]]: current: Dict[str, T] = dict() ranges: Dict[str, Tuple[int, int, T]] = dict() start = 0 @@ -399,15 +403,28 @@ def _merge_ranges( stop = next_stop yield (start, stop, current) + @overload + def iter_ranges(self, key: str) -> Iterator[Tuple[int, int, T]]: + ... + + @overload + def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[ + int, int, Dict[str, T]]]: + ... + @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None) -> \ + Union[Iterator[Tuple[int, int, T]], + Iterator[Tuple[int, int, Dict[str, T]]]]: if isinstance(key, str): return self._value_lists[key].iter_ranges() if key is None: - key = self.keys() + keys = self.keys() + else: + keys = key return self._merge_ranges({ a_key: self._value_lists[a_key].iter_ranges() - for a_key in key}) + for a_key in keys}) @overload def iter_ranges_by_id( diff --git a/spinn_utilities/ranged/ranged_list.py b/spinn_utilities/ranged/ranged_list.py index 153d3e11..8bb87b7b 100644 --- a/spinn_utilities/ranged/ranged_list.py +++ b/spinn_utilities/ranged/ranged_list.py @@ -132,7 +132,8 @@ def get_value_by_id(self, the_id: int) -> T: return self.__the_values[the_id] @overrides(AbstractList.get_single_value_by_slice) - def get_single_value_by_slice(self, slice_start: int, slice_stop: int): + def get_single_value_by_slice( + self, slice_start: int, slice_stop: int) -> T: slice_start, slice_stop = self._check_slice_in_range( slice_start, slice_stop) diff --git a/spinn_utilities/ranged/single_view.py b/spinn_utilities/ranged/single_view.py index 40f42118..7cc699ff 100644 --- a/spinn_utilities/ranged/single_view.py +++ b/spinn_utilities/ranged/single_view.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import ( Dict, Generic, Iterator, Optional, Sequence, Tuple, overload, - TYPE_CHECKING) + TYPE_CHECKING, Union) from spinn_utilities.overrides import overrides from .abstract_dict import AbstractDict, T, _StrSeq, _Keys from .abstract_view import AbstractView @@ -48,7 +48,7 @@ def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]: ... @overrides(AbstractDict.get_value) - def get_value(self, key: _Keys): + def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]: if isinstance(key, str): return self._range_dict.get_list(key).get_value_by_id( the_id=self._id) @@ -74,7 +74,7 @@ def iter_all_values( ... @overrides(AbstractDict.iter_all_values) - def iter_all_values(self, key, update_safe=False): + def iter_all_values(self, key: _Keys, update_safe: bool = False): if isinstance(key, str): yield self._range_dict.get_list(key).get_value_by_id( the_id=self._id) @@ -82,7 +82,7 @@ def iter_all_values(self, key, update_safe=False): yield self._range_dict.get_values_by_id(key=key, the_id=self._id) @overrides(AbstractDict.set_value) - def set_value(self, key: str, value: T, use_list_as_value=False): + def set_value(self, key: str, value: T, use_list_as_value: bool = False): return self._range_dict.get_list(key).set_value_by_id( value=value, the_id=self._id) @@ -96,5 +96,7 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None + ) -> Union[Iterator[Tuple[int, int, T]], + Iterator[Tuple[int, int, Dict[str, T]]]]: return self._range_dict.iter_ranges_by_id(key=key, the_id=self._id) diff --git a/spinn_utilities/ranged/slice_view.py b/spinn_utilities/ranged/slice_view.py index 94c34958..924c322e 100644 --- a/spinn_utilities/ranged/slice_view.py +++ b/spinn_utilities/ranged/slice_view.py @@ -14,7 +14,7 @@ from __future__ import annotations from typing import ( Dict, Generic, Iterable, Iterator, Optional, Sequence, Tuple, overload, - TYPE_CHECKING) + TYPE_CHECKING, Union) from spinn_utilities.overrides import overrides from .abstract_dict import AbstractDict, T, _StrSeq, _Keys from .abstract_view import AbstractView @@ -50,7 +50,7 @@ def get_value(self, key: Optional[_StrSeq]) -> Dict[str, T]: ... @overrides(AbstractDict.get_value) - def get_value(self, key: _Keys): + def get_value(self, key: _Keys) -> Union[T, Dict[str, T]]: if isinstance(key, str): return self._range_dict.get_list(key).get_single_value_by_slice( slice_start=self._start, slice_stop=self._stop) @@ -88,7 +88,7 @@ def iter_all_values( ... @overrides(AbstractDict.iter_all_values, extend_defaults=True) - def iter_all_values(self, key=None, update_safe=False): + def iter_all_values(self, key: _Keys = None, update_safe: bool = False): if isinstance(key, str): if update_safe: return self.update_safe_iter_all_values(key) @@ -99,8 +99,8 @@ def iter_all_values(self, key=None, update_safe=False): update_safe=update_safe) @overrides(AbstractDict.set_value) - def set_value( - self, key: str, value: _ValueType, use_list_as_value=False): + def set_value(self, key: str, value: _ValueType, + use_list_as_value: bool = False): self._range_dict.get_list(key).set_value_by_slice( slice_start=self._start, slice_stop=self._stop, value=value, use_list_as_value=use_list_as_value) @@ -115,6 +115,8 @@ def iter_ranges(self, key: Optional[_StrSeq] = None) -> Iterator[ ... @overrides(AbstractDict.iter_ranges) - def iter_ranges(self, key=None): + def iter_ranges(self, key: _Keys = None + ) -> Union[Iterator[Tuple[int, int, T]], + Iterator[Tuple[int, int, Dict[str, T]]]]: return self._range_dict.iter_ranges_by_slice( key=key, slice_start=self._start, slice_stop=self._stop) diff --git a/unittests/abstract_base/abstract_has_constraints.py b/unittests/abstract_base/abstract_has_constraints.py index fb960db9..0744a058 100644 --- a/unittests/abstract_base/abstract_has_constraints.py +++ b/unittests/abstract_base/abstract_has_constraints.py @@ -12,8 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.abstract_base import ( - AbstractBase, abstractproperty, abstractmethod) + AbstractBase, abstractmethod) class AbstractHasConstraints(object, metaclass=AbstractBase): @@ -23,7 +24,7 @@ class AbstractHasConstraints(object, metaclass=AbstractBase): __slots__ = () @abstractmethod - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): """ Add a new constraint to the collection of constraints :param constraint: constraint to add @@ -33,8 +34,9 @@ def add_constraint(self, constraint): If the constraint is not valid """ - @abstractproperty - def constraints(self): + @property + @abstractmethod + def constraints(self) -> Any: """ An iterable of constraints :return: iterable of constraints diff --git a/unittests/abstract_base/checked_bad_param.py b/unittests/abstract_base/checked_bad_param.py index eb1d574f..7ef46279 100644 --- a/unittests/abstract_base/checked_bad_param.py +++ b/unittests/abstract_base/checked_bad_param.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -25,9 +26,9 @@ def set_label(selfself, label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, not_constraint): + def add_constraint(self, not_constraint: Any): raise NotImplementedError("We set our own constrainst") @overrides(AbstractHasConstraints.constraints) - def constraints(self): + def constraints(self) -> Any: return ["No night feeds", "No nappy changes"] diff --git a/unittests/abstract_base/grandparent.py b/unittests/abstract_base/grandparent.py index b142ebfb..6fd67428 100644 --- a/unittests/abstract_base/grandparent.py +++ b/unittests/abstract_base/grandparent.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -26,9 +27,9 @@ def set_label(selfself, label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): raise NotImplementedError("We set our own constrainst") @overrides(AbstractHasConstraints.constraints) - def constraints(self): + def constraints(self) -> Any: return ["No night feeds", "No nappy changes"] diff --git a/unittests/abstract_base/no_label.py b/unittests/abstract_base/no_label.py index 4b5d99ee..ec90b657 100644 --- a/unittests/abstract_base/no_label.py +++ b/unittests/abstract_base/no_label.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -23,9 +24,9 @@ def set_label(selfself, label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): raise NotImplementedError("We set our own constraints") @overrides(AbstractHasConstraints.constraints) - def constraints(self): + def constraints(self) -> Any: return ["No night feeds", "No nappy changes"] diff --git a/unittests/abstract_base/unchecked_bad_param.py b/unittests/abstract_base/unchecked_bad_param.py index b7289d77..f570fdba 100644 --- a/unittests/abstract_base/unchecked_bad_param.py +++ b/unittests/abstract_base/unchecked_bad_param.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Any from spinn_utilities.overrides import overrides from .abstract_grandparent import AbstractGrandParent from .abstract_has_constraints import AbstractHasConstraints @@ -25,9 +26,9 @@ def set_label(selfself, not_label): pass @overrides(AbstractHasConstraints.add_constraint) - def add_constraint(self, constraint): + def add_constraint(self, constraint: Any): raise NotImplementedError("We set our own constrainst") @overrides(AbstractHasConstraints.constraints) - def constraints(self): + def constraints(self) -> Any: return ["No night feeds", "No nappy changes"] diff --git a/unittests/ranged/single_list_test.py b/unittests/ranged/single_list_test.py index f5cf57a6..3562c02d 100644 --- a/unittests/ranged/single_list_test.py +++ b/unittests/ranged/single_list_test.py @@ -99,6 +99,28 @@ def test_iter_by_slice(): assert [2, 4, 8] == list(single.iter_by_slice(2, 5)) +def test_plus_minus(): + a_list = RangedList(5, 12, "twelve") + b_list = a_list + 2 + c_list = b_list - 2 + assert a_list == c_list + d_list = RangedList(5, 8, "eight") + e_list = a_list + d_list + f_list = e_list - d_list + assert a_list == f_list + + +def test_multi_divide(): + a_list = RangedList(5, 12, "twelve") + b_list = a_list * 2 + c_list = b_list / 2 + assert a_list == c_list + d_list = RangedList(5, 4, "four") + e_list = a_list / d_list + f_list = e_list * d_list + assert a_list == f_list + + def test_equals(): a_list = RangedList(5, 12, "twelve") double = SingleList(a_list=a_list, operation=lambda x: x * 2) diff --git a/unittests/test_log.py b/unittests/test_log.py index 0e896399..aad534f7 100644 --- a/unittests/test_log.py +++ b/unittests/test_log.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. +from datetime import datetime import logging +from typing import List, Optional, Tuple from spinn_utilities.log import ( _BraceMessage, ConfiguredFilter, ConfiguredFormatter, FormatAdapter, LogLevelTooHighException) @@ -54,13 +56,15 @@ def __init__(self): self.data = [] @overrides(LogStore.store_log) - def store_log(self, level, message, timestamp=None): + def store_log(self, level: int, message: str, + timestamp: Optional[datetime] = None): if level == logging.CRITICAL: 1/0 self.data.append((level, message)) @overrides(LogStore.retreive_log_messages) - def retreive_log_messages(self, min_level=0): + def retreive_log_messages( + self, min_level: int = 0) -> List[Tuple[int, str]]: result = [] for (level, message) in self.data: if level >= min_level: @@ -68,7 +72,7 @@ def retreive_log_messages(self, min_level=0): return result @overrides(LogStore.get_location) - def get_location(self): + def get_location(self) -> str: return "MOCK" diff --git a/unittests/test_overrides.py b/unittests/test_overrides.py index b4e291c5..64ea950e 100644 --- a/unittests/test_overrides.py +++ b/unittests/test_overrides.py @@ -13,6 +13,7 @@ # limitations under the License. import pytest +from typing import Any from spinn_utilities.abstract_base import abstractmethod from spinn_utilities.overrides import overrides @@ -21,11 +22,11 @@ class Base(object): - def foo(self, x, y, z): + def foo(self, x: Any, y: Any, z: Any): """this is the doc""" return [x, y, z] - def foodef(self, x, y, z=True): + def foodef(self, x: Any, y: Any, z: Any = True): """this is the doc""" return [x, y, z] @@ -37,7 +38,7 @@ def boo(self) -> int: def test_basic_use(): class Sub(Base): @overrides(Base.foo) - def foo(self, x, y, z): + def foo(self, x: Any, y: Any, z: Any): return super().foo(z, y, x) assert Sub().foo(1, 2, 3) == [3, 2, 1] @@ -45,7 +46,7 @@ def foo(self, x, y, z): def test_doc_no_sub_extend(): class Sub(Base): @overrides(Base.foo, extend_doc=True) - def foo(self, x, y, z): + def foo(self, x: Any, y: Any, z: Any): return [z, y, x] assert Sub.foo.__doc__ == "this is the doc" @@ -53,7 +54,7 @@ def foo(self, x, y, z): def test_doc_no_sub_no_extend(): class Sub(Base): @overrides(Base.foo, extend_doc=False) - def foo(self, x, y, z): + def foo(self, x: Any, y: Any, z: Any): return [z, y, x] assert Sub.foo.__doc__ == "this is the doc" @@ -61,7 +62,7 @@ def foo(self, x, y, z): def test_doc_sub_no_extend(): class Sub(Base): @overrides(Base.foo, extend_doc=False) - def foo(self, x, y, z): + def foo(self, x: Any, y: Any, z: Any): """(abc)""" return [z, y, x] assert Sub.foo.__doc__ == "(abc)" @@ -70,7 +71,7 @@ def foo(self, x, y, z): def test_doc_sub_extend(): class Sub(Base): @overrides(Base.foo, extend_doc=True) - def foo(self, x, y, z): + def foo(self, x: Any, y: Any, z: Any): """(abc)""" return [z, y, x] assert Sub.foo.__doc__ == "this is the doc(abc)" @@ -80,7 +81,7 @@ def test_removes_param(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foo) - def foo(self, x, y): + def foo(self, x: Any, y: Any): return [y, x] assert str(e.value) == WRONG_ARGS.format(3) @@ -89,7 +90,7 @@ def test_adds_param(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foo) - def foo(self, x, y, z, w): + def foo(self, x: Any, y: Any, z: Any, w: Any): return [w, z, y, x] assert str(e.value) == WRONG_ARGS.format(5) @@ -97,7 +98,7 @@ def foo(self, x, y, z, w): def test_adds_expected_param(): class Sub(Base): @overrides(Base.foo, additional_arguments=["w"]) - def foo(self, x, y, z, w): + def foo(self, x: Any, y: Any, z: Any, w: Any): return [w, z, y, x] assert Sub().foo(1, 2, 3, 4) == [4, 3, 2, 1] @@ -106,7 +107,7 @@ def test_renames_param(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foo) - def foo(self, x, y, w): + def foo(self, x: Any, y: Any, w: Any): return [w, y, x] assert str(e.value) == "Missing argument z" @@ -115,7 +116,7 @@ def test_renames_param_expected(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foo, additional_arguments=["w"]) - def foo(self, x, y, w): + def foo(self, x: Any, y: Any, w: Any): return [w, y, x] assert str(e.value) == WRONG_ARGS.format(4) # TODO: Fix the AWFUL error message in this case! @@ -124,7 +125,7 @@ def foo(self, x, y, w): def test_changes_params_defaults(): class Sub(Base): @overrides(Base.foodef) - def foodef(self, x, y, z=False): + def foodef(self, x: Any, y: Any, z: Any = False): return [z, y, x] assert Sub().foodef(1, 2) == [False, 2, 1] @@ -133,7 +134,7 @@ def test_undefaults_super_param(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foodef) - def foodef(self, x, y, z): + def foodef(self, x: Any, y: Any, z: Any): return [z, y, x] assert str(e.value) == BAD_DEFS @@ -142,7 +143,7 @@ def test_defaults_super_param(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foodef) - def foodef(self, x, y=1, z=2): + def foodef(self, x: Any, y: Any = 1, z: Any = 2): return [z, y, x] assert str(e.value) == BAD_DEFS # TODO: Should this case fail at all? @@ -151,7 +152,7 @@ def foodef(self, x, y=1, z=2): def test_defaults_super_param_expected(): class Sub(Base): @overrides(Base.foodef, extend_defaults=True) - def foodef(self, x, y=1, z=2): + def foodef(self, x: Any, y: Any = 1, z: Any = 2): return [z, y, x] assert Sub().foodef(7) == [2, 1, 7] @@ -160,7 +161,7 @@ def test_defaults_extra_param(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foodef, additional_arguments=['pdq']) - def foodef(self, x, y, z=1, pdq=2): + def foodef(self, x: Any, y: Any, z: Any = 1, pdq: Any = 2): return [z, y, x, pdq] assert str(e.value) == BAD_DEFS @@ -169,7 +170,7 @@ def test_defaults_super_param_no_super_defaults(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foo) - def foo(self, x, y, z=7): + def foo(self, x: Any, y: Any, z: Any = 7): return [z, y, x] assert str(e.value) == BAD_DEFS # TODO: Should this case fail at all? @@ -179,7 +180,7 @@ def test_crazy_extends(): with pytest.raises(AttributeError) as e: class Sub(Base): @overrides(Base.foo) - def bar(self, x, y, z): + def bar(self, x: Any, y: Any, z: Any): return [z, y, x] assert str(e.value) == \ "super class method name foo does not match bar. Ensure overrides is "\