-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a9b34b7
commit 5bae6d9
Showing
4 changed files
with
362 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
180 changes: 180 additions & 0 deletions
180
adaptive_scheduler/_server_support/multi_run_manager.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
import ipywidgets as ipw | ||
|
||
from adaptive_scheduler.widgets import _disable_widgets_output_scrollbar, info | ||
|
||
if TYPE_CHECKING: | ||
from adaptive_scheduler._server_support.run_manager import RunManager | ||
|
||
|
||
class MultiRunManager: | ||
"""A manager that can contain multiple RunManagers. | ||
Parameters | ||
---------- | ||
run_managers | ||
Initial list of RunManagers to include. | ||
Attributes | ||
---------- | ||
run_managers | ||
Dictionary of managed RunManagers, keyed by their names. | ||
""" | ||
|
||
def __init__(self, run_managers: list[RunManager] | None = None) -> None: | ||
self.run_managers: dict[str, RunManager] = {} | ||
self._widget: ipw.VBox | None = None | ||
self._tab_widget: ipw.Tab | None = None | ||
self._info_widgets: dict[str, ipw.Widget] = {} | ||
self._update_all_button: ipw.Button | None = None | ||
if run_managers: | ||
for rm in run_managers: | ||
self.add_run_manager(rm) | ||
|
||
def add_run_manager( | ||
self, | ||
run_manager: RunManager, | ||
*, | ||
start: bool = False, | ||
wait_for: str | None = None, | ||
) -> None: | ||
"""Add a new RunManager to the MultiRunManager. | ||
Parameters | ||
---------- | ||
run_manager | ||
The RunManager to add. | ||
start | ||
Whether to start the RunManager immediately after adding it. | ||
wait_for | ||
The name of another RunManager to wait for before starting this one. | ||
Only applicable if start is True. | ||
Raises | ||
------ | ||
ValueError | ||
If a RunManager with the same name already exists. | ||
KeyError | ||
If the specified wait_for RunManager does not exist. | ||
""" | ||
if run_manager.job_name in self.run_managers: | ||
msg = f"A RunManager with the name '{run_manager.job_name}' already exists." | ||
raise ValueError(msg) | ||
|
||
self.run_managers[run_manager.job_name] = run_manager | ||
self._info_widgets[run_manager.job_name] = info( | ||
run_manager, | ||
display_widget=False, | ||
disable_widgets_output_scrollbar=False, | ||
) | ||
|
||
if start: | ||
if wait_for: | ||
if wait_for not in self.run_managers: | ||
msg = f"No RunManager with the name '{wait_for}' exists." | ||
raise KeyError(msg) | ||
run_manager.start(wait_for=self.run_managers[wait_for]) | ||
else: | ||
run_manager.start() | ||
elif wait_for: | ||
msg = "`start` must be True if `wait_for` is used." | ||
raise ValueError(msg) | ||
|
||
if self._widget is not None: | ||
self._update_widget() | ||
|
||
def remove_run_manager(self, name: str) -> None: | ||
"""Remove a RunManager from the MultiRunManager. | ||
Parameters | ||
---------- | ||
name | ||
The name of the RunManager to remove. | ||
Raises | ||
------ | ||
KeyError | ||
If no RunManager with the given name exists. | ||
""" | ||
if name in self.run_managers: | ||
rm = self.run_managers.pop(name) | ||
rm.cancel() | ||
self._info_widgets.pop(name) | ||
if self._widget is not None: | ||
self._update_widget() | ||
else: | ||
msg = f"No RunManager with the name '{name}' exists." | ||
raise KeyError(msg) | ||
|
||
def start_all(self) -> None: | ||
"""Start all RunManagers.""" | ||
for run_manager in self.run_managers.values(): | ||
run_manager.start() | ||
|
||
def cancel_all(self) -> None: | ||
"""Cancel all RunManagers.""" | ||
for run_manager in self.run_managers.values(): | ||
run_manager.cancel() | ||
|
||
def _create_widget(self) -> ipw.VBox: | ||
"""Create the widget for displaying RunManager info and update button.""" | ||
self._tab_widget = ipw.Tab() | ||
children = list(self._info_widgets.values()) | ||
self._tab_widget.children = children | ||
for i, name in enumerate(self.run_managers.keys()): | ||
self._tab_widget.set_title(i, f"RunManager: {name}") | ||
|
||
self._update_all_button = ipw.Button( | ||
description="Update All", | ||
button_style="info", | ||
tooltip="Update all RunManagers", | ||
) | ||
self._update_all_button.on_click(self._update_all_callback) | ||
|
||
return ipw.VBox([self._update_all_button, self._tab_widget]) | ||
|
||
def _update_widget(self) -> None: | ||
"""Update the widget when RunManagers are added or removed.""" | ||
if self._tab_widget is not None: | ||
current_children = list(self._tab_widget.children) | ||
new_children = list(self._info_widgets.values()) | ||
|
||
# Create a new tuple of children | ||
updated_children = tuple(child for child in current_children if child in new_children) | ||
|
||
# Add new children | ||
for widget in new_children: | ||
if widget not in updated_children: | ||
updated_children += (widget,) | ||
|
||
# Update the widget's children | ||
self._tab_widget.children = updated_children | ||
|
||
# Update titles | ||
for i, name in enumerate(self.run_managers.keys()): | ||
self._tab_widget.set_title(i, f"RunManager: {name}") | ||
|
||
def _update_all_callback(self, _: ipw.Button) -> None: | ||
"""Callback function for the Update All button.""" | ||
assert self._tab_widget is not None | ||
for widget in self._tab_widget.children: | ||
update_button = widget.children[0].children[1].children[0] | ||
assert update_button.description == "update info" | ||
update_button.click() | ||
|
||
def info(self) -> ipw.VBox: | ||
"""Display info about all RunManagers in a widget with an Update All button.""" | ||
if self._widget is None: | ||
_disable_widgets_output_scrollbar() | ||
self._widget = self._create_widget() | ||
return self._widget | ||
|
||
def _repr_html_(self) -> str: | ||
"""HTML representation for Jupyter notebooks.""" | ||
return self.info() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,178 @@ | ||
"""Test the MultiRunManager class.""" | ||
|
||
from __future__ import annotations | ||
|
||
import asyncio | ||
from typing import TYPE_CHECKING | ||
from unittest.mock import patch | ||
|
||
import ipywidgets as ipw | ||
import pytest | ||
|
||
from adaptive_scheduler._server_support.multi_run_manager import MultiRunManager | ||
from adaptive_scheduler._server_support.run_manager import RunManager | ||
|
||
if TYPE_CHECKING: | ||
from pathlib import Path | ||
|
||
import adaptive | ||
|
||
from .helpers import MockScheduler | ||
|
||
|
||
@pytest.fixture() | ||
def mock_run_manager( | ||
mock_scheduler: MockScheduler, | ||
learners: list[adaptive.Learner1D] | ||
| list[adaptive.BalancingLearner] | ||
| list[adaptive.SequenceLearner], | ||
fnames: list[str] | list[Path], | ||
) -> RunManager: | ||
"""Create a mock RunManager for testing.""" | ||
return RunManager(mock_scheduler, learners[:1], fnames[:1], job_name="test-rm") | ||
|
||
|
||
def test_multi_run_manager_init() -> None: | ||
"""Test the initialization of MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
assert isinstance(mrm, MultiRunManager) | ||
assert mrm.run_managers == {} | ||
|
||
|
||
def test_multi_run_manager_add_run_manager(mock_run_manager: RunManager) -> None: | ||
"""Test adding a RunManager to MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
assert len(mrm.run_managers) == 1 | ||
assert "test-rm" in mrm.run_managers | ||
assert mrm.run_managers["test-rm"] == mock_run_manager | ||
|
||
|
||
def test_multi_run_manager_add_duplicate_run_manager(mock_run_manager: RunManager) -> None: | ||
"""Test adding a duplicate RunManager to MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
with pytest.raises(ValueError, match="A RunManager with the name 'test-rm' already exists."): | ||
mrm.add_run_manager(mock_run_manager) | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_multi_run_manager_add_run_manager_with_start(mock_run_manager: RunManager) -> None: | ||
"""Test adding a RunManager to MultiRunManager with start=True.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager, start=True) | ||
await asyncio.sleep(0.1) | ||
assert mock_run_manager.status() == "running" | ||
mock_run_manager.cancel() | ||
|
||
|
||
def test_multi_run_manager_add_run_manager_with_invalid_wait_for( | ||
mock_run_manager: RunManager, | ||
) -> None: | ||
"""Test adding a RunManager with an invalid wait_for parameter.""" | ||
mrm = MultiRunManager() | ||
with pytest.raises(KeyError, match="No RunManager with the name 'non-existent' exists."): | ||
mrm.add_run_manager(mock_run_manager, start=True, wait_for="non-existent") | ||
|
||
|
||
def test_multi_run_manager_add_run_manager_with_wait_for_without_start( | ||
mock_run_manager: RunManager, | ||
) -> None: | ||
"""Test adding a RunManager with wait_for but without start=True.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager( | ||
RunManager( | ||
mock_run_manager.scheduler, | ||
mock_run_manager.learners, | ||
mock_run_manager.fnames, | ||
job_name="rm1", | ||
), | ||
) | ||
with pytest.raises(ValueError, match="`start` must be True if `wait_for` is used."): | ||
mrm.add_run_manager(mock_run_manager, wait_for="rm1") | ||
|
||
|
||
def test_multi_run_manager_remove_run_manager(mock_run_manager: RunManager) -> None: | ||
"""Test removing a RunManager from MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
mrm.remove_run_manager("test-rm") | ||
assert len(mrm.run_managers) == 0 | ||
|
||
|
||
def test_multi_run_manager_remove_non_existent_run_manager() -> None: | ||
"""Test removing a non-existent RunManager from MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
with pytest.raises(KeyError, match="No RunManager with the name 'non-existent' exists."): | ||
mrm.remove_run_manager("non-existent") | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_multi_run_manager_start_all(mock_run_manager: RunManager) -> None: | ||
"""Test starting all RunManagers in MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
mrm.start_all() | ||
await asyncio.sleep(0.1) | ||
assert mock_run_manager.status() == "running" | ||
mock_run_manager.cancel() | ||
|
||
|
||
@pytest.mark.asyncio() | ||
async def test_multi_run_manager_cancel_all(mock_run_manager: RunManager) -> None: | ||
"""Test cancelling all RunManagers in MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
mrm.start_all() | ||
await asyncio.sleep(0.1) | ||
mrm.cancel_all() | ||
await asyncio.sleep(0.1) | ||
assert mock_run_manager.status() == "cancelled" | ||
|
||
|
||
def test_multi_run_manager_create_widget(mock_run_manager: RunManager) -> None: | ||
"""Test creating the widget for MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
vbox = mrm._create_widget() | ||
assert isinstance(vbox, ipw.VBox) | ||
tab = vbox.children[1] | ||
assert isinstance(tab, ipw.Tab) | ||
assert len(tab.children) == 1 | ||
assert tab.get_title(0) == "RunManager: test-rm" | ||
|
||
|
||
def test_multi_run_manager_update_widget(mock_run_manager: RunManager) -> None: | ||
"""Test updating the widget for MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
mrm._widget = mrm._create_widget() | ||
new_rm = RunManager( | ||
mock_run_manager.scheduler, | ||
mock_run_manager.learners, | ||
mock_run_manager.fnames, | ||
job_name="new-rm", | ||
) | ||
mrm.add_run_manager(new_rm) | ||
assert len(mrm._widget.children) == 2 | ||
assert mrm._widget.children[1].get_title(1) == "RunManager: new-rm" | ||
|
||
|
||
def test_multi_run_manager_info(mock_run_manager: RunManager) -> None: | ||
"""Test the info method of MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
vbox = mrm.info() | ||
assert isinstance(vbox, ipw.VBox) | ||
tab = vbox.children[1] | ||
assert isinstance(tab, ipw.Tab) | ||
assert len(tab.children) == 1 | ||
|
||
|
||
def test_multi_run_manager_repr_html(mock_run_manager: RunManager) -> None: | ||
"""Test the _repr_html_ method of MultiRunManager.""" | ||
mrm = MultiRunManager() | ||
mrm.add_run_manager(mock_run_manager) | ||
with patch("IPython.display.display") as mocked_display: | ||
mrm._repr_html_() | ||
assert mocked_display.called |