Skip to content

Commit

Permalink
Refactor engine
Browse files Browse the repository at this point in the history
  • Loading branch information
superstar54 committed Nov 29, 2024
1 parent fe3ede7 commit cf2ec0d
Show file tree
Hide file tree
Showing 4 changed files with 1,153 additions and 1,072 deletions.
231 changes: 231 additions & 0 deletions aiida_workgraph/engine/awaitable_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,231 @@
from __future__ import annotations

import functools
from aiida.orm import ProcessNode
from aiida.engine.processes.workchains.awaitable import (
Awaitable,
AwaitableAction,
AwaitableTarget,
construct_awaitable,
)
from aiida.orm import load_node
from aiida.common import exceptions
from typing import Any, List
import logging


class AwaitableHandler:
"""Handles awaitable objects and their resolutions."""

def __init__(
self, _awaitables, runner, logger: logging.Logger, process, ctx_manager
):
self.runner = runner
self.logger = logger
self.process = process
self.ctx_manager = ctx_manager
self.ctx = ctx_manager.ctx
self._awaitables: List[Awaitable] = _awaitables
self._temp = {"awaitables": {}}
self.ctx._awaitable_actions = []

def insert_awaitable(self, awaitable: Awaitable) -> None:
"""Insert an awaitable that should be terminated before before continuing to the next step.
:param awaitable: the thing to await
"""
ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key)

# Already assign the awaitable itself to the location in the context container where it is supposed to end up
# once it is resolved. This is especially important for the `APPEND` action, since it needs to maintain the
# order, but the awaitables will not necessarily be resolved in the order in which they are added. By using the
# awaitable as a placeholder, in the `_resolve_awaitable`, it can be found and replaced by the resolved value.
if awaitable.action == AwaitableAction.ASSIGN:
ctx[key] = awaitable
elif awaitable.action == AwaitableAction.APPEND:
ctx.setdefault(key, []).append(awaitable)
else:
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

self._awaitables.append(
awaitable
) # add only if everything went ok, otherwise we end up in an inconsistent state
self.update_process_status()

def resolve_awaitable(self, awaitable: Awaitable, value: Any) -> None:
"""Resolve an awaitable.
Precondition: must be an awaitable that was previously inserted.
:param awaitable: the awaitable to resolve
:param value: the value to assign to the awaitable
"""
ctx, key = self.ctx_manager.resolve_nested_context(awaitable.key)

if awaitable.action == AwaitableAction.ASSIGN:
ctx[key] = value
elif awaitable.action == AwaitableAction.APPEND:
# Find the same awaitable inserted in the context
container = ctx[key]
for index, placeholder in enumerate(container):
if (
isinstance(placeholder, Awaitable)
and placeholder.pk == awaitable.pk
):
container[index] = value
break
else:
raise AssertionError(
f"Awaitable `{awaitable.pk} was not in `ctx.{awaitable.key}`"
)
else:
raise AssertionError(f"Unsupported awaitable action: {awaitable.action}")

awaitable.resolved = True
# remove awaitabble from the list
self._awaitables = [a for a in self._awaitables if a.pk != awaitable.pk]

if not self.process.has_terminated():
# the process may be terminated, for example, if the process was killed or excepted
# then we should not try to update it
self.update_process_status()

def update_process_status(self) -> None:
"""Set the process status with a message accounting the current sub processes that we are waiting for."""
if self._awaitables:
status = f"Waiting for child processes: {', '.join([str(_.pk) for _ in self._awaitables])}"
self.process.node.set_process_status(status)
else:
self.process.node.set_process_status(None)

def action_awaitables(self) -> None:
"""Handle the awaitables that are currently registered with the work chain.
Depending on the class type of the awaitable's target a different callback
function will be bound with the awaitable and the runner will be asked to
call it when the target is completed
"""
for awaitable in self._awaitables:
# if the waitable already has a callback, skip
if awaitable.pk in self.ctx._awaitable_actions:
continue
if awaitable.target == AwaitableTarget.PROCESS:
callback = functools.partial(
self.process.call_soon, self.on_awaitable_finished, awaitable
)
self.runner.call_on_process_finish(awaitable.pk, callback)
self.ctx._awaitable_actions.append(awaitable.pk)
elif awaitable.target == "asyncio.tasks.Task":
# this is a awaitable task, the callback function is already set
self.ctx._awaitable_actions.append(awaitable.pk)
else:
assert f"invalid awaitable target '{awaitable.target}'"

def on_awaitable_finished(self, awaitable: Awaitable) -> None:
"""Callback function, for when an awaitable process instance is completed.
The awaitable will be effectuated on the context of the work chain and removed from the internal list. If all
awaitables have been dealt with, the work chain process is resumed.
:param awaitable: an Awaitable instance
"""
self.logger.debug(f"Awaitable {awaitable.key} finished.")

if isinstance(awaitable.pk, int):
self.logger.info(
"received callback that awaitable with key {} and pk {} has terminated".format(
awaitable.key, awaitable.pk
)
)
try:
node = load_node(awaitable.pk)
except (exceptions.MultipleObjectsError, exceptions.NotExistent):
raise ValueError(
f"provided pk<{awaitable.pk}> could not be resolved to a valid Node instance"
)

