Skip to content

Commit

Permalink
Merge branch 'main' into feat/upgrade-numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
chongshenng committed Nov 7, 2024
2 parents 60571ca + 11665ee commit 94972d3
Show file tree
Hide file tree
Showing 16 changed files with 624 additions and 164 deletions.
1 change: 1 addition & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
)
FLEET_API_REST_DEFAULT_ADDRESS = "0.0.0.0:9095"
EXEC_API_DEFAULT_ADDRESS = "0.0.0.0:9093"
SIMULATIONIO_API_DEFAULT_ADDRESS = "0.0.0.0:9096"

# Constants for ping
PING_DEFAULT_INTERVAL = 30
Expand Down
247 changes: 140 additions & 107 deletions src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
ISOLATION_MODE_SUBPROCESS,
MISSING_EXTRA_REST,
SERVERAPPIO_API_DEFAULT_ADDRESS,
SIMULATIONIO_API_DEFAULT_ADDRESS,
TRANSPORT_TYPE_GRPC_ADAPTER,
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_REST,
Expand All @@ -63,6 +64,7 @@
from flwr.proto.grpcadapter_pb2_grpc import add_GrpcAdapterServicer_to_server
from flwr.superexec.app import load_executor
from flwr.superexec.exec_grpc import run_exec_api_grpc
from flwr.superexec.simulation import SimulationEngine

from .client_manager import ClientManager
from .history import History
Expand All @@ -79,6 +81,7 @@
from .superlink.fleet.grpc_rere.fleet_servicer import FleetServicer
from .superlink.fleet.grpc_rere.server_interceptor import AuthenticateServerInterceptor
from .superlink.linkstate import LinkStateFactory
from .superlink.simulation.simulationio_grpc import run_simulationio_api_grpc

DATABASE = ":flwr-in-memory-state:"
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"
Expand Down Expand Up @@ -215,6 +218,7 @@ def run_superlink() -> None:
# Parse IP addresses
serverappio_address, _, _ = _format_address(args.serverappio_api_address)
exec_address, _, _ = _format_address(args.exec_api_address)
simulationio_address, _, _ = _format_address(args.simulationio_api_address)

# Obtain certificates
certificates = _try_obtain_certificates(args)
Expand All @@ -225,128 +229,148 @@ def run_superlink() -> None:
# Initialize FfsFactory
ffs_factory = FfsFactory(args.storage_dir)

