Skip to content

Commit

Permalink
Fix CI
Browse files Browse the repository at this point in the history
  • Loading branch information
jovany-wang committed Jul 21, 2023
1 parent e46fc13 commit 4d5390b
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 66 deletions.
1 change: 0 additions & 1 deletion fed/proxy/grpc/grpc_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,6 @@ def __init__(
self._lock = threading.Lock()

async def start(self):
print(f"==================listen_addr={self._listen_addr}")
port = self._listen_addr[self._listen_addr.index(':') + 1 :]
try:
await _run_grpc_server(
Expand Down
77 changes: 15 additions & 62 deletions tests/test_transport_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,22 @@ async def is_ready(self):

def _test_start_receiver_proxy(
addresses: str,
config: dict,
party: str,
logging_level: str,
expected_metadata: dict,
):
# Create RecevrProxyActor
# Not that this is now a threaded actor.
party_addr = addresses[party]
listen_addr = party_addr
print(f"160======listen_addr={listen_addr}")
if not listen_addr:
listen_addr = party_addr['address']
address = addresses[party]
listening_address = config['cross_silo_message'].get('listening_address', None)
if not listening_address:
listening_address = address

receiver_proxy_actor = TestReceiverProxyActor.options(
name=f"ReceiverProxyActor-{party}", max_concurrency=1000
).remote(
listen_addr=listen_addr,
listen_addr=listening_address,
party=party,
expected_metadata=expected_metadata
)
Expand Down Expand Up @@ -195,70 +195,23 @@ def test_send_grpc_with_meta():
global_context.get_global_context().get_cleanup_manager().start()

SERVER_ADDRESS = "127.0.0.1:12344"
party = 'test_party'
cluster_config = {'test_party': SERVER_ADDRESS}
party_name = 'test_party'
addresses = {party_name: SERVER_ADDRESS}
_test_start_receiver_proxy(
cluster_config, party, logging_level='info',
addresses,
{'cross_silo_message': {}},
party_name,
logging_level='info',
expected_metadata=metadata,
)
_start_sender_proxy(
cluster_config,
party,
addresses,
party_name,
logging_level='info',
proxy_cls=GrpcSenderProxy,
proxy_config=GrpcCrossSiloMessageConfig())
sent_objs = []
sent_obj = send(party, "data", 0, 1)
sent_objs.append(sent_obj)
for result in ray.get(sent_objs):
assert result

global_context.get_global_context().get_cleanup_manager().graceful_stop()
global_context.clear_global_context()
ray.shutdown()


def test_send_grpc_with_party_specific_meta():
compatible_utils.init_ray(address='local')
cluster_config = {
constants.KEY_OF_CLUSTER_ADDRESSES: "",
constants.KEY_OF_CURRENT_PARTY_NAME: "",
constants.KEY_OF_TLS_CONFIG: "",
}
sender_proxy_config = CrossSiloMessageConfig(
http_header={"key": "value"})
job_config = {
constants.KEY_OF_CROSS_SILO_MESSAGE_CONFIG:
sender_proxy_config,
}
compatible_utils._init_internal_kv()
compatible_utils.kv.put(constants.KEY_OF_CLUSTER_CONFIG,
cloudpickle.dumps(cluster_config))
compatible_utils.kv.put(constants.KEY_OF_JOB_CONFIG,
cloudpickle.dumps(job_config))
global_context.get_global_context().get_cleanup_manager().start()

SERVER_ADDRESS = "127.0.0.1:12344"
party = 'test_party'
cluster_parties_config = {
'test_party': {
'cross_silo_message': CrossSiloMessageConfig(
http_header={"token": "test-party-token"},
listening_address=SERVER_ADDRESS,)
}
}
_test_start_receiver_proxy(
cluster_parties_config, party, logging_level='info',
expected_metadata={"key": "value", "token": "test-party-token"},
)
_start_sender_proxy(
cluster_parties_config,
party,
logging_level='info',
proxy_cls=GrpcSenderProxy,
proxy_config=sender_proxy_config)
sent_objs = []
sent_obj = send(party, "data", 0, 1)
sent_obj = send(party_name, "data", 0, 1)
sent_objs.append(sent_obj)
for result in ray.get(sent_objs):
assert result
Expand Down
6 changes: 3 additions & 3 deletions tests/test_transport_proxy_tls.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,18 @@ def test_n_to_1_transport():
NUM_DATA = 10
SERVER_ADDRESS = "127.0.0.1:65422"
party = 'test_party'
cluster_config = {'test_party': {'address': SERVER_ADDRESS}}
addresses = {'test_party': SERVER_ADDRESS}
config = GrpcCrossSiloMessageConfig()
_start_receiver_proxy(
cluster_config,
addresses,
party,
logging_level='info',
tls_config=tls_config,
proxy_cls=GrpcReceiverProxy,
proxy_config=config
)
_start_sender_proxy(
cluster_config,
addresses,
party,
logging_level='info',
tls_config=tls_config,
Expand Down

0 comments on commit 4d5390b

Please sign in to comment.