Skip to content

Commit

Permalink
feat(framework) Introduce federation_options mapping in LinkState (
Browse files Browse the repository at this point in the history
…#4438)

Co-authored-by: Heng Pan <pan@flower.ai>
  • Loading branch information
jafermarq and panh99 authored Nov 6, 2024
1 parent 4473cb8 commit 11665ee
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 57 deletions.
6 changes: 3 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from unittest.mock import MagicMock, patch
from uuid import uuid4

from flwr.common import RecordSet
from flwr.common import ConfigsRecord, RecordSet
from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL
from flwr.common.message import Error
from flwr.common.serde import (
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
"""Test tasks are deleted in sqlite state once messages are pulled."""
# Prepare
state = LinkStateFactory("").state()
run_id = state.create_run("", "", "", {})
run_id = state.create_run("", "", "", {}, ConfigsRecord())
self.driver = InMemoryDriver(MagicMock(state=lambda: state))
self.driver.init_run(run_id=run_id)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
Expand All @@ -259,7 +259,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = LinkStateFactory(":flwr-in-memory-state:")
state = state_factory.state()
run_id = state.create_run("", "", "", {})
run_id = state.create_run("", "", "", {}, ConfigsRecord())
self.driver = InMemoryDriver(state_factory)
self.driver.init_run(run_id=run_id)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/superlink/driver/serverappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import grpc

from flwr.common import ConfigsRecord
from flwr.common.constant import Status
from flwr.common.logger import log
from flwr.common.serde import (
Expand Down Expand Up @@ -112,6 +113,7 @@ def CreateRun(
request.fab_version,
fab_hash,
user_config_from_proto(request.override_config),
ConfigsRecord(),
)
return CreateRunResponse(run_id=run_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import grpc

from flwr.common import ConfigsRecord
from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
compute_hmac,
Expand Down Expand Up @@ -334,7 +335,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {})
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -365,7 +366,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {})
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
request = GetRunRequest(run_id=run_id)
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand Down
15 changes: 15 additions & 0 deletions src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RUN_ID_NUM_BYTES,
Status,
)
from flwr.common.record import ConfigsRecord
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611
from flwr.server.superlink.linkstate.linkstate import LinkState
Expand Down Expand Up @@ -69,6 +70,7 @@ def __init__(self) -> None:
# Map run_id to RunRecord
self.run_ids: dict[int, RunRecord] = {}
self.contexts: dict[int, Context] = {}
self.federation_options: dict[int, ConfigsRecord] = {}
self.task_ins_store: dict[UUID, TaskIns] = {}
self.task_res_store: dict[UUID, TaskRes] = {}

Expand Down Expand Up @@ -378,12 +380,14 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `node_public_keys`."""
return self.public_key_to_node_id.get(node_public_key)

# pylint: disable=too-many-arguments,too-many-positional-arguments
def create_run(
self,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
federation_options: ConfigsRecord,
) -> int:
"""Create a new run for the specified `fab_hash`."""
# Sample a random int64 as run_id
Expand All @@ -407,6 +411,9 @@ def create_run(
pending_at=now().isoformat(),
)
self.run_ids[run_id] = run_record

# Record federation options. Leave empty if not passed
self.federation_options[run_id] = federation_options
return run_id
log(ERROR, "Unexpected run creation failure.")
return 0
Expand Down Expand Up @@ -514,6 +521,14 @@ def get_pending_run_id(self) -> Optional[int]:

return pending_run_id

def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
"""Retrieve the federation options for the specified `run_id`."""
with self.lock:
if run_id not in self.run_ids:
log(ERROR, "`run_id` is invalid")
return None
return self.federation_options[run_id]

def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool:
"""Acknowledge a ping received from a node, serving as a heartbeat."""
with self.lock:
Expand Down
19 changes: 18 additions & 1 deletion src/py/flwr/server/superlink/linkstate/linkstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from uuid import UUID

from flwr.common import Context
from flwr.common.record import ConfigsRecord
from flwr.common.typing import Run, RunStatus, UserConfig
from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611

Expand Down Expand Up @@ -152,12 +153,13 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]:
"""Retrieve stored `node_id` filtered by `node_public_keys`."""

@abc.abstractmethod
def create_run(
def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments
self,
fab_id: Optional[str],
fab_version: Optional[str],
fab_hash: Optional[str],
override_config: UserConfig,
federation_options: ConfigsRecord,
) -> int:
"""Create a new run for the specified `fab_hash`."""

Expand Down Expand Up @@ -227,6 +229,21 @@ def get_pending_run_id(self) -> Optional[int]:
there is no Run pending.
"""

@abc.abstractmethod
def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]:
"""Retrieve the federation options for the specified `run_id`.
Parameters
----------
run_id : int
The identifier of the run.
Returns
-------
Optional[ConfigsRecord]
The federation options for the run if it exists; None otherwise.
"""

@abc.abstractmethod
def store_server_private_public_key(
self, private_key: bytes, public_key: bytes
Expand Down
Loading

0 comments on commit 11665ee

Please sign in to comment.