Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Nov 25, 2024
1 parent e1e7297 commit d87cd1f
Show file tree
Hide file tree
Showing 17 changed files with 99 additions and 89 deletions.
12 changes: 6 additions & 6 deletions adaptive_scheduler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
)

__all__ = [
"__version__",
"client_support",
"PBS",
"SLURM",
"MultiRunManager",
"RunManager",
"SlurmExecutor",
"SlurmTask",
"__version__",
"client_support",
"scheduler",
"server_support",
"slurm_run",
"SLURM",
"start_one_by_one",
"utils",
"SlurmExecutor",
"SlurmTask",
"MultiRunManager",
]
2 changes: 1 addition & 1 deletion adaptive_scheduler/_scheduler/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ def start_job(self, name: str, *, index: int | None = None) -> None:
submit_cmd = f"{self.submit_cmd} {name} {self.batch_fname(name_prefix)}"
run_submit(submit_cmd, name)

def extra_scheduler(self, *, index: int | None = None) -> str: # noqa: ARG002
def extra_scheduler(self, *, index: int | None = None) -> str:
"""Get the extra scheduler options."""
msg = "extra_scheduler is not implemented."
raise NotImplementedError(msg)
4 changes: 2 additions & 2 deletions adaptive_scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from adaptive_scheduler._scheduler.slurm import SLURM, slurm_partitions

__all__ = [
"LocalMockScheduler",
"PBS",
"SLURM",
"slurm_partitions",
"BaseScheduler",
"DefaultScheduler",
"LocalMockScheduler",
"slurm_partitions",
]


Expand Down
26 changes: 13 additions & 13 deletions adaptive_scheduler/server_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,25 @@
from ._server_support.slurm_run import slurm_run

