Skip to content

Commit

Permalink
feat(framework) Verify the TaskIns TTL when saving TaskRes (#3609)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <pan@flower.ai>
  • Loading branch information
mohammadnaseri and panh99 authored Sep 26, 2024
1 parent 8fa9c56 commit 83cd4ba
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 6 deletions.
17 changes: 17 additions & 0 deletions src/py/flwr/server/superlink/state/in_memory_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,23 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
log(ERROR, errors)
return None

with self.lock:
# Check if the TaskIns it is replying to exists and is valid
task_ins_id = task_res.task.ancestry[0]
task_ins = self.task_ins_store.get(UUID(task_ins_id))

if task_ins is None:
log(ERROR, "TaskIns with task_id %s does not exist.", task_ins_id)
return None

if task_ins.task.created_at + task_ins.task.ttl <= time.time():
log(
ERROR,
"Failed to store TaskRes: TaskIns with task_id %s has expired.",
task_ins_id,
)
return None

# Validate run_id
if task_res.run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
Expand Down
40 changes: 39 additions & 1 deletion src/py/flwr/server/superlink/state/sqlite_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,18 @@ def store_task_res(self, task_res: TaskRes) -> Optional[UUID]:
# Create task_id
task_id = uuid4()

# Store TaskIns
task_ins_id = task_res.task.ancestry[0]
task_ins = self.get_valid_task_ins(task_ins_id)
if task_ins is None:
log(
ERROR,
"Failed to store TaskRes: "
"TaskIns with task_id %s does not exist or has expired.",
task_ins_id,
)
return None

# Store TaskRes
task_res.task_id = str(task_id)
data = (task_res_to_dict(task_res),)

Expand Down Expand Up @@ -810,6 +821,33 @@ def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
log(ERROR, "`node_id` does not exist.")
return False

def get_valid_task_ins(self, task_id: str) -> Optional[dict[str, Any]]:
    """Fetch a TaskIns row by ID and return it only if it has not expired.

    Parameters
    ----------
    task_id : str
        Value matched against the `task_id` column of the `task_ins` table.

    Returns
    -------
    Optional[dict[str, Any]]
        The first matching row. `None` is returned when no row matches, or
        when the row has a non-null `ttl` and `created_at + ttl` is not later
        than the current time.
    """
    # The SQL text is kept flush-left so the string sent to the database is
    # exactly the statement below, with no extra leading whitespace.
    select_by_id = """
SELECT *
FROM task_ins
WHERE task_id = :task_id
"""
    matches = self.query(select_by_id, {"task_id": task_id})

    # Unknown task_id: nothing to validate
    if not matches:
        return None

    record = matches[0]
    born_at = record["created_at"]
    lifetime = record["ttl"]
    now = time.time()

    # A row without a TTL is never considered expired
    if lifetime is None:
        return record

    # The TaskIns is expired once its deadline is reached (inclusive)
    deadline = born_at + lifetime
    if deadline <= now:
        return None

    return record


def dict_factory(
cursor: sqlite3.Cursor,
Expand Down
53 changes: 48 additions & 5 deletions src/py/flwr/server/superlink/state/state_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from abc import abstractmethod
from datetime import datetime, timezone
from unittest.mock import patch
from uuid import uuid4

from flwr.common import DEFAULT_TTL
from flwr.common.constant import ErrorCode
Expand Down Expand Up @@ -302,7 +301,10 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None:
# Prepare
state: State = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})
task_ins_id = uuid4()

task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id)
task_ins_id = state.store_task_ins(task_ins)

task_res = create_task_res(
producer_node_id=0,
anonymous=True,
Expand All @@ -312,7 +314,9 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None:

# Execute
task_res_uuid = state.store_task_res(task_res)
task_res_list = state.get_task_res(task_ids={task_ins_id}, limit=None)

if task_ins_id is not None:
task_res_list = state.get_task_res(task_ids={task_ins_id}, limit=None)

# Assert
retrieved_task_res = task_res_list[0]
Expand Down Expand Up @@ -507,11 +511,23 @@ def test_num_task_res(self) -> None:
# Prepare
state: State = self.state_factory()
run_id = state.create_run(None, None, "9f86d08", {})

task_ins_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id)
task_ins_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id)
task_ins_id_0 = state.store_task_ins(task_ins_0)
task_ins_id_1 = state.store_task_ins(task_ins_1)

task_0 = create_task_res(
producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id
producer_node_id=0,
anonymous=True,
ancestry=[str(task_ins_id_0)],
run_id=run_id,
)
task_1 = create_task_res(
producer_node_id=0, anonymous=True, ancestry=["1"], run_id=run_id
producer_node_id=0,
anonymous=True,
ancestry=[str(task_ins_id_1)],
run_id=run_id,
)

# Store two tasks
Expand Down Expand Up @@ -664,6 +680,33 @@ def test_node_unavailable_error(self) -> None:
assert err_taskres.task.HasField("error")
assert err_taskres.task.error.code == ErrorCode.NODE_UNAVAILABLE

def test_store_task_res_task_ins_expired(self) -> None:
    """Check that store_task_res returns None for an expired TaskIns."""
    # Prepare
    state: State = self.state_factory()
    run_id = state.create_run(None, None, "9f86d08", {})

    # Backdate the TaskIns so that only 0.5 seconds of its TTL remain
    expiring_ins = create_task_ins(
        consumer_node_id=0, anonymous=True, run_id=run_id
    )
    expiring_ins.task.created_at = time.time() - expiring_ins.task.ttl + 0.5
    stored_ins_id = state.store_task_ins(expiring_ins)

    def _clock_past_deadline() -> float:
        # Evaluated on every call: 0.1 seconds after the TaskIns deadline
        return expiring_ins.task.created_at + expiring_ins.task.ttl + 0.1

    with patch("time.time", side_effect=_clock_past_deadline):
        reply = create_task_res(
            producer_node_id=0,
            anonymous=True,
            ancestry=[str(stored_ins_id)],
            run_id=run_id,
        )

        # Execute
        outcome = state.store_task_res(reply)

    # Assert
    assert outcome is None


def create_task_ins(
consumer_node_id: int,
Expand Down

0 comments on commit 83cd4ba

Please sign in to comment.