Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle Job-Mismatch request #174

Merged
merged 5 commits into from
Aug 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion fed/grpc/fed.proto
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@ message SendDataRequest {
};

message SendDataResponse {
string result = 1;
int32 code = 1;
string result = 2;
};
147 changes: 132 additions & 15 deletions fed/grpc/pb3/fed_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
# source: fed.proto
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
Expand All @@ -28,12 +27,113 @@



DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3')
DESCRIPTOR = _descriptor.FileDescriptor(
name='fed.proto',
package='',
syntax='proto3',
serialized_options=b'\200\001\001',
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"0\n\x10SendDataResponse\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0e\n\x06result\x18\x02 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3'
)



_SENDDATAREQUEST = DESCRIPTOR.message_types_by_name['SendDataRequest']
_SENDDATARESPONSE = DESCRIPTOR.message_types_by_name['SendDataResponse']

_SENDDATAREQUEST = _descriptor.Descriptor(
name='SendDataRequest',
full_name='SendDataRequest',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='data', full_name='SendDataRequest.data', index=0,
number=1, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=b"",
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='upstream_seq_id', full_name='SendDataRequest.upstream_seq_id', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='downstream_seq_id', full_name='SendDataRequest.downstream_seq_id', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='job_name', full_name='SendDataRequest.job_name', index=3,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=13,
serialized_end=114,
)


_SENDDATARESPONSE = _descriptor.Descriptor(
name='SendDataResponse',
full_name='SendDataResponse',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='code', full_name='SendDataResponse.code', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='result', full_name='SendDataResponse.result', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=116,
serialized_end=164,
)

DESCRIPTOR.message_types_by_name['SendDataRequest'] = _SENDDATAREQUEST
DESCRIPTOR.message_types_by_name['SendDataResponse'] = _SENDDATARESPONSE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)

SendDataRequest = _reflection.GeneratedProtocolMessageType('SendDataRequest', (_message.Message,), {
'DESCRIPTOR' : _SENDDATAREQUEST,
'__module__' : 'fed_pb2'
Expand All @@ -48,15 +148,32 @@
})
_sym_db.RegisterMessage(SendDataResponse)

_GRPCSERVICE = DESCRIPTOR.services_by_name['GrpcService']
if _descriptor._USE_C_DESCRIPTORS == False:

DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b'\200\001\001'
_SENDDATAREQUEST._serialized_start=13
_SENDDATAREQUEST._serialized_end=114
_SENDDATARESPONSE._serialized_start=116
_SENDDATARESPONSE._serialized_end=150
_GRPCSERVICE._serialized_start=152
_GRPCSERVICE._serialized_end=216

DESCRIPTOR._options = None

_GRPCSERVICE = _descriptor.ServiceDescriptor(
name='GrpcService',
full_name='GrpcService',
file=DESCRIPTOR,
index=0,
serialized_options=None,
create_key=_descriptor._internal_create_key,
serialized_start=166,
serialized_end=230,
methods=[
_descriptor.MethodDescriptor(
name='SendData',
full_name='GrpcService.SendData',
index=0,
containing_service=None,
input_type=_SENDDATAREQUEST,
output_type=_SENDDATARESPONSE,
serialized_options=None,
create_key=_descriptor._internal_create_key,
),
])
_sym_db.RegisterServiceDescriptor(_GRPCSERVICE)

DESCRIPTOR.services_by_name['GrpcService'] = _GRPCSERVICE

# @@protoc_insertion_point(module_scope)
8 changes: 4 additions & 4 deletions fed/grpc/pb4/fed_pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@



DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"\"\n\x10SendDataResponse\x12\x0e\n\x06result\x18\x01 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3')
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tfed.proto\"e\n\x0fSendDataRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\x17\n\x0fupstream_seq_id\x18\x02 \x01(\t\x12\x19\n\x11\x64ownstream_seq_id\x18\x03 \x01(\t\x12\x10\n\x08job_name\x18\x04 \x01(\t\"0\n\x10SendDataResponse\x12\x0c\n\x04\x63ode\x18\x01 \x01(\x05\x12\x0e\n\x06result\x18\x02 \x01(\t2@\n\x0bGrpcService\x12\x31\n\x08SendData\x12\x10.SendDataRequest\x1a\x11.SendDataResponse\"\x00\x42\x03\x80\x01\x01\x62\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
Expand All @@ -39,7 +39,7 @@
_globals['_SENDDATAREQUEST']._serialized_start=13
_globals['_SENDDATAREQUEST']._serialized_end=114
_globals['_SENDDATARESPONSE']._serialized_start=116
_globals['_SENDDATARESPONSE']._serialized_end=150
_globals['_GRPCSERVICE']._serialized_start=152
_globals['_GRPCSERVICE']._serialized_end=216
_globals['_SENDDATARESPONSE']._serialized_end=164
_globals['_GRPCSERVICE']._serialized_start=166
_globals['_GRPCSERVICE']._serialized_end=230
# @@protoc_insertion_point(module_scope)
20 changes: 17 additions & 3 deletions fed/proxy/grpc/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ async def send(
timeout=timeout,
metadata=grpc_metadata,
)
return response
self.handle_response_error(response)
return response.result

def get_grpc_config_by_party(self, dest_party):
"""Overide global config by party specific config
Expand Down Expand Up @@ -167,6 +168,17 @@ async def get_proxy_config(self, dest_party=None):
proxy_config.update({'grpc_options': grpc_options})
return proxy_config

def handle_response_error(self, response):
if response.code == 200:
return

if 400 <= response.code < 500:
# Request error should also be identified as a sending failure,
# though the request was physically sent.
logger.warning(f"Request was successfully sent but got error response, "
f"code: {response.code}, message: {response.result}.")
raise RuntimeError(response.result)


async def send_data_grpc(
data,
Expand All @@ -192,9 +204,10 @@ async def send_data_grpc(
)
logger.debug(
f'Received data response from seq_id {downstream_seq_id}, '
f'code: {response.code}, '
f'result: {response.result}.'
)
return response.result
return response


class GrpcReceiverProxy(ReceiverProxy):
Expand Down Expand Up @@ -287,6 +300,7 @@ async def SendData(self, request, context):
f"The reason may be that the ReceiverProxy is listening "
f"on the same address with that job.")
return fed_pb2.SendDataResponse(
code=417,
result=f"JobName mis-match, expected {self._job_name}, got {job_name}.")
upstream_seq_id = request.upstream_seq_id
downstream_seq_id = request.downstream_seq_id
Expand All @@ -309,7 +323,7 @@ async def SendData(self, request, context):
event = get_from_two_dim_dict(self._events, upstream_seq_id, downstream_seq_id)
event.set()
logger.debug(f"Event set for {upstream_seq_id}")
return fed_pb2.SendDataResponse(result="OK")
return fed_pb2.SendDataResponse(code=200, result="OK")


async def _run_grpc_server(
Expand Down
5 changes: 3 additions & 2 deletions tests/multi-jobs/test_ignore_other_job_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,10 @@ async def send(
timeout=timeout,
metadata=grpc_metadata,
)
assert "JobName mis-match" in response
assert response.code == 417
assert "JobName mis-match" in response.result
# So that process can exit
raise RuntimeError()
raise RuntimeError(response.result)


@fed.remote
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transport_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ async def SendData(self, request, context):
assert v == metadata[k]
event = asyncio.Event()
event.set()
return fed_pb2.SendDataResponse(result="OK")
return fed_pb2.SendDataResponse(code=200, result="OK")


async def _test_run_grpc_server(
Expand Down
Loading