Skip to content

Commit

Permalink
Merge pull request #60 from SpiNNakerManchester/overrides_check
Browse files Browse the repository at this point in the history
Overrides check
  • Loading branch information
Christian-B authored Jan 15, 2024
2 parents c897315 + d2646b3 commit 3231de8
Show file tree
Hide file tree
Showing 13 changed files with 108 additions and 56 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/python_actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ jobs:
path: support
- name: Install pip, etc
uses: ./support/actions/python-tools
- name: Install mypy
run: pip install mypy

- name: Install Spinnaker Dependencies
uses: ./support/actions/install-spinn-deps
Expand Down Expand Up @@ -78,6 +80,9 @@ jobs:
if: matrix.python-version == 3.12
uses: ./support/actions/check-copyrights

- name: Lint with mypy
run: mypy mcmc

# # Add the following as required in the future
# - name: Validate XML
# if: matrix.python-version == 3.8
Expand Down
22 changes: 13 additions & 9 deletions mcmc/mcmc_cholesky_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,22 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from enum import Enum

from spinn_utilities.overrides import overrides

from spinnman.model.enums import ExecutableType

from pacman.model.graphs.machine import MachineVertex
from pacman.model.placements import Placement
from pacman.model.resources import ConstantSDRAM
from spinn_utilities.overrides import overrides

from spinn_front_end_common.abstract_models.abstract_has_associated_binary \
import AbstractHasAssociatedBinary
from spinn_front_end_common.abstract_models\
.abstract_generates_data_specification \
import AbstractGeneratesDataSpecification
from spinn_front_end_common.utilities.utility_objs.executable_type \
import ExecutableType

from enum import Enum
from spinn_front_end_common.interface.ds import DataSpecificationGenerator


