Skip to content

Commit

Permalink
typing
Browse files Browse the repository at this point in the history
  • Loading branch information
Christian-B committed Oct 30, 2024
1 parent 66ef6ea commit 2d8d46c
Show file tree
Hide file tree
Showing 47 changed files with 162 additions and 129 deletions.
8 changes: 4 additions & 4 deletions spinnman/board_test_configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ class BoardTestConfiguration(object):
Configuration to use for a test board
"""

def __init__(self):
self.remotehost = None
self.auto_detect_bmp = None
def __init__(self) -> None:
self.remotehost: str = "UNSET"
self.auto_detect_bmp: bool = False

def set_up_remote_board(self, version: Optional[int] = None):
def set_up_remote_board(self, version: Optional[int] = None) -> None:
"""
Gets a remote board to test, returning the first that it finds.
Expand Down
7 changes: 3 additions & 4 deletions spinnman/connections/token_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TokenBucket(object):
"""
__slots__ = ('_capacity', '_tokens', '_fill_rate', '_timestamp')

def __init__(self, tokens, fill_rate):
def __init__(self, tokens: int, fill_rate: float):
"""
:param int tokens: the total tokens in the bucket
:param float fill_rate:
Expand All @@ -38,7 +38,7 @@ def __init__(self, tokens, fill_rate):
self._fill_rate = float(fill_rate)
self._timestamp = time.time()

def consume(self, tokens, block=True):
def consume(self, tokens: int, block: bool = True) -> bool:
"""
Consume tokens from the bucket. Returns True if there were
sufficient tokens.
Expand All @@ -65,11 +65,10 @@ def consume(self, tokens, block=True):
return False

@property
def tokens(self):
def tokens(self) -> float:
"""
The number of tokens currently in the bucket.
:rtype: int
"""
if self._tokens < self._capacity:
now = time.time()
Expand Down
5 changes: 4 additions & 1 deletion spinnman/connections/udp_packet_connections/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from spinnman.messages.sdp import SDPHeader


# Kept for spalloc_server to use
def update_sdp_header_for_udp_send(sdp_header, source_x, source_y):
def update_sdp_header_for_udp_send(
sdp_header: SDPHeader, source_x: int, source_y: int) -> None:
"""
Apply defaults to the SDP header for sending over UDP.
Expand Down
2 changes: 1 addition & 1 deletion spinnman/data/spinnman_data_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ def get_new_id(cls) -> int:
return cls.__data._app_id_tracker.get_new_id()

@classmethod
def free_id(cls, app_id: int):
def free_id(cls, app_id: int) -> None:
"""
Frees up an app_id.
Expand Down
26 changes: 14 additions & 12 deletions spinnman/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,17 @@
from __future__ import annotations
import traceback
from types import TracebackType
from typing import Any, List, Optional, FrozenSet, Union, TYPE_CHECKING
from typing import (
Any, Generic, List, Optional, FrozenSet, TYPE_CHECKING, TypeVar, Union)
if TYPE_CHECKING:
from spinnman.messages.scp.enums import SCPResult
from spinnman.model.enums import CPUState
from spinnman.model import CPUInfos
from spinnman.messages.scp.abstract_messages import AbstractSCPRequest
from spinnman.connections.udp_packet_connections import SCAMPConnection

T = TypeVar("T")


class SpinnmanException(Exception):
"""
Expand Down Expand Up @@ -64,13 +67,13 @@ def problem(self) -> str:
return self._problem


class SpinnmanInvalidParameterException(SpinnmanException):
class SpinnmanInvalidParameterException(SpinnmanException, Generic[T]):
"""
An exception that indicates that the value of one of the parameters
passed was invalid.
"""

def __init__(self, parameter: str, value, problem: str):
def __init__(self, parameter: str, value: T, problem: str):
"""
:param str parameter: The name of the parameter that is invalid
:param str value: The value of the parameter that is invalid
Expand All @@ -93,7 +96,7 @@ def parameter(self) -> str:
return self._parameter

@property
def value(self):
def value(self) -> T:
"""
The value that is invalid.
"""
Expand All @@ -115,7 +118,7 @@ class SpinnmanInvalidParameterTypeException(SpinnmanException):
passed was invalid.
"""

