Skip to content

Commit

Permalink
Exception re-raising
Browse files Browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed Aug 15, 2024
1 parent fac0724 commit d82a651
Showing 1 changed file with 19 additions and 6 deletions.
25 changes: 19 additions & 6 deletions allenact/base_abstractions/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
Union,
final,
)
from concurrent.futures import ThreadPoolExecutor, wait
import concurrent.futures as cf
from contextlib import contextmanager

import gym
Expand Down Expand Up @@ -452,7 +452,7 @@ def __init__(

if self.any_parallel:
# Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)
self.executor = cf.ThreadPoolExecutor(max_workers=self.thread_pool_size)

# Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
# resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
Expand All @@ -462,7 +462,7 @@ def __init__(
# after step is the one where we also parallelize instantiating new tasks
with self.wrap_with_task_batch_size_0() as true_task_batch_size: # type:ignore
if self.parallel_init:
wait(
self.wait_for_futures_and_raise_errors(
[
self.executor.submit(self.make_new_task, it)
for it in range(1, true_task_batch_size)
Expand All @@ -473,6 +473,19 @@ def __init__(
for it in range(1, true_task_batch_size):
self.make_new_task(it)

@staticmethod
def wait_for_futures_and_raise_errors(
futures: Sequence[cf.Future],
) -> Sequence[Any]:
results = []
cf.wait(futures)
for future in futures:
try:
results.append(future.result()) # This will re-raise any exceptions
except Exception:
raise
return results

@contextmanager
def wrap_with_task_batch_size_0(self):
task_batch_size = self.task_sampler.task_batch_size
Expand Down Expand Up @@ -504,7 +517,7 @@ def obs_extract(it, task):
res[it] = task.get_observations()

if self.parallel_get_observations:
wait(
self.wait_for_futures_and_raise_errors(
[
self.executor.submit(obs_extract, it, task)
for it, task in enumerate(self.tasks)
Expand Down Expand Up @@ -549,7 +562,7 @@ def before_step(it, task):
env_actions[it], intermediates[it] = task._before_env_step(action[it])

if self.parallel_before_step:
wait(
self.wait_for_futures_and_raise_errors(
[
self.executor.submit(before_step, it, task)
for it, task in enumerate(self.tasks)
Expand Down Expand Up @@ -609,7 +622,7 @@ def after_step(it, task):

with self.wrap_with_task_batch_size_0(): # type:ignore
if self.parallel_after_step:
wait(
self.wait_for_futures_and_raise_errors(
[
self.executor.submit(after_step, it, task)
for it, task in enumerate(self.tasks)
Expand Down

0 comments on commit d82a651

Please sign in to comment.