Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jordis-ai2 committed Jul 11, 2024
1 parent 3d3b9b1 commit f5afde3
Showing 1 changed file with 31 additions and 9 deletions.
40 changes: 31 additions & 9 deletions allenact/base_abstractions/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,10 +370,10 @@ def set_seed(self, seed: int) -> None:


class BatchedTask(Generic[EnvType]):
"""An abstract class defining a batch of goal directed 'tasks.' Agents interact
with their environment through a task by taking a `step` after which they
receive new observations, rewards, and (potentially) other useful
information.
"""An abstract class defining a batch of goal directed 'tasks.' Agents
interact with their environment through a task by taking a `step` after
which they receive new observations, rewards, and (potentially) other
useful information.
A BatchedTask is a wrapper around a specific Task
and allows for multiple tasks to be simultaneously executed in the same scene.
Expand Down Expand Up @@ -439,7 +439,9 @@ def __init__(
self.tasks.append(self.make_new_task(it))

# Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
self.executor = ThreadPoolExecutor(max_workers=min(10, self.task_sampler.task_batch_size))
self.executor = ThreadPoolExecutor(
max_workers=min(10, self.task_sampler.task_batch_size)
)

# Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
# resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
Expand Down Expand Up @@ -471,7 +473,12 @@ def get_observations(self, **kwargs) -> List[Any]: # -> Dict[str, Any]:
def obs_extract(it, task):
res[it] = task.get_observations()

wait([self.executor.submit(obs_extract, it, task) for it, task in enumerate(self.tasks)])
wait(
[
self.executor.submit(obs_extract, it, task)
for it, task in enumerate(self.tasks)
]
)
# for it, task in enumerate(self.tasks):
# obs_extract(it, task)

Expand Down Expand Up @@ -547,7 +554,12 @@ def update_after_step(it, current_task):
infos[it] = info

# Ensure completion with wait():
wait([self.executor.submit(update_after_step, it, current_task) for it, current_task in enumerate(self.tasks)])
wait(
[
self.executor.submit(update_after_step, it, current_task)
for it, current_task in enumerate(self.tasks)
]
)
# for it, current_task in enumerate(self.tasks):
# update_after_step(it, current_task)

Expand All @@ -567,7 +579,12 @@ def _step(self, action: Any) -> List[RLStepResult]:
def before_step(it, task):
actions[it], intermediates[it] = task._before_env_step(action[it])

wait([self.executor.submit(before_step, it, task) for it, task in enumerate(self.tasks)])
wait(
[
self.executor.submit(before_step, it, task)
for it, task in enumerate(self.tasks)
]
)
# for it, task in enumerate(self.tasks):
# before_step(it, task)

Expand All @@ -580,7 +597,12 @@ def before_step(it, task):
def after_step(it, task):
srs[it] = task._after_env_step(action[it], actions[it], intermediates[it])

wait([self.executor.submit(after_step, it, task) for it, task in enumerate(self.tasks)])
wait(
[
self.executor.submit(after_step, it, task)
for it, task in enumerate(self.tasks)
]
)
# for it, task in enumerate(self.tasks):
# after_step(it, task)

Expand Down

0 comments on commit f5afde3

Please sign in to comment.