diff --git a/spinn_machine/virtual_machine.py b/spinn_machine/virtual_machine.py index 12c34d9c..49d15920 100644 --- a/spinn_machine/virtual_machine.py +++ b/spinn_machine/virtual_machine.py @@ -29,20 +29,17 @@ def virtual_machine( - width: int, height: int, n_cpus_per_chip: Optional[int] = None, - validate: bool = True): + width: int, height: int, validate: bool = True): """ Create a virtual SpiNNaker machine, used for planning execution. :param int width: the width of the virtual machine in chips :param int height: the height of the virtual machine in chips - :param int n_cpus_per_chip: The number of CPUs to put on each chip :param bool validate: if True will call the machine validate function :returns: a virtual machine (that cannot execute code) :rtype: ~spinn_machine.Machine """ - - factory = _VirtualMachine(width, height, n_cpus_per_chip, validate) + factory = _VirtualMachine(width, height, validate) return factory.machine @@ -67,9 +64,10 @@ class _VirtualMachine(object): ORIGIN = "Virtual" def __init__( - self, width: int, height: int, - n_cpus_per_chip: Optional[int] = None, validate: bool = True): + self, width: int, height: int, validate: bool = True): version = MachineDataView.get_machine_version() + version.verify_size(height, width) + max_cores = version.max_cores_per_chip self._n_router_entries = version.n_router_entries self._machine = version.create_machine( width, height, origin=self.ORIGIN) @@ -104,17 +102,11 @@ def __init__( # If there are no wrap arounds, and the the size is not 2 * 2, # the possible chips depend on the 48 chip board's gaps configured_chips: Dict[XY, Tuple[XY, int]] = dict() - if n_cpus_per_chip is None: - for eth in ethernet_chips: - for (xy, n_cores) in self._machine.get_xy_cores_by_ethernet( - *eth): - if xy not in unused_chips: - configured_chips[xy] = (eth, n_cores) - else: - for eth in ethernet_chips: - for xy in self._machine.get_xys_by_ethernet(*eth): - if xy not in unused_chips: - configured_chips[xy] = (eth, n_cpus_per_chip) + for eth in ethernet_chips: + for (xy, n_cores) in self._machine.get_xy_cores_by_ethernet( + *eth): + if xy not in unused_chips: + configured_chips[xy] = (eth, min(n_cores, max_cores)) # for chip in self._unreachable_outgoing_chips: # configured_chips.remove(chip) diff --git a/unittests/test_virtual_machine.py b/unittests/test_virtual_machine.py index 353fd256..936b3324 100644 --- a/unittests/test_virtual_machine.py +++ b/unittests/test_virtual_machine.py @@ -126,6 +126,8 @@ def test_version_5_8_by_8(self): self.assertFalse((0, 4) in list(vm.chip_coordinates)) count = sum(1 for _chip in vm.chips for _link in _chip.router.links) self.assertEqual(240, count) + count = sum(_chip.n_processors for _chip in vm.chips) + self.assertEqual(count, 856) def test_version_5_12_by_12(self): set_config("Machine", "version", 5) @@ -167,16 +169,19 @@ def test_version_5_hole2(self): self.assertEqual(48, len(list(vm.local_xys))) self.assertEqual((0, 4), vm.get_unused_xy()) - def test_new_vm_with_monitor(self): + def test_new_vm_with_max_cores(self): set_config("Machine", "version", 2) n_cpus = 13 - vm = virtual_machine(2, 2, n_cpus_per_chip=n_cpus, validate=True) + set_config("Machine", "max_machine_core", n_cpus) + vm = virtual_machine(2, 2, validate=True) _chip = vm[1, 1] self.assertEqual(n_cpus, _chip.n_processors) self.assertEqual(n_cpus - 1, _chip.n_user_processors) self.assertEqual(1, _chip.n_monitor_processors) self.assertEqual(n_cpus - 1, len(list(_chip.user_processors))) self.assertEqual(1, len(list(_chip.monitor_processors))) + count = sum(_chip.n_processors for _chip in vm.chips) + self.assertEqual(count, 4 * n_cpus) def test_iter_chips(self): set_config("Machine", "version", 2)