DaskExecutor (#943)
* adds DaskScheduler

* upgrades

* no ci for 3.8

* lazy import dask

* updates

* drop py3.7 support

* lockfile

* fixes?

* future

* fixes

* fixes

* fixes

* fixes for 3.8

* update cassettes

* ci

* changelog and DaskExecutor.from_kwargs

* fixes as_completed

* refactor Executor type to protocol

* fixes protocol

* fixes

* Update webknossos/Changelog.md

Co-authored-by: Philipp Otto <philippotto@users.noreply.github.com>

* Changelog

---------

Co-authored-by: Philipp Otto <philippotto@users.noreply.github.com>
normanrz and philippotto authored Oct 11, 2023
1 parent 1cb7101 commit 14efb6f
Showing 60 changed files with 3,532 additions and 5,710 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/ci.yml
@@ -94,15 +94,15 @@ jobs:
           poetry install
       - name: Check typing
-        if: ${{ matrix.executors == 'multiprocessing' }}
+        if: ${{ matrix.executors == 'multiprocessing' && matrix.python-version == '3.11' }}
         run: ./typecheck.sh

       - name: Check formatting
-        if: ${{ matrix.executors == 'multiprocessing' }}
+        if: ${{ matrix.executors == 'multiprocessing' && matrix.python-version == '3.11' }}
         run: ./format.sh check

       - name: Lint code
-        if: ${{ matrix.executors == 'multiprocessing' }}
+        if: ${{ matrix.executors == 'multiprocessing' && matrix.python-version == '3.11' }}
         run: ./lint.sh

       - name: Run multiprocessing tests
@@ -160,15 +160,15 @@ jobs:
           poetry install --extras all
       - name: Check formatting
-        if: matrix.group == 1
+        if: ${{ matrix.group == 1 && matrix.python-version == '3.11' }}
         run: ./format.sh check

       - name: Lint code
-        if: matrix.group == 1
+        if: ${{ matrix.group == 1 && matrix.python-version == '3.11' }}
         run: ./lint.sh

       - name: Check typing
-        if: matrix.group == 1
+        if: ${{ matrix.group == 1 && matrix.python-version == '3.11' }}
         run: ./typecheck.sh

       - name: Python tests
4 changes: 4 additions & 0 deletions cluster_tools/Changelog.md
@@ -10,10 +10,14 @@ For upgrade instructions, please check the respective *Breaking Changes* section
 [Commits](https://github.com/scalableminds/webknossos-libs/compare/v0.13.7...HEAD)

 ### Breaking Changes
+- Dropped support for Python 3.7. [#943](https://github.com/scalableminds/webknossos-libs/pull/943)
+- Please use `Executor.as_completed` instead of `concurrent.futures.as_completed`, because the latter does not work with `DaskExecutor` futures. [#943](https://github.com/scalableminds/webknossos-libs/pull/943)

 ### Added
+- Added `DaskExecutor` (only Python >= 3.9). [#943](https://github.com/scalableminds/webknossos-libs/pull/943)

 ### Changed
+- The exported `Executor` type is now implemented as a protocol. [#943](https://github.com/scalableminds/webknossos-libs/pull/943)

 ### Fixed

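To illustrate the `as_completed` breaking change above, here is a minimal sketch (assuming a local Dask cluster; `get_executor("dask")` with no further kwargs lets `distributed.Client` start one):

```python
import cluster_tools

def double(x: int) -> int:
    return 2 * x

# kwargs to get_executor("dask", ...) are forwarded to distributed.Client.
with cluster_tools.get_executor("dask") as executor:
    futs = executor.map_to_futures(double, [1, 2, 3])
    # Use the executor's as_completed instead of concurrent.futures.as_completed;
    # the stdlib helper does not understand Dask futures.
    for fut in executor.as_completed(futs):
        print(fut.result())
```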
19 changes: 13 additions & 6 deletions cluster_tools/cluster_tools/__init__.py
@@ -1,7 +1,7 @@
-from typing import Any, Union, overload
-
-from typing_extensions import Literal
+from typing import Any, Literal, overload

+from cluster_tools.executor_protocol import Executor
+from cluster_tools.executors.dask import DaskExecutor
 from cluster_tools.executors.debug_sequential import DebugSequentialExecutor
 from cluster_tools.executors.multiprocessing_ import MultiprocessingExecutor
 from cluster_tools.executors.pickle_ import PickleExecutor
@@ -70,6 +70,11 @@ def get_executor(
     ...


+@overload
+def get_executor(environment: Literal["dask"], **kwargs: Any) -> DaskExecutor:
+    ...
+
+
 @overload
 def get_executor(
     environment: Literal["multiprocessing"], **kwargs: Any
@@ -105,6 +110,11 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor":
         return PBSExecutor(**kwargs)
     elif environment == "kubernetes":
         return KubernetesExecutor(**kwargs)
+    elif environment == "dask":
+        if "client" in kwargs:
+            return DaskExecutor(kwargs["client"])
+        else:
+            return DaskExecutor.from_kwargs(**kwargs)
     elif environment == "multiprocessing":
         global did_start_test_multiprocessing
         if not did_start_test_multiprocessing:
@@ -119,6 +129,3 @@ def get_executor(environment: str, **kwargs: Any) -> "Executor":
     elif environment == "test_pickling":
         return PickleExecutor(**kwargs)
     raise Exception("Unknown executor: {}".format(environment))
-
-
-Executor = Union[ClusterExecutor, MultiprocessingExecutor]
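For context, a construction sketch for the new `dask` branch of `get_executor` (the `n_workers=2` value is illustrative):

```python
from distributed import Client

import cluster_tools

# Either pass an existing Dask client through get_executor ...
client = Client(n_workers=2)
executor = cluster_tools.get_executor("dask", client=client)

# ... or omit it and have the remaining kwargs forwarded to distributed.Client.
executor = cluster_tools.get_executor("dask", n_workers=2)
```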
59 changes: 59 additions & 0 deletions cluster_tools/cluster_tools/executor_protocol.py
@@ -0,0 +1,59 @@
from concurrent.futures import Future
from os import PathLike
from typing import (
    Callable,
    ContextManager,
    Iterable,
    Iterator,
    List,
    Optional,
    Protocol,
    TypeVar,
)

from typing_extensions import ParamSpec

_T = TypeVar("_T")
_P = ParamSpec("_P")
_S = TypeVar("_S")


class Executor(Protocol, ContextManager["Executor"]):
    @classmethod
    def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
        ...

    def submit(
        self,
        __fn: Callable[_P, _T],
        /,
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> "Future[_T]":
        ...

    def map_unordered(self, fn: Callable[[_S], _T], args: Iterable[_S]) -> Iterator[_T]:
        ...

    def map_to_futures(
        self,
        fn: Callable[[_S], _T],
        args: Iterable[_S],
        output_pickle_path_getter: Optional[Callable[[_S], PathLike]] = None,
    ) -> List["Future[_T]"]:
        ...

    def map(
        self,
        fn: Callable[[_S], _T],
        iterables: Iterable[_S],
        timeout: Optional[float] = None,
        chunksize: Optional[int] = None,
    ) -> Iterator[_T]:
        ...

    def forward_log(self, fut: "Future[_T]") -> _T:
        ...

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        ...
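Because `Executor` is a `Protocol`, downstream code can now be typed against the interface rather than a concrete class. A minimal sketch (the helper function is illustrative, not part of the package):

```python
from typing import List

from cluster_tools import Executor

def _increment(x: int) -> int:
    return x + 1

def run_all(executor: Executor, items: List[int]) -> List[int]:
    # Any structurally conforming executor (multiprocessing, slurm,
    # kubernetes, dask, ...) type-checks here.
    with executor:
        return list(executor.map_unordered(_increment, items))
```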
128 changes: 128 additions & 0 deletions cluster_tools/cluster_tools/executors/dask.py
@@ -0,0 +1,128 @@
import os
from concurrent import futures
from concurrent.futures import Future
from functools import partial
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Iterable,
    Iterator,
    List,
    Optional,
    TypeVar,
    cast,
)

from typing_extensions import ParamSpec

from cluster_tools._utils.warning import enrich_future_with_uncaught_warning
from cluster_tools.executors.multiprocessing_ import CFutDict, MultiprocessingExecutor

if TYPE_CHECKING:
    from distributed import Client

_T = TypeVar("_T")
_P = ParamSpec("_P")
_S = TypeVar("_S")


class DaskExecutor(futures.Executor):
    client: "Client"

    def __init__(
        self,
        client: "Client",
    ) -> None:
        self.client = client

    @classmethod
    def from_kwargs(
        cls,
        **kwargs: Any,
    ) -> "DaskExecutor":
        from distributed import Client

        return cls(Client(**kwargs))

    @classmethod
    def as_completed(cls, futures: List["Future[_T]"]) -> Iterator["Future[_T]"]:
        from distributed import as_completed

        return as_completed(futures)

    def submit(  # type: ignore[override]
        self,
        __fn: Callable[_P, _T],
        *args: _P.args,
        **kwargs: _P.kwargs,
    ) -> "Future[_T]":
        if "__cfut_options" in kwargs:
            output_pickle_path = cast(CFutDict, kwargs["__cfut_options"])[
                "output_pickle_path"
            ]
            del kwargs["__cfut_options"]

            __fn = partial(
                MultiprocessingExecutor._execute_and_persist_function,
                output_pickle_path,
                __fn,
            )
        fut = self.client.submit(partial(__fn, *args, **kwargs))

        enrich_future_with_uncaught_warning(fut)
        return fut

    def map_unordered(self, fn: Callable[[_S], _T], args: Iterable[_S]) -> Iterator[_T]:
        futs: List["Future[_T]"] = self.map_to_futures(fn, args)

        # Return a separate generator so that map_unordered itself is not
        # evaluated lazily (otherwise, jobs would also be submitted lazily).
        def result_generator() -> Iterator:
            for fut in self.as_completed(futs):
                yield fut.result()

        return result_generator()

    def map_to_futures(
        self,
        fn: Callable[[_S], _T],
        args: Iterable[_S],  # TODO change: allow more than one arg per call
        output_pickle_path_getter: Optional[Callable[[_S], os.PathLike]] = None,
    ) -> List["Future[_T]"]:
        if output_pickle_path_getter is not None:
            futs = [
                self.submit(  # type: ignore[call-arg]
                    fn,
                    arg,
                    __cfut_options={
                        "output_pickle_path": output_pickle_path_getter(arg)
                    },
                )
                for arg in args
            ]
        else:
            futs = [self.submit(fn, arg) for arg in args]

        return futs

    def map(  # type: ignore[override]
        self,
        fn: Callable[[_S], _T],
        iterables: Iterable[_S],
        timeout: Optional[float] = None,
        chunksize: Optional[int] = None,
    ) -> Iterator[_T]:
        if chunksize is None:
            chunksize = 1
        return super().map(fn, iterables, timeout=timeout, chunksize=chunksize)

    def forward_log(self, fut: "Future[_T]") -> _T:
        return fut.result()

    def shutdown(self, wait: bool = True, *, cancel_futures: bool = False) -> None:
        if wait:
            self.client.close(timeout=60 * 60 * 24)
        else:
            self.client.close()
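Result persistence works as with the other executors: when `map_to_futures` is given an `output_pickle_path_getter`, each result is additionally pickled to the returned path via `MultiprocessingExecutor._execute_and_persist_function`. A sketch with illustrative paths and worker count:

```python
from pathlib import Path

from cluster_tools import DaskExecutor

def square(x: int) -> int:
    return x * x

# Assumes a local cluster; from_kwargs forwards its kwargs to distributed.Client.
with DaskExecutor.from_kwargs(n_workers=2) as executor:
    futs = executor.map_to_futures(
        square,
        [1, 2, 3],
        # Hypothetical output location: one pickle file per argument.
        output_pickle_path_getter=lambda x: Path(f"squared_{x}.pickle"),
    )
    for fut in executor.as_completed(futs):
        print(fut.result())
```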
18 changes: 17 additions & 1 deletion cluster_tools/cluster_tools/executors/multiprocessing_.py
@@ -16,11 +16,12 @@
     List,
     Optional,
     Tuple,
+    TypedDict,
     TypeVar,
     cast,
 )

-from typing_extensions import ParamSpec, TypedDict
+from typing_extensions import ParamSpec

 from cluster_tools._utils import pickling
 from cluster_tools._utils.multiprocessing_logging_handler import (
@@ -85,6 +86,10 @@ def __init__(
         else:
             self._mp_logging_handler_pool = _MultiprocessingLoggingHandlerPool()

+    @classmethod
+    def as_completed(cls, futs: List["Future[_T]"]) -> Iterator["Future[_T]"]:
+        return futures.as_completed(futs)
+
     def submit(  # type: ignore[override]
         self,
         __fn: Callable[_P, _T],
@@ -143,6 +148,17 @@ def submit(  # type: ignore[override]
         enrich_future_with_uncaught_warning(fut)
         return fut

+    def map(  # type: ignore[override]
+        self,
+        fn: Callable[[_S], _T],
+        iterables: Iterable[_S],
+        timeout: Optional[float] = None,
+        chunksize: Optional[int] = None,
+    ) -> Iterator[_T]:
+        if chunksize is None:
+            chunksize = 1
+        return super().map(fn, iterables, timeout=timeout, chunksize=chunksize)
+
     def _submit_via_io(
         self,
         __fn: Callable[_P, _T],
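The new `map` override accepts `chunksize=None` and falls back to 1, mirroring `DaskExecutor.map`. A brief sketch (the `max_workers=2` argument is illustrative):

```python
import cluster_tools

def triple(x: int) -> int:
    return 3 * x

with cluster_tools.get_executor("multiprocessing", max_workers=2) as executor:
    # chunksize may be omitted or passed as None; it now defaults to 1.
    print(list(executor.map(triple, range(4))))
```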
26 changes: 19 additions & 7 deletions cluster_tools/cluster_tools/executors/pickle_.py
@@ -1,19 +1,28 @@
 from concurrent.futures import Future
-from typing import Any, Callable, TypeVar
+from functools import partial
+from typing import Callable, TypeVar
+
+from typing_extensions import ParamSpec

 from cluster_tools._utils import pickling
 from cluster_tools.executors.multiprocessing_ import MultiprocessingExecutor

 # The module name includes a _-suffix to avoid name clashes with the standard library pickle module.

 _T = TypeVar("_T")
+_P = ParamSpec("_P")
+_S = TypeVar("_S")


-def _pickle_identity(obj: _T) -> _T:
+def _pickle_identity(obj: _S) -> _S:
     return pickling.loads(pickling.dumps(obj))


-def _pickle_identity_executor(fn: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
+def _pickle_identity_executor(
+    fn: Callable[_P, _T],
+    *args: _P.args,
+    **kwargs: _P.kwargs,
+) -> _T:
     result = fn(*args, **kwargs)
     return _pickle_identity(result)

@@ -27,13 +36,16 @@ class PickleExecutor(MultiprocessingExecutor):

     def submit(  # type: ignore[override]
         self,
-        fn: Callable[..., _T],
-        *args: Any,
-        **kwargs: Any,
+        fn: Callable[_P, _T],
+        /,
+        *args: _P.args,
+        **kwargs: _P.kwargs,
     ) -> "Future[_T]":
         (fn_pickled, args_pickled, kwargs_pickled) = _pickle_identity(
             (fn, args, kwargs)
         )
         return super().submit(
-            _pickle_identity_executor, fn_pickled, *args_pickled, **kwargs_pickled
+            partial(_pickle_identity_executor, fn_pickled),
+            *args_pickled,
+            **kwargs_pickled,
         )
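`PickleExecutor` round-trips the function, its arguments, and its result through pickling, which makes the `test_pickling` environment a handy smoke test for whether a job would survive submission to a real cluster. A short sketch:

```python
import cluster_tools

def greet(name: str) -> str:
    return f"Hello, {name}!"

# Behaves like the multiprocessing executor, but pickles fn, args, and
# result on the way through to surface serialization problems early.
with cluster_tools.get_executor("test_pickling") as executor:
    print(executor.submit(greet, "world").result())
```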