__all__ = [
"BaseManager",
"DatabaseManager",
"JobManager",
"KillManager",
"_wait_for_finished",
"_start_after",
"start_one_by_one",
"logs_with_string_or_condition",
"RunManager",
"MaxRestartsReachedError",
"MultiRunManager",
"slurm_run",
"RunManager",
"_delete_old_ipython_profiles",
"_get_all_files",
"_get_infos",
"parse_log_files",
"BaseManager",
"log",
"_start_after",
"_wait_for_finished",
"cleanup_scheduler_files",
"console",
"MaxRestartsReachedError",
"get_allowed_url",
"_get_all_files",
"cleanup_scheduler_files",
"_delete_old_ipython_profiles",
"log",
"logs_with_string_or_condition",
"parse_log_files",
"periodically_clean_ipython_profiles",
"slurm_run",
"start_one_by_one",
]
52 changes: 31 additions & 21 deletions example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,25 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"import adaptive_scheduler\n",
"import random\n",
"\n",
"\n",
"def h(x, width=0.01, offset=0):\n",
" for _ in range(10): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
" return x + width ** 2 / (width ** 2 + (x - offset) ** 2)\n",
" return x + width**2 / (width**2 + (x - offset) ** 2)\n",
"\n",
"\n",
"# Define the sequence/samples we want to run\n",
"xs = np.linspace(0, 1, 10_000)\n",
"\n",
"# ⚠️ Here a `learner` is an `adaptive` concept, read it as `jobs`.\n",
"# ⚠️ `fnames` are the result locations\n",
"learners, fnames = adaptive_scheduler.utils.split_sequence_in_sequence_learners(\n",
" h, xs, n_learners=10\n",
" h,\n",
" xs,\n",
" n_learners=10,\n",
")\n",
"\n",
"run_manager = adaptive_scheduler.slurm_run(\n",
Expand All @@ -48,7 +52,7 @@
" nodes=1, # number of nodes per `learner`\n",
" cores_per_node=1, # number of cores on 1 node per `learner`\n",
" log_interval=5, # how often to produce a log message\n",
" save_interval=5, # how often to save the results\n",
" save_interval=5, # how often to save the results\n",
")\n",
"run_manager.start()"
]
Expand Down Expand Up @@ -85,18 +89,18 @@
"from functools import partial\n",
"\n",
"import adaptive\n",
"\n",
"import adaptive_scheduler\n",
"\n",
"\n",
"def h(x, width=0.01, offset=0):\n",
" import numpy as np\n",
" import random\n",
"\n",
" for _ in range(10): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
"\n",
" a = width\n",
" return x + a ** 2 / (a ** 2 + (x - offset) ** 2)\n",
" return x + a**2 / (a**2 + (x - offset) ** 2)\n",
"\n",
"\n",
"offsets = [i / 10 - 0.5 for i in range(5)]\n",
Expand Down Expand Up @@ -266,16 +270,16 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from adaptive import SequenceLearner\n",
"from adaptive_scheduler.utils import split, combo_to_fname\n",
"\n",
"from adaptive_scheduler.utils import split\n",
"\n",
"\n",
"def g(xyz):\n",
" x, y, z = xyz\n",
" for _ in range(5): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
" return x ** 2 + y ** 2 + z ** 2\n",
" return x**2 + y**2 + z**2\n",
"\n",
"\n",
"xs = np.linspace(0, 10, 11)\n",
Expand All @@ -302,11 +306,17 @@
"\n",
"\n",
"scheduler = adaptive_scheduler.scheduler.DefaultScheduler(\n",
" cores=10, executor_type=\"ipyparallel\",\n",
" cores=10,\n",
" executor_type=\"ipyparallel\",\n",
") # PBS or SLURM\n",
"\n",
"run_manager2 = adaptive_scheduler.server_support.RunManager(\n",
" scheduler, learners, fnames, goal=goal, log_interval=30, save_interval=30,\n",
" scheduler,\n",
" learners,\n",
" fnames,\n",
" goal=goal,\n",
" log_interval=30,\n",
" save_interval=30,\n",
")\n",
"run_manager2.start()"
]
Expand Down Expand Up @@ -343,19 +353,19 @@
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"from adaptive import SequenceLearner\n",
"from adaptive_scheduler.utils import split, combo2fname\n",
"from adaptive.utils import named_product\n",
"\n",
"from adaptive_scheduler.utils import combo2fname\n",
"\n",
"\n",
"def g(combo):\n",
" x, y, z = combo[\"x\"], combo[\"y\"], combo[\"z\"]\n",
"\n",
" for _ in range(5): # Burn some CPU time just because\n",
" np.linalg.eig(np.random.rand(1000, 1000))\n",
"\n",
" return x ** 2 + y ** 2 + z ** 2\n",
" return x**2 + y**2 + z**2\n",
"\n",
"\n",
"combos = named_product(x=np.linspace(0, 10), y=np.linspace(-1, 1), z=np.linspace(-3, 3))\n",
Expand All @@ -364,15 +374,15 @@
"\n",
"# We could run this as 1 job with N nodes, but we can also split it up in multiple jobs.\n",
"# This is desireable when you don't want to run a single job with 300 nodes for example.\n",
"# Note that \n",
"# Note that\n",
"# `adaptive_scheduler.utils.split_sequence_in_sequence_learners(g, combos, 100, \"data\")`\n",
"# does the same!\n",
"\n",
"njobs = 100\n",
"split_combos = list(split(combos, njobs))\n",
"\n",
"print(\n",
" f\"Length of split_combos: {len(split_combos)} and length of split_combos[0]: {len(split_combos[0])}.\"\n",
" f\"Length of split_combos: {len(split_combos)} and length of split_combos[0]: {len(split_combos[0])}.\",\n",
")\n",
"\n",
"learners = [SequenceLearner(g, combos_part) for combos_part in split_combos]\n",
Expand All @@ -393,17 +403,16 @@
"outputs": [],
"source": [
"from functools import partial\n",
"\n",
"import adaptive_scheduler\n",
"from adaptive_scheduler.scheduler import DefaultScheduler, PBS, SLURM\n",
"from adaptive_scheduler.scheduler import SLURM, DefaultScheduler\n",
"\n",
"\n",
"def goal(learner):\n",
" return learner.done() # the standard goal for a SequenceLearner\n",
"\n",
"\n",
"extra_scheduler = (\n",
" [\"--exclusive\", \"--time=24:00:00\"] if DefaultScheduler is SLURM else []\n",
")\n",
"extra_scheduler = [\"--exclusive\", \"--time=24:00:00\"] if DefaultScheduler is SLURM else []\n",
"\n",
"scheduler = adaptive_scheduler.scheduler.DefaultScheduler(\n",
" cores=10,\n",
Expand Down Expand Up @@ -459,7 +468,8 @@
"source": [
"run_manager3.load_learners() # load the data into the learners\n",
"result = sum(\n",
" [l.result() for l in learners], []\n",
" [l.result() for l in learners],\n",
" [],\n",
") # combine all learner's result into 1 list"
]
}
Expand Down
14 changes: 7 additions & 7 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@
import zmq.asyncio


