diff --git a/.bazelrc b/.bazelrc
index 057739f0a09c..759335df4797 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -168,6 +168,7 @@ test:ci --flaky_test_attempts=3
 test:ci --nocache_test_results
 test:ci --spawn_strategy=local
 test:ci --test_output=errors
+test:ci --experimental_ui_max_stdouterr_bytes=-1
 test:ci --test_verbose_timeout_warnings
 test:ci-debug -c dbg
 test:ci-debug --copt="-g"
diff --git a/rllib/BUILD b/rllib/BUILD
index c43ca3e2949f..8f141b901c42 100644
--- a/rllib/BUILD
+++ b/rllib/BUILD
@@ -463,6 +463,25 @@ py_test(
     args = ["--dir=tuned_examples/sac"]
 )
 
+# TODO (simon): These tests are not learning, yet.
+# py_test(
+#     name = "learning_tests_multi_agent_pendulum_sac",
+#     main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
+#     tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous"],
+#     size = "large",
+#     srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
+#     args = ["--enable-new-api-stack", "--num-agents=2"]
+# )
+
+# py_test(
+#     name = "learning_tests_multi_agent_pendulum_sac_multi_gpu",
+#     main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
+#     tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous", "multi_gpu"],
+#     size = "large",
+#     srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
+#     args = ["--enable-new-api-stack", "--num-agents=2", "--num-gpus=2"]
+# )
+
 # --------------------------------------------------------------------
 # Algorithms (Compilation, Losses, simple functionality tests)
 # rllib/algorithms/
diff --git a/rllib/algorithms/dqn/dqn.py b/rllib/algorithms/dqn/dqn.py
index 58e0a0f29ecc..445257f80336 100644
--- a/rllib/algorithms/dqn/dqn.py
+++ b/rllib/algorithms/dqn/dqn.py
@@ -16,7 +16,6 @@
 
 from ray.rllib.algorithms.algorithm import Algorithm
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
-from ray.rllib.algorithms.dqn.dqn_rainbow_learner import TD_ERROR_KEY
 from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
 from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
 from ray.rllib.core.learner import Learner
@@ -64,6 +63,7 @@
     REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
     SAMPLE_TIMER,
     SYNCH_WORKER_WEIGHTS_TIMER,
+    TD_ERROR_KEY,
     TIMERS,
 )
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
@@ -662,7 +662,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
                 num_items=self.config.train_batch_size,
                 n_step=self.config.n_step,
                 gamma=self.config.gamma,
-                beta=self.config.replay_buffer_config["beta"],
+                beta=self.config.replay_buffer_config.get("beta"),
             )
 
         # Perform an update on the buffer-sampled train batch.
@@ -700,6 +700,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
                 },
                 reduce="sum",
             )
+
         # TODO (sven): Uncomment this once agent steps are available in the
         # Learner stats.
         # self.metrics.log_dict(
diff --git a/rllib/algorithms/dqn/dqn_rainbow_learner.py b/rllib/algorithms/dqn/dqn_rainbow_learner.py
index 1aba7f757008..abc73ada8413 100644
--- a/rllib/algorithms/dqn/dqn_rainbow_learner.py
+++ b/rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -13,7 +13,10 @@
     override,
     OverrideToImplementCustomLogic_CallToSuperRecommended,
 )
-from ray.rllib.utils.metrics import LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES
+from ray.rllib.utils.metrics import (
+    LAST_TARGET_UPDATE_TS,
+    NUM_TARGET_UPDATES,
+)
 from ray.rllib.utils.typing import ModuleID
 
 if TYPE_CHECKING:
@@ -32,7 +35,6 @@
 QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
 QF_PREDS = "qf_preds"
 QF_PROBS = "qf_probs"
-TD_ERROR_KEY = "td_error"
 TD_ERROR_MEAN_KEY = "td_error_mean"
 
 
diff --git a/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py b/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py
index 0a887908b114..7ec354ba61f9 100644
--- a/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py
+++ b/rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py
@@ -14,13 +14,13 @@
     QF_TARGET_NEXT_PROBS,
     QF_PREDS,
     QF_PROBS,
-    TD_ERROR_KEY,
     TD_ERROR_MEAN_KEY,
 )
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner.torch.torch_learner import TorchLearner
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
+from ray.rllib.utils.metrics import TD_ERROR_KEY
 from ray.rllib.utils.nested_dict import NestedDict
 from ray.rllib.utils.typing import ModuleID, TensorType
 
