Skip to content

Commit

Permalink
feat(framework) Verify message TTL when storing TaskIns and TaskRes (#…
Browse files Browse the repository at this point in the history
…3596)

Co-authored-by: Heng Pan <pan@flower.ai>
  • Loading branch information
mohammadnaseri and panh99 authored Sep 26, 2024
1 parent 6d37e25 commit 8fa9c56
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/py/flwr/server/utils/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Validators."""


import time
from typing import Union

from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
Expand Down Expand Up @@ -47,6 +48,11 @@ def validate_task_ins_or_res(tasks_ins_res: Union[TaskIns, TaskRes]) -> list[str
# unix timestamp of 27 March 2024 00h:00m:00s UTC
validation_errors.append("`pushed_at` is not a recent timestamp")

# Verify TTL and created_at time
current_time = time.time()
if tasks_ins_res.task.created_at + tasks_ins_res.task.ttl <= current_time:
validation_errors.append("Task TTL has expired")

# TaskIns specific
if isinstance(tasks_ins_res, TaskIns):
# Task producer
Expand Down
18 changes: 18 additions & 0 deletions src/py/flwr/server/utils/validator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ def test_is_valid_task_res(self) -> None:
val_errors = validate_task_ins_or_res(msg)
self.assertTrue(val_errors, (producer_node_id, anonymous, ancestry))

def test_task_ttl_expired(self) -> None:
"""Test validation for expired Task TTL."""
# Prepare an expired TaskIns
expired_task_ins = create_task_ins(0, True)
expired_task_ins.task.created_at = time.time() - 10 # 10 seconds ago
expired_task_ins.task.ttl = 6 # 6 seconds TTL

expired_task_res = create_task_res(0, True, ["1"])
expired_task_res.task.created_at = time.time() - 10 # 10 seconds ago
expired_task_res.task.ttl = 6 # 6 seconds TTL

# Execute & Assert
val_errors_ins = validate_task_ins_or_res(expired_task_ins)
self.assertIn("Task TTL has expired", val_errors_ins)

val_errors_res = validate_task_ins_or_res(expired_task_res)
self.assertIn("Task TTL has expired", val_errors_res)


def create_task_ins(
consumer_node_id: int,
Expand Down

0 comments on commit 8fa9c56

Please sign in to comment.