Skip to content

Commit

Permalink
feat(framework) Add new RPCs to Control service (#4241)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Sep 23, 2024
1 parent 247cada commit b6babe9
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 12 deletions.
7 changes: 7 additions & 0 deletions src/proto/flwr/proto/control.proto
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,11 @@ import "flwr/proto/run.proto";
service Control {
// Request to create a new run
rpc CreateRun(CreateRunRequest) returns (CreateRunResponse) {}

// Get the status of a given run
rpc GetRunStatus(GetRunStatusRequest) returns (GetRunStatusResponse) {}

// Update the status of a given run
rpc UpdateRunStatus(UpdateRunStatusRequest)
returns (UpdateRunStatusResponse) {}
}
20 changes: 20 additions & 0 deletions src/proto/flwr/proto/run.proto
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ message Run {
string fab_hash = 5;
}

message RunStatus {
// "starting", "running", "finished"
string status = 1;
// "completed", "failed", "stopped" or "" (non-finished)
string sub_status = 2;
// failure details
string details = 3;
}

// CreateRun
message CreateRunRequest {
string fab_id = 1;
Expand All @@ -40,3 +49,14 @@ message CreateRunResponse { uint64 run_id = 1; }
// GetRun
message GetRunRequest { uint64 run_id = 1; }
message GetRunResponse { Run run = 1; }

// UpdateRunStatus
message UpdateRunStatusRequest {
uint64 run_id = 1;
RunStatus run_status = 2;
}
message UpdateRunStatusResponse {}

// GetRunStatus
message GetRunStatusRequest { repeated uint64 run_ids = 1; }
message GetRunStatusResponse { map<uint64, RunStatus> run_status_dict = 1; }
6 changes: 3 additions & 3 deletions src/py/flwr/proto/control_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

68 changes: 68 additions & 0 deletions src/py/flwr/proto/control_pb2_grpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def __init__(self, channel):
request_serializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString,
)
self.GetRunStatus = channel.unary_unary(
'/flwr.proto.Control/GetRunStatus',
request_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
)
self.UpdateRunStatus = channel.unary_unary(
'/flwr.proto.Control/UpdateRunStatus',
request_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
response_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
)


class ControlServicer(object):
Expand All @@ -31,6 +41,20 @@ def CreateRun(self, request, context):
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def GetRunStatus(self, request, context):
"""Get the status of a given run
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')

def UpdateRunStatus(self, request, context):
"""Update the status of a given run
"""
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
context.set_details('Method not implemented!')
raise NotImplementedError('Method not implemented!')


def add_ControlServicer_to_server(servicer, server):
rpc_method_handlers = {
Expand All @@ -39,6 +63,16 @@ def add_ControlServicer_to_server(servicer, server):
request_deserializer=flwr_dot_proto_dot_run__pb2.CreateRunRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.CreateRunResponse.SerializeToString,
),
'GetRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.GetRunStatus,
request_deserializer=flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.SerializeToString,
),
'UpdateRunStatus': grpc.unary_unary_rpc_method_handler(
servicer.UpdateRunStatus,
request_deserializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.FromString,
response_serializer=flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.SerializeToString,
),
}
generic_handler = grpc.method_handlers_generic_handler(
'flwr.proto.Control', rpc_method_handlers)
Expand All @@ -65,3 +99,37 @@ def CreateRun(request,
flwr_dot_proto_dot_run__pb2.CreateRunResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def GetRunStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/GetRunStatus',
flwr_dot_proto_dot_run__pb2.GetRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.GetRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)

@staticmethod
def UpdateRunStatus(request,
target,
options=(),
channel_credentials=None,
call_credentials=None,
insecure=False,
compression=None,
wait_for_ready=None,
timeout=None,
metadata=None):
return grpc.experimental.unary_unary(request, target, '/flwr.proto.Control/UpdateRunStatus',
flwr_dot_proto_dot_run__pb2.UpdateRunStatusRequest.SerializeToString,
flwr_dot_proto_dot_run__pb2.UpdateRunStatusResponse.FromString,
options, channel_credentials,
insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
26 changes: 26 additions & 0 deletions src/py/flwr/proto/control_pb2_grpc.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,16 @@ class ControlStub:
flwr.proto.run_pb2.CreateRunResponse]
"""Request to create a new run"""

GetRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.GetRunStatusRequest,
flwr.proto.run_pb2.GetRunStatusResponse]
"""Get the status of a given run"""

UpdateRunStatus: grpc.UnaryUnaryMultiCallable[
flwr.proto.run_pb2.UpdateRunStatusRequest,
flwr.proto.run_pb2.UpdateRunStatusResponse]
"""Update the status of a given run"""


class ControlServicer(metaclass=abc.ABCMeta):
@abc.abstractmethod
Expand All @@ -23,5 +33,21 @@ class ControlServicer(metaclass=abc.ABCMeta):
"""Request to create a new run"""
pass

@abc.abstractmethod
def GetRunStatus(self,
request: flwr.proto.run_pb2.GetRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.GetRunStatusResponse:
"""Get the status of a given run"""
pass

@abc.abstractmethod
def UpdateRunStatus(self,
request: flwr.proto.run_pb2.UpdateRunStatusRequest,
context: grpc.ServicerContext,
) -> flwr.proto.run_pb2.UpdateRunStatusResponse:
"""Update the status of a given run"""
pass


def add_ControlServicer_to_server(servicer: ControlServicer, server: grpc.Server) -> None: ...
32 changes: 23 additions & 9 deletions src/py/flwr/proto/run_pb2.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

86 changes: 86 additions & 0 deletions src/py/flwr/proto/run_pb2.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,29 @@ class Run(google.protobuf.message.Message):
def ClearField(self, field_name: typing_extensions.Literal["fab_hash",b"fab_hash","fab_id",b"fab_id","fab_version",b"fab_version","override_config",b"override_config","run_id",b"run_id"]) -> None: ...
global___Run = Run

class RunStatus(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
STATUS_FIELD_NUMBER: builtins.int
SUB_STATUS_FIELD_NUMBER: builtins.int
DETAILS_FIELD_NUMBER: builtins.int
status: typing.Text
""""starting", "running", "finished" """

sub_status: typing.Text
""""completed", "failed", "stopped" or "" (non-finished)"""

details: typing.Text
"""failure details"""

def __init__(self,
*,
status: typing.Text = ...,
sub_status: typing.Text = ...,
details: typing.Text = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["details",b"details","status",b"status","sub_status",b"sub_status"]) -> None: ...
global___RunStatus = RunStatus

class CreateRunRequest(google.protobuf.message.Message):
"""CreateRun"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
Expand Down Expand Up @@ -126,3 +149,66 @@ class GetRunResponse(google.protobuf.message.Message):
def HasField(self, field_name: typing_extensions.Literal["run",b"run"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["run",b"run"]) -> None: ...
global___GetRunResponse = GetRunResponse

class UpdateRunStatusRequest(google.protobuf.message.Message):
"""UpdateRunStatus"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_ID_FIELD_NUMBER: builtins.int
RUN_STATUS_FIELD_NUMBER: builtins.int
run_id: builtins.int
@property
def run_status(self) -> global___RunStatus: ...
def __init__(self,
*,
run_id: builtins.int = ...,
run_status: typing.Optional[global___RunStatus] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["run_status",b"run_status"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["run_id",b"run_id","run_status",b"run_status"]) -> None: ...
global___UpdateRunStatusRequest = UpdateRunStatusRequest

class UpdateRunStatusResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
def __init__(self,
) -> None: ...
global___UpdateRunStatusResponse = UpdateRunStatusResponse

class GetRunStatusRequest(google.protobuf.message.Message):
"""GetRunStatus"""
DESCRIPTOR: google.protobuf.descriptor.Descriptor
RUN_IDS_FIELD_NUMBER: builtins.int
@property
def run_ids(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.int]: ...
def __init__(self,
*,
run_ids: typing.Optional[typing.Iterable[builtins.int]] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_ids",b"run_ids"]) -> None: ...
global___GetRunStatusRequest = GetRunStatusRequest

class GetRunStatusResponse(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
class RunStatusDictEntry(google.protobuf.message.Message):
DESCRIPTOR: google.protobuf.descriptor.Descriptor
KEY_FIELD_NUMBER: builtins.int
VALUE_FIELD_NUMBER: builtins.int
key: builtins.int
@property
def value(self) -> global___RunStatus: ...
def __init__(self,
*,
key: builtins.int = ...,
value: typing.Optional[global___RunStatus] = ...,
) -> None: ...
def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...

RUN_STATUS_DICT_FIELD_NUMBER: builtins.int
@property
def run_status_dict(self) -> google.protobuf.internal.containers.MessageMap[builtins.int, global___RunStatus]: ...
def __init__(self,
*,
run_status_dict: typing.Optional[typing.Mapping[builtins.int, global___RunStatus]] = ...,
) -> None: ...
def ClearField(self, field_name: typing_extensions.Literal["run_status_dict",b"run_status_dict"]) -> None: ...
global___GetRunStatusResponse = GetRunStatusResponse

0 comments on commit b6babe9

Please sign in to comment.