Reuse channel to send data. (#141)
Before this PR, we created a new channel object every time we sent data. This PR reuses the channel, caching one client stub per destination party, to make the code cleaner.
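
Conceptually, the change keeps one gRPC channel (and its client stub) per destination party and reuses it across sends, instead of opening a fresh channel for every message. Below is a minimal illustrative sketch of that caching pattern, not RayFed's actual API; `StubCache` and `make_stub` are made-up names:

```
class StubCache:
    """Illustrative only: create one stub per destination and reuse it."""

    def __init__(self, make_stub):
        # make_stub(address) -> stub; in RayFed's case this would open a
        # gRPC channel and wrap it in the generated service stub.
        self._make_stub = make_stub
        self._stubs = {}

    def get(self, party, address):
        # Build the stub lazily on the first send to `party`, then reuse it.
        if party not in self._stubs:
            self._stubs[party] = self._make_stub(address)
        return self._stubs[party]


if __name__ == "__main__":
    created = []
    cache = StubCache(lambda addr: created.append(addr) or f"stub({addr})")
    for _ in range(3):
        cache.get("bob", "127.0.0.1:11011")  # only the first call builds a stub
    print(created)  # -> ['127.0.0.1:11011']: one stub reused for three sends
```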

We also ran a benchmark contributed by @fengsp (thanks for the contribution) against this PR; the improvement turns out to be modest.

The baseline (before this PR) is:
```
num calls: 10000
total time (ms) =  79241.70279502869
per task overhead (ms) = 7.924171590805054
```

and the result after this PR is:
```
num calls: 10000
total time (ms) =  78388.72385025024
per task overhead (ms) = 7.83887369632721
```
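
For reference, the per-task overhead is just the total time divided by the number of calls: 79241.7 ms / 10000 ≈ 7.92 ms before versus 78388.7 ms / 10000 ≈ 7.84 ms after, i.e. roughly a 1% reduction.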

---------

Signed-off-by: Qing Wang <kingchin1218@gmail.com>
jovany-wang authored Jul 10, 2023
1 parent 204a0a3 commit 702565d
Showing 2 changed files with 107 additions and 57 deletions.
65 changes: 65 additions & 0 deletions benchmarks/many_tiny_tasks_benchmark.py
@@ -0,0 +1,65 @@
# Copyright 2023 The RayFed Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import ray
import time
import sys
import fed


@fed.remote
class MyActor:
    def run(self):
        return None


@fed.remote
class Aggregator:
    def aggr(self, val1, val2):
        return None


def main(party):
    ray.init(address='local')

    cluster = {
        'alice': {'address': '127.0.0.1:11010'},
        'bob': {'address': '127.0.0.1:11011'},
    }
    fed.init(cluster=cluster, party=party)

    actor_alice = MyActor.party("alice").remote()
    actor_bob = MyActor.party("bob").remote()
    aggregator = Aggregator.party("alice").remote()

    start = time.time()
    num_calls = 10000
    for i in range(num_calls):
        val_alice = actor_alice.run.remote()
        val_bob = actor_bob.run.remote()
        sum_val_obj = aggregator.aggr.remote(val_alice, val_bob)
        fed.get(sum_val_obj)
        if i % 100 == 0:
            print(f"Running {i}th call")
    print(f"num calls: {num_calls}")
    print("total time (ms) = ", (time.time() - start)*1000)
    print("per task overhead (ms) =", (time.time() - start)*1000/num_calls)

    fed.shutdown()
    ray.shutdown()


if __name__ == "__main__":
    assert len(sys.argv) == 2, 'Please run this script with party.'
    main(sys.argv[1])
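
To reproduce the numbers above, the benchmark is launched once per party (the party name is the script's single argument), e.g. `python benchmarks/many_tiny_tasks_benchmark.py alice` in one terminal and `python benchmarks/many_tiny_tasks_benchmark.py bob` in another; both processes typically need to be running for the cross-party calls to complete.
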
99 changes: 42 additions & 57 deletions fed/proxy/barriers.py
@@ -124,65 +124,30 @@ async def _run_grpc_server(


 async def send_data_grpc(
-    dest,
     data,
+    stub,
     upstream_seq_id,
     downstream_seq_id,
     metadata=None,
-    tls_config=None,
-    retry_policy=None,
-    grpc_options=None
 ):
-    grpc_options = get_grpc_options(retry_policy=retry_policy) if \
-        grpc_options is None else fed_utils.dict2tuple(grpc_options)
-    tls_enabled = fed_utils.tls_enabled(tls_config)
     cluster_config = fed_config.get_cluster_config()
-    metadata = fed_utils.dict2tuple(metadata)
-    if tls_enabled:
-        ca_cert, private_key, cert_chain = fed_utils.load_cert_config(tls_config)
-        credentials = grpc.ssl_channel_credentials(
-            certificate_chain=cert_chain,
-            private_key=private_key,
-            root_certificates=ca_cert,
-        )
-
-        async with grpc.aio.secure_channel(
-            dest,
-            credentials,
-            options=grpc_options,
-        ) as channel:
-            stub = fed_pb2_grpc.GrpcServiceStub(channel)
-            data = cloudpickle.dumps(data)
-            request = fed_pb2.SendDataRequest(
-                data=data,
-                upstream_seq_id=str(upstream_seq_id),
-                downstream_seq_id=str(downstream_seq_id),
-            )
-            # wait for downstream's reply
-            response = await stub.SendData(
-                request, metadata=metadata, timeout=cluster_config.cross_silo_timeout)
-            logger.debug(
-                f'Received data response from seq_id {downstream_seq_id}, '
-                f'result: {response.result}.'
-            )
-            return response.result
-    else:
-        async with grpc.aio.insecure_channel(dest, options=grpc_options) as channel:
-            stub = fed_pb2_grpc.GrpcServiceStub(channel)
-            data = cloudpickle.dumps(data)
-            request = fed_pb2.SendDataRequest(
-                data=data,
-                upstream_seq_id=str(upstream_seq_id),
-                downstream_seq_id=str(downstream_seq_id),
-            )
-            # wait for downstream's reply
-            response = await stub.SendData(
-                request, metadata=metadata, timeout=cluster_config.cross_silo_timeout)
-            logger.debug(
-                f'Received data response from seq_id {downstream_seq_id} '
-                f'result: {response.result}.'
-            )
-            return response.result
+    data = cloudpickle.dumps(data)
+    request = fed_pb2.SendDataRequest(
+        data=data,
+        upstream_seq_id=str(upstream_seq_id),
+        downstream_seq_id=str(downstream_seq_id),
+    )
+    # Waiting for the reply from downstream.
+    response = await stub.SendData(
+        request,
+        metadata=fed_utils.dict2tuple(metadata),
+        timeout=cluster_config.cross_silo_timeout,
+    )
+    logger.debug(
+        f'Received data response from seq_id {downstream_seq_id}, '
+        f'result: {response.result}.'
+    )
+    return response.result


@ray.remote
@@ -207,6 +172,8 @@ def __init__(
         self._tls_config = tls_config
         self.retry_policy = retry_policy
         self._grpc_metadata = fed_config.get_job_config().grpc_metadata
+        # Mapping the destination party name to the reused client stub.
+        self._stubs = {}
         cluster_config = fed_config.get_cluster_config()
         set_max_message_length(cluster_config.cross_silo_messages_max_size)

@@ -235,15 +202,33 @@ async def send(
         dest_addr = self._cluster[dest_party]['address']
         dest_party_grpc_config = self.setup_grpc_config(dest_party)
         try:
+            tls_enabled = fed_utils.tls_enabled(self._tls_config)
+            grpc_options = dest_party_grpc_config['grpc_options']
+            grpc_options = get_grpc_options(retry_policy=self.retry_policy) if \
+                grpc_options is None else fed_utils.dict2tuple(grpc_options)
+
+            if dest_party not in self._stubs:
+                if tls_enabled:
+                    ca_cert, private_key, cert_chain = fed_utils.load_cert_config(
+                        self._tls_config)
+                    credentials = grpc.ssl_channel_credentials(
+                        certificate_chain=cert_chain,
+                        private_key=private_key,
+                        root_certificates=ca_cert,
+                    )
+                    channel = grpc.aio.secure_channel(
+                        dest_addr, credentials, options=grpc_options)
+                else:
+                    channel = grpc.aio.insecure_channel(dest_addr, options=grpc_options)
+                stub = fed_pb2_grpc.GrpcServiceStub(channel)
+                self._stubs[dest_party] = stub
+
             response = await send_data_grpc(
-                dest=dest_addr,
                 data=data,
+                stub=self._stubs[dest_party],
                 upstream_seq_id=upstream_seq_id,
                 downstream_seq_id=downstream_seq_id,
                 metadata=dest_party_grpc_config['grpc_metadata'],
-                tls_config=self._tls_config,
-                retry_policy=self.retry_policy,
-                grpc_options=dest_party_grpc_config['grpc_options']
             )
         except Exception as e:
             logger.error(f'Failed to {send_log_msg}, error: {e}')