def __init__(self, parameter: str, param_type, problem: str):
def __init__(self, parameter: str, param_type: str, problem: str):
"""
:param str parameter: The name of the parameter that is invalid
:param str param_type: The type of the parameter that is invalid
Expand All @@ -138,7 +141,7 @@ def parameter(self) -> str:
return self._parameter

@property
def type(self):
def type(self) -> str:
"""
The value that is invalid.
"""
Expand Down Expand Up @@ -186,13 +189,13 @@ def __init__(self) -> None:
super().__init__("connection is closed")


class SpinnmanTimeoutException(SpinnmanException):
class SpinnmanTimeoutException(SpinnmanException, Generic[T]):
"""
An exception that indicates that a timeout occurred before an operation
could finish.
"""

def __init__(self, operation: Any, timeout: Optional[float],
def __init__(self, operation: T, timeout: Optional[float],
msg: Optional[str] = None):
"""
:param operation: The operation being performed
Expand All @@ -206,11 +209,10 @@ def __init__(self, operation: Any, timeout: Optional[float],
self._timeout = timeout

@property
def operation(self) -> str:
def operation(self) -> T:
"""
The operation that was performed.
:rtype: str
"""
return self._operation

Expand All @@ -230,7 +232,7 @@ class SpinnmanUnexpectedResponseCodeException(SpinnmanException):
for the current operation.
"""

def __init__(self, operation: str, command,
def __init__(self, operation: str, command: str,
response: Union[str, SCPResult]):
"""
:param str operation: The operation being performed
Expand All @@ -254,7 +256,7 @@ def operation(self) -> str:
return self._operation

@property
def command(self):
def command(self) -> str:
"""
The command being executed.
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ def get_min_packet_length() -> int:
"""
return 2

def __str__(self):
def __str__(self) -> str:
return f"EIEIOCommandMessage:{self._eieio_command_header}"

def __repr__(self):
def __repr__(self) -> str:
return self.__str__()
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@ class EventStopRequest(EIEIOCommandMessage):
"""
__slots__ = ()

def __init__(self):
def __init__(self) -> None:
super().__init__(EIEIOCommandHeader(EIEIO_COMMAND_IDS.EVENT_STOP))
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

from typing import Optional

from spinnman.constants import EIEIO_COMMAND_IDS
from .eieio_command_message import EIEIOCommandMessage
from .eieio_command_header import EIEIOCommandHeader
Expand All @@ -29,9 +30,9 @@ class NotificationProtocolDatabaseLocation(EIEIOCommandMessage):
"""
__slots__ = "_database_path",

def __init__(self, database_path=None):
def __init__(self, database_path: Optional[str] = None):
"""
:param str database_path:
:param database_path:
The location of the database. If ``None``, this is an
acknowledgement, stating that the database has now been read.
"""
Expand All @@ -55,15 +56,19 @@ def database_path(self) -> Optional[str]:
return None

@property
def bytestring(self):
def bytestring(self) -> bytes:
data = super().bytestring
if self._database_path is not None:
data += self._database_path
return data

@staticmethod
def from_bytestring(command_header, data, offset):
def from_bytestring(
command_header: EIEIOCommandHeader, data: bytes,
offset: int) -> 'NotificationProtocolDatabaseLocation':
database_path = None
if len(data) - offset > 0:
database_path = data[offset:]
raise Exception(
"https://github.com/SpiNNakerManchester/SpiNNMan/issues/424")
# database_path = data[offset:]
return NotificationProtocolDatabaseLocation(database_path)
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from spinn_utilities.overrides import overrides
from spinnman.constants import EIEIO_COMMAND_IDS
from .eieio_command_message import EIEIOCommandMessage
from .eieio_command_header import EIEIOCommandHeader
Expand All @@ -26,10 +27,12 @@ class NotificationProtocolPauseStop(EIEIOCommandMessage):
"""
__slots__ = ()

def __init__(self):
def __init__(self) -> None:
super().__init__(EIEIOCommandHeader(
EIEIO_COMMAND_IDS.STOP_PAUSE_NOTIFICATION))

@staticmethod
def from_bytestring(command_header, data, offset):
@overrides(EIEIOCommandMessage.from_bytestring)
def from_bytestring(command_header: EIEIOCommandHeader, data: bytes,
offset: int) -> "NotificationProtocolPauseStop":
return NotificationProtocolPauseStop()
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from spinn_utilities.overrides import overrides
from spinnman.constants import EIEIO_COMMAND_IDS
from .eieio_command_message import EIEIOCommandMessage
from .eieio_command_header import EIEIOCommandHeader
Expand All @@ -26,10 +27,12 @@ class NotificationProtocolStartResume(EIEIOCommandMessage):
"""
__slots__ = ()

def __init__(self):
def __init__(self) -> None:
super().__init__(EIEIOCommandHeader(
EIEIO_COMMAND_IDS.START_RESUME_NOTIFICATION))

@staticmethod
def from_bytestring(command_header, data, offset):
@overrides(EIEIOCommandMessage.from_bytestring)
def from_bytestring(command_header: EIEIOCommandHeader, data: bytes,
offset: int) -> "NotificationProtocolStartResume":
return NotificationProtocolStartResume()
4 changes: 2 additions & 2 deletions spinnman/messages/eieio/command_messages/padding_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ class PaddingRequest(EIEIOCommandMessage):
"""
__slots__ = ()

def __init__(self):
def __init__(self) -> None:
super().__init__(EIEIOCommandHeader(EIEIO_COMMAND_IDS.EVENT_PADDING))

@staticmethod
def get_min_packet_length():
def get_min_packet_length() -> int:
return 2
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def channel(self, request_id: int) -> int:
"""
return self._requests.channel(request_id)

def region_id(self, request_id) -> int:
def region_id(self, request_id: int) -> int:
"""
The region_id for this request_id.
Expand All @@ -159,7 +159,7 @@ def region_id(self, request_id) -> int:
"""
return self._requests.region_id(request_id)

def start_address(self, request_id) -> int:
def start_address(self, request_id: int) -> int:
"""
The start_address for this request_id.
Expand All @@ -169,7 +169,7 @@ def start_address(self, request_id) -> int:
"""
return self._requests.start_address(request_id)

def space_to_be_read(self, request_id) -> int:
def space_to_be_read(self, request_id: int) -> int:
"""
The space_to_be_read for this request_id.
Expand Down Expand Up @@ -356,7 +356,7 @@ def __init__(self, channel: Union[List[int], int],
else:
self._space_to_be_read = space_to_be_read

def channel(self, request_id) -> int:
def channel(self, request_id: int) -> int:
"""
Gets the channel for this request_id
Expand All @@ -373,7 +373,7 @@ def channel(self, request_id) -> int:
f"channel request needs to be comprised between 0 and "
f"{len(self._channel) - 1:d}; current value: {request_id:d}")