diff --git a/rllib/algorithms/sac/sac.py b/rllib/algorithms/sac/sac.py
index c58170cc44e7..d8f8b74a1de5 100644
--- a/rllib/algorithms/sac/sac.py
+++ b/rllib/algorithms/sac/sac.py
@@ -352,6 +352,7 @@ def validate(self) -> None:
             ] not in [
                 "EpisodeReplayBuffer",
                 "PrioritizedEpisodeReplayBuffer",
+                "MultiAgentEpisodeReplayBuffer",
             ]:
                 raise ValueError(
                     "When using the new `EnvRunner API` the replay buffer must be of type "
diff --git a/rllib/algorithms/sac/sac_learner.py b/rllib/algorithms/sac/sac_learner.py
index 94a2e907de96..cdbc44ee749b 100644
--- a/rllib/algorithms/sac/sac_learner.py
+++ b/rllib/algorithms/sac/sac_learner.py
@@ -20,7 +20,6 @@
 QF_TWIN_LOSS_KEY = "qf_twin_loss"
 QF_TWIN_PREDS = "qf_twin_preds"
 TD_ERROR_MEAN_KEY = "td_error_mean"
-TD_ERROR_KEY = "td_error"
 
 
 class SACLearner(DQNRainbowLearner):
diff --git a/rllib/algorithms/sac/torch/sac_torch_learner.py b/rllib/algorithms/sac/torch/sac_torch_learner.py
index 0565e950136b..229b8cc4549f 100644
--- a/rllib/algorithms/sac/torch/sac_torch_learner.py
+++ b/rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -15,17 +15,15 @@
     QF_TWIN_LOSS_KEY,
     QF_TWIN_PREDS,
     TD_ERROR_MEAN_KEY,
-    TD_ERROR_KEY,
     SACLearner,
 )
-from ray.rllib.core import DEFAULT_MODULE_ID
 from ray.rllib.core.columns import Columns
 from ray.rllib.core.learner.learner import (
     POLICY_LOSS_KEY,
 )
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
-from ray.rllib.utils.metrics import ALL_MODULES
+from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY
 from ray.rllib.utils.nested_dict import NestedDict
 from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType
 
