Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync TaskTiger worker heartbeat #331

Merged
merged 1 commit into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 56 additions & 13 deletions tasktiger/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,25 @@ def __init__(self, worker: "Worker"):
self.connection = worker.connection
self.config = worker.config

def heartbeat(
    self,
    queue: str,
    task_ids: Collection[str],
    log: BoundLogger,
    locks: Collection[Lock],
    queue_lock: Optional[Semaphore],
) -> None:
    """
    Send a worker heartbeat for the given tasks and refresh held locks.

    Calls ``self.worker.heartbeat(queue, task_ids)``, then tries to
    reacquire every lock in ``locks`` and, when ``queue_lock`` is given,
    renews it. Lock failures are only logged, never raised.

    :param queue: Queue name passed through to the worker heartbeat.
    :param task_ids: IDs of the tasks passed through to the worker
        heartbeat.
    :param log: Logger used to report lock failures.
    :param locks: Locks to reacquire.
    :param queue_lock: Optional queue lock to renew.
    """
    self.worker.heartbeat(queue, task_ids)

    for lock in locks:
        try:
            lock.reacquire()
        except LockError:
            # Best effort: a lock that cannot be reacquired is reported,
            # and the remaining locks are still processed.
            log.warning("could not reacquire lock", lock=lock.name)

    if queue_lock:
        # renew() returns a pair; only the success flag is needed here.
        acquired, _ = queue_lock.renew()
        if not acquired:
            log.debug("queue lock renew failure")

def execute(
self,
queue: str,
Expand Down Expand Up @@ -351,18 +370,7 @@ def check_child_exit() -> Optional[int]:
break

try:
self.worker.heartbeat(queue, all_task_ids)
for lock in locks:
try:
lock.reacquire()
except LockError:
log.warning(
"could not reacquire lock", lock=lock.name
)
if queue_lock:
acquired, current_locks = queue_lock.renew()
if not acquired:
log.debug("queue lock renew failure")
self.heartbeat(queue, all_task_ids, log, locks, queue_lock)
except OSError as e:
# EINTR happens if the task completed. Since we're just
# renewing locks/heartbeat it's okay if we get interrupted.
Expand All @@ -386,6 +394,19 @@ class SyncExecutor(Executor):

exit_worker_on_job_timeout = True

def _periodic_heartbeat(
self,
queue: str,
task_ids: Collection[str],
log: BoundLogger,
locks: Collection[Lock],
queue_lock: Optional[Semaphore],
stop_event: threading.Event,
) -> None:
while not stop_event.is_set():
stop_event.wait(self.config["ACTIVE_TASK_UPDATE_TIMER"])
self.heartbeat(queue, task_ids, log, locks, queue_lock)

def execute(
self,
queue: str,
Expand All @@ -394,5 +415,27 @@ def execute(
locks: Collection[Lock],
queue_lock: Optional[Semaphore],
) -> bool:
# Run heartbeat thread.
all_task_ids = {task.id for task in tasks}
stop_event = threading.Event()
heartbeat_thread = threading.Thread(
target=self._periodic_heartbeat,
kwargs={
"queue": queue,
"task_ids": all_task_ids,
"log": log,
"locks": locks,
"queue_lock": queue_lock,
"stop_event": stop_event,
},
)
heartbeat_thread.start()

# Run the tasks.
return self.execute_tasks(tasks, log)
result = self.execute_tasks(tasks, log)

# Stop the heartbeat thread.
stop_event.set()
heartbeat_thread.join()

return result
34 changes: 32 additions & 2 deletions tests/test_workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,16 @@
from freezefrog import FreezeTime

from tasktiger import Task, Worker
from tasktiger._internal import ACTIVE
from tasktiger.executor import SyncExecutor

from .config import DELAY
from .tasks import (
exception_task,
long_task_killed,
long_task_ok,
simple_task,
sleep_task,
wait_for_long_task,
)
from .test_base import BaseTestCase
Expand All @@ -35,10 +38,12 @@ def test_max_workers(self):

# Start two workers and wait until they start processing.
worker1 = Process(
target=external_worker, kwargs={"max_workers_per_queue": 2}
target=external_worker,
kwargs={"worker_kwargs": {"max_workers_per_queue": 2}},
)
worker2 = Process(
target=external_worker, kwargs={"max_workers_per_queue": 2}
target=external_worker,
kwargs={"worker_kwargs": {"max_workers_per_queue": 2}},
)
worker1.start()
worker2.start()
Expand Down Expand Up @@ -181,3 +186,28 @@ def test_handles_timeout(self, tiger, ensure_queues):
with pytest.raises(SystemExit):
worker.run(once=True, force_once=True)
ensure_queues(error={"default": 1})

def test_heartbeat(self, tiger):
    """
    Check that a worker using ``SyncExecutor`` keeps updating the task's
    score under the ACTIVE key of the ``default`` queue while it runs.
    """
    task = Task(tiger, sleep_task)
    task.delay()

    # Start a worker and wait until it starts processing.
    worker = Process(
        target=external_worker,
        kwargs={
            "patch_config": {"ACTIVE_TASK_UPDATE_TIMER": DELAY / 2},
            "worker_kwargs": {"executor_class": SyncExecutor},
        },
    )
    worker.start()

    try:
        time.sleep(DELAY / 2)

        key = tiger._key(ACTIVE, "default")
        conn = tiger.connection
        heartbeat_1 = conn.zscore(key, task.id)
        time.sleep(DELAY / 2)
        heartbeat_2 = conn.zscore(key, task.id)
        assert heartbeat_2 > heartbeat_1 > 0
    finally:
        # Always stop and reap the worker, even when an assertion or a
        # Redis call above fails, so no child process is left behind.
        worker.kill()
        worker.join()
9 changes: 4 additions & 5 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def get_tiger():
return tiger


def external_worker(n=None, patch_config=None, max_workers_per_queue=None):
def external_worker(n=None, patch_config=None, worker_kwargs=None):
"""
Runs a worker. To be used with multiprocessing.Pool.map.
"""
Expand All @@ -79,12 +79,11 @@ def external_worker(n=None, patch_config=None, max_workers_per_queue=None):
if patch_config:
tiger.config.update(patch_config)

worker = Worker(tiger)
if worker_kwargs is None:
worker_kwargs = {}

if max_workers_per_queue is not None:
worker.max_workers_per_queue = max_workers_per_queue
Worker(tiger, **worker_kwargs).run(once=True, force_once=True)

worker.run(once=True, force_once=True)
tiger.connection.close()


Expand Down
Loading