Skip to content

Commit

Permalink
Add initial MultiRunManager (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt authored Oct 30, 2024
1 parent a9b34b7 commit 5bae6d9
Show file tree
Hide file tree
Showing 4 changed files with 362 additions and 0 deletions.
2 changes: 2 additions & 0 deletions adaptive_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from adaptive_scheduler._version import __version__
from adaptive_scheduler.scheduler import PBS, SLURM
from adaptive_scheduler.server_support import (
MultiRunManager,
RunManager,
slurm_run,
start_one_by_one,
Expand All @@ -23,4 +24,5 @@
"utils",
"SlurmExecutor",
"SlurmTask",
"MultiRunManager",
]
180 changes: 180 additions & 0 deletions adaptive_scheduler/_server_support/multi_run_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import ipywidgets as ipw

from adaptive_scheduler.widgets import _disable_widgets_output_scrollbar, info

if TYPE_CHECKING:
from adaptive_scheduler._server_support.run_manager import RunManager


class MultiRunManager:
"""A manager that can contain multiple RunManagers.
Parameters
----------
run_managers
Initial list of RunManagers to include.
Attributes
----------
run_managers
Dictionary of managed RunManagers, keyed by their names.
"""

def __init__(self, run_managers: list[RunManager] | None = None) -> None:
self.run_managers: dict[str, RunManager] = {}
self._widget: ipw.VBox | None = None
self._tab_widget: ipw.Tab | None = None
self._info_widgets: dict[str, ipw.Widget] = {}
self._update_all_button: ipw.Button | None = None
if run_managers:
for rm in run_managers:
self.add_run_manager(rm)

def add_run_manager(
self,
run_manager: RunManager,
*,
start: bool = False,
wait_for: str | None = None,
) -> None:
"""Add a new RunManager to the MultiRunManager.
Parameters
----------
run_manager
The RunManager to add.
start
Whether to start the RunManager immediately after adding it.
wait_for
The name of another RunManager to wait for before starting this one.
Only applicable if start is True.
Raises
------
ValueError
If a RunManager with the same name already exists.
KeyError
If the specified wait_for RunManager does not exist.
"""
if run_manager.job_name in self.run_managers:
msg = f"A RunManager with the name '{run_manager.job_name}' already exists."
raise ValueError(msg)

self.run_managers[run_manager.job_name] = run_manager
self._info_widgets[run_manager.job_name] = info(
run_manager,
display_widget=False,
disable_widgets_output_scrollbar=False,
)

if start:
if wait_for:
if wait_for not in self.run_managers:
msg = f"No RunManager with the name '{wait_for}' exists."
raise KeyError(msg)
run_manager.start(wait_for=self.run_managers[wait_for])
else:
run_manager.start()
elif wait_for:
msg = "`start` must be True if `wait_for` is used."
raise ValueError(msg)

if self._widget is not None:
self._update_widget()

def remove_run_manager(self, name: str) -> None:
"""Remove a RunManager from the MultiRunManager.
Parameters
----------
name
The name of the RunManager to remove.
Raises
------
KeyError
If no RunManager with the given name exists.
"""
if name in self.run_managers:
rm = self.run_managers.pop(name)
rm.cancel()
self._info_widgets.pop(name)
if self._widget is not None:
self._update_widget()
else:
msg = f"No RunManager with the name '{name}' exists."
raise KeyError(msg)

def start_all(self) -> None:
"""Start all RunManagers."""
for run_manager in self.run_managers.values():
run_manager.start()

def cancel_all(self) -> None:
"""Cancel all RunManagers."""
for run_manager in self.run_managers.values():
run_manager.cancel()

def _create_widget(self) -> ipw.VBox:
"""Create the widget for displaying RunManager info and update button."""
self._tab_widget = ipw.Tab()
children = list(self._info_widgets.values())
self._tab_widget.children = children
for i, name in enumerate(self.run_managers.keys()):
self._tab_widget.set_title(i, f"RunManager: {name}")

self._update_all_button = ipw.Button(
description="Update All",
button_style="info",
tooltip="Update all RunManagers",
)
self._update_all_button.on_click(self._update_all_callback)

return ipw.VBox([self._update_all_button, self._tab_widget])

def _update_widget(self) -> None:
"""Update the widget when RunManagers are added or removed."""
if self._tab_widget is not None:
current_children = list(self._tab_widget.children)
new_children = list(self._info_widgets.values())

# Create a new tuple of children
updated_children = tuple(child for child in current_children if child in new_children)

# Add new children
for widget in new_children:
if widget not in updated_children:
updated_children += (widget,)

# Update the widget's children
self._tab_widget.children = updated_children

# Update titles
for i, name in enumerate(self.run_managers.keys()):
self._tab_widget.set_title(i, f"RunManager: {name}")

def _update_all_callback(self, _: ipw.Button) -> None:
"""Callback function for the Update All button."""
assert self._tab_widget is not None
for widget in self._tab_widget.children:
update_button = widget.children[0].children[1].children[0]
assert update_button.description == "update info"
update_button.click()

def info(self) -> ipw.VBox:
"""Display info about all RunManagers in a widget with an Update All button."""
if self._widget is None:
_disable_widgets_output_scrollbar()
self._widget = self._create_widget()
return self._widget

def _repr_html_(self) -> str:
"""HTML representation for Jupyter notebooks."""
return self.info()
2 changes: 2 additions & 0 deletions adaptive_scheduler/server_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ._server_support.database_manager import DatabaseManager
from ._server_support.job_manager import JobManager, MaxRestartsReachedError
from ._server_support.kill_manager import KillManager, logs_with_string_or_condition
from ._server_support.multi_run_manager import MultiRunManager
from ._server_support.parse_logs import _get_infos, parse_log_files
from ._server_support.run_manager import (
RunManager,
Expand All @@ -33,6 +34,7 @@
"start_one_by_one",
"logs_with_string_or_condition",
"RunManager",
"MultiRunManager",
"slurm_run",
"_get_infos",
"parse_log_files",
Expand Down
178 changes: 178 additions & 0 deletions tests/test_multi_run_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
"""Test the MultiRunManager class."""

from __future__ import annotations

import asyncio
from typing import TYPE_CHECKING
from unittest.mock import patch

import ipywidgets as ipw
import pytest

from adaptive_scheduler._server_support.multi_run_manager import MultiRunManager
from adaptive_scheduler._server_support.run_manager import RunManager

if TYPE_CHECKING:
from pathlib import Path

import adaptive

from .helpers import MockScheduler


@pytest.fixture()
def mock_run_manager(
mock_scheduler: MockScheduler,
learners: list[adaptive.Learner1D]
| list[adaptive.BalancingLearner]
| list[adaptive.SequenceLearner],
fnames: list[str] | list[Path],
) -> RunManager:
"""Create a mock RunManager for testing."""
return RunManager(mock_scheduler, learners[:1], fnames[:1], job_name="test-rm")


def test_multi_run_manager_init() -> None:
"""Test the initialization of MultiRunManager."""
mrm = MultiRunManager()
assert isinstance(mrm, MultiRunManager)
assert mrm.run_managers == {}


def test_multi_run_manager_add_run_manager(mock_run_manager: RunManager) -> None:
"""Test adding a RunManager to MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
assert len(mrm.run_managers) == 1
assert "test-rm" in mrm.run_managers
assert mrm.run_managers["test-rm"] == mock_run_manager


def test_multi_run_manager_add_duplicate_run_manager(mock_run_manager: RunManager) -> None:
"""Test adding a duplicate RunManager to MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
with pytest.raises(ValueError, match="A RunManager with the name 'test-rm' already exists."):
mrm.add_run_manager(mock_run_manager)


@pytest.mark.asyncio()
async def test_multi_run_manager_add_run_manager_with_start(mock_run_manager: RunManager) -> None:
"""Test adding a RunManager to MultiRunManager with start=True."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager, start=True)
await asyncio.sleep(0.1)
assert mock_run_manager.status() == "running"
mock_run_manager.cancel()


def test_multi_run_manager_add_run_manager_with_invalid_wait_for(
mock_run_manager: RunManager,
) -> None:
"""Test adding a RunManager with an invalid wait_for parameter."""
mrm = MultiRunManager()
with pytest.raises(KeyError, match="No RunManager with the name 'non-existent' exists."):
mrm.add_run_manager(mock_run_manager, start=True, wait_for="non-existent")


def test_multi_run_manager_add_run_manager_with_wait_for_without_start(
mock_run_manager: RunManager,
) -> None:
"""Test adding a RunManager with wait_for but without start=True."""
mrm = MultiRunManager()
mrm.add_run_manager(
RunManager(
mock_run_manager.scheduler,
mock_run_manager.learners,
mock_run_manager.fnames,
job_name="rm1",
),
)
with pytest.raises(ValueError, match="`start` must be True if `wait_for` is used."):
mrm.add_run_manager(mock_run_manager, wait_for="rm1")


def test_multi_run_manager_remove_run_manager(mock_run_manager: RunManager) -> None:
"""Test removing a RunManager from MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
mrm.remove_run_manager("test-rm")
assert len(mrm.run_managers) == 0


def test_multi_run_manager_remove_non_existent_run_manager() -> None:
"""Test removing a non-existent RunManager from MultiRunManager."""
mrm = MultiRunManager()
with pytest.raises(KeyError, match="No RunManager with the name 'non-existent' exists."):
mrm.remove_run_manager("non-existent")


@pytest.mark.asyncio()
async def test_multi_run_manager_start_all(mock_run_manager: RunManager) -> None:
"""Test starting all RunManagers in MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
mrm.start_all()
await asyncio.sleep(0.1)
assert mock_run_manager.status() == "running"
mock_run_manager.cancel()


@pytest.mark.asyncio()
async def test_multi_run_manager_cancel_all(mock_run_manager: RunManager) -> None:
"""Test cancelling all RunManagers in MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
mrm.start_all()
await asyncio.sleep(0.1)
mrm.cancel_all()
await asyncio.sleep(0.1)
assert mock_run_manager.status() == "cancelled"


def test_multi_run_manager_create_widget(mock_run_manager: RunManager) -> None:
"""Test creating the widget for MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
vbox = mrm._create_widget()
assert isinstance(vbox, ipw.VBox)
tab = vbox.children[1]
assert isinstance(tab, ipw.Tab)
assert len(tab.children) == 1
assert tab.get_title(0) == "RunManager: test-rm"


def test_multi_run_manager_update_widget(mock_run_manager: RunManager) -> None:
"""Test updating the widget for MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
mrm._widget = mrm._create_widget()
new_rm = RunManager(
mock_run_manager.scheduler,
mock_run_manager.learners,
mock_run_manager.fnames,
job_name="new-rm",
)
mrm.add_run_manager(new_rm)
assert len(mrm._widget.children) == 2
assert mrm._widget.children[1].get_title(1) == "RunManager: new-rm"


def test_multi_run_manager_info(mock_run_manager: RunManager) -> None:
"""Test the info method of MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
vbox = mrm.info()
assert isinstance(vbox, ipw.VBox)
tab = vbox.children[1]
assert isinstance(tab, ipw.Tab)
assert len(tab.children) == 1


def test_multi_run_manager_repr_html(mock_run_manager: RunManager) -> None:
"""Test the _repr_html_ method of MultiRunManager."""
mrm = MultiRunManager()
mrm.add_run_manager(mock_run_manager)
with patch("IPython.display.display") as mocked_display:
mrm._repr_html_()
assert mocked_display.called

0 comments on commit 5bae6d9

Please sign in to comment.