@pytest.fixture()
@pytest.fixture
def mock_scheduler(tmp_path: Path) -> MockScheduler:
"""Fixture for creating a MockScheduler instance."""
return MockScheduler(log_folder=str(tmp_path), cores=8)


@pytest.fixture()
@pytest.fixture
def db_manager(
mock_scheduler: MockScheduler,
learners: list[adaptive.Learner1D]
Expand Down Expand Up @@ -99,14 +99,14 @@ def fnames(
raise NotImplementedError(msg)


@pytest.fixture()
@pytest.fixture
def socket(db_manager: DatabaseManager) -> zmq.asyncio.Socket:
"""Fixture for creating a ZMQ socket."""
with get_socket(db_manager) as socket:
yield socket


@pytest.fixture()
@pytest.fixture
def job_manager(
db_manager: DatabaseManager,
mock_scheduler: MockScheduler,
Expand All @@ -116,7 +116,7 @@ def job_manager(
return JobManager(job_names, db_manager, mock_scheduler, interval=0.05)


@pytest.fixture()
@pytest.fixture
def _mock_slurm_partitions_output() -> Generator[None, None, None]:
"""Mock `slurm_partitions` function."""
mock_output = "hb120v2-low\nhb60-high\nnc24-low*\nnd40v2-mpi\n"
Expand All @@ -125,7 +125,7 @@ def _mock_slurm_partitions_output() -> Generator[None, None, None]:
yield


@pytest.fixture()
@pytest.fixture
def _mock_slurm_partitions() -> Generator[None, None, None]:
"""Mock `slurm_partitions` function."""
with (
Expand All @@ -141,7 +141,7 @@ def _mock_slurm_partitions() -> Generator[None, None, None]:
yield


@pytest.fixture()
@pytest.fixture
def _mock_slurm_queue() -> Generator[None, None, None]:
"""Mock `SLURM.queue` function."""
with patch(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_client_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def client(zmq_url: str) -> zmq.Socket:
return client


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_get_learner(zmq_url: str) -> None:
"""Test `get_learner` function."""
with tempfile.NamedTemporaryFile() as tmpfile:
Expand Down Expand Up @@ -94,7 +94,7 @@ async def test_get_learner(zmq_url: str) -> None:
mock_log.exception.assert_called_with("got an exception")


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_tell_done(zmq_url: str) -> None:
"""Test `tell_done` function."""
fname = "test_learner_file.pkl"
Expand Down
14 changes: 7 additions & 7 deletions tests/test_database_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_simple_database_get_all(tmp_path: Path) -> None:
assert done_entries[1][1].fname == "file3.txt"


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_start_and_cancel(db_manager: DatabaseManager) -> None:
"""Test starting and canceling the DatabaseManager."""
db_manager.start()
Expand Down Expand Up @@ -172,7 +172,7 @@ def test_database_manager_as_dicts(
]


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_dispatch_start_stop(
db_manager: DatabaseManager,
learners: list[adaptive.Learner1D]
Expand Down Expand Up @@ -205,7 +205,7 @@ async def test_database_manager_dispatch_start_stop(
assert entry.is_done is True


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_start_and_update(
socket: zmq.asyncio.Socket,
db_manager: DatabaseManager,
Expand Down Expand Up @@ -259,7 +259,7 @@ async def test_database_manager_start_and_update(
assert entry.job_id is None


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_start_stop(
socket: zmq.asyncio.Socket,
db_manager: DatabaseManager,
Expand Down Expand Up @@ -322,7 +322,7 @@ async def test_database_manager_start_stop(
await send_message(socket, start_message)


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_database_manager_stop_request_and_requests(
socket: zmq.asyncio.Socket,
db_manager: DatabaseManager,
Expand Down Expand Up @@ -531,7 +531,7 @@ def test_ensure_str_invalid_input(invalid_input: list[str]) -> None:
_ensure_str(invalid_input) # type: ignore[arg-type]


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_dependencies(
db_manager: DatabaseManager,
fnames: list[str] | list[Path],
Expand Down Expand Up @@ -599,7 +599,7 @@ async def test_dependencies(
db_manager._choose_fname()


@pytest.mark.asyncio()
@pytest.mark.asyncio
async def test_replace_learner(db_manager: DatabaseManager) -> None:
"""Test replacing a learner in the DatabaseManager."""
db_manager.create_empty_db()
Expand Down
Loading

0 comments on commit d87cd1f

Please sign in to comment.