diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index c10c57648900..dbf79ae60287 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -21,7 +21,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from flwr.common import RecordSet +from flwr.common import ConfigsRecord, RecordSet from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL from flwr.common.message import Error from flwr.common.serde import ( @@ -232,7 +232,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: """Test tasks are deleted in sqlite state once messages are pulled.""" # Prepare state = LinkStateFactory("").state() - run_id = state.create_run("", "", "", {}) + run_id = state.create_run("", "", "", {}, ConfigsRecord()) self.driver = InMemoryDriver(MagicMock(state=lambda: state)) self.driver.init_run(run_id=run_id) msg_ids, node_id = push_messages(self.driver, self.num_nodes) @@ -259,7 +259,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = LinkStateFactory(":flwr-in-memory-state:") state = state_factory.state() - run_id = state.create_run("", "", "", {}) + run_id = state.create_run("", "", "", {}, ConfigsRecord()) self.driver = InMemoryDriver(state_factory) self.driver.init_run(run_id=run_id) msg_ids, node_id = push_messages(self.driver, self.num_nodes) diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index 9e4d72adb747..e1820fee0659 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -23,6 +23,7 @@ import grpc +from flwr.common import ConfigsRecord from flwr.common.constant import Status from flwr.common.logger import log from flwr.common.serde import ( @@ -112,6 +113,7 @@ def CreateRun( request.fab_version, fab_hash, user_config_from_proto(request.override_config), + ConfigsRecord(), ) return CreateRunResponse(run_id=run_id) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index d44f4eb7e8f9..fe6f0540e280 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -20,6 +20,7 @@ import grpc +from flwr.common import ConfigsRecord from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, @@ -334,7 +335,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) - run_id = self.state.create_run("", "", "", {}) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -365,7 +366,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) - run_id = self.state.create_run("", "", "", {}) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 52194a5a9ac8..085a6a2e29a7 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -30,6 +30,7 @@ RUN_ID_NUM_BYTES, Status, ) +from flwr.common.record import ConfigsRecord from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.linkstate.linkstate import LinkState @@ -69,6 +70,7 @@ def __init__(self) -> None: # Map run_id to RunRecord self.run_ids: dict[int, RunRecord] = {} self.contexts: dict[int, Context] = {} + self.federation_options: dict[int, ConfigsRecord] = {} self.task_ins_store: dict[UUID, TaskIns] = {} self.task_res_store: dict[UUID, TaskRes] = {} @@ -378,12 +380,14 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" return self.public_key_to_node_id.get(node_public_key) + # pylint: disable=too-many-arguments,too-many-positional-arguments def create_run( self, fab_id: Optional[str], fab_version: Optional[str], fab_hash: Optional[str], override_config: UserConfig, + federation_options: ConfigsRecord, ) -> int: """Create a new run for the specified `fab_hash`.""" # Sample a random int64 as run_id @@ -407,6 +411,9 @@ def create_run( pending_at=now().isoformat(), ) self.run_ids[run_id] = run_record + + # Record federation options. Leave empty if not passed + self.federation_options[run_id] = federation_options return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -514,6 +521,14 @@ def get_pending_run_id(self) -> Optional[int]: return pending_run_id + def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]: + """Retrieve the federation options for the specified `run_id`.""" + with self.lock: + if run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") + return None + return self.federation_options[run_id] + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" with self.lock: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index a64fdfa92a06..4144ab89e2be 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -20,6 +20,7 @@ from uuid import UUID from flwr.common import Context +from flwr.common.record import ConfigsRecord from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -152,12 +153,13 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" @abc.abstractmethod - def create_run( + def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments self, fab_id: Optional[str], fab_version: Optional[str], fab_hash: Optional[str], override_config: UserConfig, + federation_options: ConfigsRecord, ) -> int: """Create a new run for the specified `fab_hash`.""" @@ -227,6 +229,21 @@ def get_pending_run_id(self) -> Optional[int]: there is no Run pending. """ + @abc.abstractmethod + def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]: + """Retrieve the federation options for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run. + + Returns + ------- + Optional[ConfigsRecord] + The federation options for the run if it exists; None otherwise. + """ + @abc.abstractmethod def store_server_private_public_key( self, private_key: bytes, public_key: bytes diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 2cdea58a7cb7..16e19c9dc11b 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -23,7 +23,7 @@ from unittest.mock import patch from uuid import UUID -from flwr.common import DEFAULT_TTL, Context, RecordSet, now +from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, RecordSet, now from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, @@ -60,7 +60,9 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) # Execute run = state.get_run(run_id) @@ -75,8 +77,12 @@ def test_get_pending_run_id(self) -> None: """Test if get_pending_run_id works correctly.""" # Prepare state = self.state_factory() - _ = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) - run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + _ = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) + run_id2 = state.create_run( + None, None, "fffffff", {"mock_key": "mock_value"}, ConfigsRecord() + ) state.update_run_status(run_id2, RunStatus(Status.STARTING, "", "")) # Execute @@ -95,8 +101,12 @@ def test_get_and_update_run_status(self) -> None: """Test if get_run_status and update_run_status work correctly.""" # Prepare state = self.state_factory() - run_id1 = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) - run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + run_id1 = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) + run_id2 = state.create_run( + None, None, "fffffff", {"mock_key": "mock_value"}, ConfigsRecord() + ) state.update_run_status(run_id2, RunStatus(Status.STARTING, "", "")) state.update_run_status(run_id2, RunStatus(Status.RUNNING, "", "")) @@ -113,7 +123,9 @@ def test_status_transition_valid(self) -> None: """Test valid run status transactions.""" # Prepare state = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) # Execute and assert status1 = state.get_run_status({run_id})[run_id] @@ -135,7 +147,9 @@ def test_status_transition_invalid(self) -> None: """Test invalid run status transitions.""" # Prepare state = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) run_statuses = [ RunStatus(Status.PENDING, "", ""), RunStatus(Status.STARTING, "", ""), @@ -194,7 +208,7 @@ def test_store_task_ins_one(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -230,7 +244,7 @@ def test_store_task_ins_invalid_node_id(self) -> None: state = self.state_factory() node_id = state.create_node(1e3) invalid_node_id = 61016 if node_id != 61016 else 61017 - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=invalid_node_id, anonymous=False, run_id=run_id ) @@ -248,7 +262,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins_0 = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -322,7 +336,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -337,7 +351,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -352,7 +366,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: # Prepare state: LinkState = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -369,7 +383,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: # Prepare state: LinkState = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -389,7 +403,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: # Prepare state: LinkState = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -434,7 +448,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_ins_id = state.store_task_ins(task_ins) @@ -460,7 +474,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -472,7 +486,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_ids = [] # Execute @@ -489,7 +503,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -505,7 +519,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -527,7 +541,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10) # Execute @@ -542,7 +556,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -559,7 +573,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = 0 # Execute & Assert @@ -578,7 +592,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -597,7 +611,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -612,7 +626,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: LinkState = self.state_factory() - state.create_run(None, None, "9f86d08", {}) + state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -626,7 +640,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -644,7 +658,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_ins_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -757,7 +771,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -776,7 +790,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns @@ -818,7 +832,7 @@ def test_store_task_res_task_ins_expired(self) -> None: """Test behavior of store_task_res when the TaskIns it references is expired.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_ins.task.created_at = time.time() - task_ins.task.ttl + 0.5 @@ -872,7 +886,7 @@ def test_store_task_res_limit_ttl(self) -> None: # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=0, anonymous=True, run_id=run_id @@ -904,7 +918,7 @@ def test_get_task_ins_not_return_expired(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -924,7 +938,7 @@ def test_get_task_res_not_return_expired(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -955,7 +969,7 @@ def test_get_task_res_returns_empty_for_missing_taskins(self) -> None: does not exist.""" # Prepare state = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins_id = "5b0a3fc2-edba-4525-a89a-04b83420b7c8" task_res = create_task_res( @@ -978,7 +992,7 @@ def test_get_task_res_return_if_not_expired(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -1010,7 +1024,7 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -1042,7 +1056,7 @@ def test_get_set_serverapp_context(self) -> None: state=RecordSet(), run_config={"test": "test"}, ) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute init = state.get_serverapp_context(run_id) @@ -1094,7 +1108,7 @@ def test_add_and_get_serverapp_log(self) -> None: """Test adding and retrieving serverapp logs.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) log_entry_1 = "Log entry 1" log_entry_2 = "Log entry 2" timestamp = now().timestamp() @@ -1114,7 +1128,7 @@ def test_get_serverapp_log_after_timestamp(self) -> None: """Test retrieving serverapp logs after a specific timestamp.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) log_entry_1 = "Log entry 1" log_entry_2 = "Log entry 2" state.add_serverapp_log(run_id, log_entry_1) @@ -1136,7 +1150,7 @@ def test_get_serverapp_log_after_timestamp_no_logs(self) -> None: found.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) log_entry = "Log entry" state.add_serverapp_log(run_id, log_entry) timestamp = now().timestamp() @@ -1150,6 +1164,31 @@ def test_get_serverapp_log_after_timestamp_no_logs(self) -> None: assert latest == 0 assert retrieved_logs == "" + def test_create_run_with_and_without_federation_options(self) -> None: + """Test that the recording and fetching of federation options works.""" + # Prepare + state = self.state_factory() + # A run w/ federation options + fed_options = ConfigsRecord({"setting-a": 123, "setting-b": [4, 5, 6]}) + run_id = state.create_run( + None, + None, + "fffffff", + {"mock_key": "mock_value"}, + federation_options=fed_options, + ) + state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + + # Execute + fed_options_fetched = state.get_federation_options(run_id=run_id) + + # Assert + assert fed_options_fetched == fed_options + + # Generate a run_id that doesn't exist. Then check None is returned + unique_int = next(num for num in range(0, 1) if num not in {run_id}) + assert state.get_federation_options(run_id=unique_int) is None + def create_task_ins( consumer_node_id: int, diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index ad73bd4fcce0..77bb3337d7f4 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -33,6 +33,7 @@ RUN_ID_NUM_BYTES, Status, ) +from flwr.common.record import ConfigsRecord from flwr.common.typing import Run, RunStatus, UserConfig # pylint: disable=E0611 @@ -45,6 +46,8 @@ from .linkstate import LinkState from .utils import ( + configsrecord_from_bytes, + configsrecord_to_bytes, context_from_bytes, context_to_bytes, convert_sint64_to_uint64, @@ -95,7 +98,8 @@ running_at TEXT, finished_at TEXT, sub_status TEXT, - details TEXT + details TEXT, + federation_options BLOB ); """ @@ -810,12 +814,14 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: return uint64_node_id return None + # pylint: disable=too-many-arguments,too-many-positional-arguments def create_run( self, fab_id: Optional[str], fab_version: Optional[str], fab_hash: Optional[str], override_config: UserConfig, + federation_options: ConfigsRecord, ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id @@ -830,15 +836,29 @@ def create_run( if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0: query = ( "INSERT INTO run " - "(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, " - "starting_at, running_at, finished_at, sub_status, details)" - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);" + "(run_id, fab_id, fab_version, fab_hash, override_config, " + "federation_options, pending_at, starting_at, running_at, finished_at, " + "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);" ) if fab_hash: fab_id, fab_version = "", "" override_config_json = json.dumps(override_config) - data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json] - data += [now().isoformat(), "", "", "", "", ""] + data = [ + sint64_run_id, + fab_id, + fab_version, + fab_hash, + override_config_json, + configsrecord_to_bytes(federation_options), + ] + data += [ + now().isoformat(), + "", + "", + "", + "", + "", + ] self.query(query, tuple(data)) return uint64_run_id log(ERROR, "Unexpected run creation failure.") @@ -1003,6 +1023,21 @@ def get_pending_run_id(self) -> Optional[int]: return pending_run_id + def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]: + """Retrieve the federation options for the specified `run_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = convert_uint64_to_sint64(run_id) + query = "SELECT federation_options FROM run WHERE run_id = ?;" + rows = self.query(query, (sint64_run_id,)) + + # Check if the run_id exists + if not rows: + log(ERROR, "`run_id` is invalid") + return None + + row = rows[0] + return configsrecord_from_bytes(row["federation_options"]) + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" sint64_node_id = convert_uint64_to_sint64(node_id) diff --git a/src/py/flwr/server/superlink/linkstate/utils.py b/src/py/flwr/server/superlink/linkstate/utils.py index 4a18e8880c9d..5e8c2be00cba 100644 --- a/src/py/flwr/server/superlink/linkstate/utils.py +++ b/src/py/flwr/server/superlink/linkstate/utils.py @@ -20,12 +20,15 @@ from os import urandom from uuid import uuid4 -from flwr.common import Context, log, serde +from flwr.common import ConfigsRecord, Context, log, serde from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.typing import RunStatus from flwr.proto.error_pb2 import Error # pylint: disable=E0611 from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 + +# pylint: disable=E0611 +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 NODE_UNAVAILABLE_ERROR_REASON = ( @@ -146,6 +149,18 @@ def context_from_bytes(context_bytes: bytes) -> Context: return serde.context_from_proto(ProtoContext.FromString(context_bytes)) +def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes: + """Serialize a `ConfigsRecord` to bytes.""" + return serde.configs_record_to_proto(configs_record).SerializeToString() + + +def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord: + """Deserialize `ConfigsRecord` from bytes.""" + return serde.configs_record_from_proto( + ProtoConfigsRecord.FromString(configsrecord_bytes) + ) + + def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: """Generate a TaskRes with a node unavailable error from a TaskIns.""" current_time = time.time() diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 5d31bcd5edc4..247c594f9766 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -21,7 +21,7 @@ from typing_extensions import override -from flwr.common import Context, RecordSet +from flwr.common import ConfigsRecord, Context, RecordSet from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS, Status, SubStatus from flwr.common.logger import log from flwr.common.typing import Fab, RunStatus, UserConfig @@ -133,7 +133,9 @@ def _create_run( f"FAB ({fab.hash_str}) hash from request doesn't match contents" ) - run_id = self.linkstate.create_run(None, None, fab_hash, override_config) + run_id = self.linkstate.create_run( + None, None, fab_hash, override_config, ConfigsRecord() + ) return run_id def _create_context(self, run_id: int) -> None: