Commit: Update task.py

jordis-ai2 committed Jul 12, 2024
1 parent fd600c2 commit c217391
Showing 1 changed file with 53 additions and 57 deletions.
allenact/base_abstractions/task.py
@@ -29,7 +29,6 @@
 from allenact.base_abstractions.misc import RLStepResult
 from allenact.base_abstractions.sensor import Sensor, SensorSuite
 from allenact.utils.misc_utils import deprecated
-from allenact.utils.system import get_logger

 COMPLETE_TASK_METRICS_KEY = "__AFTER_TASK_METRICS__"
 COMPLETE_TASK_CALLBACK_KEY = "__AFTER_TASK_CALLBACK__"
@@ -409,24 +408,13 @@ def __init__(
         parallel_before_step: bool = False,
         parallel_after_step: bool = True,
         parallel_get_observations: bool = True,
-        **kwargs,
+        max_thread_pool_size: int = 10,
+        **task_kwargs: Any,
     ) -> None:
         assert hasattr(
             task_sampler, "task_batch_size"
         ), "BatchedTask requires task_sampler to contain a `task_batch_size`"

-        # Keep a reference to the task sampler
-        self.task_sampler = task_sampler
-
-        self.callback_sensor_suite = callback_sensor_suite
-
-        self.env = env
-
-        self.parallel_before_step = parallel_before_step
-        self.parallel_after_step = parallel_after_step
-        self.parallel_get_observations = parallel_get_observations
-        self.any_parallel = parallel_before_step or parallel_after_step or parallel_get_observations
-
         # Instantiate the first actual task from the currently sampled info
         self.tasks = [
             task_class(
@@ -435,21 +423,35 @@ def __init__(
                 task_info=task_info,
                 max_steps=max_steps,
                 batch_index=0,
-                **kwargs,
+                **task_kwargs,
             )
         ]
         self.tasks[0].batch_index = 0

-        # If task_batch_size greater than 0, instantiate the rest of tasks
-        if self.task_sampler.task_batch_size > 0:
+        if task_sampler.task_batch_size > 0:
+            # Keep a reference to the task sampler
+            self.task_sampler = task_sampler
+
+            self.callback_sensor_suite = callback_sensor_suite
+            self.env = env
+
+            self.parallel_before_step = parallel_before_step
+            self.parallel_after_step = parallel_after_step
+            self.parallel_get_observations = parallel_get_observations
+            self.any_parallel = (
+                parallel_before_step or parallel_after_step or parallel_get_observations
+            )
+            self.thread_pool_size = min(
+                max_thread_pool_size, self.task_sampler.task_batch_size
+            )
+
+            # If task_batch_size greater than 0, instantiate the rest of tasks
             for it in range(1, self.task_sampler.task_batch_size):
                 self.tasks.append(self.make_new_task(it))

             if self.any_parallel:
                 # Also, a ThreadPoolExecutor to collect all data (possibly) under IO bottlenecks
-                self.executor = ThreadPoolExecutor(
-                    max_workers=min(10, self.task_sampler.task_batch_size)
-                )
+                self.executor = ThreadPoolExecutor(max_workers=self.thread_pool_size)

             # Also, a mutex to enable underlying task sampler implementations to ensure e.g. only one process
             # resets the sampler when called from a ThreadPoolExecutor (next_task must be thread safe, possibly
@@ -519,23 +521,12 @@ def render(self, mode: str = "rgb", *args, **kwargs) -> np.ndarray:
         raise NotImplementedError()

     def step(self, action: Any) -> RLStepResult:
-        rewards, dones, infos = self._step(action=action)
-
-        return RLStepResult(
-            observation=self.get_observations(),
-            reward=rewards,  # type:ignore
-            done=dones,  # type:ignore
-            info=infos,  # type:ignore
-        )
-
-    @final
-    def _step(self, action: Any) -> List[RLStepResult]:
         # Prepare all actions
-        actions = [None] * len(self.tasks)
+        env_actions = [None] * len(self.tasks)
         intermediates = [None] * len(self.tasks)

         def before_step(it, task):
-            actions[it], intermediates[it] = task._before_env_step(action[it])
+            env_actions[it], intermediates[it] = task._before_env_step(action[it])

         if self.parallel_before_step:
             wait(
@@ -549,42 +540,44 @@ def before_step(it, task):
                 before_step(it, task)

         # Step over all tasks
-        self.env.step(actions)
+        self.env.step(env_actions)

         # Prepare all results (excluding observations)
         rewards = [None] * len(self.tasks)
         dones = [None] * len(self.tasks)
         infos = [None] * len(self.tasks)

-        def after_step(it, current_task):
-            sr = current_task._after_env_step(action[it], actions[it], intermediates[it])
+        def after_step(it, task):
+            sr = task._after_env_step(action[it], env_actions[it], intermediates[it])

+            assert sr.observation is None, "step result observation is to be added by the BatchedTask"
+
             info = sr.info or {}

             # If reward is Sequence, it's assumed to follow the same order imposed by spaces' flatten operation
             if isinstance(sr.reward, Sequence):
-                if isinstance(current_task._total_reward, Sequence):
+                if isinstance(task._total_reward, Sequence):
                     for it, rew in enumerate(sr.reward):
-                        current_task._total_reward[it] += float(rew)
+                        task._total_reward[it] += float(rew)
                 else:
-                    current_task._total_reward = [float(r) for r in sr.reward]
+                    task._total_reward = [float(r) for r in sr.reward]
             else:
-                current_task._total_reward += float(sr.reward)  # type:ignore
+                task._total_reward += float(sr.reward)  # type:ignore

-            current_task._increment_num_steps_taken()
+            task._increment_num_steps_taken()

             done = sr.done

-            if current_task.is_done():
+            if task.is_done():
                 done = True

-            metrics = current_task.metrics()
+            metrics = task.metrics()
             if metrics is not None and len(metrics) != 0:
                 info[COMPLETE_TASK_METRICS_KEY] = metrics

             if self.callback_sensor_suite is not None:
                 task_callback_data = self.callback_sensor_suite.get_observations(
-                    env=current_task.env, task=current_task
+                    env=task.env, task=task
                 )
                 info[COMPLETE_TASK_CALLBACK_KEY] = task_callback_data

@@ -605,37 +598,40 @@ def after_step(it, current_task):
             for it, task in enumerate(self.tasks):
                 after_step(it, task)

-        return rewards, dones, infos
+        return RLStepResult(
+            observation=self.get_observations(),
+            reward=rewards,  # type:ignore
+            done=dones,  # type:ignore
+            info=infos,  # type:ignore
+        )

+    @final
+    def _step(self, action: Any) -> Tuple[List, List, List]:
+        raise RuntimeError("Unexpected call to `_step` in BatchedTask")
+
     def reached_max_steps(self) -> bool:
-        get_logger().warning("Unexpected call to `reached_max_steps` in BatchedTask")
-        return False
+        raise RuntimeError("Unexpected call to `reached_max_steps` in BatchedTask")

     def reached_terminal_state(self) -> bool:
-        get_logger().warning("Unexpected call to `reached_terminal_state` in BatchedTask")
-        return False
+        raise RuntimeError("Unexpected call to `reached_terminal_state` in BatchedTask")

     def is_done(self) -> bool:
         return False

     def num_steps_taken(self) -> int:
-        get_logger().warning("Unexpected call to `num_steps_taken` in BatchedTask")
-        return -1
+        raise RuntimeError("Unexpected call to `num_steps_taken` in BatchedTask")

     def close(self) -> None:
         if self.any_parallel:
             self.executor.shutdown(cancel_futures=True)
         self.tasks[0].close()

     def metrics(self) -> Dict[str, Any]:
-        get_logger().warning("Unexpected call to `metrics` in BatchedTask")
-        return {}
+        raise RuntimeError("Unexpected call to `metrics` in BatchedTask")

     def query_expert(self, **kwargs) -> Tuple[Any, bool]:
-        get_logger().warning("Unexpected call to `query_expert` in BatchedTask")
-        return None, False
+        raise RuntimeError("Unexpected call to `query_expert` in BatchedTask")

     @property
     def cumulative_reward(self) -> float:
-        get_logger().warning("Unexpected call to `cumulative_reward` in BatchedTask")
-        return 0.0
+        raise RuntimeError("Unexpected call to `cumulative_reward` in BatchedTask")
