Commit

amend
matteobettini committed Jul 30, 2024
1 parent 44a1e8a commit 01be43f
Showing 13 changed files with 350 additions and 146 deletions.
114 changes: 108 additions & 6 deletions benchmarl/algorithms/common.py
@@ -7,19 +7,27 @@
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type

from tensordict import TensorDictBase
from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.data import (
CompositeSpec,
DiscreteTensorSpec,
LazyTensorStorage,
OneHotDiscreteTensorSpec,
ReplayBuffer,
TensorDictReplayBuffer,
)
from torchrl.data.replay_buffers import RandomSampler, SamplerWithoutReplacement
from torchrl.envs import Compose, Transform
from torchrl.envs import (
Compose,
EnvBase,
InitTracker,
TensorDictPrimer,
Transform,
TransformedEnv,
)
from torchrl.objectives import LossModule
from torchrl.objectives.utils import HardUpdate, SoftUpdate, TargetNetUpdater

@@ -51,6 +59,16 @@ def __init__(self, experiment):
self.action_spec = experiment.action_spec
self.state_spec = experiment.state_spec
self.action_mask_spec = experiment.action_mask_spec
self.has_independent_critic = (
experiment.algorithm_config.has_independent_critic()
)
self.has_centralized_critic = (
experiment.algorithm_config.has_centralized_critic()
)
self.has_critic = experiment.algorithm_config.has_critic
self.has_rnn = self.model_config.is_rnn or (
self.critic_model_config.is_rnn and self.has_critic
)

# Cached values that will be instantiated only once and then remain fixed
self._losses_and_updaters = {}
@@ -142,10 +160,7 @@ def get_replay_buffer(
"""
memory_size = self.experiment_config.replay_buffer_memory_size(self.on_policy)
sampling_size = self.experiment_config.train_minibatch_size(self.on_policy)
if (
self.experiment.model_config.is_rnn
or self.experiment.critic_model_config.is_rnn
):
if self.has_rnn:
sequence_length = -(
-self.experiment_config.collected_frames_per_batch(self.on_policy)
// self.experiment_config.n_envs_per_worker(self.on_policy)
@@ -229,6 +244,69 @@ def get_parameters(self, group: str) -> Dict[str, Iterable]:
loss=self.get_loss_and_updater(group)[0],
)

def process_env_fun(
self,
env_fun: Callable[[], EnvBase],
) -> Callable[[], EnvBase]:
"""
This function can be used to wrap env_fun with algorithm-specific transforms.
Args:
env_fun (callable): a function that takes no args and creates an environment
Returns: a function that takes no args and creates an environment
"""
if self.has_rnn:

def model_fun():
env = env_fun()

spec_actor = self.model_config.get_model_state_spec()
spec_actor = CompositeSpec(
{
group: CompositeSpec(
spec_actor.expand(len(agents), *spec_actor.shape),
shape=(len(agents),),
)
for group, agents in self.group_map.items()
}
)
# if self.has_critic and self.critic_model_config.is_rnn and False:
# spec_critic = self.critic_model_config.get_model_state_spec()
# if (
# self.has_independent_critic
# or not self.critic_model_config.share_param_critic
# ):
# spec_critic = CompositeSpec(
# {
# group: CompositeSpec(
# spec_critic.expand(len(agents), *spec_critic.shape),
# shape=(len(agents),),
# )
# for group, agents in self.group_map.items()
# }
# )
# spec_actor.update(spec_critic)

env = TransformedEnv(
env,
Compose(
*(
[InitTracker(init_key="is_init")]
+ (
[TensorDictPrimer(spec_actor)]
if len(spec_actor.keys(True, True)) > 0
else []
)
)
),
)
return env

return model_fun

return env_fun

###############################
# Abstract methods to implement
###############################
@@ -410,3 +488,27 @@ def supports_discrete_actions() -> bool:
If the algorithm supports discrete actions
"""
raise NotImplementedError

@staticmethod
def has_independent_critic() -> bool:
"""
If the algorithm uses an independent critic
"""
return False

@staticmethod
def has_centralized_critic() -> bool:
"""
If the algorithm uses a centralized critic
"""
return False

def has_critic(self) -> bool:
"""
If the algorithm uses a critic
"""
if self.has_centralized_critic() and self.has_independent_critic():
raise ValueError(
"Algorithm can either have a centralized critic or an indpendent one"
)
return self.has_centralized_critic() or self.has_independent_critic()
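
For reference, a minimal standalone sketch of what the new process_env_fun does for RNN models. The environment name, the flat "_hidden" key, and the hidden size 128 are illustrative assumptions; BenchMARL builds the primer spec per agent group from the model config instead.

# Sketch only: env, spec key, and hidden size are assumptions, not BenchMARL internals.
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import Compose, InitTracker, TensorDictPrimer, TransformedEnv
from torchrl.envs.libs.gym import GymEnv


def env_fun():
    return GymEnv("Pendulum-v1")  # hypothetical stand-in for the task env


hidden_spec = CompositeSpec({"_hidden": UnboundedContinuousTensorSpec(shape=(1, 128))})


def wrapped_env_fun():
    return TransformedEnv(
        env_fun(),
        Compose(
            InitTracker(init_key="is_init"),  # flags the first step after each reset
            TensorDictPrimer(hidden_spec),  # pre-populates the recurrent-state entry
        ),
    )


td = wrapped_env_fun().reset()  # now carries "is_init" and the primed "_hidden" entry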
18 changes: 12 additions & 6 deletions benchmarl/algorithms/iddpg.py
@@ -123,12 +123,14 @@ def _get_policy_for_loss(
in_keys=[(group, "param")],
out_keys=[(group, "action")],
distribution_class=TanhDelta if self.use_tanh_mapping else Delta,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {},
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {}
),
return_log_prob=False,
safe=not self.use_tanh_mapping,
)
@@ -249,3 +251,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_independent_critic() -> bool:
return True
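
These per-algorithm overrides feed the has_critic logic added to AlgorithmConfig above. A hedged, self-contained sketch of that interplay (class names are illustrative):

class AlgorithmConfigSketch:
    # Defaults mirror the new base-class behaviour: no critic unless declared.
    @staticmethod
    def has_independent_critic() -> bool:
        return False

    @staticmethod
    def has_centralized_critic() -> bool:
        return False

    def has_critic(self) -> bool:
        if self.has_centralized_critic() and self.has_independent_critic():
            raise ValueError(
                "Algorithm can have either a centralized critic or an independent one"
            )
        return self.has_centralized_critic() or self.has_independent_critic()


class IddpgConfigSketch(AlgorithmConfigSketch):
    # IDDPG-style algorithms train an independent (per-group) critic.
    @staticmethod
    def has_independent_critic() -> bool:
        return True


assert not AlgorithmConfigSketch().has_critic()  # nothing declared -> no critic
assert IddpgConfigSketch().has_critic()  # independent critic -> has a critic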
4 changes: 4 additions & 0 deletions benchmarl/algorithms/ippo.py
@@ -325,3 +325,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return True

@staticmethod
def has_independent_critic() -> bool:
return True
24 changes: 15 additions & 9 deletions benchmarl/algorithms/isac.py
@@ -199,15 +199,17 @@ def _get_policy_for_loss(
spec=self.action_spec[group, "action"],
in_keys=[(group, "loc"), (group, "scale")],
out_keys=[(group, "action")],
distribution_class=IndependentNormal
if not self.use_tanh_normal
else TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {},
distribution_class=(
IndependentNormal if not self.use_tanh_normal else TanhNormal
),
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
),
return_log_prob=True,
log_prob_key=(group, "log_prob"),
)
@@ -387,3 +389,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_independent_critic() -> bool:
return True
18 changes: 12 additions & 6 deletions benchmarl/algorithms/maddpg.py
@@ -123,12 +123,14 @@ def _get_policy_for_loss(
in_keys=[(group, "param")],
out_keys=[(group, "action")],
distribution_class=TanhDelta if self.use_tanh_mapping else Delta,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {},
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_mapping
else {}
),
return_log_prob=False,
safe=not self.use_tanh_mapping,
)
@@ -299,3 +301,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_centralized_critic() -> bool:
return True
4 changes: 4 additions & 0 deletions benchmarl/algorithms/mappo.py
@@ -361,3 +361,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return True

@staticmethod
def has_centralized_critic() -> bool:
return True
24 changes: 15 additions & 9 deletions benchmarl/algorithms/masac.py
@@ -199,15 +199,17 @@ def _get_policy_for_loss(
spec=self.action_spec[group, "action"],
in_keys=[(group, "loc"), (group, "scale")],
out_keys=[(group, "action")],
distribution_class=IndependentNormal
if not self.use_tanh_normal
else TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {},
distribution_class=(
IndependentNormal if not self.use_tanh_normal else TanhNormal
),
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
),
return_log_prob=True,
log_prob_key=(group, "log_prob"),
)
@@ -461,3 +463,7 @@ def supports_discrete_actions() -> bool:
@staticmethod
def on_policy() -> bool:
return False

@staticmethod
def has_centralized_critic() -> bool:
return True
2 changes: 1 addition & 1 deletion benchmarl/conf/model/layers/gru.yaml
@@ -2,7 +2,7 @@
name: gru

hidden_size: 128
compile: True
compile: False

mlp_num_cells: [256, 256]
mlp_layer_class: torch.nn.Linear
39 changes: 21 additions & 18 deletions benchmarl/experiment/experiment.py
@@ -322,8 +322,13 @@ def __init__(
self.task = task
self.model_config = model_config
self.critic_model_config = (
critic_model_config if critic_model_config is not None else model_config
critic_model_config
if critic_model_config is not None
else copy.deepcopy(model_config)
)
self.model_config.is_critic = False
self.critic_model_config.is_critic = True

self.algorithm_config = algorithm_config
self.seed = seed

@@ -377,23 +382,17 @@ def _set_action_type(self):
)

def _setup_task(self):
test_env = self.model_config.process_env_fun(
self.task.get_env_fun(
num_envs=self.config.evaluation_episodes,
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
),
self.task,
test_env = self.task.get_env_fun(
num_envs=self.config.evaluation_episodes,
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
)()
env_func = self.model_config.process_env_fun(
self.task.get_env_fun(
num_envs=self.config.n_envs_per_worker(self.on_policy),
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
),
self.task,
env_func = self.task.get_env_fun(
num_envs=self.config.n_envs_per_worker(self.on_policy),
continuous_actions=self.continuous_actions,
seed=self.seed,
device=self.config.sampling_device,
)

transforms_env = self.task.get_env_transforms(test_env)
@@ -429,6 +428,10 @@ def _setup_task(self):

def _setup_algorithm(self):
self.algorithm = self.algorithm_config.get_algorithm(experiment=self)

self.test_env = self.algorithm.process_env_fun(lambda: self.test_env)()
self.env_func = self.algorithm.process_env_fun(self.env_func)

self.replay_buffers = {
group: self.algorithm.get_replay_buffer(
group=group,
@@ -612,7 +615,7 @@ def _collection_loop(self):
for group in self.train_group_map.keys():
group_batch = batch.exclude(*self._get_excluded_keys(group))
group_batch = self.algorithm.process_batch(group, group_batch)
if not (self.model_config.is_rnn or self.critic_model_config.is_rnn):
if not self.algorithm.has_rnn:
group_batch = group_batch.reshape(-1)
self.replay_buffers[group].extend(group_batch)

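
A short sketch of why the critic model config is now deep-copied when none is passed explicitly: flipping is_critic on the critic copy must not leak into the actor config. The config class below is a stand-in, not BenchMARL's ModelConfig.

import copy
from dataclasses import dataclass


@dataclass
class ModelConfigSketch:
    is_rnn: bool = False
    is_critic: bool = False


actor_cfg = ModelConfigSketch(is_rnn=True)
critic_cfg = copy.deepcopy(actor_cfg)  # independent copy, not an alias

actor_cfg.is_critic = False
critic_cfg.is_critic = True

assert actor_cfg.is_critic is False  # the deepcopy keeps the two configs separate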
2 changes: 2 additions & 0 deletions benchmarl/models/cnn.py
@@ -114,6 +114,8 @@ def __init__(
share_params=kwargs.pop("share_params"),
device=kwargs.pop("device"),
action_spec=kwargs.pop("action_spec"),
model_index=kwargs.pop("model_index"),
is_critic=kwargs.pop("is_critic"),
)

self.x = self.input_spec[self.image_in_keys[0]].shape[-3]
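
The CNN model simply forwards the two new shared arguments to its base class. A minimal sketch of this kwargs.pop forwarding pattern, with hypothetical class names:

class ModelSketch:
    def __init__(self, model_index: int, is_critic: bool):
        self.model_index = model_index
        self.is_critic = is_critic


class CnnSketch(ModelSketch):
    def __init__(self, **kwargs):
        super().__init__(
            model_index=kwargs.pop("model_index"),
            is_critic=kwargs.pop("is_critic"),
        )
        self.cnn_kwargs = kwargs  # whatever remains is CNN-specific


m = CnnSketch(model_index=0, is_critic=True, num_cells=[32, 32])
assert m.is_critic and m.cnn_kwargs == {"num_cells": [32, 32]}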
