[RLlib] Add support for multi-agent off-policy algorithms in the new API stack. #45182

Merged
41 commits
baa1398
wip
sven1977 Apr 29, 2024
a1eb1f9
wip
sven1977 Apr 29, 2024
6538b58
fixes
sven1977 Apr 29, 2024
683f515
Merge branch 'master' of https://github.com/ray-project/ray into chan…
sven1977 Apr 30, 2024
366a4b9
wip
sven1977 Apr 30, 2024
a8b2d0c
wip
sven1977 Apr 30, 2024
f76628a
merge
sven1977 May 3, 2024
81421d9
Fixed a bug with 'TERMINATEDS/TRUNCATEDS' in replay buffer sampling t…
simonsays1980 May 3, 2024
bd54d5a
LINTER.
simonsays1980 May 3, 2024
6ee006f
Added docs to new 'sample' method and removed old sample methods.
simonsays1980 May 6, 2024
a345d09
Merge branch 'master' into change_episode_buffers_to_return_episode_l…
simonsays1980 May 6, 2024
b77fd5a
Replaced 'td_error' by 'TD_ERROR_KEY'.
simonsays1980 May 6, 2024
6e11ff6
Needed to define 'TD_ERROR_KEY' in 'replay_buffer.utils' b/c import e…
simonsays1980 May 6, 2024
b39b9a8
Fixed a small bug in test code.
simonsays1980 May 7, 2024
e6cf4f7
Merge branch 'master' into change_episode_buffers_to_return_episode_l…
simonsays1980 May 7, 2024
eebc04d
Interchanged 'new_obs' with our constant 'Columns.NEXT_OBS' for bette…
simonsays1980 May 7, 2024
d12f16f
Added new sampling method in 'MultiAgentEpisodeReplayBuffer' for 'ind…
simonsays1980 May 7, 2024
2247c02
Changed 'truncated/terminated' logic in 'MultiEnv' and 'MultiAgentEpi…
simonsays1980 May 8, 2024
827adda
Switched back to 'pid'.
simonsays1980 May 10, 2024
1e67ccf
Commented out NaN metrics b/c they produced hundreds of warnings.
simonsays1980 May 10, 2024
c748df8
Changed comment.
simonsays1980 May 10, 2024
fc35faa
Little changes here and there to clean up sample logic and multi-…
simonsays1980 May 10, 2024
c336ac8
Added suggestions from @sven1977's review.
simonsays1980 May 10, 2024
6409007
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 13, 2024
81c3893
Merged master
simonsays1980 May 13, 2024
c522597
Modified multi-agent buffer tests to correspond to the changes in '_s…
simonsays1980 May 13, 2024
b8fbe19
Changed 'MultiAgentEpisode' and 'MultiEnv' back to master.
simonsays1980 May 13, 2024
feafb6b
Apply suggestions from code review
sven1977 May 14, 2024
d2f9030
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 14, 2024
2fd7717
Added slots to 'MultiAgentEpisode' which should help reducing memory …
simonsays1980 May 15, 2024
a3416a8
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 15, 2024
2296cfc
Changed multi-agent SAC example such that at a minimum 2 agents are u…
simonsays1980 May 16, 2024
8582ad9
Merge branch 'master' of https://github.com/ray-project/ray
simonsays1980 May 16, 2024
c8d72fa
Merge branch 'master' into change_ma_buffer_to_use_list_of_episodes
simonsays1980 May 16, 2024
ffbf3de
Multiple performance tunings that bring the multi-agent buffer into d…
simonsays1980 May 16, 2024
47888a4
LINTER.
simonsays1980 May 16, 2024
7d6497e
Merge branch 'change_ma_buffer_to_use_list_of_episodes' of github.com…
simonsays1980 May 16, 2024
cccd48d
Merge branch 'master' of https://github.com/ray-project/ray into chan…
sven1977 May 17, 2024
e96b9ce
test BAZEL printout
sven1977 May 17, 2024
9d409dd
Commented out off-policy multi-agent examples that were not learning.
simonsays1980 May 17, 2024
41d0b18
Merge branch 'change_ma_buffer_to_use_list_of_episodes' of github.com…
simonsays1980 May 17, 2024
1 change: 1 addition & 0 deletions .bazelrc
@@ -168,6 +168,7 @@ test:ci --flaky_test_attempts=3
test:ci --nocache_test_results
test:ci --spawn_strategy=local
test:ci --test_output=errors
test:ci --experimental_ui_max_stdouterr_bytes=-1
test:ci --test_verbose_timeout_warnings
test:ci-debug -c dbg
test:ci-debug --copt="-g"
19 changes: 19 additions & 0 deletions rllib/BUILD
@@ -463,6 +463,25 @@ py_test(
args = ["--dir=tuned_examples/sac"]
)

# TODO (simon): These tests are not learning, yet.
# py_test(
# name = "learning_tests_multi_agent_pendulum_sac",
# main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous"],
# size = "large",
# srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
# args = ["--enable-new-api-stack", "--num-agents=2"]
# )

# py_test(
# name = "learning_tests_multi_agent_pendulum_sac_multi_gpu",
# main = "tuned_examples/sac/multi_agent_pendulum_sac.py",
# tags = ["team:rllib", "exclusive", "learning_tests", "torch_only", "learning_tests_pendulum", "learning_tests_continuous", "multi_gpu"],
# size = "large",
# srcs = ["tuned_examples/sac/multi_agent_pendulum_sac.py"],
# args = ["--enable-new-api-stack", "--num-agents=2", "--num-gpus=2"]
# )
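
(Once these tests learn reliably and the targets above are re-enabled, they would be run like the other tuned-example targets, e.g. with `bazel test //rllib:learning_tests_multi_agent_pendulum_sac`; invocation shown for illustration only.)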

# --------------------------------------------------------------------
# Algorithms (Compilation, Losses, simple functionality tests)
# rllib/algorithms/
5 changes: 3 additions & 2 deletions rllib/algorithms/dqn/dqn.py
@@ -16,7 +16,6 @@

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.algorithm_config import AlgorithmConfig, NotProvided
from ray.rllib.algorithms.dqn.dqn_rainbow_learner import TD_ERROR_KEY
from ray.rllib.algorithms.dqn.dqn_tf_policy import DQNTFPolicy
from ray.rllib.algorithms.dqn.dqn_torch_policy import DQNTorchPolicy
from ray.rllib.core.learner import Learner
@@ -64,6 +63,7 @@
REPLAY_BUFFER_UPDATE_PRIOS_TIMER,
SAMPLE_TIMER,
SYNCH_WORKER_WEIGHTS_TIMER,
TD_ERROR_KEY,
TIMERS,
)
from ray.rllib.utils.deprecation import DEPRECATED_VALUE
@@ -662,7 +662,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
num_items=self.config.train_batch_size,
n_step=self.config.n_step,
gamma=self.config.gamma,
beta=self.config.replay_buffer_config["beta"],
beta=self.config.replay_buffer_config.get("beta"),
)

# Perform an update on the buffer-sampled train batch.
@@ -700,6 +700,7 @@ def _training_step_new_api_stack(self, *, with_noise_reset) -> ResultDict:
},
reduce="sum",
)

# TODO (sven): Uncomment this once agent steps are available in the
# Learner stats.
# self.metrics.log_dict(self.metrics.peek(
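
A note on the `replay_buffer_config.get("beta")` change above: `beta` is a
prioritized-replay parameter, so configs for non-prioritized buffers (such as the
multi-agent episode buffer added in this PR) may not define it at all. A minimal,
self-contained sketch of the behavioral difference (illustrative values only):

# Illustrative config without a prioritized-replay "beta" entry.
replay_buffer_config = {"type": "MultiAgentEpisodeReplayBuffer", "capacity": 100_000}

beta = replay_buffer_config.get("beta")  # -> None; sampling can fall back to uniform
# replay_buffer_config["beta"]           # -> would raise KeyError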
6 changes: 4 additions & 2 deletions rllib/algorithms/dqn/dqn_rainbow_learner.py
@@ -13,7 +13,10 @@
override,
OverrideToImplementCustomLogic_CallToSuperRecommended,
)
from ray.rllib.utils.metrics import LAST_TARGET_UPDATE_TS, NUM_TARGET_UPDATES
from ray.rllib.utils.metrics import (
LAST_TARGET_UPDATE_TS,
NUM_TARGET_UPDATES,
)
from ray.rllib.utils.typing import ModuleID

if TYPE_CHECKING:
@@ -32,7 +35,6 @@
QF_TARGET_NEXT_PROBS = "qf_target_next_probs"
QF_PREDS = "qf_preds"
QF_PROBS = "qf_probs"
TD_ERROR_KEY = "td_error"
TD_ERROR_MEAN_KEY = "td_error_mean"


2 changes: 1 addition & 1 deletion rllib/algorithms/dqn/torch/dqn_rainbow_torch_learner.py
@@ -14,13 +14,13 @@
QF_TARGET_NEXT_PROBS,
QF_PREDS,
QF_PROBS,
TD_ERROR_KEY,
TD_ERROR_MEAN_KEY,
)
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.torch.torch_learner import TorchLearner
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import TD_ERROR_KEY
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import ModuleID, TensorType

1 change: 1 addition & 0 deletions rllib/algorithms/sac/sac.py
@@ -352,6 +352,7 @@ def validate(self) -> None:
] not in [
"EpisodeReplayBuffer",
"PrioritizedEpisodeReplayBuffer",
"MultiAgentEpisodeReplayBuffer",
]:
raise ValueError(
"When using the new `EnvRunner API` the replay buffer must be of type "
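
For context, a minimal configuration sketch that satisfies this extended buffer-type
check on the new API stack (mirroring the tuned example added later in this PR):

from ray.rllib.algorithms.sac import SACConfig

config = (
    SACConfig()
    .api_stack(
        enable_rl_module_and_learner=True,
        enable_env_runner_and_connector_v2=True,
    )
    .training(
        replay_buffer_config={
            # Now accepted by `validate()` alongside the single-agent episode buffers.
            "type": "MultiAgentEpisodeReplayBuffer",
            "capacity": 100_000,
        },
    )
)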
1 change: 0 additions & 1 deletion rllib/algorithms/sac/sac_learner.py
@@ -20,7 +20,6 @@
QF_TWIN_LOSS_KEY = "qf_twin_loss"
QF_TWIN_PREDS = "qf_twin_preds"
TD_ERROR_MEAN_KEY = "td_error_mean"
TD_ERROR_KEY = "td_error"


class SACLearner(DQNRainbowLearner):
9 changes: 3 additions & 6 deletions rllib/algorithms/sac/torch/sac_torch_learner.py
@@ -15,17 +15,15 @@
QF_TWIN_LOSS_KEY,
QF_TWIN_PREDS,
TD_ERROR_MEAN_KEY,
TD_ERROR_KEY,
SACLearner,
)
from ray.rllib.core import DEFAULT_MODULE_ID
from ray.rllib.core.columns import Columns
from ray.rllib.core.learner.learner import (
POLICY_LOSS_KEY,
)
from ray.rllib.utils.annotations import override
from ray.rllib.utils.framework import try_import_torch
from ray.rllib.utils.metrics import ALL_MODULES
from ray.rllib.utils.metrics import ALL_MODULES, TD_ERROR_KEY
from ray.rllib.utils.nested_dict import NestedDict
from ray.rllib.utils.typing import ModuleID, ParamDict, TensorType

@@ -221,8 +219,6 @@ def compute_loss_for_module(
# Note further, we use here the Huber loss instead of the mean squared error
# as it improves training performance.
critic_loss = torch.mean(
# TODO (simon): Introduce priority weights when episode buffer is ready.
# batch[PRIO_WEIGHTS] *
batch["weights"]
* torch.nn.HuberLoss(reduction="none", delta=1.0)(
q_selected, q_selected_target
@@ -303,6 +299,7 @@ def compute_loss_for_module(
def compute_gradients(
self, loss_per_module: Dict[str, TensorType], **kwargs
) -> ParamDict:
# Set all grads to `None`.
for optim in self._optimizer_parameters:
optim.zero_grad(set_to_none=True)

@@ -317,7 +314,7 @@ def compute_gradients(
for component in (
["qf", "policy", "alpha"] + ["qf_twin"] if config.twin_q else []
):
self.metrics.peek(DEFAULT_MODULE_ID, component + "_loss").backward(
self.metrics.peek(module_id, component + "_loss").backward(
retain_graph=True
)
grads.update(
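
Why the `DEFAULT_MODULE_ID` -> `module_id` switch above matters: loss stats are
logged per RLModule ID, and in a multi-agent run there may be no module registered
under the single-agent default ID at all. A small, self-contained illustration of
that keying (plain Python, not RLlib code; values are made up):

DEFAULT_MODULE_ID = "default_policy"  # RLlib's single-agent default module ID

# Hypothetical per-module loss stats, as produced in a 2-policy run.
logged_losses = {
    "p0": {"qf_loss": 0.7, "policy_loss": 0.3, "alpha_loss": 0.1},
    "p1": {"qf_loss": 0.9, "policy_loss": 0.4, "alpha_loss": 0.2},
}

for module_id, stats in logged_losses.items():
    qf_loss = stats["qf_loss"]  # correct: this module's own logged loss
# logged_losses[DEFAULT_MODULE_ID]  # would raise KeyError: no "default_policy" here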
3 changes: 0 additions & 3 deletions rllib/connectors/common/agent_to_module_mapping.py
@@ -133,9 +133,6 @@ def __call__(
shared_data: Optional[dict] = None,
**kwargs,
) -> Any:
# This Connector should only be used in a multi-agent setting.
assert not episodes or isinstance(episodes[0], MultiAgentEpisode)

# Current agent to module mapping function.
# agent_to_module_mapping_fn = shared_data.get("agent_to_module_mapping_fn")
# Store in shared data, which module IDs map to which episode/agent, such
5 changes: 4 additions & 1 deletion rllib/connectors/common/batch_individual_items.py
@@ -33,7 +33,10 @@ def __call__(
# to a batch structure of:
# [module_id] -> [col0] -> [list of items]
if is_marl_module and column in rl_module:
assert is_multi_agent
# assert is_multi_agent
# TODO (simon, sven): Check, if we need for other cases this check.
Review comment (Contributor):
This is a good point. There are still some "weird" assumptions left in some connectors' logic.
We should comb these out and make it clearer when the logic should go down the single-agent (SAEps) vs. multi-agent (MAEps) episode path.

Some of this stuff has to do with the fact that EnvRunners can either have a SingleAgentRLModule OR a MultiAgentRLModule, but Learners always(!) have a MultiAgentModule. Maybe we should have Learners that operate on SingleAgentRLModules for simplicity and more transparency. It shouldn't be too hard to fix that on the Learner side.

# If MA Off-Policy and independent sampling we need to overcome
# this check.
module_data = column_data
for col, col_data in module_data.copy().items():
if isinstance(col_data, list) and col != Columns.INFOS:
9 changes: 5 additions & 4 deletions rllib/env/multi_agent_env_runner.py
@@ -3,7 +3,6 @@

from collections import defaultdict
from functools import partial
import numpy as np
from typing import DefaultDict, Dict, List, Optional

from ray.rllib.algorithms.algorithm_config import AlgorithmConfig
@@ -603,9 +602,11 @@ def get_metrics(self) -> ResultDict:
module_episode_returns,
)

# If no episodes at all, log NaN stats.
if len(self._done_episodes_for_metrics) == 0:
self._log_episode_metrics(np.nan, np.nan, np.nan)
# TODO (simon): This results in hundreds of warnings in the logs
Review comment (Contributor):
We'll have to see. This might lead to Tune errors in the sense that at the beginning, if no episode is done yet, Tune will complain that none of the stop criteria (e.g. num_env_steps_sampled_lifetime) can be found in the result dict.

# b/c reducing over NaNs is not supported.
# # If no episodes at all, log NaN stats.
# if len(self._done_episodes_for_metrics) == 0:
# self._log_episode_metrics(np.nan, np.nan, np.nan)

# Log num episodes counter for this iteration.
self.metrics.log_value(
24 changes: 24 additions & 0 deletions rllib/env/multi_agent_episode.py
@@ -58,6 +58,30 @@ class MultiAgentEpisode:
up to here, b/c there is nothing to learn from these "premature" rewards.
"""

__slots__ = (
"id_",
"agent_to_module_mapping_fn",
"_agent_to_module_mapping",
"observation_space",
"action_space",
"env_t_started",
"env_t",
"agent_t_started",
"env_t_to_agent_t",
"_hanging_actions_end",
"_hanging_extra_model_outputs_end",
"_hanging_rewards_end",
"_hanging_actions_begin",
"_hanging_extra_model_outputs_begin",
"_hanging_rewards_begin",
"is_terminated",
"is_truncated",
"agent_episodes",
"_temporary_timestep_data",
"_start_time",
"_last_step_time",
)

SKIP_ENV_TS_TAG = "S"

def __init__(
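
Background on the `__slots__` addition (commit 2fd7717): a class with `__slots__`
stores its attributes in fixed slots instead of a per-instance `__dict__`, which
reduces memory when many episode objects sit in a replay buffer and turns
attribute-name typos into errors. A small, self-contained illustration (not RLlib code):

import sys

class WithDict:
    def __init__(self):
        self.id_ = "e1"
        self.env_t = 0

class WithSlots:
    __slots__ = ("id_", "env_t")

    def __init__(self):
        self.id_ = "e1"
        self.env_t = 0

print(sys.getsizeof(WithDict().__dict__))  # extra dict allocated per instance
print(sys.getsizeof(WithSlots()))          # compact, fixed-layout instance
# WithSlots().typo = 1                     # would raise AttributeError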
79 changes: 79 additions & 0 deletions rllib/tuned_examples/sac/multi_agent_pendulum_sac.py
@@ -0,0 +1,79 @@
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.examples.envs.classes.multi_agent import MultiAgentPendulum
from ray.rllib.utils.metrics import (
ENV_RUNNER_RESULTS,
EPISODE_RETURN_MEAN,
NUM_ENV_STEPS_SAMPLED_LIFETIME,
)
from ray.tune.registry import register_env

from ray.rllib.utils.test_utils import add_rllib_example_script_args

parser = add_rllib_example_script_args()
# Use `parser` to add your own custom command line options to this script
# and (if needed) use their values to set up `config` below.
args = parser.parse_args()

register_env(
"multi_agent_pendulum",
lambda _: MultiAgentPendulum({"num_agents": args.num_agents or 2}),
)

config = (
SACConfig()
.environment(env="multi_agent_pendulum")
.rl_module(
model_config_dict={
"fcnet_hiddens": [256, 256],
"fcnet_activation": "relu",
"post_fcnet_hiddens": [],
"post_fcnet_activation": None,
"post_fcnet_weights_initializer": "orthogonal_",
"post_fcnet_weights_initializer_config": {"gain": 0.01},
}
)
.api_stack(
enable_rl_module_and_learner=True,
enable_env_runner_and_connector_v2=True,
)
.env_runners(
rollout_fragment_length=1,
num_env_runners=2,
num_envs_per_env_runner=1,
)
.training(
initial_alpha=1.001,
lr=3e-4,
target_entropy="auto",
n_step=1,
tau=0.005,
train_batch_size_per_learner=256,
target_network_update_freq=1,
replay_buffer_config={
"type": "MultiAgentEpisodeReplayBuffer",
"capacity": 100000,
},
num_steps_sampled_before_learning_starts=256,
)
.reporting(
metrics_num_episodes_for_smoothing=5,
min_sample_timesteps_per_iteration=1000,
)
)

if args.num_agents:
config.multi_agent(
policy_mapping_fn=lambda aid, *arg, **kw: f"p{aid}",
policies={f"p{i}" for i in range(args.num_agents)},
)

stop = {
NUM_ENV_STEPS_SAMPLED_LIFETIME: 500000,
# `episode_return_mean` is the sum of all agents/policies' returns.
f"{ENV_RUNNER_RESULTS}/{EPISODE_RETURN_MEAN}": -400.0 * (args.num_agents or 2),
}

if __name__ == "__main__":
from ray.rllib.utils.test_utils import run_rllib_example_script_experiment

run_rllib_example_script_experiment(config, args, stop=stop)
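
As with the (currently commented-out) BUILD targets above, this script is meant to be
launched directly, e.g. `python rllib/tuned_examples/sac/multi_agent_pendulum_sac.py
--enable-new-api-stack --num-agents=2` (flags as listed in the BUILD args; shown here
for illustration).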
1 change: 1 addition & 0 deletions rllib/utils/metrics/__init__.py
@@ -91,3 +91,4 @@
# Learner.
LEARNER_STATS_KEY = "learner_stats"
ALL_MODULES = "__all_modules__"
TD_ERROR_KEY = "td_error"
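
With this move, `TD_ERROR_KEY` is defined in exactly one place and the DQN/SAC learners
import it from `ray.rllib.utils.metrics` (see the diffs above) instead of each defining
their own "td_error" string, e.g.:

from ray.rllib.utils.metrics import TD_ERROR_KEY  # == "td_error"

# Illustrative: key TD errors consistently, regardless of which learner produced them.
td_errors_by_module = {"p0": {TD_ERROR_KEY: [0.12, 0.05]}}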