diff --git a/acme/jax/experiments/make_distributed_experiment.py b/acme/jax/experiments/make_distributed_experiment.py index 8830bd2393..b04708bdcd 100644 --- a/acme/jax/experiments/make_distributed_experiment.py +++ b/acme/jax/experiments/make_distributed_experiment.py @@ -345,6 +345,13 @@ def build_actor( if inference_server_config is not None: num_inference_nodes = num_tasks_per_inference_server * num_inference_servers num_actors_per_server = math.ceil(num_actors / num_inference_nodes) + thread_pool_size = ( + 2 * max( + inference_server_config.batch_size, + num_actors_per_server, + ) + ) + inference_nodes = [] for i in range(num_inference_servers): with program.group(f'inference_server_{i}'): @@ -355,9 +362,7 @@ def build_actor( build_inference_server, inference_server_config, learner, - courier_kwargs={ - 'thread_pool_size': num_actors_per_server, - }, + courier_kwargs={'thread_pool_size': thread_pool_size}, ) ) ) diff --git a/acme/jax/inference_server.py b/acme/jax/inference_server.py index e87071e6f3..a7d331aa9b 100644 --- a/acme/jax/inference_server.py +++ b/acme/jax/inference_server.py @@ -137,9 +137,10 @@ def dereference_params_and_call_handler(*args, **kwargs): return handler(*args_with_dereferenced_params, **kwargs_with_dereferenced_params) + max_parallelism = 2 * max(len(self._devices), self._config.batch_size) return lp.batched_handler( batch_size=self._config.batch_size, timeout=self._config.timeout, pad_batch=True, - max_parallelism=2 * len(self._devices))( - dereference_params_and_call_handler) + max_parallelism=max_parallelism, + )(dereference_params_and_call_handler)