class MCMCCholeskyRegions(Enum):
Expand Down Expand Up @@ -61,20 +64,21 @@ def __init__(

@property
@overrides(MachineVertex.sdram_required)
def sdram_required(self):
def sdram_required(self) -> ConstantSDRAM:
return ConstantSDRAM(self._n_parameter_bytes)

@overrides(AbstractHasAssociatedBinary.get_binary_file_name)
def get_binary_file_name(self):
def get_binary_file_name(self) -> str:
return "mcmc_cholesky.aplx"

@overrides(AbstractHasAssociatedBinary.get_binary_start_type)
def get_binary_start_type(self):
def get_binary_start_type(self) -> ExecutableType:
return ExecutableType.SYNC

@overrides(
AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(self, spec, placement):
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):

# Reserve and write the parameters region
spec.reserve_memory_region(
Expand Down
34 changes: 21 additions & 13 deletions mcmc/mcmc_coordinator_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,24 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import List

from pacman.model.graphs.machine import MachineVertex
from pacman.model.resources import ConstantSDRAM
from spinn_utilities.overrides import overrides

from spinnman.model.enums import ExecutableType

from pacman.model.placements import Placement

from spinn_front_end_common.abstract_models.abstract_has_associated_binary \
import AbstractHasAssociatedBinary
from spinn_front_end_common.abstract_models\
.abstract_generates_data_specification \
import AbstractGeneratesDataSpecification
from spinn_front_end_common.data import FecDataView
from spinn_front_end_common.interface.ds import DataType
from spinn_front_end_common.utilities.utility_objs.executable_type \
import ExecutableType
from spinn_front_end_common.interface.ds import (
DataSpecificationGenerator, DataType)
from spinn_utilities.progress_bar import ProgressBar

import numpy
Expand Down Expand Up @@ -187,22 +192,23 @@ def acknowledge_partition_name(self):

@property
@overrides(MachineVertex.sdram_required)
def sdram_required(self):
def sdram_required(self) -> ConstantSDRAM:
sdram = self._N_PARAMETER_BYTES + self._data_size
sdram += len(self._mcmc_vertices) * self._KEY_ELEMENT_TYPE.size
return ConstantSDRAM(sdram)

@overrides(AbstractHasAssociatedBinary.get_binary_file_name)
def get_binary_file_name(self):
def get_binary_file_name(self) -> str:
return "mcmc_coordinator.aplx"

@overrides(AbstractHasAssociatedBinary.get_binary_start_type)
def get_binary_start_type(self):
def get_binary_start_type(self) -> ExecutableType:
return ExecutableType.SYNC

@overrides(
AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(self, spec, placement):
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):
routing_info = FecDataView.get_routing_infos()

# Reserve and write the parameters region
Expand All @@ -213,14 +219,15 @@ def generate_data_specification(self, spec, placement):

# Get the placement of the vertices and find out how many chips
# are needed
keys = list()
keys: List[int] = list()
for vertex in self._mcmc_vertices:
mcmc_placement = FecDataView.get_placement_of_vertex(vertex)
self._mcmc_placements.append(mcmc_placement)
if self._is_receiver_placement(mcmc_placement):
key = routing_info.get_first_key_from_pre_vertex(
vertex, self._acknowledge_partition_name)
keys.append(key)
if key is not None:
keys.append(key)
keys.sort()

# Write the data size in words
Expand All @@ -232,16 +239,17 @@ def generate_data_specification(self, spec, placement):
spec.write_value(len(keys), data_type=DataType.UINT32)

# Write the key
routing_info = routing_info.get_routing_info_from_pre_vertex(
vertex_routing_info = routing_info.get_routing_info_from_pre_vertex(
self, self._data_partition_name)
spec.write_value(routing_info.key, data_type=DataType.UINT32)
assert vertex_routing_info is not None
spec.write_value(vertex_routing_info.key, data_type=DataType.UINT32)

# Write the window size
spec.write_value(self._window_size, data_type=DataType.UINT32)

# Write the sequence mask
spec.write_value(
~routing_info.mask & 0xFFFFFFFF, data_type=DataType.UINT32)
~vertex_routing_info.mask & 0xFFFFFFFF, data_type=DataType.UINT32)

# Write the timer
spec.write_value(self._send_timer, data_type=DataType.UINT32)
Expand All @@ -265,7 +273,7 @@ def generate_data_specification(self, spec, placement):
spec.end_specification()

@overrides(MachineVertex.get_n_keys_for_partition)
def get_n_keys_for_partition(self, partition_id):
def get_n_keys_for_partition(self, partition_id: str) -> int:
return self._n_sequences

def read_samples(self, buffer_manager):
Expand Down
12 changes: 9 additions & 3 deletions mcmc/mcmc_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,35 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import List
from spinn_utilities.abstract_base import AbstractBase
from spinn_utilities.abstract_base import abstractmethod
from mcmc.mcmc_parameter import MCMCParameter
from mcmc.mcmc_state_variable import MCMCStateVariable


class MCMCModel(object, metaclass=AbstractBase):

@abstractmethod
def get_parameters(self):
def get_parameters(self) -> List[MCMCParameter]:
""" Get the parameters of the model
:rtype: list of :py:class:`mcmc.mcmc_parameter.MCMCParameter`
"""
raise NotImplementedError

@abstractmethod
def get_state_variables(self):
def get_state_variables(self) -> List[MCMCStateVariable]:
""" Get the state variables of the model
:rtype: list of :py:class:`mcmc.mcmc_state_variable.MCMCStateVariable`
"""
raise NotImplementedError

@abstractmethod
def get_binary_name(self):
def get_binary_name(self) -> str:
""" Get the name of the binary compiled with this model
:rtype: str
"""
raise NotImplementedError
16 changes: 10 additions & 6 deletions mcmc/mcmc_root_finder_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,16 @@
from pacman.model.resources import ConstantSDRAM
from spinn_utilities.overrides import overrides

from pacman.model.placements import Placement

from spinnman.model.enums import ExecutableType

from spinn_front_end_common.abstract_models.abstract_has_associated_binary \
import AbstractHasAssociatedBinary
from spinn_front_end_common.abstract_models\
.abstract_generates_data_specification \
import AbstractGeneratesDataSpecification
from spinn_front_end_common.utilities.utility_objs.executable_type \
import ExecutableType
from spinn_front_end_common.interface.ds import DataSpecificationGenerator

from enum import Enum

Expand Down Expand Up @@ -60,20 +63,21 @@ def __init__(

@property
@overrides(MachineVertex.sdram_required)
def sdram_required(self):
def sdram_required(self) -> ConstantSDRAM:
return ConstantSDRAM(self._n_parameter_bytes)

@overrides(AbstractHasAssociatedBinary.get_binary_file_name)
def get_binary_file_name(self):
def get_binary_file_name(self) -> str:
return "mcmc_root_finder.aplx"

@overrides(AbstractHasAssociatedBinary.get_binary_start_type)
def get_binary_start_type(self):
def get_binary_start_type(self) -> ExecutableType:
return ExecutableType.SYNC

@overrides(
AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(self, spec, placement):
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):

# Reserve and write the parameters region
spec.reserve_memory_region(
Expand Down
20 changes: 13 additions & 7 deletions mcmc/mcmc_vertex.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
from pacman.model.resources import ConstantSDRAM
from spinn_utilities.overrides import overrides

from spinnman.model.enums import ExecutableType

from pacman.model.placements import Placement

from spinn_front_end_common.abstract_models.abstract_has_associated_binary \
import AbstractHasAssociatedBinary
from spinn_front_end_common.abstract_models\
Expand All @@ -25,12 +29,11 @@
from spinn_front_end_common.data import FecDataView
from spinn_front_end_common.interface.buffer_management.buffer_models\
.abstract_receive_buffers_to_host import AbstractReceiveBuffersToHost
from spinn_front_end_common.interface.ds import DataType
from spinn_front_end_common.interface.ds import (
DataSpecificationGenerator, DataType)
from spinn_front_end_common.utilities import helpful_functions
from spinn_front_end_common.interface.buffer_management \
import recording_utilities
from spinn_front_end_common.utilities.utility_objs.executable_type \
import ExecutableType

from enum import Enum
import numpy
Expand Down Expand Up @@ -195,20 +198,21 @@ def cholesky_result_partition_name(self):

@property
@overrides(MachineVertex.sdram_required)
def sdram_required(self):
def sdram_required(self) -> ConstantSDRAM:
return ConstantSDRAM(self._sdram_usage)

@overrides(AbstractHasAssociatedBinary.get_binary_file_name)
def get_binary_file_name(self):
def get_binary_file_name(self) -> str:
return self._model.get_binary_name()

@overrides(AbstractHasAssociatedBinary.get_binary_start_type)
def get_binary_start_type(self):
def get_binary_start_type(self) -> ExecutableType:
return ExecutableType.SYNC

@overrides(
AbstractGeneratesDataSpecification.generate_data_specification)
def generate_data_specification(self, spec, placement):
def generate_data_specification(
self, spec: DataSpecificationGenerator, placement: Placement):

routing_info = FecDataView.get_routing_infos()
# Reserve and write the recording regions
Expand Down Expand Up @@ -258,13 +262,15 @@ def generate_data_specification(self, spec, placement):
if (self._model.root_finder):
routing_info_rf = routing_info.get_routing_info_from_pre_vertex(
self, self._parameter_partition_name)
assert routing_info_rf is not None
spec.write_value(routing_info_rf.key, data_type=DataType.UINT32)
else:
spec.write_value(0, data_type=DataType.UINT32)

if (self._model.cholesky):
routing_info_ch = routing_info.get_routing_info_from_pre_vertex(
self, self._cholesky_partition_name)
assert routing_info_ch is not None
spec.write_value(routing_info_ch.key, data_type=DataType.UINT32)
else:
spec.write_value(0, data_type=DataType.UINT32)
Expand Down
13 changes: 13 additions & 0 deletions mcmc/py.typed
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) 2023 The University of Manchester
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
7 changes: 4 additions & 3 deletions mcmc_examples/arma/arma_fixed_point_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

from typing import List
from spinn_utilities.overrides import overrides

from spinn_front_end_common.interface.ds import DataType
Expand Down Expand Up @@ -41,19 +42,19 @@ def __init__(
self._q_jump_scale = q_jump_scale

@overrides(MCMCModel.get_binary_name)
def get_binary_name(self):
def get_binary_name(self) -> str:
return "arma.aplx"

@overrides(MCMCModel.get_parameters)
def get_parameters(self):
def get_parameters(self) -> List[MCMCParameter]:
return [
MCMCParameter(self._parameters, DataType.S1615), # array ?
MCMCParameter(self._p_jump_scale, DataType.S1615),
MCMCParameter(self._q_jump_scale, DataType.S1615)
]

@overrides(MCMCModel.get_state_variables)
def get_state_variables(self):
def get_state_variables(self) -> List[MCMCStateVariable]:
return [
MCMCStateVariable("order_p", 10, DataType.S1615), # check type
MCMCStateVariable("order_q", 10, DataType.S1615)
Expand Down
Loading

0 comments on commit 3231de8

Please sign in to comment.