Skip to content

Commit

Permalink
feat(framework) Add ExecUserAuthInterceptor to Exec API (#4630)
Browse files Browse the repository at this point in the history
Co-authored-by: Heng Pan <pan@flower.ai>
Co-authored-by: Javier <jafermarq@users.noreply.github.com>
Co-authored-by: Daniel J. Beutel <daniel@flower.ai>
  • Loading branch information
4 people authored Dec 10, 2024
1 parent 1b43ac7 commit c955c27
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 8 deletions.
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,16 +78,16 @@ tomli = "^2.0.1"
tomli-w = "^1.0.0"
pathspec = "^0.12.1"
rich = "^13.5.0"
pyyaml = "^6.0.2"
requests = "^2.31.0"
# Optional dependencies (Simulation Engine)
ray = { version = "==2.10.0", optional = true, python = ">=3.9,<3.12" }
# Optional dependencies (REST transport layer)
requests = { version = "^2.31.0", optional = true }
starlette = { version = "^0.31.0", optional = true }
uvicorn = { version = "^0.23.0", extras = ["standard"], optional = true }

[tool.poetry.extras]
simulation = ["ray"]
rest = ["requests", "starlette", "uvicorn"]
rest = ["starlette", "uvicorn"]

[tool.poetry.group.dev.dependencies]
types-dataclasses = "==0.6.6"
Expand Down Expand Up @@ -131,6 +131,7 @@ mdformat-gfm = "==0.3.6"
mdformat-frontmatter = "==2.0.1"
mdformat-beautysh = "==0.1.1"
twine = "==5.1.1"
types-PyYAML = "^6.0.2"
pyroma = "==4.2"
check-wheel-contents = "==0.4.0"
GitPython = "==3.1.32"
Expand Down
2 changes: 2 additions & 0 deletions src/py/flwr/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,8 @@
# Retry configurations
MAX_RETRY_DELAY = 20 # Maximum delay duration between two consecutive retries.

AUTH_TYPE = "auth_type"


class MessageType:
"""Message type."""
Expand Down
53 changes: 52 additions & 1 deletion src/py/flwr/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
from logging import DEBUG, INFO, WARN
from pathlib import Path
from time import sleep
from typing import Optional
from typing import Any, Optional

import grpc
import yaml
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives.asymmetric import ec
from cryptography.hazmat.primitives.serialization import (
Expand All @@ -37,8 +38,10 @@
from flwr.common import GRPC_MAX_MESSAGE_LENGTH, EventType, event
from flwr.common.address import parse_address
from flwr.common.args import try_obtain_server_certificates
from flwr.common.auth_plugin import ExecAuthPlugin
from flwr.common.config import get_flwr_dir, parse_config_args
from flwr.common.constant import (
AUTH_TYPE,
CLIENT_OCTET,
EXEC_API_DEFAULT_SERVER_ADDRESS,
FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
Expand Down Expand Up @@ -88,6 +91,15 @@
BASE_DIR = get_flwr_dir() / "superlink" / "ffs"


try:
from flwr.ee import get_exec_auth_plugins
except ImportError:

def get_exec_auth_plugins() -> dict[str, type[ExecAuthPlugin]]:
"""Return all Exec API authentication plugins."""
raise NotImplementedError("No authentication plugins are currently supported.")


def start_server( # pylint: disable=too-many-arguments,too-many-locals
*,
server_address: str = FLEET_API_GRPC_BIDI_DEFAULT_ADDRESS,
Expand Down Expand Up @@ -246,6 +258,12 @@ def run_superlink() -> None:
# Obtain certificates
certificates = try_obtain_server_certificates(args, args.fleet_api_type)

user_auth_config = _try_obtain_user_auth_config(args)
auth_plugin: Optional[ExecAuthPlugin] = None
# user_auth_config is None only if the args.user_auth_config is not provided
if user_auth_config is not None:
auth_plugin = _try_obtain_exec_auth_plugin(user_auth_config)

# Initialize StateFactory
state_factory = LinkStateFactory(args.database)

Expand All @@ -263,6 +281,7 @@ def run_superlink() -> None:
config=parse_config_args(
[args.executor_config] if args.executor_config else args.executor_config
),
auth_plugin=auth_plugin,
)
grpc_servers = [exec_server]

Expand Down Expand Up @@ -559,6 +578,32 @@ def _try_setup_node_authentication(
)


def _try_obtain_user_auth_config(args: argparse.Namespace) -> Optional[dict[str, Any]]:
if args.user_auth_config is not None:
with open(args.user_auth_config, encoding="utf-8") as file:
config: dict[str, Any] = yaml.safe_load(file)
return config
return None


def _try_obtain_exec_auth_plugin(config: dict[str, Any]) -> Optional[ExecAuthPlugin]:
auth_config: dict[str, Any] = config.get("authentication", {})
auth_type: str = auth_config.get(AUTH_TYPE, "")
try:
all_plugins: dict[str, type[ExecAuthPlugin]] = get_exec_auth_plugins()
auth_plugin_class = all_plugins[auth_type]
return auth_plugin_class(config=auth_config)
except KeyError:
if auth_type != "":
sys.exit(
f'Authentication type "{auth_type}" is not supported. '
"Please provide a valid authentication type in the configuration."
)
sys.exit("No authentication type is provided in the configuration.")
except NotImplementedError:
sys.exit("No authentication plugins are currently supported.")


def _run_fleet_api_grpc_rere(
address: str,
state_factory: LinkStateFactory,
Expand Down Expand Up @@ -746,6 +791,12 @@ def _add_args_common(parser: argparse.ArgumentParser) -> None:
type=str,
help="The SuperLink's public key (as a path str) to enable authentication.",
)
parser.add_argument(
"--user-auth-config",
help="The path to the user authentication configuration YAML file.",
type=str,
default=None,
)


def _add_args_serverappio_api(parser: argparse.ArgumentParser) -> None:
Expand Down
19 changes: 18 additions & 1 deletion src/py/flwr/superexec/exec_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,21 @@
# ==============================================================================
"""SuperExec gRPC API."""

from collections.abc import Sequence
from logging import INFO
from typing import Optional

import grpc

from flwr.common import GRPC_MAX_MESSAGE_LENGTH
from flwr.common.auth_plugin import ExecAuthPlugin
from flwr.common.logger import log
from flwr.common.typing import UserConfig
from flwr.proto.exec_pb2_grpc import add_ExecServicer_to_server
from flwr.server.superlink.ffs.ffs_factory import FfsFactory
from flwr.server.superlink.fleet.grpc_bidi.grpc_server import generic_create_grpc_server
from flwr.server.superlink.linkstate import LinkStateFactory
from flwr.superexec.exec_user_auth_interceptor import ExecUserAuthInterceptor

from .exec_servicer import ExecServicer
from .executor import Executor
Expand All @@ -39,6 +42,7 @@ def run_exec_api_grpc(
ffs_factory: FfsFactory,
certificates: Optional[tuple[bytes, bytes, bytes]],
config: UserConfig,
auth_plugin: Optional[ExecAuthPlugin] = None,
) -> grpc.Server:
"""Run Exec API (gRPC, request-response)."""
executor.set_config(config)
Expand All @@ -47,16 +51,29 @@ def run_exec_api_grpc(
linkstate_factory=state_factory,
ffs_factory=ffs_factory,
executor=executor,
auth_plugin=auth_plugin,
)
interceptors: Optional[Sequence[grpc.ServerInterceptor]] = None
if auth_plugin is not None:
interceptors = [ExecUserAuthInterceptor(auth_plugin)]
exec_add_servicer_to_server_fn = add_ExecServicer_to_server
exec_grpc_server = generic_create_grpc_server(
servicer_and_add_fn=(exec_servicer, exec_add_servicer_to_server_fn),
server_address=address,
max_message_length=GRPC_MAX_MESSAGE_LENGTH,
certificates=certificates,
interceptors=interceptors,
)

log(INFO, "Flower Deployment Engine: Starting Exec API on %s", address)
if auth_plugin is None:
log(INFO, "Flower Deployment Engine: Starting Exec API on %s", address)
else:
log(
INFO,
"Flower Deployment Engine: Starting Exec API with user "
"authentication on %s",
address,
)
exec_grpc_server.start()

return exec_grpc_server
25 changes: 22 additions & 3 deletions src/py/flwr/superexec/exec_servicer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
import time
from collections.abc import Generator
from logging import ERROR, INFO
from typing import Any
from typing import Any, Optional

import grpc

from flwr.common import now
from flwr.common.auth_plugin import ExecAuthPlugin
from flwr.common.constant import LOG_STREAM_INTERVAL, Status, SubStatus
from flwr.common.logger import log
from flwr.common.serde import (
Expand Down Expand Up @@ -60,11 +61,13 @@ def __init__(
linkstate_factory: LinkStateFactory,
ffs_factory: FfsFactory,
executor: Executor,
auth_plugin: Optional[ExecAuthPlugin] = None,
) -> None:
self.linkstate_factory = linkstate_factory
self.ffs_factory = ffs_factory
self.executor = executor
self.executor.initialize(linkstate_factory, ffs_factory)
self.auth_plugin = auth_plugin

def StartRun(
self, request: StartRunRequest, context: grpc.ServicerContext
Expand Down Expand Up @@ -164,14 +167,30 @@ def GetLoginDetails(
) -> GetLoginDetailsResponse:
"""Start login."""
log(INFO, "ExecServicer.GetLoginDetails")
return GetLoginDetailsResponse(login_details={})
if self.auth_plugin is None:
context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"ExecServicer initialized without user authentication",
)
raise grpc.RpcError() # This line is unreachable
return GetLoginDetailsResponse(
login_details=self.auth_plugin.get_login_details()
)

def GetAuthTokens(
self, request: GetAuthTokensRequest, context: grpc.ServicerContext
) -> GetAuthTokensResponse:
"""Get auth token."""
log(INFO, "ExecServicer.GetAuthTokens")
return GetAuthTokensResponse(auth_tokens={})
if self.auth_plugin is None:
context.abort(
grpc.StatusCode.UNIMPLEMENTED,
"ExecServicer initialized without user authentication",
)
raise grpc.RpcError() # This line is unreachable
return GetAuthTokensResponse(
auth_tokens=self.auth_plugin.get_auth_tokens(dict(request.auth_details))
)


def _create_list_runs_response(run_ids: set[int], state: LinkState) -> ListRunsResponse:
Expand Down
101 changes: 101 additions & 0 deletions src/py/flwr/superexec/exec_user_auth_interceptor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
# Copyright 2024 Flower Labs GmbH. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Flower Exec API interceptor."""


from typing import Any, Callable, Union

import grpc

from flwr.common.auth_plugin import ExecAuthPlugin
from flwr.proto.exec_pb2 import ( # pylint: disable=E0611
GetAuthTokensRequest,
GetAuthTokensResponse,
GetLoginDetailsRequest,
GetLoginDetailsResponse,
StartRunRequest,
StartRunResponse,
StreamLogsRequest,
StreamLogsResponse,
)

Request = Union[
StartRunRequest,
StreamLogsRequest,
GetLoginDetailsRequest,
GetAuthTokensRequest,
]

Response = Union[
StartRunResponse, StreamLogsResponse, GetLoginDetailsResponse, GetAuthTokensResponse
]


class ExecUserAuthInterceptor(grpc.ServerInterceptor): # type: ignore
"""Exec API interceptor for user authentication."""

def __init__(
self,
auth_plugin: ExecAuthPlugin,
):
self.auth_plugin = auth_plugin

def intercept_service(
self,
continuation: Callable[[Any], Any],
handler_call_details: grpc.HandlerCallDetails,
) -> grpc.RpcMethodHandler:
"""Flower server interceptor authentication logic.
Intercept all unary-unary/unary-stream calls from users and authenticate users
by validating auth metadata sent by the user. Continue RPC call if user is
authenticated, else, terminate RPC call by setting context to abort.
"""
# One of the method handlers in
# `flwr.superexec.exec_servicer.ExecServicer`
method_handler: grpc.RpcMethodHandler = continuation(handler_call_details)
return self._generic_auth_unary_method_handler(method_handler)

def _generic_auth_unary_method_handler(
self, method_handler: grpc.RpcMethodHandler
) -> grpc.RpcMethodHandler:
def _generic_method_handler(
request: Request,
context: grpc.ServicerContext,
) -> Response:
call = method_handler.unary_unary or method_handler.unary_stream
metadata = context.invocation_metadata()
if isinstance(
request, (GetLoginDetailsRequest, GetAuthTokensRequest)
) or self.auth_plugin.validate_tokens_in_metadata(metadata):
return call(request, context) # type: ignore

tokens = self.auth_plugin.refresh_tokens(context.invocation_metadata())
if tokens is not None:
context.send_initial_metadata(tokens)
return call(request, context) # type: ignore

context.abort(grpc.StatusCode.UNAUTHENTICATED, "Access denied")
raise grpc.RpcError() # This line is unreachable

if method_handler.unary_unary:
message_handler = grpc.unary_unary_rpc_method_handler
else:
message_handler = grpc.unary_stream_rpc_method_handler
return message_handler(
_generic_method_handler,
request_deserializer=method_handler.request_deserializer,
response_serializer=method_handler.response_serializer,
)

0 comments on commit c955c27

Please sign in to comment.