diff --git a/tasktiger/executor.py b/tasktiger/executor.py index fc5f08e..293f4ea 100644 --- a/tasktiger/executor.py +++ b/tasktiger/executor.py @@ -63,6 +63,25 @@ def __init__(self, worker: "Worker"): self.connection = worker.connection self.config = worker.config + def heartbeat( + self, + queue: str, + task_ids: Collection[str], + log: BoundLogger, + locks: Collection[Lock], + queue_lock: Optional[Semaphore], + ) -> None: + self.worker.heartbeat(queue, task_ids) + for lock in locks: + try: + lock.reacquire() + except LockError: + log.warning("could not reacquire lock", lock=lock.name) + if queue_lock: + acquired, current_locks = queue_lock.renew() + if not acquired: + log.debug("queue lock renew failure") + def execute( self, queue: str, @@ -351,18 +370,7 @@ def check_child_exit() -> Optional[int]: break try: - self.worker.heartbeat(queue, all_task_ids) - for lock in locks: - try: - lock.reacquire() - except LockError: - log.warning( - "could not reacquire lock", lock=lock.name - ) - if queue_lock: - acquired, current_locks = queue_lock.renew() - if not acquired: - log.debug("queue lock renew failure") + self.heartbeat(queue, all_task_ids, log, locks, queue_lock) except OSError as e: # EINTR happens if the task completed. Since we're just # renewing locks/heartbeat it's okay if we get interrupted. @@ -386,6 +394,19 @@ class SyncExecutor(Executor): exit_worker_on_job_timeout = True + def _periodic_heartbeat( + self, + queue: str, + task_ids: Collection[str], + log: BoundLogger, + locks: Collection[Lock], + queue_lock: Optional[Semaphore], + stop_event: threading.Event, + ) -> None: + while not stop_event.is_set(): + stop_event.wait(self.config["ACTIVE_TASK_UPDATE_TIMER"]) + self.heartbeat(queue, task_ids, log, locks, queue_lock) + def execute( self, queue: str, @@ -394,5 +415,27 @@ def execute( locks: Collection[Lock], queue_lock: Optional[Semaphore], ) -> bool: + # Run heartbeat thread. + all_task_ids = {task.id for task in tasks} + stop_event = threading.Event() + heartbeat_thread = threading.Thread( + target=self._periodic_heartbeat, + kwargs={ + "queue": queue, + "task_ids": all_task_ids, + "log": log, + "locks": locks, + "queue_lock": queue_lock, + "stop_event": stop_event, + }, + ) + heartbeat_thread.start() + # Run the tasks. - return self.execute_tasks(tasks, log) + result = self.execute_tasks(tasks, log) + + # Stop the heartbeat thread. + stop_event.set() + heartbeat_thread.join() + + return result diff --git a/tests/test_workers.py b/tests/test_workers.py index bd21e8a..4288812 100644 --- a/tests/test_workers.py +++ b/tests/test_workers.py @@ -8,13 +8,16 @@ from freezefrog import FreezeTime from tasktiger import Task, Worker +from tasktiger._internal import ACTIVE from tasktiger.executor import SyncExecutor +from .config import DELAY from .tasks import ( exception_task, long_task_killed, long_task_ok, simple_task, + sleep_task, wait_for_long_task, ) from .test_base import BaseTestCase @@ -35,10 +38,12 @@ def test_max_workers(self): # Start two workers and wait until they start processing. worker1 = Process( - target=external_worker, kwargs={"max_workers_per_queue": 2} + target=external_worker, + kwargs={"worker_kwargs": {"max_workers_per_queue": 2}}, ) worker2 = Process( - target=external_worker, kwargs={"max_workers_per_queue": 2} + target=external_worker, + kwargs={"worker_kwargs": {"max_workers_per_queue": 2}}, ) worker1.start() worker2.start() @@ -181,3 +186,28 @@ def test_handles_timeout(self, tiger, ensure_queues): with pytest.raises(SystemExit): worker.run(once=True, force_once=True) ensure_queues(error={"default": 1}) + + def test_heartbeat(self, tiger): + task = Task(tiger, sleep_task) + task.delay() + + # Start a worker and wait until it starts processing. + worker = Process( + target=external_worker, + kwargs={ + "patch_config": {"ACTIVE_TASK_UPDATE_TIMER": DELAY / 2}, + "worker_kwargs": {"executor_class": SyncExecutor}, + }, + ) + worker.start() + + time.sleep(DELAY / 2) + + key = tiger._key(ACTIVE, "default") + conn = tiger.connection + heartbeat_1 = conn.zscore(key, task.id) + time.sleep(DELAY / 2) + heartbeat_2 = conn.zscore(key, task.id) + assert heartbeat_2 > heartbeat_1 > 0 + + worker.kill() diff --git a/tests/utils.py b/tests/utils.py index 582250c..50f8b17 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -70,7 +70,7 @@ def get_tiger(): return tiger -def external_worker(n=None, patch_config=None, max_workers_per_queue=None): +def external_worker(n=None, patch_config=None, worker_kwargs=None): """ Runs a worker. To be used with multiprocessing.Pool.map. """ @@ -79,12 +79,11 @@ def external_worker(n=None, patch_config=None, max_workers_per_queue=None): if patch_config: tiger.config.update(patch_config) - worker = Worker(tiger) + if worker_kwargs is None: + worker_kwargs = {} - if max_workers_per_queue is not None: - worker.max_workers_per_queue = max_workers_per_queue + Worker(tiger, **worker_kwargs).run(once=True, force_once=True) - worker.run(once=True, force_once=True) tiger.connection.close()