Skip to content

Commit

Permalink
Add transaction mode to worker class
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed May 23, 2023
1 parent 03c4d1c commit 185ad3b
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 9 deletions.
57 changes: 52 additions & 5 deletions src/blueapi/worker/reworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class RunEngineWorker(Worker[Task]):
_ctx: BlueskyContext
_stop_timeout: float

_transaction_lock: RLock
_pending_transaction: Optional[ActiveTask]

_state: WorkerState
_errors: List[str]
_warnings: List[str]
Expand All @@ -70,6 +73,9 @@ def __init__(
self._ctx = ctx
self._stop_timeout = stop_timeout

self._transaction_lock = RLock()
self._pending_transaction = None

self._state = WorkerState.from_bluesky_state(ctx.run_engine.state)
self._errors = []
self._warnings = []
Expand All @@ -85,8 +91,49 @@ def __init__(
self._stopping = Event()
self._stopped = Event()

def submit_task(self, name: str, task: Task) -> None:
active_task = ActiveTask(name, task)
def begin_transaction(self, task: Task) -> str:
task_id: str = str(uuid.uuid4())
with self._transaction_lock:
if self._pending_transaction is not None:
raise WorkerBusyError("There is already a transaction in progress")
self._pending_transaction = ActiveTask(task_id, task)
return task_id

def clear_transaction(self) -> str:
with self._transaction_lock:
if self._pending_transaction is None:
raise Exception("No transaction to clear")

task_id = self._pending_transaction.task_id
self._pending_transaction = None
return task_id

def commit_transaction(self, task_id: str) -> None:
with self._transaction_lock:
if self._pending_transaction is None:
raise Exception("No transaction to commit")

pending_id = self._pending_transaction.task_id
if pending_id == task_id:
self._submit_active_task(self._pending_transaction)
else:
raise KeyError(
"Not committing the transaction requested, asked to commit"
f"{task_id} when {pending_id} is in progress"
)

def get_pending(self) -> Optional[Task]:
with self._transaction_lock:
if self._pending_transaction is None:
return None
else:
return self._pending_transaction.task

def submit_task(self, task_id: str, task: Task) -> None:
active_task = ActiveTask(task_id, task)
self._submit_active_task(active_task)

def _submit_active_task(self, active_task: ActiveTask) -> None:
LOGGER.info(f"Submitting: {active_task}")
try:
self._task_queue.put_nowait(active_task)
Expand Down Expand Up @@ -196,11 +243,11 @@ def _report_status(
warnings = self._warnings
if self._current is not None:
task_status = TaskStatus(
task_name=self._current.name,
task_name=self._current.task_id,
task_complete=self._current.is_complete,
task_failed=self._current.is_error or bool(errors),
)
correlation_id = self._current.name
correlation_id = self._current.task_id
else:
task_status = None
correlation_id = None
Expand All @@ -215,7 +262,7 @@ def _report_status(

def _on_document(self, name: str, document: Mapping[str, Any]) -> None:
if self._current is not None:
correlation_id = self._current.name
correlation_id = self._current.task_id
self._data_events.publish(
DataEvent(name=name, doc=document), correlation_id
)
Expand Down
2 changes: 1 addition & 1 deletion src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _lookup_params(

@dataclass
class ActiveTask:
name: str
task_id: str
task: Task
is_complete: bool = False
is_error: bool = False
49 changes: 46 additions & 3 deletions src/blueapi/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from typing import Generic, Optional, TypeVar

from blueapi.core import DataEvent, EventStream

Expand All @@ -15,13 +15,56 @@ class Worker(ABC, Generic[T]):
"""

@abstractmethod
def submit_task(self, __name: str, __task: T) -> None:
def begin_transaction(self, __task: T) -> str:
"""
Begin a new transaction, lock the worker with a pending task,
do not allow new transactions until this one is run or cleared.
Args:
__task: The task to run if this transaction is committed
Returns:
str: An ID for the task
"""

@abstractmethod
def clear_transaction(self) -> str:
"""
Clear any existing transaction. Raise an error if
unable.
Returns:
str: The ID of the task cleared
"""

@abstractmethod
def commit_transaction(self, __task_id: str) -> None:
"""
Commit the pending transaction and run the
embedded task
Args:
__task_id: The ID of the task to run, must match
the pending transaction
"""

@abstractmethod
def get_pending(self) -> Optional[T]:
"""_summary_
Returns:
Optional[Task]: _description_
"""

@abstractmethod
def submit_task(self, __task_id: str, __task: T) -> None:
"""
Submit a task to be run
Args:
__name (str): A unique name to identify this task
__name (str): name of the plan to be run
__task (T): The task to run
__correlation_id (str): unique identifier of the task
"""

@abstractmethod
Expand Down

0 comments on commit 185ad3b

Please sign in to comment.