def region_id(self, request_id) -> int:
def region_id(self, request_id: int) -> int:
"""
Gets the region_id for this request_id
Expand Down
2 changes: 1 addition & 1 deletion spinnman/messages/eieio/command_messages/start_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ class StartRequests(EIEIOCommandMessage):
"""
__slots__ = ()

def __init__(self):
def __init__(self) -> None:
super().__init__(EIEIOCommandHeader(
EIEIO_COMMAND_IDS.START_SENDING_REQUESTS))
2 changes: 1 addition & 1 deletion spinnman/messages/eieio/command_messages/stop_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@ class StopRequests(EIEIOCommandMessage):
"""
__slots__ = ()

def __init__(self):
def __init__(self) -> None:
super().__init__(EIEIOCommandHeader(
EIEIO_COMMAND_IDS.STOP_SENDING_REQUESTS))
3 changes: 2 additions & 1 deletion spinnman/messages/eieio/create_eieio_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
from spinnman.constants import EIEIO_COMMAND_IDS


def read_eieio_command_message(data, offset):
def read_eieio_command_message(
data: bytes, offset: int) -> EIEIOCommandMessage:
"""
Reads the content of an EIEIO command message and returns an object
identifying the command which was contained in the packet, including
Expand Down
2 changes: 1 addition & 1 deletion spinnman/messages/eieio/create_eieio_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
EIEIODataMessage, EIEIODataHeader)


def read_eieio_data_message(data, offset):
def read_eieio_data_message(data: bytes, offset: int) -> EIEIODataMessage:
"""
Reads the content of an EIEIO data message and returns an object
identifying the data which was contained in the packet.
Expand Down
Loading

0 comments on commit 2d8d46c

Please sign in to comment.