From 3a1420dfa2dba7ef0fc4c8cda32c5db83215c69e Mon Sep 17 00:00:00 2001
From: Changwan Ryu
Date: Thu, 25 Apr 2024 23:32:19 -0700
Subject: [PATCH] Acme: Increase thread pool size to prevent a hang between
 actors and inference servers

According to the documentation [1], when the thread pool size is smaller
than the batch size, the server can hang: the batched handler waits to
collect the next example while every thread is busy synchronously waiting
for results.

PiperOrigin-RevId: 628306415
Change-Id: I9c48a689d0e667577f361495524c8fd2b980653e
---
 acme/jax/experiments/make_distributed_experiment.py | 11 ++++++++---
 acme/jax/inference_server.py                        |  5 +++--
 2 files changed, 11 insertions(+), 5 deletions(-)

diff --git a/acme/jax/experiments/make_distributed_experiment.py b/acme/jax/experiments/make_distributed_experiment.py
index 8830bd2393..b04708bdcd 100644
--- a/acme/jax/experiments/make_distributed_experiment.py
+++ b/acme/jax/experiments/make_distributed_experiment.py
@@ -345,6 +345,13 @@ def build_actor(
   if inference_server_config is not None:
     num_inference_nodes = num_tasks_per_inference_server * num_inference_servers
     num_actors_per_server = math.ceil(num_actors / num_inference_nodes)
+    thread_pool_size = (
+        2 * max(
+            inference_server_config.batch_size,
+            num_actors_per_server,
+        )
+    )
+
     inference_nodes = []
     for i in range(num_inference_servers):
       with program.group(f'inference_server_{i}'):
@@ -355,9 +362,7 @@
                     build_inference_server,
                     inference_server_config,
                     learner,
-                    courier_kwargs={
-                        'thread_pool_size': num_actors_per_server,
-                    },
+                    courier_kwargs={'thread_pool_size': thread_pool_size},
                 )
             )
         )
diff --git a/acme/jax/inference_server.py b/acme/jax/inference_server.py
index e87071e6f3..a7d331aa9b 100644
--- a/acme/jax/inference_server.py
+++ b/acme/jax/inference_server.py
@@ -137,9 +137,10 @@ def dereference_params_and_call_handler(*args, **kwargs):
       return handler(*args_with_dereferenced_params,
                      **kwargs_with_dereferenced_params)
 
+    max_parallelism = 2 * max(len(self._devices), self._config.batch_size)
     return lp.batched_handler(
         batch_size=self._config.batch_size,
         timeout=self._config.timeout,
         pad_batch=True,
-        max_parallelism=2 * len(self._devices))(
-            dereference_params_and_call_handler)
+        max_parallelism=max_parallelism,
+    )(dereference_params_and_call_handler)
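
Note: below is a minimal, self-contained sketch of the failure mode this
patch addresses. It is not Launchpad's or courier's real implementation;
Batcher, handle, release, and run are hypothetical stand-ins for the
server's synchronous RPC threads and lp.batched_handler's batch
collection, used only to illustrate why the thread pool must be able to
hold at least batch_size blocked callers.

# Hypothetical sketch (not the Launchpad API): reproduces the hang when
# the thread pool is smaller than the batch size.
import threading
from concurrent.futures import ThreadPoolExecutor, wait

BATCH_SIZE = 4


class Batcher:
  """Collects BATCH_SIZE requests, then releases them all at once."""

  def __init__(self):
    self._cv = threading.Condition()
    self._pending = 0

  def handle(self, x):
    with self._cv:
      self._pending += 1
      self._cv.notify_all()
      # Block until a full batch is collected, like a batched inference
      # handler waiting for enough examples before running the model.
      while self._pending < BATCH_SIZE:
        self._cv.wait()
    return 2 * x

  def release(self):
    # Demo-only escape hatch: unsticks deadlocked threads after the
    # timeout so the script can exit cleanly.
    with self._cv:
      self._pending = BATCH_SIZE
      self._cv.notify_all()


def run(thread_pool_size):
  # Each pool thread plays the role of a server thread that blocks
  # synchronously waiting for its result.
  batcher = Batcher()
  pool = ThreadPoolExecutor(max_workers=thread_pool_size)
  futures = [pool.submit(batcher.handle, i) for i in range(BATCH_SIZE)]
  done, not_done = wait(futures, timeout=2.0)
  batcher.release()
  pool.shutdown(wait=True)
  return not not_done  # True iff every request completed in time.


print(run(thread_pool_size=BATCH_SIZE - 1))  # False: batch never fills.
print(run(thread_pool_size=BATCH_SIZE))      # True: batch fills, all return.

With BATCH_SIZE - 1 threads, every thread sits inside handle() waiting
for one more example, while the final request is stuck in the executor
queue with no free thread to deliver it; this mirrors the actor /
inference-server hang. The patch's sizing,
2 * max(batch_size, num_actors_per_server), guarantees the pool can hold
a full batch of blocked callers with headroom, and max_parallelism is
raised to 2 * max(len(devices), batch_size) for the same reason on the
handler side.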