if awaitable.outputs:
value = {
entry.link_label: entry.node
for entry in node.base.links.get_outgoing()
}
else:
value = node # type: ignore
else:
# In this case, the pk and key are the same.
self.logger.info(
"received callback that awaitable {} has terminated".format(
awaitable.key
)
)
try:
# if awaitable is cancelled, the result is None
if awaitable.cancelled():
self.process.task_manager.set_task_state_info(
awaitable.key, "state", "KILLED"
)
# set child tasks state to SKIPPED
self.process.task_manager.set_tasks_state(
self.ctx._connectivity["child_node"][awaitable.key],
"SKIPPED",
)
self.process.report(f"Task: {awaitable.key} cancelled.")
else:
results = awaitable.result()
self.process.task_manager.update_normal_task_state(
awaitable.key, results
)
except Exception as e:
self.logger.error(f"Error in awaitable {awaitable.key}: {e}")
self.process.task_manager.set_task_state_info(
awaitable.key, "state", "FAILED"
)
# set child tasks state to SKIPPED
self.process.task_manager.set_tasks_state(
self.ctx._connectivity["child_node"][awaitable.key],
"SKIPPED",
)
self.process.report(f"Task: {awaitable.key} failed: {e}")
self.process.run_error_handlers(awaitable.key)
value = None

self.resolve_awaitable(awaitable, value)

# node finished, update the task state and result
# udpate the task state
self.process.task_manager.update_task_state(awaitable.key)
# try to resume the workgraph, if the workgraph is already resumed
# by other awaitable, this will not work
try:
self.process.resume()
except Exception as e:
print(e)

def construct_awaitable_function(
self, name: str, awaitable_target: Awaitable
) -> None:
"""Construct the awaitable function."""
awaitable = Awaitable(
**{
"pk": name,
"action": AwaitableAction.ASSIGN,
"target": "asyncio.tasks.Task",
"outputs": False,
}
)
awaitable_target.key = name
awaitable_target.pk = name
awaitable_target.action = AwaitableAction.ASSIGN
awaitable_target.add_done_callback(self.on_awaitable_finished)
return awaitable

def to_context(self, **kwargs: Awaitable | ProcessNode) -> None:
"""Add a dictionary of awaitables to the context.
This is a convenience method that provides syntactic sugar, for a user to add multiple intersteps that will
assign a certain value to the corresponding key in the context of the work graph.
"""
for key, value in kwargs.items():
awaitable = construct_awaitable(value)
awaitable.key = key
self.insert_awaitable(awaitable)
69 changes: 69 additions & 0 deletions aiida_workgraph/engine/context_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

from aiida.common.extendeddicts import AttributeDict
from typing import Any
from aiida_workgraph.utils import get_nested_dict
import logging


class ContextManager:
"""Manages the context for the WorkGraphEngine."""

_CONTEXT = "CONTEXT"

def __init__(self, _context, process, logger: logging.Logger):
self.process = process
self._context = _context
self.logger = logger

@property
def ctx(self) -> AttributeDict:
"""Access the context."""
return self._context

def resolve_nested_context(self, key: str) -> tuple[AttributeDict, str]:
"""
Returns a reference to a sub-dictionary of the context and the last key,
after resolving a potentially segmented key where required sub-dictionaries are created as needed.
:param key: A key into the context, where words before a dot are interpreted as a key for a sub-dictionary
"""
ctx = self.ctx
ctx_path = key.split(".")

for index, path in enumerate(ctx_path[:-1]):
try:
ctx = ctx[path]
except KeyError: # see below why this is the only exception we have to catch here
ctx[
path
] = AttributeDict() # create the sub-dict and update the context
ctx = ctx[path]
continue

# Notes:
# * the first ctx (self.ctx) is guaranteed to be an AttributeDict, hence the post-"dereference" checking
# * the values can be many different things: on insertion they are either AtrributeDict, List or Awaitables
# (subclasses of AttributeDict) but after resolution of an Awaitable this will be the value itself
# * assumption: a resolved value is never a plain AttributeDict, on the other hand if a resolved Awaitable
# would be an AttributeDict we can append things to it since the order of tasks is maintained.
if type(ctx) != AttributeDict: # pylint: disable=C0123
raise ValueError(
f"Can not update the context for key `{key}`: "
f' found instance of `{type(ctx)}` at `{".".join(ctx_path[:index + 1])}`, expected AttributeDict'
)

return ctx, ctx_path[-1]

def update_context_variable(self, value: Any) -> Any:
"""Replace placeholders in the value with actual context values."""
if isinstance(value, dict):
return {k: self.update_context_variable(v) for k, v in value.items()}
elif (
isinstance(value, str)
and value.strip().startswith("{{")
and value.strip().endswith("}}")
):
name = value[2:-2].strip()
return get_nested_dict(self.ctx, name)
return value
Loading

0 comments on commit cf2ec0d

Please sign in to comment.