-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
fe3ede7
commit cf2ec0d
Showing
4 changed files
with
1,153 additions
and
1,072 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,231 @@ | ||
from __future__ import annotations | ||
|
||
import functools | ||
from aiida.orm import ProcessNode | ||
from aiida.engine.processes.workchains.awaitable import ( | ||
Awaitable, | ||
AwaitableAction, | ||
AwaitableTarget, | ||
construct_awaitable, | ||
) | ||
from aiida.orm import load_node | ||
from aiida.common import exceptions | ||
from typing import Any, List | ||
import logging | ||
|
||
|
||
class AwaitableHandler: | ||
"""Handles awaitable objects and their resolutions.""" | ||
|
||
def __init__( | ||
self, _awaitables, runner, logger: logging.Logger, process, ctx_manager | ||
): | ||
self.runner = runner | ||
self.logger = logger | ||
self.process = process | ||
self.ctx_manager = ctx_manager | ||
self.ctx = ctx_manager.ctx | ||
self._awaitables: List[Awaitable] = _awaitables | ||
self._temp = {"awaitables": {}} | ||
self.ctx._awaitable_actions = [] | ||
|
||
def insert_awaitable(self, awaitable: Awaitable) -> None: | ||
"""Insert an awaitable that should be terminated before before continuing to the next step. | ||
:param awaitable: the thing to await | ||
""" | ||
ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key) | ||
|
||
# Already assign the awaitable itself to the location in the context container where it is supposed to end up | ||
# once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the | ||
# order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the | ||
# awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value. | ||
if awaitable.action == AwaitableAction.ASSIGN: | ||
ctx[key] = awaitable | ||
elif awaitable.action == AwaitableAction.APPEND: | ||
ctx.setdefault(key, []).append(awaitable) | ||
else: | ||
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") | ||
|
||
self._awaitables.append( | ||
awaitable | ||
) # add only if everything went ok, otherwise we end up in an inconsistent state | ||
self.update_process_status() | ||
|
||
def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None: | ||
"""Resolve an awaitable. | ||
Precondition: must be an awaitable that was previously inserted. | ||
:param awaitable: the awaitable to resolve | ||
:param value: the value to assign to the awaitable | ||
""" | ||
ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key) | ||
|
||
if awaitable.action == AwaitableAction.ASSIGN: | ||
ctx[key] = value | ||
elif awaitable.action == AwaitableAction.APPEND: | ||
# Find the same awaitable inserted in the context | ||
container = ctx[key] | ||
for index, placeholder in enumerate(container): | ||
if ( | ||
isinstance(placeholder, Awaitable) | ||
and placeholder.pk == awaitable.pk | ||
): | ||
container[index] = value | ||
break | ||
else: | ||
raise AssertionError( | ||
f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`" | ||
) | ||
else: | ||
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}") | ||
|
||
awaitable.resolved = True | ||
# remove awaitabble from the list | ||
self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk] | ||
|
||
if not self.process.has_terminated(): | ||
# the process may be terminated, for example, if the process was killed or excepted | ||
# then we should not try to update it | ||
self.update_process_status() | ||
|
||
def update_process_status(self) -> None: | ||
"""Set the process status with a message accounting the current sub processes that we are waiting for.""" | ||
if self._awaitables: | ||
status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}" | ||
self.process.node.set_process_status(status) | ||
else: | ||
self.process.node.set_process_status(None) | ||
|
||
def action_awaitables(self) -> None: | ||
"""Handle the awaitables that are currently registered with the work chain. | ||
Depending on the class type of the awaitable's target a different callback | ||
function will be bound with the awaitable and the runner will be asked to | ||
call it when the target is completed | ||
""" | ||
for awaitable in self._awaitables: | ||
# if the waitable already has a callback, skip | ||
if awaitable.pk in self.ctx._awaitable_actions: | ||
continue | ||
if awaitable.target == AwaitableTarget.PROCESS: | ||
callback = functools.partial( | ||
self.process.call_soon, self.on_awaitable_finished, awaitable | ||
) | ||
self.runner.call_on_process_finish(awaitable.pk, callback) | ||
self.ctx._awaitable_actions.append(awaitable.pk) | ||
elif awaitable.target == "asyncio.tasks.Task": | ||
# this is a awaitable task, the callback function is already set | ||
self.ctx._awaitable_actions.append(awaitable.pk) | ||
else: | ||
assert f"invalid awaitable target '{awaitable.target}'" | ||
|
||
def on_awaitable_finished(self, awaitable: Awaitable) -> None: | ||
"""Callback function, for when an awaitable process instance is completed. | ||
The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all | ||
awaitables have been dealt with, the work chain process is resumed. | ||
:param awaitable: an Awaitable instance | ||
""" | ||
self.logger.debug(f"Awaitable {awaitable.key} finished.") | ||
|
||
if isinstance(awaitable.pk, int): | ||
self.logger.info( | ||
"received callback that awaitable with key {} and pk {} has terminated".format( | ||
awaitable.key, awaitable.pk | ||
) | ||
) | ||
try: | ||
node = load_node(awaitable.pk) | ||
except (exceptions.MultipleObjectsError, exceptions.NotExistent): | ||
raise ValueError( | ||
f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance" | ||
) | ||
|
||
if awaitable.outputs: | ||
value = { | ||
entry.link_label: entry.node | ||
for entry in node.base.links.get_outgoing() | ||
} | ||
else: | ||
value = node # type: ignore | ||
else: | ||
# In this case, the pk and key are the same. | ||
self.logger.info( | ||
"received callback that awaitable {} has terminated".format( | ||
awaitable.key | ||
) | ||
) | ||
try: | ||
# if awaitable is cancelled, the result is None | ||
if awaitable.cancelled(): | ||
self.process.task_manager.set_task_state_info( | ||
awaitable.key, "state", "KILLED" | ||
) | ||
# set child tasks state to SKIPPED | ||
self.process.task_manager.set_tasks_state( | ||
self.ctx._connectivity["child_node"][awaitable.key], | ||
"SKIPPED", | ||
) | ||
self.process.report(f"Task: {awaitable.key} cancelled.") | ||
else: | ||
results = awaitable.result() | ||
self.process.task_manager.update_normal_task_state( | ||
awaitable.key, results | ||
) | ||
except Exception as e: | ||
self.logger.error(f"Error in awaitable {awaitable.key}: {e}") | ||
self.process.task_manager.set_task_state_info( | ||
awaitable.key, "state", "FAILED" | ||
) | ||
# set child tasks state to SKIPPED | ||
self.process.task_manager.set_tasks_state( | ||
self.ctx._connectivity["child_node"][awaitable.key], | ||
"SKIPPED", | ||
) | ||
self.process.report(f"Task: {awaitable.key} failed: {e}") | ||
self.process.run_error_handlers(awaitable.key) | ||
value = None | ||
|
||
self.resolve_awaitable(awaitable, value) | ||
|
||
# node finished, update the task state and result | ||
# udpate the task state | ||
self.process.task_manager.update_task_state(awaitable.key) | ||
# try to resume the workgraph, if the workgraph is already resumed | ||
# by other awaitable, this will not work | ||
try: | ||
self.process.resume() | ||
except Exception as e: | ||
print(e) | ||
|
||
def construct_awaitable_function( | ||
self, name: str, awaitable_target: Awaitable | ||
) -> None: | ||
"""Construct the awaitable function.""" | ||
awaitable = Awaitable( | ||
**{ | ||
"pk": name, | ||
"action": AwaitableAction.ASSIGN, | ||
"target": "asyncio.tasks.Task", | ||
"outputs": False, | ||
} | ||
) | ||
awaitable_target.key = name | ||
awaitable_target.pk = name | ||
awaitable_target.action = AwaitableAction.ASSIGN | ||
awaitable_target.add_done_callback(self.on_awaitable_finished) | ||
return awaitable | ||
|
||
def to_context(self, **kwargs: Awaitable | ProcessNode) -> None: | ||
"""Add a dictionary of awaitables to the context. | ||
This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will | ||
assign a certain value to the corresponding key in the context of the work graph. | ||
""" | ||
for key, value in kwargs.items(): | ||
awaitable = construct_awaitable(value) | ||
awaitable.key = key | ||
self.insert_awaitable(awaitable) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from __future__ import annotations | ||
|
||
from aiida.common.extendeddicts import AttributeDict | ||
from typing import Any | ||
from aiida_workgraph.utils import get_nested_dict | ||
import logging | ||
|
||
|
||
class ContextManager: | ||
"""Manages the context for the WorkGraphEngine.""" | ||
|
||
_CONTEXT = "CONTEXT" | ||
|
||
def __init__(self, _context, process, logger: logging.Logger): | ||
self.process = process | ||
self._context = _context | ||
self.logger = logger | ||
|
||
@property | ||
def ctx(self) -> AttributeDict: | ||
"""Access the context.""" | ||
return self._context | ||
|
||
def resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]: | ||
""" | ||
Returns a reference to a sub-dictionary of the context and the last key, | ||
after resolving a potentially segmented key where required sub-dictionaries are created as needed. | ||
:param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary | ||
""" | ||
ctx = self.ctx | ||
ctx_path = key.split(".") | ||
|
||
for index, path in enumerate(ctx_path[:-1]): | ||
try: | ||
ctx = ctx[path] | ||
except KeyError: # see below why this is the only exception we have to catch here | ||
ctx[ | ||
path | ||
] = AttributeDict() # create the sub-dict and update the context | ||
ctx = ctx[path] | ||
continue | ||
|
||
# Notes: | ||
# * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking | ||
# * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables | ||
# (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself | ||
# * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable | ||
# would be an AttributeDict we can append things to it since the order of tasks is maintained. | ||
if type(ctx) != AttributeDict: # pylint: disable=C0123 | ||
raise ValueError( | ||
f"Can not update the context for key `{key}`: " | ||
f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index + 1])}`, expected AttributeDict' | ||
) | ||
|
||
return ctx, ctx_path[-1] | ||
|
||
def update_context_variable(self, value: Any) -> Any: | ||
"""Replace placeholders in the value with actual context values.""" | ||
if isinstance(value, dict): | ||
return {k: self.update_context_variable(v) for k, v in value.items()} | ||
elif ( | ||
isinstance(value, str) | ||
and value.strip().startswith("{{") | ||
and value.strip().endswith("}}") | ||
): | ||
name = value[2:-2].strip() | ||
return get_nested_dict(self.ctx, name) | ||
return value |
Oops, something went wrong.