@@ -221,8 +219,6 @@ def compute_loss_for_module(
         # Note further, we use here the Huber loss instead of the mean squared error
         # as it improves training performance.
         critic_loss = torch.mean(
-            # TODO (simon): Introduce priority weights when episode buffer is ready.
-            # batch[PRIO_WEIGHTS] *
             batch["weights"]
             * torch.nn.HuberLoss(reduction="none", delta=1.0)(
                 q_selected, q_selected_target
@@ -303,6 +299,7 @@ def compute_loss_for_module(
     def compute_gradients(
         self, loss_per_module: Dict[str, TensorType], **kwargs
     ) -> ParamDict:
+        # Set all grads to `None`.
         for optim in self._optimizer_parameters:
             optim.zero_grad(set_to_none=True)
 
@@ -317,7 +314,7 @@ def compute_gradients(
             for component in (
                 ["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
             ):
-                self.metrics.peek(DEFAULT_MODULE_ID, component + "_loss").backward(
+                self.metrics.peek(module_id, component + "_loss").backward(
                     retain_graph=True
                 )
                 grads.update(
diff --git a/rllib/connectors/common/agent_to_module_mapping.py b/rllib/connectors/common/agent_to_module_mapping.py
index c304fa60a174..b54a20bb050f 100644
--- a/rllib/connectors/common/agent_to_module_mapping.py
+++ b/rllib/connectors/common/agent_to_module_mapping.py
@@ -133,9 +133,6 @@ def __call__(
         shared_data: Optional[dict] = None,
         **kwargs,
     ) -> Any:
-        # This Connector should only be used in a multi-agent setting.
-        assert not episodes or isinstance(episodes[0], MultiAgentEpisode)
-
         # Current agent to module mapping function.
         # agent_to_module_mapping_fn = shared_data.get("agent_to_module_mapping_fn")
         # Store in shared data, which module IDs map to which episode/agent, such
diff --git a/rllib/connectors/common/batch_individual_items.py b/rllib/connectors/common/batch_individual_items.py
index e4a4f2ac8d86..9b5460b4cb49 100644
--- a/rllib/connectors/common/batch_individual_items.py
+++ b/rllib/connectors/common/batch_individual_items.py
@@ -33,7 +33,10 @@ def __call__(
         # to a batch structure of:
         # [module_id] -> [col0] -> [list of items]
         if is_marl_module and column in rl_module:
-            assert is_multi_agent
+            # assert is_multi_agent
+            # TODO (simon, sven): Check whether this assert is needed in other
+            # cases as well. For multi-agent off-policy with independent
+            # sampling we need to skip this check.
             module_data = column_data
             for col, col_data in module_data.copy().items():
                 if isinstance(col_data, list) and col != Columns.INFOS:
diff --git a/rllib/env/multi_agent_env_runner.py b/rllib/env/multi_agent_env_runner.py
index e60300c11a4b..89938795866e 100644
--- a/rllib/env/multi_agent_env_runner.py
+++ b/rllib/env/multi_agent_env_runner.py
@@ -3,7 +3,6 @@
 
 from collections import defaultdict
 from functools import partial
-import numpy as np
 from typing import DefaultDict, Dict, List, Optional
 
 from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -603,9 +602,11 @@ def get_metrics(self) -> ResultDict:
                 module_episode_returns,
             )
 
-        # If no episodes at all, log NaN stats.
-        if len(self._done_episodes_for_metrics) == 0:
-            self._log_episode_metrics(np.nan, np.nan, np.nan)
+        # TODO (simon): This results in hundreds of warnings in the logs
+        # b/c reducing over NaNs is not supported.
+        # # If no episodes at all, log NaN stats.
+        # if len(self._done_episodes_for_metrics) == 0:
+        #     self._log_episode_metrics(np.nan, np.nan, np.nan)
 
         # Log num episodes counter for this iteration.
         self.metrics.log_value(
diff --git a/rllib/env/multi_agent_episode.py b/rllib/env/multi_agent_episode.py
index 216aaf5f0f31..a59cec2d7a63 100644
--- a/rllib/env/multi_agent_episode.py
+++ b/rllib/env/multi_agent_episode.py
@@ -58,6 +58,30 @@ class MultiAgentEpisode:
         up to here, b/c there is nothing to learn from these "premature" rewards.
""" + __slots__ = ( + "id_", + "agent_to_module_mapping_fn", + "_agent_to_module_mapping", + "observation_space", + "action_space", + "env_t_started", + "env_t", + "agent_t_started", + "env_t_to_agent_t", + "_hanging_actions_end", + "_hanging_extra_model_outputs_end", + "_hanging_rewards_end", + "_hanging_actions_begin", + "_hanging_extra_model_outputs_begin", + "_hanging_rewards_begin", + "is_terminated", + "is_truncated", + "agent_episodes", + "_temporary_timestep_data", + "_start_time", + "_last_step_time", + ) + SKIP_ENV_TS_TAG = "S" def __init__( diff --git a/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py b/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py new file mode 100644 index 000000000000..a70ca0b9d62b --- /dev/null +++ b/rllib/tuned_examples/sac/multi_agent_pendulum_sac.py @@ -0,0 +1,79 @@ +from ray.rllib.algorithms.sac import SACConfig +from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum +from ray.rllib.utils.metrics import ( + ENV_RUNNER_RESULTS, + EPISODE_RETURN_MEAN, + NUM_ENV_STEPS_SAMPLED_LIFETIME, +) +from ray.tune.registry import register_env + +from ray.rllib.utils.test_utils import add_rllib_example_script_args + +parser = add_rllib_example_script_args() +# Use `parser` to add your own custom command line options to this script +# and (if needed) use their values to set up `config` below. +args = parser.parse_args() + +register_env( + "multi_agent_pendulum", + lambda _: MultiAgentPendulum({"num_agents": args.num_agents or 2}), +) + +config = ( + SACConfig() + .environment(env="multi_agent_pendulum") + .rl_module( + model_config_dict={ + "fcnet_hiddens": [256, 256], + "fcnet_activation": "relu", + "post_fcnet_hiddens": [], + "post_fcnet_activation": None, + "post_fcnet_weights_initializer": "orthogonal_", + "post_fcnet_weights_initializer_config": {"gain": 0.01}, + } + ) + .api_stack( + enable_rl_module_and_learner=True, + enable_env_runner_and_connector_v2=True, + ) + .env_runners( + rollout_fragment_length=1, + num_env_runners=2, + num_envs_per_env_runner=1, + ) + .training( + initial_alpha=1.001, + lr=3e-4, + target_entropy="auto", + n_step=1, + tau=0.005, + train_batch_size_per_learner=256, + target_network_update_freq=1, + replay_buffer_config={ + "type": "MultiAgentEpisodeReplayBuffer", + "capacity": 100000, + }, + num_steps_sampled_before_learning_starts=256, + ) + .reporting( + metrics_num_episodes_for_smoothing=5, + min_sample_timesteps_per_iteration=1000, + ) +) + +if args.num_agents: + config.multi_agent( + policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}", + policies={f"p{i}" for i in range(args.num_agents)}, + ) + +stop = { + NUM_ENV_STEPS_SAMPLED_LIFETIME: 500000, + # `episode_return_mean` is the sum of all agents/policies' returns. + f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -400.0 * (args.num_agents or 2), +} + +if __name__ == "__main__": + from ray.rllib.utils.test_utils import run_rllib_example_script_experiment + + run_rllib_example_script_experiment(config, args, stop=stop) diff --git a/rllib/utils/metrics/__init__.py b/rllib/utils/metrics/__init__.py index 164ec20cf405..39e087da9434 100644 --- a/rllib/utils/metrics/__init__.py +++ b/rllib/utils/metrics/__init__.py @@ -91,3 +91,4 @@ # Learner. 
 LEARNER_STATS_KEY = "learner_stats"
 ALL_MODULES = "__all_modules__"
+TD_ERROR_KEY = "td_error"
diff --git a/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py b/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py
index c32415ba01a9..9b1af8e86cff 100644
--- a/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py
+++ b/rllib/utils/replay_buffers/multi_agent_episode_replay_buffer.py
@@ -3,10 +3,11 @@
 from gymnasium.core import ActType, ObsType
 import numpy as np
 import scipy
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Set, Tuple, Union
 
 from ray.rllib.core.columns import Columns
 from ray.rllib.env.multi_agent_episode import MultiAgentEpisode
+from ray.rllib.env.single_agent_episode import SingleAgentEpisode
 from ray.rllib.utils.replay_buffers.episode_replay_buffer import EpisodeReplayBuffer
 from ray.rllib.utils import force_list
 from ray.rllib.utils.annotations import override, DeveloperAPI
@@ -130,40 +131,42 @@ def add(
         """
         episodes: List["MultiAgentEpisode"] = force_list(episodes)
 
-        new_episode_ids: List[str] = []
-        for eps in episodes:
-            new_episode_ids.append(eps.id_)
-            self._num_timesteps += eps.env_steps()
-            self._num_timesteps_added += eps.env_steps()
+        new_episode_ids: Set[str] = {eps.id_ for eps in episodes}
+        total_env_timesteps = sum([eps.env_steps() for eps in episodes])
+        self._num_timesteps += total_env_timesteps
+        self._num_timesteps_added += total_env_timesteps
 
         # Evict old episodes.
-        eps_evicted: List["MultiAgentEpisode"] = []
-        eps_evicted_ids: List[Union[str, int]] = []
-        eps_evicted_idxs: List[int] = []
+        eps_evicted_ids: Set[Union[str, int]] = set()
+        eps_evicted_idxs: Set[int] = set()
         while (
             self._num_timesteps > self.capacity
            and self._num_remaining_episodes(new_episode_ids, eps_evicted_ids) != 1
        ):
            # Evict episode.
            evicted_episode = self.episodes.popleft()
-            eps_evicted.append(evicted_episode)
-            eps_evicted_ids.append(evicted_episode.id_)
-            eps_evicted_idxs.append(self.episode_id_to_index.pop(evicted_episode.id_))
+            eps_evicted_ids.add(evicted_episode.id_)
+            eps_evicted_idxs.add(self.episode_id_to_index.pop(evicted_episode.id_))
             # If this episode has a new chunk in the new episodes added,
             # we subtract it again.
             # TODO (sven, simon): Should we just treat such an episode chunk
             # as a new episode?
             if evicted_episode.id_ in new_episode_ids:
-                new_eps_to_evict = episodes[new_episode_ids.index(evicted_episode.id_)]
+                idx = next(
+                    i
+                    for i, eps in enumerate(episodes)
+                    if eps.id_ == evicted_episode.id_
+                )
+                new_eps_to_evict = episodes.pop(idx)
                 self._num_timesteps -= new_eps_to_evict.env_steps()
                 self._num_timesteps_added -= new_eps_to_evict.env_steps()
-                episodes.remove(new_eps_to_evict)
             # Remove the timesteps of the evicted episode from the counter.
             self._num_timesteps -= evicted_episode.env_steps()
             self._num_agent_timesteps -= evicted_episode.agent_steps()
             self._num_episodes_evicted += 1
             # Remove the module timesteps of the evicted episode from the counters.
             self._evict_module_episodes(evicted_episode)
+            del evicted_episode
 
         # Add agent and module steps.
         for eps in episodes:
@@ -174,41 +177,38 @@ def add(
         # Remove corresponding indices, if episodes were evicted.
         if eps_evicted_idxs:
-            new_indices = []
-            # Each index 2-tuple is of the form (ma_episode_idx, timestep) and
+            # If the episode is not evicted, we keep the index.
+            # Note, each index 2-tuple is of the form (ma_episode_idx, timestep) and
             # refers to a certain environment timestep in a certain multi-agent
             # episode.
-            for idx_tuple in self._indices:
-                # If episode index is not from an evicted episode, keep it.
-                if idx_tuple[0] not in eps_evicted_idxs:
-                    new_indices.append(idx_tuple)
-            # Assign the new list of indices.
-            self._indices = new_indices
+            self._indices = [
+                idx_tuple
+                for idx_tuple in self._indices
+                if idx_tuple[0] not in eps_evicted_idxs
+            ]
 
             # Also remove corresponding module indices.
             for module_id, module_indices in self._module_to_indices.items():
-                new_module_indices = []
                 # Each index 3-tuple is of the form
                 # (ma_episode_idx, agent_id, timestep) and refers to a certain
                 # agent timestep in a certain multi-agent episode.
-                for idx_triplet in module_indices:
-                    if idx_triplet[0] not in eps_evicted_idxs:
-                        new_module_indices.append(idx_triplet)
-                self._module_to_indices[module_id] = new_module_indices
+                self._module_to_indices[module_id] = [
+                    idx_triplet
+                    for idx_triplet in module_indices
+                    if idx_triplet[0] not in eps_evicted_idxs
+                ]
 
         for eps in episodes:
             eps = copy.deepcopy(eps)
             # If the episode is part of an already existing episode, concatenate.
             if eps.id_ in self.episode_id_to_index:
                 eps_idx = self.episode_id_to_index[eps.id_]
-                existing_eps = self.episodes[eps_idx]
+                existing_eps = self.episodes[eps_idx - self._num_episodes_evicted]
                 existing_len = len(existing_eps)
                 self._indices.extend(
                     [
                         (
                             eps_idx,
-                            # Note, we add 1 b/c the first timestep is
-                            # never sampled.
-                            existing_len + i + 1,
+                            existing_len + i,
                         )
                         for i in range(len(eps))
                     ]
@@ -223,7 +223,7 @@ def add(
                 self.episodes.append(eps)
                 eps_idx = len(self.episodes) - 1 + self._num_episodes_evicted
                 self.episode_id_to_index[eps.id_] = eps_idx
-                self._indices.extend([(eps_idx, i + 1) for i in range(len(eps))])
+                self._indices.extend([(eps_idx, i) for i in range(len(eps))])
 
                 # Add new module indices.
                 self._add_new_module_indices(eps, eps_idx, False)
@@ -240,6 +240,7 @@ def sample(
         include_extra_model_outputs: bool = False,
         replay_mode: str = "independent",
         modules_to_sample: Optional[List[ModuleID]] = None,
+        **kwargs,
     ) -> SampleBatchType:
         """Samples a batch of multi-agent transitions.
 
@@ -458,46 +459,25 @@ def _sample_independent(
         gamma: float,
         include_infos: bool,
         include_extra_model_outputs: bool,
-        modules_to_sample: Optional[List[ModuleID]],
+        modules_to_sample: Optional[Set[ModuleID]],
     ) -> SampleBatchType:
         """Samples a batch of independent multi-agent transitions."""
+
+        actual_n_step = n_step or 1
         # Sample the n-step if necessary.
-        if isinstance(n_step, tuple):
-            # Use random n-step sampling.
-            random_n_step = True
-        else:
-            actual_n_step = n_step or 1
-            random_n_step = False
+        random_n_step = isinstance(n_step, tuple)
 
-        ret = {}
+        sampled_episodes = []
         # TODO (simon): Ensure that the module has data and if not, skip it.
         # TODO (sven): Should we then error out or skip? I think the Learner
         # should handle this case when a module has no train data.
-        for module_id in modules_to_sample or self._module_to_indices.keys():
-            # Rows to return.
-            observations: List[List[ObsType]] = [[] for _ in range(batch_size_B)]
-            next_observations: List[List[ObsType]] = [[] for _ in range(batch_size_B)]
-            actions: List[List[ActType]] = [[] for _ in range(batch_size_B)]
-            rewards: List[List[float]] = [[] for _ in range(batch_size_B)]
-            is_terminated: List[bool] = [False for _ in range(batch_size_B)]
-            is_truncated: List[bool] = [False for _ in range(batch_size_B)]
-            weights: List[float] = [[1.0] for _ in range(batch_size_B)]
-            n_steps: List[List[int]] = [[] for _ in range(batch_size_B)]
-            # If `info` should be included, construct also a container for them.
-            if include_infos:
-                infos: List[List[Dict[str, Any]]] = [[] for _ in range(batch_size_B)]
-            # If `extra_model_outputs` should be included, construct a container for
-            # them.
-            if include_extra_model_outputs:
-                extra_model_outputs: List[List[Dict[str, Any]]] = [
-                    [] for _ in range(batch_size_B)
-                ]
+        modules_to_sample = modules_to_sample or set(self._module_to_indices.keys())
+        for module_id in modules_to_sample:
+            module_indices = self._module_to_indices[module_id]
             B = 0
             while B < batch_size_B:
                 # Now sample from the single-agent timesteps.
-                index_tuple = self._module_to_indices[module_id][
-                    self.rng.integers(len(self._module_to_indices[module_id]))
-                ]
+                index_tuple = module_indices[self.rng.integers(len(module_indices))]
 
                 # This will be an agent timestep (not env timestep).
                 # TODO (simon, sven): Maybe deprecate sa_episode_idx (_) in the index
@@ -507,109 +487,95 @@ def _sample_independent(
                     index_tuple[1],
                     index_tuple[2],
                 )
-                # If we cannnot make the n-step, we resample.
-                if sa_episode_ts - n_step < 0:
-                    continue
-                # If we use random n-step sampling, draw the n-step for this item.
-                if random_n_step:
-                    actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
-                # If we are at the end of an episode, continue.
-                # Note, priority sampling got us `o_(t+n)` and we need for the loss
-                # calculation in addition `o_t`.
-                # TODO (simon): Maybe introduce a variable `num_retries` until the
-                # while loop should break when not enough samples have been collected
-                # to make n-step possible.
-                if sa_episode_ts - actual_n_step < 0:
-                    continue
-                else:
-                    n_steps[B] = actual_n_step
+
                 # Get the multi-agent episode.
                 ma_episode = self.episodes[ma_episode_idx]
                 # Retrieve the single-agent episode for filtering.
                 sa_episode = ma_episode.agent_episodes[agent_id]
-                # Ensure that each row contains a tuple of the form:
-                # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
-                # TODO (simon): Implement version for sequence sampling when using RNNs.
-                sa_eps_observation = sa_episode.get_observations(
-                    slice(sa_episode_ts - actual_n_step, sa_episode_ts + 1)
-                )
-                # Note, the reward that is collected by transitioning from `o_t` to
-                # `o_(t+1)` is stored in the next transition in `SingleAgentEpisode`.
-                sa_eps_rewards = sa_episode.get_rewards(
-                    slice(sa_episode_ts - actual_n_step, sa_episode_ts)
-                )
-                observations[B] = sa_eps_observation[0]
-                next_observations[B] = sa_eps_observation[-1]
+
+                # If we use random n-step sampling, draw the n-step for this item.
+                if random_n_step:
+                    actual_n_step = int(self.rng.integers(n_step[0], n_step[1]))
+                # If we cannot make the n-step, we resample.
+                if sa_episode_ts + actual_n_step > len(sa_episode):
+                    continue
                 # Note, this will be the reward after executing action
-                # `a_(episode_ts-n_step+1)`. For `n_step>1` this will be the sum of
+                # `a_(episode_ts)`. For `n_step>1` this will be the sum of
                 # all rewards that were collected over the last n steps.
-                rewards[B] = scipy.signal.lfilter(
-                    [1], [1, -gamma], sa_eps_rewards[::-1], axis=0
+                sa_raw_rewards = sa_episode.get_rewards(
+                    slice(sa_episode_ts, sa_episode_ts + actual_n_step)
+                )
+                sa_rewards = scipy.signal.lfilter(
+                    [1], [1, -gamma], sa_raw_rewards[::-1], axis=0
                 )[-1]
-                # Note, `SingleAgentEpisode` stores the action that followed
-                # `o_t` with `o_(t+1)`, therefore, we need the next one.
-                # TODO (simon): This gets the wrong action as long as the getters are
-                # not fixed.
-                actions[B] = sa_episode.get_actions(sa_episode_ts - actual_n_step)
-                if include_infos:
-                    # If infos are included we include the ones from the last timestep
-                    # as usually the info contains additional values about the last
-                    # state.
-                    infos[B] = sa_episode.get_infos(sa_episode_ts)
-                if include_extra_model_outputs:
-                    # If `extra_model_outputs` are included we include the ones from the
-                    # first timestep as usually the `extra_model_outputs` contain
-                    # additional values from the forward pass that produced the action
-                    # at the first timestep.
-                    # Note, we extract them into single row dictionaries similar to the
-                    # infos, in a connector we can then extract these into single batch
-                    # rows.
-                    extra_model_outputs[B] = {
-                        k: sa_episode.get_extra_model_outputs(
-                            k, sa_episode_ts - actual_n_step
-                        )
-                        for k in sa_episode.extra_model_outputs.keys()
-                    }
-                # If the sampled time step is the episode's last time step check, if
-                # the episode is terminated or truncated.
-                if sa_episode_ts == sa_episode.t:
-                    is_terminated[B] = sa_episode.is_terminated
-                    is_truncated[B] = sa_episode.is_truncated
+
+                sampled_sa_episode = SingleAgentEpisode(
+                    id_=sa_episode.id_,
+                    # Provide the IDs for the learner connector.
+                    agent_id=sa_episode.agent_id,
+                    module_id=sa_episode.module_id,
+                    multi_agent_episode_id=ma_episode.id_,
+                    # Ensure that each episode contains a tuple of the form:
+                    # (o_t, a_t, sum(r_(t:t+n_step)), o_(t+n_step))
+                    # Two observations (t and t+n).
+                    observations=[
+                        sa_episode.get_observations(sa_episode_ts),
+                        sa_episode.get_observations(sa_episode_ts + actual_n_step),
+                    ],
+                    observation_space=sa_episode.observation_space,
+                    infos=(
+                        [
+                            sa_episode.get_infos(sa_episode_ts),
+                            sa_episode.get_infos(sa_episode_ts + actual_n_step),
+                        ]
+                        if include_infos
+                        else None
+                    ),
+                    actions=[sa_episode.get_actions(sa_episode_ts)],
+                    action_space=sa_episode.action_space,
+                    rewards=[sa_rewards],
+                    # If the sampled time step is the single-agent episode's last
+                    # time step, check whether the single-agent episode is
+                    # terminated or truncated.
+                    terminated=(
+                        sa_episode_ts + actual_n_step >= len(sa_episode)
+                        and sa_episode.is_terminated
+                    ),
+                    truncated=(
+                        sa_episode_ts + actual_n_step >= len(sa_episode)
+                        and sa_episode.is_truncated
+                    ),
+                    extra_model_outputs={
+                        "weights": [1.0],
+                        "n_step": [actual_n_step],
+                        **(
+                            {
+                                k: [
+                                    sa_episode.get_extra_model_outputs(k, sa_episode_ts)
+                                ]
+                                for k in sa_episode.extra_model_outputs.keys()
+                            }
+                            if include_extra_model_outputs
+                            else {}
+                        ),
+                    },
+                    # TODO (sven): Support lookback buffers.
+                    len_lookback_buffer=0,
+                    t_started=sa_episode_ts,
+                )
+                # Append single-agent episode to the list of sampled episodes.
+                sampled_episodes.append(sampled_sa_episode)
 
                 # Increase counter.
                 B += 1
             # Increase the per module timesteps counter.
-            self.sampled_timesteps_per_module[module_id] += batch_size_B
-            ret[module_id] = {
-                # Note, observation and action spaces could be complex. `batch`
-                # takes care of these.
-                Columns.OBS: batch(observations),
-                Columns.ACTIONS: batch(actions),
-                Columns.REWARDS: np.array(rewards),
-                Columns.NEXT_OBS: batch(next_observations),
-                Columns.TERMINATEDS: np.array(is_terminated),
-                Columns.TRUNCATEDS: np.array(is_truncated),
-                "weights": np.array(weights),
-                "n_step": np.array(n_steps),
-            }
-            # Include infos if necessary.
-            if include_infos:
-                ret[module_id].update(
-                    {
-                        Columns.INFOS: infos,
-                    }
-                )
-            # Include extra model outputs, if necessary.
-            if include_extra_model_outputs:
-                ret[module_id].update(
-                    # These could be complex, too.
-                    batch(extra_model_outputs)
-                )
+            self.sampled_timesteps_per_module[module_id] += B
+
         # Increase the counter for environment timesteps.
         self.sampled_timesteps += batch_size_B
 
         # Return multi-agent dictionary.
-        return ret
+        return sampled_episodes
 
     def _sample_synchonized(
         self,
@@ -899,7 +865,9 @@ def _add_new_module_indices(
                 sa_episode_in_buffer = False
             if sa_episode_in_buffer:
                 existing_eps_len = len(
-                    self.episodes[episode_idx].agent_episodes[agent_id]
+                    self.episodes[
+                        episode_idx - self._num_episodes_evicted
+                    ].agent_episodes[agent_id]
                 )
             else:
                 existing_eps_len = 0
@@ -910,7 +878,7 @@ def _add_new_module_indices(
                         # Keep the MAE index for sampling
                        episode_idx,
                        agent_id,
-                        existing_eps_len + i + 1,
+                        existing_eps_len + i,
                    )
                    for i in range(len(module_eps))
                ]
diff --git a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
index f7e818a659cf..b2a656d75f06 100644
--- a/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
+++ b/rllib/utils/replay_buffers/prioritized_episode_replay_buffer.py
@@ -217,6 +217,7 @@ def add(
             # TODO (sven, simon): Should we just treat such an episode chunk
             # as a new episode?
             if eps_evicted_ids[-1] in new_episode_ids:
+                # TODO (simon): Apply the same logic as in the MA-case.
                 len_to_subtract = len(
                     episodes[new_episode_ids.index(eps_evicted_idxs[-1])]
                 )
@@ -288,7 +289,7 @@ def add(
                         for i in range(len(eps))
                     ]
                 )
-                # Increase index.
+                # Increase index to the new length of `self._indices`.
                 j = len(self._indices)
 
     @override(EpisodeReplayBuffer)
diff --git a/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py b/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py
index 14a3860c5e6c..3844b5a485c6 100644
--- a/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py
+++ b/rllib/utils/replay_buffers/tests/test_multi_agent_episode_replay_buffer.py
@@ -5,6 +5,7 @@
 from ray.rllib.utils.replay_buffers.multi_agent_episode_replay_buffer import (
     MultiAgentEpisodeReplayBuffer,
 )
+from ray.rllib.utils.test_utils import check
 
 
 class TestMultiAgentEpisodeReplayBuffer(unittest.TestCase):
@@ -150,59 +151,58 @@ def test_buffer_independent_sample_logic(self):
 
         for i in range(1000):
             sample = buffer.sample(batch_size_B=16, n_step=1)
             self.assertTrue(buffer.get_sampled_timesteps() == 16 * (i + 1))
-            self.assertTrue("module_1" in sample)
-            self.assertTrue("module_2" in sample)
-            for module_id in sample:
-                self.assertTrue(buffer.get_sampled_timesteps(module_id) == 16 * (i + 1))
+            module_ids = {eps.module_id for eps in sample}
+            self.assertTrue("module_1" in module_ids)
+            self.assertTrue("module_2" in module_ids)
+            for eps in sample:
+                # For both modules, we should have 16 x (i + 1) timesteps sampled.
+                # Note, this must be the same here as the number of timesteps sampled
+                # altogether, b/c we sample both modules.
+                self.assertTrue(
+                    buffer.get_sampled_timesteps("module_1") == 16 * (i + 1)
+                )
+                self.assertTrue(
+                    buffer.get_sampled_timesteps("module_2") == 16 * (i + 1)
+                )
 
                 (
                     obs,
-                    actions,
-                    rewards,
+                    action,
+                    reward,
                     next_obs,
                     is_terminated,
                     is_truncated,
-                    weights,
-                    n_steps,
+                    weight,
+                    n_step,
                 ) = (
-                    sample[module_id]["obs"],
-                    sample[module_id]["actions"],
-                    sample[module_id]["rewards"],
-                    sample[module_id]["new_obs"],
-                    sample[module_id]["terminateds"],
-                    sample[module_id]["truncateds"],
-                    sample[module_id]["weights"],
-                    sample[module_id]["n_step"],
+                    eps.get_observations(0),
+                    eps.get_actions(-1),
+                    eps.get_rewards(-1),
+                    eps.get_observations(-1),
+                    eps.is_terminated,
+                    eps.is_truncated,
+                    eps.get_extra_model_outputs("weights", -1),
+                    eps.get_extra_model_outputs("n_step", -1),
                 )
 
                 # Make sure terminated and truncated are never both True.
-                assert not np.any(np.logical_and(is_truncated, is_terminated))
-
-                # All fields have same shape.
-                assert (
-                    obs.shape[:2]
-                    == rewards.shape
-                    == actions.shape
-                    == next_obs.shape
-                    == is_truncated.shape
-                    == is_terminated.shape
-                )
+                assert not (is_truncated and is_terminated)
 
                 # Note, floating point numbers cannot be compared directly.
                 tolerance = 1e-8
 
                 # Assert that actions correspond to the observations.
-                self.assertTrue(np.all(actions - obs < tolerance))
+                check(obs, action, atol=tolerance)
                 # Assert that next observations are correctly one step after
                 # observations.
-                self.assertTrue(np.all(next_obs - obs - 1 < tolerance))
+                check(next_obs, obs + 1, atol=tolerance)
                 # Assert that the reward comes from the next observation.
-                self.assertTrue(np.all(rewards * 10 - next_obs < tolerance))
+                check(reward * 10, next_obs, atol=tolerance)
                 # Furthermore, assert that the importance sampling weights are
                 # one for `beta=0.0`.
-                self.assertTrue(np.all(weights - 1.0 < tolerance))
+                check(weight, 1.0, atol=tolerance)
                 # Assert that all n-steps are 1.0 as passed into `sample`.
-                self.assertTrue(np.all(n_steps - 1.0 < tolerance))
+                check(n_step, 1.0, atol=tolerance)
 
     def test_buffer_synchronized_sample_logic(self):
         """Samples synchronized from the multi-agent buffer."""
diff --git a/rllib/utils/replay_buffers/utils.py b/rllib/utils/replay_buffers/utils.py
index 3b1bb6b6924f..c825c64fc3ff 100644
--- a/rllib/utils/replay_buffers/utils.py
+++ b/rllib/utils/replay_buffers/utils.py
@@ -8,7 +8,7 @@
 from ray.rllib.utils.annotations import OldAPIStack
 from ray.rllib.utils.deprecation import DEPRECATED_VALUE
 from ray.rllib.utils.from_config import from_config
-from ray.rllib.utils.metrics import ALL_MODULES
+from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY
 from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY
 from ray.rllib.utils.replay_buffers import (
     EpisodeReplayBuffer,
@@ -30,9 +30,6 @@
 
 logger = logging.getLogger(__name__)
 
-# TODO (simon): Move all regular keys to the metric constants file.
-TD_ERROR_KEY = "td_error"
-
 
 @DeveloperAPI
 def update_priorities_in_episode_replay_buffer(