diff --git a/adaptive_scheduler/__init__.py b/adaptive_scheduler/__init__.py index 9ef15b40..a9eb4263 100644 --- a/adaptive_scheduler/__init__.py +++ b/adaptive_scheduler/__init__.py @@ -5,6 +5,7 @@ from adaptive_scheduler._version import __version__ from adaptive_scheduler.scheduler import PBS, SLURM from adaptive_scheduler.server_support import ( + MultiRunManager, RunManager, slurm_run, start_one_by_one, @@ -23,4 +24,5 @@ "utils", "SlurmExecutor", "SlurmTask", + "MultiRunManager", ] diff --git a/adaptive_scheduler/_server_support/multi_run_manager.py b/adaptive_scheduler/_server_support/multi_run_manager.py new file mode 100644 index 00000000..8e758953 --- /dev/null +++ b/adaptive_scheduler/_server_support/multi_run_manager.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import ipywidgets as ipw + +from adaptive_scheduler.widgets import _disable_widgets_output_scrollbar, info + +if TYPE_CHECKING: + from adaptive_scheduler._server_support.run_manager import RunManager + + +class MultiRunManager: + """A manager that can contain multiple RunManagers. + + Parameters + ---------- + run_managers + Initial list of RunManagers to include. + + Attributes + ---------- + run_managers + Dictionary of managed RunManagers, keyed by their names. + + """ + + def __init__(self, run_managers: list[RunManager] | None = None) -> None: + self.run_managers: dict[str, RunManager] = {} + self._widget: ipw.VBox | None = None + self._tab_widget: ipw.Tab | None = None + self._info_widgets: dict[str, ipw.Widget] = {} + self._update_all_button: ipw.Button | None = None + if run_managers: + for rm in run_managers: + self.add_run_manager(rm) + + def add_run_manager( + self, + run_manager: RunManager, + *, + start: bool = False, + wait_for: str | None = None, + ) -> None: + """Add a new RunManager to the MultiRunManager. + + Parameters + ---------- + run_manager + The RunManager to add. + start + Whether to start the RunManager immediately after adding it. + wait_for + The name of another RunManager to wait for before starting this one. + Only applicable if start is True. + + Raises + ------ + ValueError + If a RunManager with the same name already exists. + KeyError + If the specified wait_for RunManager does not exist. + + """ + if run_manager.job_name in self.run_managers: + msg = f"A RunManager with the name '{run_manager.job_name}' already exists." + raise ValueError(msg) + + self.run_managers[run_manager.job_name] = run_manager + self._info_widgets[run_manager.job_name] = info( + run_manager, + display_widget=False, + disable_widgets_output_scrollbar=False, + ) + + if start: + if wait_for: + if wait_for not in self.run_managers: + msg = f"No RunManager with the name '{wait_for}' exists." + raise KeyError(msg) + run_manager.start(wait_for=self.run_managers[wait_for]) + else: + run_manager.start() + elif wait_for: + msg = "`start` must be True if `wait_for` is used." + raise ValueError(msg) + + if self._widget is not None: + self._update_widget() + + def remove_run_manager(self, name: str) -> None: + """Remove a RunManager from the MultiRunManager. + + Parameters + ---------- + name + The name of the RunManager to remove. + + Raises + ------ + KeyError + If no RunManager with the given name exists. + + """ + if name in self.run_managers: + rm = self.run_managers.pop(name) + rm.cancel() + self._info_widgets.pop(name) + if self._widget is not None: + self._update_widget() + else: + msg = f"No RunManager with the name '{name}' exists." + raise KeyError(msg) + + def start_all(self) -> None: + """Start all RunManagers.""" + for run_manager in self.run_managers.values(): + run_manager.start() + + def cancel_all(self) -> None: + """Cancel all RunManagers.""" + for run_manager in self.run_managers.values(): + run_manager.cancel() + + def _create_widget(self) -> ipw.VBox: + """Create the widget for displaying RunManager info and update button.""" + self._tab_widget = ipw.Tab() + children = list(self._info_widgets.values()) + self._tab_widget.children = children + for i, name in enumerate(self.run_managers.keys()): + self._tab_widget.set_title(i, f"RunManager: {name}") + + self._update_all_button = ipw.Button( + description="Update All", + button_style="info", + tooltip="Update all RunManagers", + ) + self._update_all_button.on_click(self._update_all_callback) + + return ipw.VBox([self._update_all_button, self._tab_widget]) + + def _update_widget(self) -> None: + """Update the widget when RunManagers are added or removed.""" + if self._tab_widget is not None: + current_children = list(self._tab_widget.children) + new_children = list(self._info_widgets.values()) + + # Create a new tuple of children + updated_children = tuple(child for child in current_children if child in new_children) + + # Add new children + for widget in new_children: + if widget not in updated_children: + updated_children += (widget,) + + # Update the widget's children + self._tab_widget.children = updated_children + + # Update titles + for i, name in enumerate(self.run_managers.keys()): + self._tab_widget.set_title(i, f"RunManager: {name}") + + def _update_all_callback(self, _: ipw.Button) -> None: + """Callback function for the Update All button.""" + assert self._tab_widget is not None + for widget in self._tab_widget.children: + update_button = widget.children[0].children[1].children[0] + assert update_button.description == "update info" + update_button.click() + + def info(self) -> ipw.VBox: + """Display info about all RunManagers in a widget with an Update All button.""" + if self._widget is None: + _disable_widgets_output_scrollbar() + self._widget = self._create_widget() + return self._widget + + def _repr_html_(self) -> str: + """HTML representation for Jupyter notebooks.""" + return self.info() diff --git a/adaptive_scheduler/server_support.py b/adaptive_scheduler/server_support.py index 6561cd02..d55dfd64 100644 --- a/adaptive_scheduler/server_support.py +++ b/adaptive_scheduler/server_support.py @@ -15,6 +15,7 @@ from ._server_support.database_manager import DatabaseManager from ._server_support.job_manager import JobManager, MaxRestartsReachedError from ._server_support.kill_manager import KillManager, logs_with_string_or_condition +from ._server_support.multi_run_manager import MultiRunManager from ._server_support.parse_logs import _get_infos, parse_log_files from ._server_support.run_manager import ( RunManager, @@ -33,6 +34,7 @@ "start_one_by_one", "logs_with_string_or_condition", "RunManager", + "MultiRunManager", "slurm_run", "_get_infos", "parse_log_files", diff --git a/tests/test_multi_run_manager.py b/tests/test_multi_run_manager.py new file mode 100644 index 00000000..e580421b --- /dev/null +++ b/tests/test_multi_run_manager.py @@ -0,0 +1,178 @@ +"""Test the MultiRunManager class.""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING +from unittest.mock import patch + +import ipywidgets as ipw +import pytest + +from adaptive_scheduler._server_support.multi_run_manager import MultiRunManager +from adaptive_scheduler._server_support.run_manager import RunManager + +if TYPE_CHECKING: + from pathlib import Path + + import adaptive + + from .helpers import MockScheduler + + +@pytest.fixture() +def mock_run_manager( + mock_scheduler: MockScheduler, + learners: list[adaptive.Learner1D] + | list[adaptive.BalancingLearner] + | list[adaptive.SequenceLearner], + fnames: list[str] | list[Path], +) -> RunManager: + """Create a mock RunManager for testing.""" + return RunManager(mock_scheduler, learners[:1], fnames[:1], job_name="test-rm") + + +def test_multi_run_manager_init() -> None: + """Test the initialization of MultiRunManager.""" + mrm = MultiRunManager() + assert isinstance(mrm, MultiRunManager) + assert mrm.run_managers == {} + + +def test_multi_run_manager_add_run_manager(mock_run_manager: RunManager) -> None: + """Test adding a RunManager to MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + assert len(mrm.run_managers) == 1 + assert "test-rm" in mrm.run_managers + assert mrm.run_managers["test-rm"] == mock_run_manager + + +def test_multi_run_manager_add_duplicate_run_manager(mock_run_manager: RunManager) -> None: + """Test adding a duplicate RunManager to MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + with pytest.raises(ValueError, match="A RunManager with the name 'test-rm' already exists."): + mrm.add_run_manager(mock_run_manager) + + +@pytest.mark.asyncio() +async def test_multi_run_manager_add_run_manager_with_start(mock_run_manager: RunManager) -> None: + """Test adding a RunManager to MultiRunManager with start=True.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager, start=True) + await asyncio.sleep(0.1) + assert mock_run_manager.status() == "running" + mock_run_manager.cancel() + + +def test_multi_run_manager_add_run_manager_with_invalid_wait_for( + mock_run_manager: RunManager, +) -> None: + """Test adding a RunManager with an invalid wait_for parameter.""" + mrm = MultiRunManager() + with pytest.raises(KeyError, match="No RunManager with the name 'non-existent' exists."): + mrm.add_run_manager(mock_run_manager, start=True, wait_for="non-existent") + + +def test_multi_run_manager_add_run_manager_with_wait_for_without_start( + mock_run_manager: RunManager, +) -> None: + """Test adding a RunManager with wait_for but without start=True.""" + mrm = MultiRunManager() + mrm.add_run_manager( + RunManager( + mock_run_manager.scheduler, + mock_run_manager.learners, + mock_run_manager.fnames, + job_name="rm1", + ), + ) + with pytest.raises(ValueError, match="`start` must be True if `wait_for` is used."): + mrm.add_run_manager(mock_run_manager, wait_for="rm1") + + +def test_multi_run_manager_remove_run_manager(mock_run_manager: RunManager) -> None: + """Test removing a RunManager from MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + mrm.remove_run_manager("test-rm") + assert len(mrm.run_managers) == 0 + + +def test_multi_run_manager_remove_non_existent_run_manager() -> None: + """Test removing a non-existent RunManager from MultiRunManager.""" + mrm = MultiRunManager() + with pytest.raises(KeyError, match="No RunManager with the name 'non-existent' exists."): + mrm.remove_run_manager("non-existent") + + +@pytest.mark.asyncio() +async def test_multi_run_manager_start_all(mock_run_manager: RunManager) -> None: + """Test starting all RunManagers in MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + mrm.start_all() + await asyncio.sleep(0.1) + assert mock_run_manager.status() == "running" + mock_run_manager.cancel() + + +@pytest.mark.asyncio() +async def test_multi_run_manager_cancel_all(mock_run_manager: RunManager) -> None: + """Test cancelling all RunManagers in MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + mrm.start_all() + await asyncio.sleep(0.1) + mrm.cancel_all() + await asyncio.sleep(0.1) + assert mock_run_manager.status() == "cancelled" + + +def test_multi_run_manager_create_widget(mock_run_manager: RunManager) -> None: + """Test creating the widget for MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + vbox = mrm._create_widget() + assert isinstance(vbox, ipw.VBox) + tab = vbox.children[1] + assert isinstance(tab, ipw.Tab) + assert len(tab.children) == 1 + assert tab.get_title(0) == "RunManager: test-rm" + + +def test_multi_run_manager_update_widget(mock_run_manager: RunManager) -> None: + """Test updating the widget for MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + mrm._widget = mrm._create_widget() + new_rm = RunManager( + mock_run_manager.scheduler, + mock_run_manager.learners, + mock_run_manager.fnames, + job_name="new-rm", + ) + mrm.add_run_manager(new_rm) + assert len(mrm._widget.children) == 2 + assert mrm._widget.children[1].get_title(1) == "RunManager: new-rm" + + +def test_multi_run_manager_info(mock_run_manager: RunManager) -> None: + """Test the info method of MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + vbox = mrm.info() + assert isinstance(vbox, ipw.VBox) + tab = vbox.children[1] + assert isinstance(tab, ipw.Tab) + assert len(tab.children) == 1 + + +def test_multi_run_manager_repr_html(mock_run_manager: RunManager) -> None: + """Test the _repr_html_ method of MultiRunManager.""" + mrm = MultiRunManager() + mrm.add_run_manager(mock_run_manager) + with patch("IPython.display.display") as mocked_display: + mrm._repr_html_() + assert mocked_display.called