# Start ServerAppIo API
serverappio_server: grpc.Server = run_serverappio_api_grpc(
address=serverappio_address,
# Start Exec API
executor = load_executor(args)
exec_server: grpc.Server = run_exec_api_grpc(
address=exec_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
executor=executor,
certificates=certificates,
config=parse_config_args(
[args.executor_config] if args.executor_config else args.executor_config
),
)
grpc_servers = [serverappio_server]
grpc_servers = [exec_server]

# Start Fleet API
bckg_threads = []
if not args.fleet_api_address:
if args.fleet_api_type in [
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_GRPC_ADAPTER,
]:
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS

fleet_address, host, port = _format_address(args.fleet_api_address)

num_workers = args.fleet_api_num_workers
if num_workers != 1:
log(
WARN,
"The Fleet API currently supports only 1 worker. "
"You have specified %d workers. "
"Support for multiple workers will be added in future releases. "
"Proceeding with a single worker.",
args.fleet_api_num_workers,
)
num_workers = 1
# Determine Exec plugin
# If simulation is used, don't start ServerAppIo and Fleet APIs
sim_exec = isinstance(executor, SimulationEngine)

if args.fleet_api_type == TRANSPORT_TYPE_REST:
if (
importlib.util.find_spec("requests")
and importlib.util.find_spec("starlette")
and importlib.util.find_spec("uvicorn")
) is None:
sys.exit(MISSING_EXTRA_REST)

_, ssl_certfile, ssl_keyfile = (
certificates if certificates is not None else (None, None, None)
)

fleet_thread = threading.Thread(
target=_run_fleet_api_rest,
args=(
host,
port,
ssl_keyfile,
ssl_certfile,
state_factory,
ffs_factory,
num_workers,
),
)
fleet_thread.start()
bckg_threads.append(fleet_thread)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
maybe_keys = _try_setup_node_authentication(args, certificates)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if maybe_keys is not None:
(
node_public_keys,
server_private_key,
server_public_key,
) = maybe_keys
state = state_factory.state()
state.store_node_public_keys(node_public_keys)
state.store_server_private_public_key(
private_key_to_bytes(server_private_key),
public_key_to_bytes(server_public_key),
)
log(
INFO,
"Node authentication enabled with %d known public keys",
len(node_public_keys),
)
interceptors = [AuthenticateServerInterceptor(state)]
bckg_threads = []

fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
if sim_exec:
simulationio_server: grpc.Server = run_simulationio_api_grpc(
address=simulationio_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
fleet_server = _run_fleet_api_grpc_adapter(
address=fleet_address,
grpc_servers.append(simulationio_server)

else:
# Start ServerAppIo API
serverappio_server: grpc.Server = run_serverappio_api_grpc(
address=serverappio_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)
grpc_servers.append(fleet_server)
else:
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")

# Start Exec API
exec_server: grpc.Server = run_exec_api_grpc(
address=exec_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
executor=load_executor(args),
certificates=certificates,
config=parse_config_args(
[args.executor_config] if args.executor_config else args.executor_config
),
)
grpc_servers.append(exec_server)
grpc_servers.append(serverappio_server)

# Start Fleet API
if not args.fleet_api_address:
if args.fleet_api_type in [
TRANSPORT_TYPE_GRPC_RERE,
TRANSPORT_TYPE_GRPC_ADAPTER,
]:
args.fleet_api_address = FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
elif args.fleet_api_type == TRANSPORT_TYPE_REST:
args.fleet_api_address = FLEET_API_REST_DEFAULT_ADDRESS

fleet_address, host, port = _format_address(args.fleet_api_address)

num_workers = args.fleet_api_num_workers
if num_workers != 1:
log(
WARN,
"The Fleet API currently supports only 1 worker. "
"You have specified %d workers. "
"Support for multiple workers will be added in future releases. "
"Proceeding with a single worker.",
args.fleet_api_num_workers,
)
num_workers = 1

if args.fleet_api_type == TRANSPORT_TYPE_REST:
if (
importlib.util.find_spec("requests")
and importlib.util.find_spec("starlette")
and importlib.util.find_spec("uvicorn")
) is None:
sys.exit(MISSING_EXTRA_REST)

_, ssl_certfile, ssl_keyfile = (
certificates if certificates is not None else (None, None, None)
)

if args.isolation == ISOLATION_MODE_SUBPROCESS:
# Scheduler thread
scheduler_th = threading.Thread(
target=_flwr_serverapp_scheduler,
args=(state_factory, args.serverappio_api_address, args.ssl_ca_certfile),
)
scheduler_th.start()
bckg_threads.append(scheduler_th)
fleet_thread = threading.Thread(
target=_run_fleet_api_rest,
args=(
host,
port,
ssl_keyfile,
ssl_certfile,
state_factory,
ffs_factory,
num_workers,
),
)
fleet_thread.start()
bckg_threads.append(fleet_thread)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_RERE:
maybe_keys = _try_setup_node_authentication(args, certificates)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if maybe_keys is not None:
(
node_public_keys,
server_private_key,
server_public_key,
) = maybe_keys
state = state_factory.state()
state.store_node_public_keys(node_public_keys)
state.store_server_private_public_key(
private_key_to_bytes(server_private_key),
public_key_to_bytes(server_public_key),
)
log(
INFO,
"Node authentication enabled with %d known public keys",
len(node_public_keys),
)
interceptors = [AuthenticateServerInterceptor(state)]

fleet_server = _run_fleet_api_grpc_rere(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
interceptors=interceptors,
)
grpc_servers.append(fleet_server)
elif args.fleet_api_type == TRANSPORT_TYPE_GRPC_ADAPTER:
fleet_server = _run_fleet_api_grpc_adapter(
address=fleet_address,
state_factory=state_factory,
ffs_factory=ffs_factory,
certificates=certificates,
)
grpc_servers.append(fleet_server)
else:
raise ValueError(f"Unknown fleet_api_type: {args.fleet_api_type}")

if args.isolation == ISOLATION_MODE_SUBPROCESS:
# Scheduler thread
scheduler_th = threading.Thread(
target=_flwr_serverapp_scheduler,
args=(
state_factory,
args.serverappio_api_address,
args.ssl_ca_certfile,
),
)
scheduler_th.start()
bckg_threads.append(scheduler_th)

# Graceful shutdown
register_exit_handlers(
Expand All @@ -361,7 +385,7 @@ def run_superlink() -> None:
for thread in bckg_threads:
if not thread.is_alive():
sys.exit(1)
serverappio_server.wait_for_termination(timeout=1)
exec_server.wait_for_termination(timeout=1)


def _flwr_serverapp_scheduler(
Expand Down Expand Up @@ -657,6 +681,7 @@ def _parse_args_run_superlink() -> argparse.ArgumentParser:
_add_args_serverappio_api(parser=parser)
_add_args_fleet_api(parser=parser)
_add_args_exec_api(parser=parser)
_add_args_simulationio_api(parser=parser)

return parser

Expand Down Expand Up @@ -790,3 +815,11 @@ def _add_args_exec_api(parser: argparse.ArgumentParser) -> None:
"For example:\n\n`--executor-config 'verbose=true "
'root-certificates="certificates/superlink-ca.crt"\'`',
)


def _add_args_simulationio_api(parser: argparse.ArgumentParser) -> None:
    """Register the SimulationIo API address option on ``parser``.

    Adds the ``--simulationio-api-address`` flag. When the flag is omitted,
    its value falls back to ``SIMULATIONIO_API_DEFAULT_ADDRESS``.
    """
    help_text = "SimulationIo API (gRPC) server address (IPv4, IPv6, or a domain name)."
    fallback_address = SIMULATIONIO_API_DEFAULT_ADDRESS
    parser.add_argument(
        "--simulationio-api-address",
        default=fallback_address,
        help=help_text,
    )
6 changes: 3 additions & 3 deletions src/py/flwr/server/driver/inmemory_driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from unittest.mock import MagicMock, patch
from uuid import uuid4

from flwr.common import RecordSet
from flwr.common import ConfigsRecord, RecordSet
from flwr.common.constant import NODE_ID_NUM_BYTES, PING_MAX_INTERVAL
from flwr.common.message import Error
from flwr.common.serde import (
Expand Down Expand Up @@ -232,7 +232,7 @@ def test_task_store_consistency_after_push_pull_sqlitestate(self) -> None:
"""Test tasks are deleted in sqlite state once messages are pulled."""
# Prepare
state = LinkStateFactory("").state()
run_id = state.create_run("", "", "", {})
run_id = state.create_run("", "", "", {}, ConfigsRecord())
self.driver = InMemoryDriver(MagicMock(state=lambda: state))
self.driver.init_run(run_id=run_id)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
Expand All @@ -259,7 +259,7 @@ def test_task_store_consistency_after_push_pull_inmemory_state(self) -> None:
# Prepare
state_factory = LinkStateFactory(":flwr-in-memory-state:")
state = state_factory.state()
run_id = state.create_run("", "", "", {})
run_id = state.create_run("", "", "", {}, ConfigsRecord())
self.driver = InMemoryDriver(state_factory)
self.driver.init_run(run_id=run_id)
msg_ids, node_id = push_messages(self.driver, self.num_nodes)
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/server/superlink/driver/serverappio_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import grpc

from flwr.common import ConfigsRecord
from flwr.common.constant import Status
from flwr.common.logger import log
from flwr.common.serde import (
Expand Down Expand Up @@ -112,6 +113,7 @@ def CreateRun(
request.fab_version,
fab_hash,
user_config_from_proto(request.override_config),
ConfigsRecord(),
)
return CreateRunResponse(run_id=run_id)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import grpc

from flwr.common import ConfigsRecord
from flwr.common.constant import FLEET_API_GRPC_RERE_DEFAULT_ADDRESS
from flwr.common.secure_aggregation.crypto.symmetric_encryption import (
compute_hmac,
Expand Down Expand Up @@ -334,7 +335,7 @@ def test_successful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {})
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
request = GetRunRequest(run_id=run_id)
shared_secret = generate_shared_key(
self._node_private_key, self._server_public_key
Expand Down Expand Up @@ -365,7 +366,7 @@ def test_unsuccessful_get_run_with_metadata(self) -> None:
self.state.create_node(
ping_interval=30, public_key=public_key_to_bytes(self._node_public_key)
)
run_id = self.state.create_run("", "", "", {})
run_id = self.state.create_run("", "", "", {}, ConfigsRecord())
request = GetRunRequest(run_id=run_id)
node_private_key, _ = generate_key_pairs()
shared_secret = generate_shared_key(node_private_key, self._server_public_key)
Expand Down
Loading

0 comments on commit 94972d3

Please sign in to comment.