Skip to content

Commit

Permalink
Merge pull request #257 from SpiNNakerManchester/overrides_check
Browse files Browse the repository at this point in the history
Overrides check
  • Loading branch information
Christian-B authored Jan 15, 2024
2 parents 1aae7f5 + fb779e0 commit e13f213
Show file tree
Hide file tree
Showing 21 changed files with 311 additions and 134 deletions.
1 change: 1 addition & 0 deletions .pylint_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,5 +38,6 @@ iter
UnusedVariable

# Misc
Github
haha
X'th
2 changes: 2 additions & 0 deletions spinn_utilities/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from spinn_utilities.config_holder import (
add_default_cfg, clear_cfg_files)
from spinn_utilities.data.utils_data_writer import UtilsDataWriter
from spinn_utilities.overrides import overrides

BASE_CONFIG_FILE = "spinn_utilities.cfg"

Expand All @@ -39,4 +40,5 @@ def add_spinn_utilities_cfg() -> None:
"""
Loads the default config values for spinn_utilities
"""
overrides.check_types()
add_default_cfg(os.path.join(os.path.dirname(__file__), BASE_CONFIG_FILE))
9 changes: 5 additions & 4 deletions spinn_utilities/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import logging
import re
import sys
from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple
from inspect import getfullargspec
from .log_store import LogStore
from .overrides import overrides
Expand Down Expand Up @@ -214,7 +214,7 @@ def __init__(self, logger: logging.Logger, extra=None):
super().__init__(logger, extra)
self.do_log = logger._log # pylint: disable=protected-access

@overrides(logging.LoggerAdapter.log, extend_doc=False)
@overrides(logging.LoggerAdapter.log, extend_doc=False, adds_typing=True)
def log(self, level: int, msg: object, *args, **kwargs):
"""
Delegate a log call to the underlying logger, applying appropriate
Expand Down Expand Up @@ -247,8 +247,9 @@ def log(self, level: int, msg: object, *args, **kwargs):
log_kwargs["exc_info"] = kwargs["exc_info"]
self.do_log(level, message, (), **log_kwargs)

@overrides(logging.LoggerAdapter.process, extend_doc=False)
def process(self, msg: object, kwargs) -> Tuple[object, dict]:
@overrides(logging.LoggerAdapter.process, extend_doc=False,
adds_typing=True)
def process(self, msg: object, kwargs: Any) -> Tuple[object, dict]:
"""
Process the logging message and keyword arguments passed in to a
logging call to insert contextual information. You can either
Expand Down
66 changes: 63 additions & 3 deletions spinn_utilities/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
# limitations under the License.

import inspect
import os
from types import FunctionType, MethodType
from typing import Any, Callable, Iterable, Optional, TypeVar

#: :meta private:
Method = TypeVar("Method", bound=Callable[..., Any])

Expand All @@ -27,6 +29,9 @@ class overrides(object):
copies the doc-string for the method, and enforces that the method
overridden is specified, making maintenance easier.
"""
# This near constant is changed by unit tests to check our code
# Github actions sets TYPE_OVERRIDES as True
__CHECK_TYPES = os.getenv("TYPE_OVERRIDES")

__slots__ = [
# The method in the superclass that this method overrides
Expand All @@ -40,13 +45,15 @@ class overrides(object):
# True if the name check is relaxed
"_relax_name_check",
# The name of the thing being overridden for error messages
"_override_name"
"_override_name",
# True if this method adds typing info not in the original
"_adds_typing"
]

def __init__(
self, super_class_method, extend_doc: bool = True,
self, super_class_method, *, extend_doc: bool = True,
additional_arguments: Optional[Iterable[str]] = None,
extend_defaults: bool = False):
extend_defaults: bool = False, adds_typing: bool = False,):
"""
:param super_class_method: The method to override in the superclass
:param bool extend_doc:
Expand All @@ -58,6 +65,9 @@ def __init__(
superclass method, e.g., that are to be injected
:param bool extend_defaults:
Whether the subclass may specify extra defaults for the parameters
:param adds_typing:
Allows more typing (of non additional) than in the subclass.
Should only be used for built in super classes
"""
if isinstance(super_class_method, property):
super_class_method = super_class_method.fget
Expand All @@ -73,6 +83,7 @@ def __init__(
self._additional_arguments = frozenset(additional_arguments)
else:
self._additional_arguments = frozenset()
self._adds_typing = adds_typing

@staticmethod
def __match_defaults(default_args, super_defaults, extend_ok):
Expand All @@ -84,6 +95,45 @@ def __match_defaults(default_args, super_defaults, extend_ok):
return len(default_args) >= len(super_defaults)
return len(default_args) == len(super_defaults)

def _verify_types(self, method_args, super_args, all_args):
"""
Check that the arguments match.
"""
if not self.__CHECK_TYPES:
return
if "self" in all_args:
all_args.remove("self")
elif "cls" in all_args:
all_args.remove("cls")
method_types = method_args.annotations
super_types = super_args.annotations
for arg in all_args:
if arg not in super_types and not self._adds_typing:
raise AttributeError(
f"Super Method {self._superclass_method.__name__} "
f"has untyped arguments including {arg}")
if arg not in method_types:
raise AttributeError(
f"Method {self._superclass_method.__name__} "
f"has untyped arguments including {arg}")

if len(all_args) == 0:
if "return" not in super_types and not self._adds_typing and \
not method_args.varkw and not method_args.varargs:
raise AttributeError(
f"Super Method {self._superclass_method.__name__} "
f"has no arguments so should declare a return type")
if "return" in super_types:
if "return" not in method_types:
raise AttributeError(
f"Method {self._superclass_method.__name__} "
f"has no return type, while super does")
else:
if "return" in method_types and not self._adds_typing:
raise AttributeError(
f"Super Method {self._superclass_method.__name__} "
f"has no return type, while this does")

def __verify_method_arguments(self, method: Method):
"""
Check that the arguments match.
Expand All @@ -109,6 +159,7 @@ def __verify_method_arguments(self, method: Method):
default_args, super_args.defaults, self._extend_defaults):
raise AttributeError(
f"Default arguments don't match {self._override_name}")
self._verify_types(method_args, super_args, all_args)

def __call__(self, method: Method) -> Method:
"""
Expand Down Expand Up @@ -142,3 +193,12 @@ def __call__(self, method: Method) -> Method:
method.__doc__ = (
self._superclass_method.__doc__ + (method.__doc__ or ""))
return method

@classmethod
def check_types(cls):
"""
If called will trigger check that all parameters are checked.
Used for testing, to avoid users being affected by the strict checks
"""
cls.__CHECK_TYPES = True
4 changes: 3 additions & 1 deletion spinn_utilities/package_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os
import sys
import traceback
from spinn_utilities.overrides import overrides


def all_modules(directory, prefix, remove_pyc_files=False):
Expand Down Expand Up @@ -81,10 +82,10 @@ def load_modules(
if module in exclusions:
print("SKIPPING " + module)
continue
print(module)
try:
__import__(module)
except Exception: # pylint: disable=broad-except
print(f"Error with {module}")
if gather_errors:
errors.append((module, sys.exc_info()))
else:
Expand Down Expand Up @@ -113,6 +114,7 @@ def load_module(
True if errors should be gathered, False to report on first error
:return: None
"""
overrides.check_types()
if exclusions is None:
exclusions = []
module = __import__(name)
Expand Down
16 changes: 8 additions & 8 deletions spinn_utilities/progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,10 @@ def update(self, amount_to_add=1):
self._currently_completed += amount_to_add
self._check_differences()

def _print_overwritten_line(self, string):
def _print_overwritten_line(self, string: str):
print("\r" + string, end="", file=self._destination)

def _print_distance_indicator(self, description):
def _print_distance_indicator(self, description: str):
if description is not None:
print(description, file=self._destination)

Expand All @@ -112,7 +112,7 @@ def _print_distance_line(self, first_space, second_space):
f"{' ' * second_space}100%{self._end_character}"
print(line, end="", file=self._destination)

def _print_progress(self, length):
def _print_progress(self, length: int):
chars_to_print = length
if not self._in_bad_terminal:
self._print_overwritten_line(self._end_character)
Expand All @@ -127,7 +127,7 @@ def _print_progress_unit(self, chars_to_print):
# pylint: disable=unused-argument
print(self._step_character, end='', file=self._destination)

def _print_progress_done(self):
def _print_progress_done(self) -> None:
self._print_progress(ProgressBar.MAX_LENGTH_IN_CHARS)
if not self._in_bad_terminal:
print(self._end_character, file=self._destination)
Expand Down Expand Up @@ -324,19 +324,19 @@ class DummyProgressBar(ProgressBar):
fails in exactly the same way.
"""
@overrides(ProgressBar._print_overwritten_line)
def _print_overwritten_line(self, string):
def _print_overwritten_line(self, string: str):
pass

@overrides(ProgressBar._print_distance_indicator)
def _print_distance_indicator(self, description):
def _print_distance_indicator(self, description: str):
pass

@overrides(ProgressBar._print_progress)
def _print_progress(self, length):
def _print_progress(self, length: int):
pass

@overrides(ProgressBar._print_progress_done)
def _print_progress_done(self):
def _print_progress_done(self) -> None:
pass

def __repr__(self):
Expand Down
8 changes: 5 additions & 3 deletions spinn_utilities/ranged/abstract_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
# Can't be Iterable[str] or Sequence[str] because that includes str itself
_StrSeq: TypeAlias = Union[
MutableSequence[str], Tuple[str, ...], FrozenSet[str], Set[str]]
_Keys: TypeAlias = Union[None, str, _StrSeq]
_Keys: TypeAlias = Optional[Union[str, _StrSeq]]


class AbstractDict(Generic[T], metaclass=AbstractBase):
Expand Down Expand Up @@ -117,7 +117,7 @@ def iter_all_values(self, key: Optional[_StrSeq],
...

@abstractmethod
def iter_all_values(self, key, update_safe=False):
def iter_all_values(self, key: _Keys, update_safe: bool = False):
"""
Iterates over the value(s) for all IDs covered by this view.
There will be one yield for each ID even if values are repeated.
Expand Down Expand Up @@ -181,7 +181,9 @@ def iter_ranges(self, key: Optional[_StrSeq]) -> Iterator[Tuple[
...

@abstractmethod
def iter_ranges(self, key=None):
def iter_ranges(self, key: _Keys = None
) -> Union[Iterator[Tuple[int, int, T]],
Iterator[Tuple[int, int, Dict[str, T]]]]:
"""
Iterates over the ranges(s) for all IDs covered by this view.
There will be one yield for each range which may cover one or
Expand Down
Loading

0 comments on commit e13f213

Please sign in to comment.