Skip to content

Commit

Permalink
Acme: Increase thread pool size to prevent a hang between actors and …
Browse files Browse the repository at this point in the history
…inference servers

According to the documentation [1]

when the thread pool size is smaller than the batch size, it is possible to hang when the batched handler is waiting to collect the next
example but all the threads are busy synchronously waiting for the results.

PiperOrigin-RevId: 628306415
Change-Id: I9c48a689d0e667577f361495524c8fd2b980653e
  • Loading branch information
galmacky authored and copybara-github committed Apr 26, 2024
1 parent aa42e1c commit 3a1420d
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
11 changes: 8 additions & 3 deletions acme/jax/experiments/make_distributed_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,13 @@ def build_actor(
if inference_server_config is not None:
num_inference_nodes = num_tasks_per_inference_server * num_inference_servers
num_actors_per_server = math.ceil(num_actors / num_inference_nodes)
thread_pool_size = (
2 * max(
inference_server_config.batch_size,
num_actors_per_server,
)
)

inference_nodes = []
for i in range(num_inference_servers):
with program.group(f'inference_server_{i}'):
Expand All @@ -355,9 +362,7 @@ def build_actor(
build_inference_server,
inference_server_config,
learner,
courier_kwargs={
'thread_pool_size': num_actors_per_server,
},
courier_kwargs={'thread_pool_size': thread_pool_size},
)
)
)
Expand Down
5 changes: 3 additions & 2 deletions acme/jax/inference_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,10 @@ def dereference_params_and_call_handler(*args, **kwargs):
return handler(*args_with_dereferenced_params,
**kwargs_with_dereferenced_params)

max_parallelism = 2 * max(len(self._devices), self._config.batch_size)
return lp.batched_handler(
batch_size=self._config.batch_size,
timeout=self._config.timeout,
pad_batch=True,
max_parallelism=2 * len(self._devices))(
dereference_params_and_call_handler)
max_parallelism=max_parallelism,
)(dereference_params_and_call_handler)

0 comments on commit 3a1420d

Please sign in to comment.