diff --git a/src/py/flwr/common/constant.py b/src/py/flwr/common/constant.py index 7fddc4a0e110..8aafb68ea17d 100644 --- a/src/py/flwr/common/constant.py +++ b/src/py/flwr/common/constant.py @@ -48,6 +48,7 @@ ) FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095" EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093" +SIMULATIONIO_API_DEFAULT_ADDRESS = "0.0.0.0:9096" # Constants for ping PING_DEFAULT_INTERVAL = 30 diff --git a/src/py/flwr/server/app.py b/src/py/flwr/server/app.py index cfada7fca933..e931cf550014 100644 --- a/src/py/flwr/server/app.py +++ b/src/py/flwr/server/app.py @@ -47,6 +47,7 @@ ISOLATION_MODE_SUBPROCESS, MISSING_EXTRA_REST, SERVERAPPIO_API_DEFAULT_ADDRESS, + SIMULATIONIO_API_DEFAULT_ADDRESS, TRANSPORT_TYPE_GRPC_ADAPTER, TRANSPORT_TYPE_GRPC_RERE, TRANSPORT_TYPE_REST, @@ -63,6 +64,7 @@ from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server from flwr.superexec.app import load_executor from flwr.superexec.exec_grpc import run_exec_api_grpc +from flwr.superexec.simulation import SimulationEngine from .client_manager import ClientManager from .history import History @@ -79,6 +81,7 @@ from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor from .superlink.linkstate import LinkStateFactory +from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc DATABASE = ":flwr-in-memory-state:" BASE_DIR = get_flwr_dir() / "superlink" / "ffs" @@ -215,6 +218,7 @@ def run_superlink() -> None: # Parse IP addresses serverappio_address, _, _ = _format_address(args.serverappio_api_address) exec_address, _, _ = _format_address(args.exec_api_address) + simulationio_address, _, _ = _format_address(args.simulationio_api_address) # Obtain certificates certificates = _try_obtain_certificates(args) @@ -225,128 +229,148 @@ def run_superlink() -> None: # Initialize FfsFactory ffs_factory = FfsFactory(args.storage_dir) - # Start ServerAppIo API - serverappio_server: grpc.Server = run_serverappio_api_grpc( - address=serverappio_address, + # Start Exec API + executor = load_executor(args) + exec_server: grpc.Server = run_exec_api_grpc( + address=exec_address, state_factory=state_factory, ffs_factory=ffs_factory, + executor=executor, certificates=certificates, + config=parse_config_args( + [args.executor_config] if args.executor_config else args.executor_config + ), ) - grpc_servers = [serverappio_server] + grpc_servers = [exec_server] - # Start Fleet API - bckg_threads = [] - if not args.fleet_api_address: - if args.fleet_api_type in [ - TRANSPORT_TYPE_GRPC_RERE, - TRANSPORT_TYPE_GRPC_ADAPTER, - ]: - args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS - elif args.fleet_api_type == TRANSPORT_TYPE_REST: - args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS - - fleet_address, host, port = _format_address(args.fleet_api_address) - - num_workers = args.fleet_api_num_workers - if num_workers != 1: - log( - WARN, - "The Fleet API currently supports only 1 worker. " - "You have specified %d workers. " - "Support for multiple workers will be added in future releases. " - "Proceeding with a single worker.", - args.fleet_api_num_workers, - ) - num_workers = 1 + # Determine Exec plugin + # If simulation is used, don't start ServerAppIo and Fleet APIs + sim_exec = isinstance(executor, SimulationEngine) - if args.fleet_api_type == TRANSPORT_TYPE_REST: - if ( - importlib.util.find_spec("requests") - and importlib.util.find_spec("starlette") - and importlib.util.find_spec("uvicorn") - ) is None: - sys.exit(MISSING_EXTRA_REST) - - _, ssl_certfile, ssl_keyfile = ( - certificates if certificates is not None else (None, None, None) - ) - - fleet_thread = threading.Thread( - target=_run_fleet_api_rest, - args=( - host, - port, - ssl_keyfile, - ssl_certfile, - state_factory, - ffs_factory, - num_workers, - ), - ) - fleet_thread.start() - bckg_threads.append(fleet_thread) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: - maybe_keys = _try_setup_node_authentication(args, certificates) - interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None - if maybe_keys is not None: - ( - node_public_keys, - server_private_key, - server_public_key, - ) = maybe_keys - state = state_factory.state() - state.store_node_public_keys(node_public_keys) - state.store_server_private_public_key( - private_key_to_bytes(server_private_key), - public_key_to_bytes(server_public_key), - ) - log( - INFO, - "Node authentication enabled with %d known public keys", - len(node_public_keys), - ) - interceptors = [AuthenticateServerInterceptor(state)] + bckg_threads = [] - fleet_server = _run_fleet_api_grpc_rere( - address=fleet_address, + if sim_exec: + simulationio_server: grpc.Server = run_simulationio_api_grpc( + address=simulationio_address, state_factory=state_factory, ffs_factory=ffs_factory, certificates=certificates, - interceptors=interceptors, ) - grpc_servers.append(fleet_server) - elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: - fleet_server = _run_fleet_api_grpc_adapter( - address=fleet_address, + grpc_servers.append(simulationio_server) + + else: + # Start ServerAppIo API + serverappio_server: grpc.Server = run_serverappio_api_grpc( + address=serverappio_address, state_factory=state_factory, ffs_factory=ffs_factory, certificates=certificates, ) - grpc_servers.append(fleet_server) - else: - raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") - - # Start Exec API - exec_server: grpc.Server = run_exec_api_grpc( - address=exec_address, - state_factory=state_factory, - ffs_factory=ffs_factory, - executor=load_executor(args), - certificates=certificates, - config=parse_config_args( - [args.executor_config] if args.executor_config else args.executor_config - ), - ) - grpc_servers.append(exec_server) + grpc_servers.append(serverappio_server) + + # Start Fleet API + if not args.fleet_api_address: + if args.fleet_api_type in [ + TRANSPORT_TYPE_GRPC_RERE, + TRANSPORT_TYPE_GRPC_ADAPTER, + ]: + args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS + elif args.fleet_api_type == TRANSPORT_TYPE_REST: + args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS + + fleet_address, host, port = _format_address(args.fleet_api_address) + + num_workers = args.fleet_api_num_workers + if num_workers != 1: + log( + WARN, + "The Fleet API currently supports only 1 worker. " + "You have specified %d workers. " + "Support for multiple workers will be added in future releases. " + "Proceeding with a single worker.", + args.fleet_api_num_workers, + ) + num_workers = 1 + + if args.fleet_api_type == TRANSPORT_TYPE_REST: + if ( + importlib.util.find_spec("requests") + and importlib.util.find_spec("starlette") + and importlib.util.find_spec("uvicorn") + ) is None: + sys.exit(MISSING_EXTRA_REST) + + _, ssl_certfile, ssl_keyfile = ( + certificates if certificates is not None else (None, None, None) + ) - if args.isolation == ISOLATION_MODE_SUBPROCESS: - # Scheduler thread - scheduler_th = threading.Thread( - target=_flwr_serverapp_scheduler, - args=(state_factory, args.serverappio_api_address, args.ssl_ca_certfile), - ) - scheduler_th.start() - bckg_threads.append(scheduler_th) + fleet_thread = threading.Thread( + target=_run_fleet_api_rest, + args=( + host, + port, + ssl_keyfile, + ssl_certfile, + state_factory, + ffs_factory, + num_workers, + ), + ) + fleet_thread.start() + bckg_threads.append(fleet_thread) + elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE: + maybe_keys = _try_setup_node_authentication(args, certificates) + interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None + if maybe_keys is not None: + ( + node_public_keys, + server_private_key, + server_public_key, + ) = maybe_keys + state = state_factory.state() + state.store_node_public_keys(node_public_keys) + state.store_server_private_public_key( + private_key_to_bytes(server_private_key), + public_key_to_bytes(server_public_key), + ) + log( + INFO, + "Node authentication enabled with %d known public keys", + len(node_public_keys), + ) + interceptors = [AuthenticateServerInterceptor(state)] + + fleet_server = _run_fleet_api_grpc_rere( + address=fleet_address, + state_factory=state_factory, + ffs_factory=ffs_factory, + certificates=certificates, + interceptors=interceptors, + ) + grpc_servers.append(fleet_server) + elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER: + fleet_server = _run_fleet_api_grpc_adapter( + address=fleet_address, + state_factory=state_factory, + ffs_factory=ffs_factory, + certificates=certificates, + ) + grpc_servers.append(fleet_server) + else: + raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}") + + if args.isolation == ISOLATION_MODE_SUBPROCESS: + # Scheduler thread + scheduler_th = threading.Thread( + target=_flwr_serverapp_scheduler, + args=( + state_factory, + args.serverappio_api_address, + args.ssl_ca_certfile, + ), + ) + scheduler_th.start() + bckg_threads.append(scheduler_th) # Graceful shutdown register_exit_handlers( @@ -361,7 +385,7 @@ def run_superlink() -> None: for thread in bckg_threads: if not thread.is_alive(): sys.exit(1) - serverappio_server.wait_for_termination(timeout=1) + exec_server.wait_for_termination(timeout=1) def _flwr_serverapp_scheduler( @@ -657,6 +681,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser: _add_args_serverappio_api(parser=parser) _add_args_fleet_api(parser=parser) _add_args_exec_api(parser=parser) + _add_args_simulationio_api(parser=parser) return parser @@ -790,3 +815,11 @@ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None: "For example:\n\n`--executor-config 'verbose=true " 'root-certificates="certificates/superlink-ca.crt"\'`', ) + + +def _add_args_simulationio_api(parser: argparse.ArgumentParser) -> None: + parser.add_argument( + "--simulationio-api-address", + help="SimulationIo API (gRPC) server address (IPv4, IPv6, or a domain name).", + default=SIMULATIONIO_API_DEFAULT_ADDRESS, + ) diff --git a/src/py/flwr/server/driver/inmemory_driver_test.py b/src/py/flwr/server/driver/inmemory_driver_test.py index c10c57648900..dbf79ae60287 100644 --- a/src/py/flwr/server/driver/inmemory_driver_test.py +++ b/src/py/flwr/server/driver/inmemory_driver_test.py @@ -21,7 +21,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 -from flwr.common import RecordSet +from flwr.common import ConfigsRecord, RecordSet from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL from flwr.common.message import Error from flwr.common.serde import ( @@ -232,7 +232,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None: """Test tasks are deleted in sqlite state once messages are pulled.""" # Prepare state = LinkStateFactory("").state() - run_id = state.create_run("", "", "", {}) + run_id = state.create_run("", "", "", {}, ConfigsRecord()) self.driver = InMemoryDriver(MagicMock(state=lambda: state)) self.driver.init_run(run_id=run_id) msg_ids, node_id = push_messages(self.driver, self.num_nodes) @@ -259,7 +259,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None: # Prepare state_factory = LinkStateFactory(":flwr-in-memory-state:") state = state_factory.state() - run_id = state.create_run("", "", "", {}) + run_id = state.create_run("", "", "", {}, ConfigsRecord()) self.driver = InMemoryDriver(state_factory) self.driver.init_run(run_id=run_id) msg_ids, node_id = push_messages(self.driver, self.num_nodes) diff --git a/src/py/flwr/server/superlink/driver/serverappio_servicer.py b/src/py/flwr/server/superlink/driver/serverappio_servicer.py index 9e4d72adb747..e1820fee0659 100644 --- a/src/py/flwr/server/superlink/driver/serverappio_servicer.py +++ b/src/py/flwr/server/superlink/driver/serverappio_servicer.py @@ -23,6 +23,7 @@ import grpc +from flwr.common import ConfigsRecord from flwr.common.constant import Status from flwr.common.logger import log from flwr.common.serde import ( @@ -112,6 +113,7 @@ def CreateRun( request.fab_version, fab_hash, user_config_from_proto(request.override_config), + ConfigsRecord(), ) return CreateRunResponse(run_id=run_id) diff --git a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py index d44f4eb7e8f9..fe6f0540e280 100644 --- a/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py +++ b/src/py/flwr/server/superlink/fleet/grpc_rere/server_interceptor_test.py @@ -20,6 +20,7 @@ import grpc +from flwr.common import ConfigsRecord from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( compute_hmac, @@ -334,7 +335,7 @@ def test_successful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) - run_id = self.state.create_run("", "", "", {}) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) shared_secret = generate_shared_key( self._node_private_key, self._server_public_key @@ -365,7 +366,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None: self.state.create_node( ping_interval=30, public_key=public_key_to_bytes(self._node_public_key) ) - run_id = self.state.create_run("", "", "", {}) + run_id = self.state.create_run("", "", "", {}, ConfigsRecord()) request = GetRunRequest(run_id=run_id) node_private_key, _ = generate_key_pairs() shared_secret = generate_shared_key(node_private_key, self._server_public_key) diff --git a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py index 52194a5a9ac8..085a6a2e29a7 100644 --- a/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/in_memory_linkstate.py @@ -30,6 +30,7 @@ RUN_ID_NUM_BYTES, Status, ) +from flwr.common.record import ConfigsRecord from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 from flwr.server.superlink.linkstate.linkstate import LinkState @@ -69,6 +70,7 @@ def __init__(self) -> None: # Map run_id to RunRecord self.run_ids: dict[int, RunRecord] = {} self.contexts: dict[int, Context] = {} + self.federation_options: dict[int, ConfigsRecord] = {} self.task_ins_store: dict[UUID, TaskIns] = {} self.task_res_store: dict[UUID, TaskRes] = {} @@ -378,12 +380,14 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" return self.public_key_to_node_id.get(node_public_key) + # pylint: disable=too-many-arguments,too-many-positional-arguments def create_run( self, fab_id: Optional[str], fab_version: Optional[str], fab_hash: Optional[str], override_config: UserConfig, + federation_options: ConfigsRecord, ) -> int: """Create a new run for the specified `fab_hash`.""" # Sample a random int64 as run_id @@ -407,6 +411,9 @@ def create_run( pending_at=now().isoformat(), ) self.run_ids[run_id] = run_record + + # Record federation options. Leave empty if not passed + self.federation_options[run_id] = federation_options return run_id log(ERROR, "Unexpected run creation failure.") return 0 @@ -514,6 +521,14 @@ def get_pending_run_id(self) -> Optional[int]: return pending_run_id + def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]: + """Retrieve the federation options for the specified `run_id`.""" + with self.lock: + if run_id not in self.run_ids: + log(ERROR, "`run_id` is invalid") + return None + return self.federation_options[run_id] + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" with self.lock: diff --git a/src/py/flwr/server/superlink/linkstate/linkstate.py b/src/py/flwr/server/superlink/linkstate/linkstate.py index a64fdfa92a06..4144ab89e2be 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate.py @@ -20,6 +20,7 @@ from uuid import UUID from flwr.common import Context +from flwr.common.record import ConfigsRecord from flwr.common.typing import Run, RunStatus, UserConfig from flwr.proto.task_pb2 import TaskIns, TaskRes # pylint: disable=E0611 @@ -152,12 +153,13 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: """Retrieve stored `node_id` filtered by `node_public_keys`.""" @abc.abstractmethod - def create_run( + def create_run( # pylint: disable=too-many-arguments,too-many-positional-arguments self, fab_id: Optional[str], fab_version: Optional[str], fab_hash: Optional[str], override_config: UserConfig, + federation_options: ConfigsRecord, ) -> int: """Create a new run for the specified `fab_hash`.""" @@ -227,6 +229,21 @@ def get_pending_run_id(self) -> Optional[int]: there is no Run pending. """ + @abc.abstractmethod + def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]: + """Retrieve the federation options for the specified `run_id`. + + Parameters + ---------- + run_id : int + The identifier of the run. + + Returns + ------- + Optional[ConfigsRecord] + The federation options for the run if it exists; None otherwise. + """ + @abc.abstractmethod def store_server_private_public_key( self, private_key: bytes, public_key: bytes diff --git a/src/py/flwr/server/superlink/linkstate/linkstate_test.py b/src/py/flwr/server/superlink/linkstate/linkstate_test.py index 2cdea58a7cb7..16e19c9dc11b 100644 --- a/src/py/flwr/server/superlink/linkstate/linkstate_test.py +++ b/src/py/flwr/server/superlink/linkstate/linkstate_test.py @@ -23,7 +23,7 @@ from unittest.mock import patch from uuid import UUID -from flwr.common import DEFAULT_TTL, Context, RecordSet, now +from flwr.common import DEFAULT_TTL, ConfigsRecord, Context, RecordSet, now from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.secure_aggregation.crypto.symmetric_encryption import ( generate_key_pairs, @@ -60,7 +60,9 @@ def test_create_and_get_run(self) -> None: """Test if create_run and get_run work correctly.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) # Execute run = state.get_run(run_id) @@ -75,8 +77,12 @@ def test_get_pending_run_id(self) -> None: """Test if get_pending_run_id works correctly.""" # Prepare state = self.state_factory() - _ = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) - run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + _ = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) + run_id2 = state.create_run( + None, None, "fffffff", {"mock_key": "mock_value"}, ConfigsRecord() + ) state.update_run_status(run_id2, RunStatus(Status.STARTING, "", "")) # Execute @@ -95,8 +101,12 @@ def test_get_and_update_run_status(self) -> None: """Test if get_run_status and update_run_status work correctly.""" # Prepare state = self.state_factory() - run_id1 = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) - run_id2 = state.create_run(None, None, "fffffff", {"mock_key": "mock_value"}) + run_id1 = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) + run_id2 = state.create_run( + None, None, "fffffff", {"mock_key": "mock_value"}, ConfigsRecord() + ) state.update_run_status(run_id2, RunStatus(Status.STARTING, "", "")) state.update_run_status(run_id2, RunStatus(Status.RUNNING, "", "")) @@ -113,7 +123,9 @@ def test_status_transition_valid(self) -> None: """Test valid run status transactions.""" # Prepare state = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) # Execute and assert status1 = state.get_run_status({run_id})[run_id] @@ -135,7 +147,9 @@ def test_status_transition_invalid(self) -> None: """Test invalid run status transitions.""" # Prepare state = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {"test_key": "test_value"}) + run_id = state.create_run( + None, None, "9f86d08", {"test_key": "test_value"}, ConfigsRecord() + ) run_statuses = [ RunStatus(Status.PENDING, "", ""), RunStatus(Status.STARTING, "", ""), @@ -194,7 +208,7 @@ def test_store_task_ins_one(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -230,7 +244,7 @@ def test_store_task_ins_invalid_node_id(self) -> None: state = self.state_factory() node_id = state.create_node(1e3) invalid_node_id = 61016 if node_id != 61016 else 61017 - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=invalid_node_id, anonymous=False, run_id=run_id ) @@ -248,7 +262,7 @@ def test_store_and_delete_tasks(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins_0 = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -322,7 +336,7 @@ def test_task_ins_store_anonymous_and_retrieve_anonymous(self) -> None: """ # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -337,7 +351,7 @@ def test_task_ins_store_anonymous_and_fail_retrieving_identitiy(self) -> None: """Store anonymous TaskIns and fail to retrieve it.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) # Execute @@ -352,7 +366,7 @@ def test_task_ins_store_identity_and_fail_retrieving_anonymous(self) -> None: # Prepare state: LinkState = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -369,7 +383,7 @@ def test_task_ins_store_identity_and_retrieve_identity(self) -> None: # Prepare state: LinkState = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -389,7 +403,7 @@ def test_task_ins_store_delivered_and_fail_retrieving(self) -> None: # Prepare state: LinkState = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -434,7 +448,7 @@ def test_task_res_store_and_retrieve_by_task_ins_id(self) -> None: """Store TaskRes retrieve it by task_ins_id.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_ins_id = state.store_task_ins(task_ins) @@ -460,7 +474,7 @@ def test_node_ids_initial_state(self) -> None: """Test retrieving all node_ids and empty initial state.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute retrieved_node_ids = state.get_nodes(run_id) @@ -472,7 +486,7 @@ def test_create_node_and_get_nodes(self) -> None: """Test creating a client node.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_ids = [] # Execute @@ -489,7 +503,7 @@ def test_create_node_public_key(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute node_id = state.create_node(ping_interval=10, public_key=public_key) @@ -505,7 +519,7 @@ def test_create_node_public_key_twice(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -527,7 +541,7 @@ def test_delete_node(self) -> None: """Test deleting a client node.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10) # Execute @@ -542,7 +556,7 @@ def test_delete_node_public_key(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute @@ -559,7 +573,7 @@ def test_delete_node_public_key_none(self) -> None: # Prepare state: LinkState = self.state_factory() public_key = b"mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = 0 # Execute & Assert @@ -578,7 +592,7 @@ def test_delete_node_wrong_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id = state.create_node(ping_interval=10, public_key=public_key) # Execute & Assert @@ -597,7 +611,7 @@ def test_get_node_id_wrong_public_key(self) -> None: state: LinkState = self.state_factory() public_key = b"mock" wrong_public_key = b"mock_mock" - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute state.create_node(ping_interval=10, public_key=public_key) @@ -612,7 +626,7 @@ def test_get_nodes_invalid_run_id(self) -> None: """Test retrieving all node_ids with invalid run_id.""" # Prepare state: LinkState = self.state_factory() - state.create_run(None, None, "9f86d08", {}) + state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) invalid_run_id = 61016 state.create_node(ping_interval=10) @@ -626,7 +640,7 @@ def test_num_task_ins(self) -> None: """Test if num_tasks returns correct number of not delivered task_ins.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -644,7 +658,7 @@ def test_num_task_res(self) -> None: """Test if num_tasks returns correct number of not delivered task_res.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins_0 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_ins_1 = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) @@ -757,7 +771,7 @@ def test_acknowledge_ping(self) -> None: """Test if acknowledge_ping works and if get_nodes return online nodes.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_ids = [state.create_node(ping_interval=10) for _ in range(100)] for node_id in node_ids[:70]: state.acknowledge_ping(node_id, ping_interval=30) @@ -776,7 +790,7 @@ def test_node_unavailable_error(self) -> None: """Test if get_task_res return TaskRes containing node unavailable error.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) node_id_0 = state.create_node(ping_interval=90) node_id_1 = state.create_node(ping_interval=30) # Create and store TaskIns @@ -818,7 +832,7 @@ def test_store_task_res_task_ins_expired(self) -> None: """Test behavior of store_task_res when the TaskIns it references is expired.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins(consumer_node_id=0, anonymous=True, run_id=run_id) task_ins.task.created_at = time.time() - task_ins.task.ttl + 0.5 @@ -872,7 +886,7 @@ def test_store_task_res_limit_ttl(self) -> None: # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=0, anonymous=True, run_id=run_id @@ -904,7 +918,7 @@ def test_get_task_ins_not_return_expired(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -924,7 +938,7 @@ def test_get_task_res_not_return_expired(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -955,7 +969,7 @@ def test_get_task_res_returns_empty_for_missing_taskins(self) -> None: does not exist.""" # Prepare state = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins_id = "5b0a3fc2-edba-4525-a89a-04b83420b7c8" task_res = create_task_res( @@ -978,7 +992,7 @@ def test_get_task_res_return_if_not_expired(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -1010,7 +1024,7 @@ def test_store_task_res_fail_if_consumer_producer_id_mismatch(self) -> None: # Prepare state = self.state_factory() node_id = state.create_node(1e3) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) task_ins = create_task_ins( consumer_node_id=node_id, anonymous=False, run_id=run_id ) @@ -1042,7 +1056,7 @@ def test_get_set_serverapp_context(self) -> None: state=RecordSet(), run_config={"test": "test"}, ) - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) # Execute init = state.get_serverapp_context(run_id) @@ -1094,7 +1108,7 @@ def test_add_and_get_serverapp_log(self) -> None: """Test adding and retrieving serverapp logs.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) log_entry_1 = "Log entry 1" log_entry_2 = "Log entry 2" timestamp = now().timestamp() @@ -1114,7 +1128,7 @@ def test_get_serverapp_log_after_timestamp(self) -> None: """Test retrieving serverapp logs after a specific timestamp.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) log_entry_1 = "Log entry 1" log_entry_2 = "Log entry 2" state.add_serverapp_log(run_id, log_entry_1) @@ -1136,7 +1150,7 @@ def test_get_serverapp_log_after_timestamp_no_logs(self) -> None: found.""" # Prepare state: LinkState = self.state_factory() - run_id = state.create_run(None, None, "9f86d08", {}) + run_id = state.create_run(None, None, "9f86d08", {}, ConfigsRecord()) log_entry = "Log entry" state.add_serverapp_log(run_id, log_entry) timestamp = now().timestamp() @@ -1150,6 +1164,31 @@ def test_get_serverapp_log_after_timestamp_no_logs(self) -> None: assert latest == 0 assert retrieved_logs == "" + def test_create_run_with_and_without_federation_options(self) -> None: + """Test that the recording and fetching of federation options works.""" + # Prepare + state = self.state_factory() + # A run w/ federation options + fed_options = ConfigsRecord({"setting-a": 123, "setting-b": [4, 5, 6]}) + run_id = state.create_run( + None, + None, + "fffffff", + {"mock_key": "mock_value"}, + federation_options=fed_options, + ) + state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")) + + # Execute + fed_options_fetched = state.get_federation_options(run_id=run_id) + + # Assert + assert fed_options_fetched == fed_options + + # Generate a run_id that doesn't exist. Then check None is returned + unique_int = next(num for num in range(0, 1) if num not in {run_id}) + assert state.get_federation_options(run_id=unique_int) is None + def create_task_ins( consumer_node_id: int, diff --git a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py index ad73bd4fcce0..77bb3337d7f4 100644 --- a/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py +++ b/src/py/flwr/server/superlink/linkstate/sqlite_linkstate.py @@ -33,6 +33,7 @@ RUN_ID_NUM_BYTES, Status, ) +from flwr.common.record import ConfigsRecord from flwr.common.typing import Run, RunStatus, UserConfig # pylint: disable=E0611 @@ -45,6 +46,8 @@ from .linkstate import LinkState from .utils import ( + configsrecord_from_bytes, + configsrecord_to_bytes, context_from_bytes, context_to_bytes, convert_sint64_to_uint64, @@ -95,7 +98,8 @@ running_at TEXT, finished_at TEXT, sub_status TEXT, - details TEXT + details TEXT, + federation_options BLOB ); """ @@ -810,12 +814,14 @@ def get_node_id(self, node_public_key: bytes) -> Optional[int]: return uint64_node_id return None + # pylint: disable=too-many-arguments,too-many-positional-arguments def create_run( self, fab_id: Optional[str], fab_version: Optional[str], fab_hash: Optional[str], override_config: UserConfig, + federation_options: ConfigsRecord, ) -> int: """Create a new run for the specified `fab_id` and `fab_version`.""" # Sample a random int64 as run_id @@ -830,15 +836,29 @@ def create_run( if self.query(query, (sint64_run_id,))[0]["COUNT(*)"] == 0: query = ( "INSERT INTO run " - "(run_id, fab_id, fab_version, fab_hash, override_config, pending_at, " - "starting_at, running_at, finished_at, sub_status, details)" - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);" + "(run_id, fab_id, fab_version, fab_hash, override_config, " + "federation_options, pending_at, starting_at, running_at, finished_at, " + "sub_status, details) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?);" ) if fab_hash: fab_id, fab_version = "", "" override_config_json = json.dumps(override_config) - data = [sint64_run_id, fab_id, fab_version, fab_hash, override_config_json] - data += [now().isoformat(), "", "", "", "", ""] + data = [ + sint64_run_id, + fab_id, + fab_version, + fab_hash, + override_config_json, + configsrecord_to_bytes(federation_options), + ] + data += [ + now().isoformat(), + "", + "", + "", + "", + "", + ] self.query(query, tuple(data)) return uint64_run_id log(ERROR, "Unexpected run creation failure.") @@ -1003,6 +1023,21 @@ def get_pending_run_id(self) -> Optional[int]: return pending_run_id + def get_federation_options(self, run_id: int) -> Optional[ConfigsRecord]: + """Retrieve the federation options for the specified `run_id`.""" + # Convert the uint64 value to sint64 for SQLite + sint64_run_id = convert_uint64_to_sint64(run_id) + query = "SELECT federation_options FROM run WHERE run_id = ?;" + rows = self.query(query, (sint64_run_id,)) + + # Check if the run_id exists + if not rows: + log(ERROR, "`run_id` is invalid") + return None + + row = rows[0] + return configsrecord_from_bytes(row["federation_options"]) + def acknowledge_ping(self, node_id: int, ping_interval: float) -> bool: """Acknowledge a ping received from a node, serving as a heartbeat.""" sint64_node_id = convert_uint64_to_sint64(node_id) diff --git a/src/py/flwr/server/superlink/linkstate/utils.py b/src/py/flwr/server/superlink/linkstate/utils.py index 4a18e8880c9d..5e8c2be00cba 100644 --- a/src/py/flwr/server/superlink/linkstate/utils.py +++ b/src/py/flwr/server/superlink/linkstate/utils.py @@ -20,12 +20,15 @@ from os import urandom from uuid import uuid4 -from flwr.common import Context, log, serde +from flwr.common import ConfigsRecord, Context, log, serde from flwr.common.constant import ErrorCode, Status, SubStatus from flwr.common.typing import RunStatus from flwr.proto.error_pb2 import Error # pylint: disable=E0611 from flwr.proto.message_pb2 import Context as ProtoContext # pylint: disable=E0611 from flwr.proto.node_pb2 import Node # pylint: disable=E0611 + +# pylint: disable=E0611 +from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord from flwr.proto.task_pb2 import Task, TaskIns, TaskRes # pylint: disable=E0611 NODE_UNAVAILABLE_ERROR_REASON = ( @@ -146,6 +149,18 @@ def context_from_bytes(context_bytes: bytes) -> Context: return serde.context_from_proto(ProtoContext.FromString(context_bytes)) +def configsrecord_to_bytes(configs_record: ConfigsRecord) -> bytes: + """Serialize a `ConfigsRecord` to bytes.""" + return serde.configs_record_to_proto(configs_record).SerializeToString() + + +def configsrecord_from_bytes(configsrecord_bytes: bytes) -> ConfigsRecord: + """Deserialize `ConfigsRecord` from bytes.""" + return serde.configs_record_from_proto( + ProtoConfigsRecord.FromString(configsrecord_bytes) + ) + + def make_node_unavailable_taskres(ref_taskins: TaskIns) -> TaskRes: """Generate a TaskRes with a node unavailable error from a TaskIns.""" current_time = time.time() diff --git a/src/py/flwr/server/superlink/simulation/__init__.py b/src/py/flwr/server/superlink/simulation/__init__.py new file mode 100644 index 000000000000..8485a3c9a3c7 --- /dev/null +++ b/src/py/flwr/server/superlink/simulation/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower SimulationIo service.""" diff --git a/src/py/flwr/server/superlink/simulation/simulationio_grpc.py b/src/py/flwr/server/superlink/simulation/simulationio_grpc.py new file mode 100644 index 000000000000..d1e79306e0b7 --- /dev/null +++ b/src/py/flwr/server/superlink/simulation/simulationio_grpc.py @@ -0,0 +1,65 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SimulationIo gRPC API.""" + + +from logging import INFO +from typing import Optional + +import grpc + +from flwr.common import GRPC_MAX_MESSAGE_LENGTH +from flwr.common.logger import log +from flwr.proto.simulationio_pb2_grpc import ( # pylint: disable=E0611 + add_SimulationIoServicer_to_server, +) +from flwr.server.superlink.ffs.ffs_factory import FfsFactory +from flwr.server.superlink.linkstate import LinkStateFactory + +from ..fleet.grpc_bidi.grpc_server import generic_create_grpc_server +from .simulationio_servicer import SimulationIoServicer + + +def run_simulationio_api_grpc( + address: str, + state_factory: LinkStateFactory, + ffs_factory: FfsFactory, + certificates: Optional[tuple[bytes, bytes, bytes]], +) -> grpc.Server: + """Run SimulationIo API (gRPC, request-response).""" + # Create SimulationIo API gRPC server + simulationio_servicer: grpc.Server = SimulationIoServicer( + state_factory=state_factory, + ffs_factory=ffs_factory, + ) + simulationio_add_servicer_to_server_fn = add_SimulationIoServicer_to_server + simulationio_grpc_server = generic_create_grpc_server( + servicer_and_add_fn=( + simulationio_servicer, + simulationio_add_servicer_to_server_fn, + ), + server_address=address, + max_message_length=GRPC_MAX_MESSAGE_LENGTH, + certificates=certificates, + ) + + log( + INFO, + "Flower Simulation Engine: Starting SimulationIo API on %s", + address, + ) + simulationio_grpc_server.start() + + return simulationio_grpc_server diff --git a/src/py/flwr/server/superlink/simulation/simulationio_servicer.py b/src/py/flwr/server/superlink/simulation/simulationio_servicer.py new file mode 100644 index 000000000000..03bed32e4332 --- /dev/null +++ b/src/py/flwr/server/superlink/simulation/simulationio_servicer.py @@ -0,0 +1,132 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SimulationIo API servicer.""" + +import threading +from logging import DEBUG, INFO + +import grpc +from grpc import ServicerContext + +from flwr.common.constant import Status +from flwr.common.logger import log +from flwr.common.serde import ( + context_from_proto, + context_to_proto, + fab_to_proto, + run_status_from_proto, + run_to_proto, +) +from flwr.common.typing import Fab, RunStatus +from flwr.proto import simulationio_pb2_grpc +from flwr.proto.log_pb2 import ( # pylint: disable=E0611 + PushLogsRequest, + PushLogsResponse, +) +from flwr.proto.run_pb2 import ( # pylint: disable=E0611 + UpdateRunStatusRequest, + UpdateRunStatusResponse, +) +from flwr.proto.simulationio_pb2 import ( # pylint: disable=E0611 + PullSimulationInputsRequest, + PullSimulationInputsResponse, + PushSimulationOutputsRequest, + PushSimulationOutputsResponse, +) +from flwr.server.superlink.ffs.ffs_factory import FfsFactory +from flwr.server.superlink.linkstate import LinkStateFactory + + +class SimulationIoServicer(simulationio_pb2_grpc.SimulationIoServicer): + """SimulationIo API servicer.""" + + def __init__( + self, state_factory: LinkStateFactory, ffs_factory: FfsFactory + ) -> None: + self.state_factory = state_factory + self.ffs_factory = ffs_factory + self.lock = threading.RLock() + + def PullSimulationInputs( + self, request: PullSimulationInputsRequest, context: ServicerContext + ) -> PullSimulationInputsResponse: + """Pull SimultionIo process inputs.""" + log(DEBUG, "SimultionIoServicer.SimultionIoInputs") + # Init access to LinkState and Ffs + state = self.state_factory.state() + ffs = self.ffs_factory.ffs() + + # Lock access to LinkState, preventing obtaining the same pending run_id + with self.lock: + # Attempt getting the run_id of a pending run + run_id = state.get_pending_run_id() + # If there's no pending run, return an empty response + if run_id is None: + return PullSimulationInputsResponse() + + # Retrieve Context, Run and Fab for the run_id + serverapp_ctxt = state.get_serverapp_context(run_id) + run = state.get_run(run_id) + fab = None + if run and run.fab_hash: + if result := ffs.get(run.fab_hash): + fab = Fab(run.fab_hash, result[0]) + if run and fab and serverapp_ctxt: + # Update run status to STARTING + if state.update_run_status(run_id, RunStatus(Status.STARTING, "", "")): + log(INFO, "Starting run %d", run_id) + return PullSimulationInputsResponse( + context=context_to_proto(serverapp_ctxt), + run=run_to_proto(run), + fab=fab_to_proto(fab), + ) + + # Raise an exception if the Run or Fab is not found, + # or if the status cannot be updated to STARTING + raise RuntimeError(f"Failed to start run {run_id}") + + def PushSimulationOutputs( + self, request: PushSimulationOutputsRequest, context: ServicerContext + ) -> PushSimulationOutputsResponse: + """Push Simulation process outputs.""" + log(DEBUG, "SimultionIoServicer.PushSimulationOutputs") + state = self.state_factory.state() + state.set_serverapp_context(request.run_id, context_from_proto(request.context)) + return PushSimulationOutputsResponse() + + def UpdateRunStatus( + self, request: UpdateRunStatusRequest, context: grpc.ServicerContext + ) -> UpdateRunStatusResponse: + """Update the status of a run.""" + log(DEBUG, "SimultionIoServicer.UpdateRunStatus") + state = self.state_factory.state() + + # Update the run status + state.update_run_status( + run_id=request.run_id, new_status=run_status_from_proto(request.run_status) + ) + return UpdateRunStatusResponse() + + def PushLogs( + self, request: PushLogsRequest, context: grpc.ServicerContext + ) -> PushLogsResponse: + """Push logs.""" + log(DEBUG, "ServerAppIoServicer.PushLogs") + state = self.state_factory.state() + + # Add logs to LinkState + merged_logs = "".join(request.logs) + state.add_serverapp_log(request.run_id, merged_logs) + return PushLogsResponse() diff --git a/src/py/flwr/simulation/__init__.py b/src/py/flwr/simulation/__init__.py index a171347b1507..912613cbad9f 100644 --- a/src/py/flwr/simulation/__init__.py +++ b/src/py/flwr/simulation/__init__.py @@ -18,6 +18,7 @@ import importlib from flwr.simulation.run_simulation import run_simulation +from flwr.simulation.simulationio_connection import SimulationIoConnection is_ray_installed = importlib.util.find_spec("ray") is not None @@ -37,6 +38,7 @@ def start_simulation(*args, **kwargs): # type: ignore __all__ = [ + "SimulationIoConnection", "run_simulation", "start_simulation", ] diff --git a/src/py/flwr/simulation/simulationio_connection.py b/src/py/flwr/simulation/simulationio_connection.py new file mode 100644 index 000000000000..a53f0f5ce317 --- /dev/null +++ b/src/py/flwr/simulation/simulationio_connection.py @@ -0,0 +1,86 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Flower SimulationIo connection.""" + + +from logging import DEBUG, WARNING +from typing import Optional, cast + +import grpc + +from flwr.common.constant import SIMULATIONIO_API_DEFAULT_ADDRESS +from flwr.common.grpc import create_channel +from flwr.common.logger import log +from flwr.proto.simulationio_pb2_grpc import SimulationIoStub # pylint: disable=E0611 + + +class SimulationIoConnection: + """`SimulationIoConnection` provides an interface to the SimulationIo API. + + Parameters + ---------- + simulationio_service_address : str (default: "[::]:9094") + The address (URL, IPv6, IPv4) of the SuperLink SimulationIo API service. + root_certificates : Optional[bytes] (default: None) + The PEM-encoded root certificates as a byte string. + If provided, a secure connection using the certificates will be + established to an SSL-enabled Flower server. + """ + + def __init__( # pylint: disable=too-many-arguments + self, + simulationio_service_address: str = SIMULATIONIO_API_DEFAULT_ADDRESS, + root_certificates: Optional[bytes] = None, + ) -> None: + self._addr = simulationio_service_address + self._cert = root_certificates + self._grpc_stub: Optional[SimulationIoStub] = None + self._channel: Optional[grpc.Channel] = None + + @property + def _is_connected(self) -> bool: + """Check if connected to the SimulationIo API server.""" + return self._channel is not None + + @property + def _stub(self) -> SimulationIoStub: + """SimulationIo stub.""" + if not self._is_connected: + self._connect() + return cast(SimulationIoStub, self._grpc_stub) + + def _connect(self) -> None: + """Connect to the SimulationIo API.""" + if self._is_connected: + log(WARNING, "Already connected") + return + self._channel = create_channel( + server_address=self._addr, + insecure=(self._cert is None), + root_certificates=self._cert, + ) + self._grpc_stub = SimulationIoStub(self._channel) + log(DEBUG, "[SimulationIO] Connected to %s", self._addr) + + def _disconnect(self) -> None: + """Disconnect from the SimulationIo API.""" + if not self._is_connected: + log(DEBUG, "Already disconnected") + return + channel: grpc.Channel = self._channel + self._channel = None + self._grpc_stub = None + channel.close() + log(DEBUG, "[SimulationIO] Disconnected") diff --git a/src/py/flwr/superexec/deployment.py b/src/py/flwr/superexec/deployment.py index 5d31bcd5edc4..247c594f9766 100644 --- a/src/py/flwr/superexec/deployment.py +++ b/src/py/flwr/superexec/deployment.py @@ -21,7 +21,7 @@ from typing_extensions import override -from flwr.common import Context, RecordSet +from flwr.common import ConfigsRecord, Context, RecordSet from flwr.common.constant import SERVERAPPIO_API_DEFAULT_ADDRESS, Status, SubStatus from flwr.common.logger import log from flwr.common.typing import Fab, RunStatus, UserConfig @@ -133,7 +133,9 @@ def _create_run( f"FAB ({fab.hash_str}) hash from request doesn't match contents" ) - run_id = self.linkstate.create_run(None, None, fab_hash, override_config) + run_id = self.linkstate.create_run( + None, None, fab_hash, override_config, ConfigsRecord() + ) return run_id def _create_context(self, run_id: int) -> None: