[RLlib] Add PPO StatelessCartPole learning tests (+LSTM) to CI. #46324

Merged
33 changes: 33 additions & 0 deletions rllib/BUILD
@@ -627,6 +627,39 @@ py_test(
srcs = ["tuned_examples/ppo/cartpole_truncated_ppo.py"],
args = ["--as-test", "--enable-new-api-stack"]
)
# StatelessCartPole
py_test(
name = "learning_tests_stateless_cartpole_ppo",
main = "tuned_examples/ppo/stateless_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "learning_tests_discrete", "torch_only"],
size = "large",
srcs = ["tuned_examples/ppo/stateless_cartpole_ppo.py"],
args = ["--as-test", "--enable-new-api-stack"]
)
py_test(
name = "learning_tests_stateless_cartpole_ppo_gpu",
main = "tuned_examples/ppo/stateless_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "gpu"],
size = "large",
srcs = ["tuned_examples/ppo/stateless_cartpole_ppo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-gpus=1"]
)
py_test(
name = "learning_tests_stateless_cartpole_ppo_multi_cpu",
main = "tuned_examples/ppo/stateless_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core"],
size = "large",
srcs = ["tuned_examples/ppo/stateless_cartpole_ppo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"]
)
py_test(
name = "learning_tests_stateless_cartpole_ppo_multi_gpu",
main = "tuned_examples/ppo/stateless_cartpole_ppo.py",
tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_discrete", "learning_tests_pytorch_use_all_core", "multi_gpu"],
size = "large",
srcs = ["tuned_examples/ppo/stateless_cartpole_ppo.py"],
args = ["--as-test", "--enable-new-api-stack", "--num-gpus=2"]
)
# Pendulum
py_test(
name = "learning_tests_pendulum_ppo",
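For context, the four new BUILD targets above all point at the same tuned-example script and differ only in their resource flags. Below is a minimal sketch of what a PPO + LSTM setup for StatelessCartPole on the new API stack typically looks like; the import paths, config calls, and hyperparameters are assumptions for illustration, not the actual contents of tuned_examples/ppo/stateless_cartpole_ppo.py.

from ray.rllib.algorithms.ppo import PPOConfig
# Assumed import path for the example env; it can differ between Ray versions.
from ray.rllib.examples.envs.classes.stateless_cartpole import StatelessCartPole

config = (
    PPOConfig()
    # What `--enable-new-api-stack` toggles: RLModule/Learner + EnvRunner/ConnectorV2.
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    # StatelessCartPole hides the velocity parts of the observation, so the
    # policy needs memory -> enable an LSTM in the default RLModule.
    .environment(StatelessCartPole)
    .rl_module(model_config_dict={"use_lstm": True, "max_seq_len": 20})
    .training(lr=0.0003, num_sgd_iter=6)
)
algo = config.build()
results = algo.train()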
2 changes: 1 addition & 1 deletion rllib/algorithms/impala/impala_learner.py
@@ -106,7 +106,7 @@ def update_from_episodes(
# algos that actually need (and know how) to do minibatching.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
min_total_mini_batches: int = 0,
num_total_mini_batches: int = 0,
reduce_fn=None, # Deprecated args.
**kwargs,
) -> ResultDict:
15 changes: 8 additions & 7 deletions rllib/core/learner/learner.py
@@ -973,7 +973,7 @@ def update_from_episodes(
# algos that actually need (and know how) to do minibatching.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
min_total_mini_batches: int = 0,
num_total_mini_batches: int = 0,
# Deprecated args.
reduce_fn=DEPRECATED_VALUE,
) -> ResultDict:
@@ -991,13 +991,14 @@ def update_from_episodes(
minibatch_size: The size of the minibatch to use for each update.
num_iters: The number of complete passes over all the sub-batches
in the input multi-agent batch.
min_total_mini_batches: The minimum number of mini-batches to loop through
num_total_mini_batches: The total number of mini-batches to loop through
(across all `num_sgd_iter` SGD iterations). It's required to set this
for multi-agent + multi-GPU situations in which the MultiAgentEpisodes
themselves are sharded roughly equally, but might contain
SingleAgentEpisodes with very lopsided length distributions. Thus,
without this limit it can happen that one Learner goes through a
different number of mini-batches than other Learners, causing deadlocks.
without this fixed, pre-computed value it can happen that one Learner
goes through a different number of mini-batches than other Learners,
causing a deadlock.

Returns:
A `ResultDict` object produced by a call to `self.metrics.reduce()`. The
@@ -1021,7 +1022,7 @@ def update_from_episodes(
timesteps=timesteps,
minibatch_size=minibatch_size,
num_iters=num_iters,
min_total_mini_batches=min_total_mini_batches,
num_total_mini_batches=num_total_mini_batches,
)

@OverrideToImplementCustomLogic
@@ -1231,7 +1232,7 @@ def _update_from_batch_or_episodes(
# algos that actually need (and know how) to do minibatching.
minibatch_size: Optional[int] = None,
num_iters: int = 1,
min_total_mini_batches: int = 0,
num_total_mini_batches: int = 0,
) -> Union[Dict[str, Any], List[Dict[str, Any]]]:

self._check_is_built()
@@ -1317,7 +1318,7 @@ def _update_from_batch_or_episodes(
batch_iter = partial(
MiniBatchCyclicIterator,
uses_new_env_runners=True,
min_total_mini_batches=min_total_mini_batches,
num_total_mini_batches=num_total_mini_batches,
)
else:
batch_iter = MiniBatchCyclicIterator
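To make the deadlock rationale in the `num_total_mini_batches` docstring above concrete, here is a small sketch with made-up numbers: even if every Learner receives the same number of MultiAgentEpisodes, the busiest module on each Learner can end up with a different timestep total, so locally derived mini-batch counts would disagree.

# Illustrative numbers only (not taken from this PR).
mini_batch_size = 128
num_sgd_iter = 4

# Timesteps of the busiest module on each Learner after sharding the episodes.
learner_0_max_module_ts = 2000
learner_1_max_module_ts = 1200

# Locally derived iteration counts would differ ...
local_0 = num_sgd_iter * learner_0_max_module_ts // mini_batch_size  # 62
local_1 = num_sgd_iter * learner_1_max_module_ts // mini_batch_size  # 37
# ... so the Learner that stops after 37 mini-batches would leave the other one
# blocked in its synchronized gradient updates. Passing the same, pre-computed
# `num_total_mini_batches` to every Learner removes that mismatch.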
77 changes: 34 additions & 43 deletions rllib/core/learner/learner_group.py
@@ -372,7 +372,7 @@ def _learner_update(
_episodes_shard=None,
_timesteps=None,
_return_state=False,
_min_total_mini_batches=0,
_num_total_mini_batches=0,
**_kwargs,
):
# If the batch shard is an `DataIterator` we have an offline
@@ -400,7 +400,7 @@ def _learner_update(
timesteps=_timesteps,
minibatch_size=minibatch_size,
num_iters=num_iters,
min_total_mini_batches=_min_total_mini_batches,
num_total_mini_batches=_num_total_mini_batches,
**_kwargs,
)
if _return_state:
@@ -485,53 +485,41 @@ def _learner_update(
from ray.data.iterator import DataIterator

if isinstance(episodes[0], DataIterator):
min_total_mini_batches = 0
num_total_mini_batches = 0
partials = [
partial(
_learner_update,
_episodes_shard=episodes_shard,
_timesteps=timesteps,
_min_total_mini_batches=min_total_mini_batches,
_num_total_mini_batches=num_total_mini_batches,
)
for episodes_shard in episodes
]
else:
eps_shards = list(
ShardEpisodesIterator(episodes, len(self._workers))
ShardEpisodesIterator(
episodes,
len(self._workers),
len_lookback_buffer=self.config.episode_lookback_horizon,
)
)
# In the multi-agent case AND `minibatch_size` AND num_workers
# > 1, we compute a max iteration counter such that the different
# Learners will not go through a different number of iterations.
min_total_mini_batches = 0
if (
isinstance(episodes[0], MultiAgentEpisode)
and minibatch_size
and len(self._workers) > 1
):
# Find episode w/ the largest single-agent episode in it, then
# compute this single-agent episode's total number of mini
# batches (if we iterated over it num_sgd_iter times with the
# mini batch size).
longest_ts = 0
per_mod_ts = defaultdict(int)
for i, shard in enumerate(eps_shards):
for ma_episode in shard:
for sa_episode in ma_episode.agent_episodes.values():
key = (i, sa_episode.module_id)
per_mod_ts[key] += len(sa_episode)
if per_mod_ts[key] > longest_ts:
longest_ts = per_mod_ts[key]
min_total_mini_batches = self._compute_num_total_mini_batches(
batch_size=longest_ts,
mini_batch_size=minibatch_size,
num_iters=num_iters,
num_total_mini_batches = 0
if minibatch_size and len(self._workers) > 1:
num_total_mini_batches = self._compute_num_total_mini_batches(
episodes,
len(self._workers),
minibatch_size,
num_iters,
)
partials = [
partial(
_learner_update,
_episodes_shard=eps_shard,
_timesteps=timesteps,
_min_total_mini_batches=min_total_mini_batches,
_num_total_mini_batches=num_total_mini_batches,
)
for eps_shard in eps_shards
]
@@ -946,20 +934,23 @@ def __del__(self):
self.shutdown()

@staticmethod
def _compute_num_total_mini_batches(batch_size, mini_batch_size, num_iters):
num_total_mini_batches = 0
rest_size = 0
for i in range(num_iters):
eaten_batch = -rest_size
while eaten_batch < batch_size:
eaten_batch += mini_batch_size
num_total_mini_batches += 1
rest_size = mini_batch_size - (eaten_batch - batch_size)
if rest_size:
num_total_mini_batches -= 1
if rest_size:
num_total_mini_batches += 1
return num_total_mini_batches
def _compute_num_total_mini_batches(
episodes,
num_shards,
mini_batch_size,
num_iters,
):
# Count total number of timesteps per module ID.
if isinstance(episodes[0], MultiAgentEpisode):
per_mod_ts = defaultdict(int)
for ma_episode in episodes:
for sa_episode in ma_episode.agent_episodes.values():
per_mod_ts[sa_episode.module_id] += len(sa_episode)
max_ts = max(per_mod_ts.values())
else:
max_ts = sum(map(len, episodes))

return int((num_iters * max_ts) / (num_shards * mini_batch_size))

@Deprecated(new="LearnerGroup.update_from_batch(async=False)", error=False)
def update(self, *args, **kwargs):
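The rewritten `_compute_num_total_mini_batches` helper replaces the old stepwise counting loop with a closed-form estimate based on the busiest module across all (unsharded) episodes. A quick check with made-up numbers, two Learner shards, four SGD iterations, and a mini-batch size of 128:

# Illustrative values only.
num_iters = 4
num_shards = 2
mini_batch_size = 128

# Multi-agent case: sum the timesteps per module ID over all episodes, then
# take the busiest module (this mirrors the loop in the method above).
per_mod_ts = {"module_a": 3000, "module_b": 4100}
max_ts = max(per_mod_ts.values())  # 4100

num_total_mini_batches = int((num_iters * max_ts) / (num_shards * mini_batch_size))
# -> int(16400 / 256) = 64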
22 changes: 17 additions & 5 deletions rllib/env/multi_agent_episode.py
@@ -1479,7 +1479,12 @@ def get_temporary_timestep_data(self, key: str) -> List[Any]:
except KeyError:
raise KeyError(f"Key {key} not found in temporary timestep data!")

def slice(self, slice_: slice) -> "MultiAgentEpisode":
def slice(
self,
slice_: slice,
*,
len_lookback_buffer: Optional[int] = None,
) -> "MultiAgentEpisode":
"""Returns a slice of this episode with the given slice object.

Works analogous to
@@ -1544,6 +1549,10 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode":
slice_: The slice object to use for slicing. This should exclude the
lookback buffer, which will be prepended automatically to the returned
slice.
len_lookback_buffer: If not None, the returned slice tries to carry this
many timesteps in its lookback buffer (if that much data is available).
If None (default), the returned slice's lookback buffer is made as large
as this episode's (`self`) current lookback buffer.

Returns:
The new MultiAgentEpisode representing the requested slice.
@@ -1630,23 +1639,26 @@ def slice(self, slice_: slice) -> "MultiAgentEpisode":
truncateds["__all__"] = all(truncateds.get(aid) for aid in self.agent_episodes)

# Determine all other slice contents.
_lb = len_lookback_buffer if len_lookback_buffer is not None else ref_lookback
if start - _lb < 0 and ref_lookback < (_lb - start):
_lb = ref_lookback + start
observations = self.get_observations(
slice(start - ref_lookback, stop + 1),
slice(start - _lb, stop + 1),
neg_index_as_lookback=True,
return_list=True,
)
actions = self.get_actions(
slice(start - ref_lookback, stop),
slice(start - _lb, stop),
neg_index_as_lookback=True,
return_list=True,
)
rewards = self.get_rewards(
slice(start - ref_lookback, stop),
slice(start - _lb, stop),
neg_index_as_lookback=True,
return_list=True,
)
extra_model_outputs = self.get_extra_model_outputs(
indices=slice(start - ref_lookback, stop),
indices=slice(start - _lb, stop),
neg_index_as_lookback=True,
return_list=True,
)
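A hedged usage sketch of the extended `slice()` signature (the `ma_episode` object and its sizes are hypothetical, not from this PR): the new `len_lookback_buffer` argument lets callers such as `ShardEpisodesIterator` request a specific lookback size for the returned shard instead of always inheriting the parent episode's full lookback.

# Assume `ma_episode` is an existing MultiAgentEpisode with 500 global
# timesteps and a lookback buffer of 50 timesteps (ref_lookback = 50).

# Default: the slice inherits as much lookback as the parent episode offers.
shard_a = ma_episode.slice(slice(100, 200))

# New: request a smaller, fixed lookback of 10 timesteps for the shard.
shard_b = ma_episode.slice(slice(100, 200), len_lookback_buffer=10)

# Near the episode start, the request is clamped to the data that actually
# exists before the slice: with start=5 and a requested lookback of 100,
# only ref_lookback + start = 50 + 5 = 55 timesteps can be used.
shard_c = ma_episode.slice(slice(5, 60), len_lookback_buffer=100)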