Skip to content

Commit

Permalink
Merge pull request #507 from tclose/serial-worker-fix
Browse files Browse the repository at this point in the history
Fixes SerialWorker Implementation
  • Loading branch information
djarecka authored Jan 28, 2022
2 parents 4db192a + 8033ff7 commit 1100d7c
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 21 deletions.
2 changes: 1 addition & 1 deletion pydra/engine/specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def _field_metadata(
if "mandatory" in fld.metadata:
if fld.metadata["mandatory"]:
raise Exception(
f"mandatory output for variable {fld.name} does not exit"
f"mandatory output for variable {fld.name} does not exist"
)
return attr.NOTHING
return val
Expand Down
2 changes: 1 addition & 1 deletion pydra/engine/tests/test_shelltask.py
Original file line number Diff line number Diff line change
Expand Up @@ -4453,7 +4453,7 @@ def test_shell_cmd_non_existing_outputs_4(tmpdir):
# An exception should be raised because the second mandatory output does not exist
with pytest.raises(Exception) as excinfo:
shelly()
assert "mandatory output for variable out_2 does not exit" == str(excinfo.value)
assert "mandatory output for variable out_2 does not exist" == str(excinfo.value)
# checking if the first output was created
assert (Path(shelly.output_dir) / Path("test_1.nii")).exists()

Expand Down
7 changes: 7 additions & 0 deletions pydra/engine/tests/test_submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,13 @@ def test_wf_with_state(plugin_dask_opt, tmpdir):
assert res[2].output.out == 5


def test_serial_wf():
# Use serial plugin to execute workflow instead of CF
wf = gen_basic_wf()
res = wf(plugin="serial")
assert res.output.out == 9


@pytest.mark.skipif(not slurm_available, reason="slurm not installed")
def test_slurm_wf(tmpdir):
wf = gen_basic_wf()
Expand Down
30 changes: 11 additions & 19 deletions pydra/engine/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,38 +116,30 @@ async def fetch_finished(self, futures):
return pending.union(unqueued)


class SerialPool:
"""A simple class to imitate a pool executor of concurrent futures."""

def submit(self, interface, **kwargs):
"""Send new task."""
self.res = interface(**kwargs)

def result(self):
"""Get the result of a task."""
return self.res

def done(self):
"""Return whether the task is finished."""
return True


class SerialWorker(Worker):
"""A worker to execute linearly."""

def __init__(self):
"""Initialize worker."""
logger.debug("Initialize SerialWorker")
self.pool = SerialPool()

def run_el(self, interface, rerun=False, **kwargs):
"""Run a task."""
self.pool.submit(interface=interface, rerun=rerun, **kwargs)
return self.pool
return self.exec_serial(interface, rerun=rerun)

def close(self):
"""Return whether the task is finished."""

async def exec_serial(self, runnable, rerun=False):
return runnable()

async def fetch_finished(self, futures):
await asyncio.gather(*futures)
return set([])

# async def fetch_finished(self, futures):
# return await asyncio.wait(futures)


class ConcurrentFuturesWorker(Worker):
"""A worker to execute in parallel using Python's concurrent futures."""
Expand Down

0 comments on commit 1100d7c

Please sign in to comment.