Skip to content

Commit

Permalink
[Feature] PPO minibatch advantage (#100)
Browse files Browse the repository at this point in the history
  • Loading branch information
matteobettini authored Jun 17, 2024
1 parent d5b0f51 commit 11c55c2
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 32 deletions.
54 changes: 39 additions & 15 deletions benchmarl/algorithms/ippo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class Ippo(Algorithm):
choices: "softplus", "exp", "relu", "biased_softplus_1";
use_tanh_normal (bool): if ``True``, use TanhNormal as the continuyous action distribution with support bound
to the action domain. Otherwise, an IndependentNormal is used.
minibatch_advantage (bool): if ``True``, advantage computation is perfomend on minibatches of size
``experiment.config.on_policy_minibatch_size`` instead of the full
``experiment.config.on_policy_collected_frames_per_batch``, this helps not exploding memory usage
"""

Expand All @@ -49,6 +52,7 @@ def __init__(
lmbda: float,
scale_mapping: str,
use_tanh_normal: bool,
minibatch_advantage: bool,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -61,6 +65,7 @@ def __init__(
self.lmbda = lmbda
self.scale_mapping = scale_mapping
self.use_tanh_normal = use_tanh_normal
self.minibatch_advantage = minibatch_advantage

#############################
# Overridden abstract methods
Expand Down Expand Up @@ -148,15 +153,17 @@ def _get_policy_for_loss(
spec=self.action_spec[group, "action"],
in_keys=[(group, "loc"), (group, "scale")],
out_keys=[(group, "action")],
distribution_class=IndependentNormal
if not self.use_tanh_normal
else TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {},
distribution_class=(
IndependentNormal if not self.use_tanh_normal else TanhNormal
),
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
),
return_log_prob=True,
log_prob_key=(group, "log_prob"),
)
Expand Down Expand Up @@ -221,14 +228,30 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
)

with torch.no_grad():
loss = self.get_loss_and_updater(group)[0]
loss.value_estimator(
batch,
params=loss.critic_network_params,
target_params=loss.target_critic_network_params,
loss = self.get_loss_and_updater(group)[0]
if self.minibatch_advantage:
increment = -(
-self.experiment.config.train_minibatch_size(self.on_policy)
// batch.shape[1]
)
else:
increment = batch.batch_size[0] + 1
last_start_index = 0
start_index = increment
minibatches = []
while last_start_index < batch.shape[0]:
minimbatch = batch[last_start_index:start_index]
minibatches.append(minimbatch)
with torch.no_grad():
loss.value_estimator(
minimbatch,
params=loss.critic_network_params,
target_params=loss.target_critic_network_params,
)
last_start_index = start_index
start_index += increment

batch = torch.cat(minibatches, dim=0)
return batch

def process_loss_vals(
Expand Down Expand Up @@ -285,6 +308,7 @@ class IppoConfig(AlgorithmConfig):
lmbda: float = MISSING
scale_mapping: str = MISSING
use_tanh_normal: bool = MISSING
minibatch_advantage: bool = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
Expand Down
54 changes: 39 additions & 15 deletions benchmarl/algorithms/mappo.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ class Mappo(Algorithm):
choices: "softplus", "exp", "relu", "biased_softplus_1";
use_tanh_normal (bool): if ``True``, use TanhNormal as the continuyous action distribution with support bound
to the action domain. Otherwise, an IndependentNormal is used.
minibatch_advantage (bool): if ``True``, advantage computation is perfomend on minibatches of size
``experiment.config.on_policy_minibatch_size`` instead of the full
``experiment.config.on_policy_collected_frames_per_batch``, this helps not exploding memory usage
"""

Expand All @@ -53,6 +56,7 @@ def __init__(
lmbda: float,
scale_mapping: str,
use_tanh_normal: bool,
minibatch_advantage: bool,
**kwargs
):
super().__init__(**kwargs)
Expand All @@ -65,6 +69,7 @@ def __init__(
self.lmbda = lmbda
self.scale_mapping = scale_mapping
self.use_tanh_normal = use_tanh_normal
self.minibatch_advantage = minibatch_advantage

#############################
# Overridden abstract methods
Expand Down Expand Up @@ -152,15 +157,17 @@ def _get_policy_for_loss(
spec=self.action_spec[group, "action"],
in_keys=[(group, "loc"), (group, "scale")],
out_keys=[(group, "action")],
distribution_class=IndependentNormal
if not self.use_tanh_normal
else TanhNormal,
distribution_kwargs={
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {},
distribution_class=(
IndependentNormal if not self.use_tanh_normal else TanhNormal
),
distribution_kwargs=(
{
"min": self.action_spec[(group, "action")].space.low,
"max": self.action_spec[(group, "action")].space.high,
}
if self.use_tanh_normal
else {}
),
return_log_prob=True,
log_prob_key=(group, "log_prob"),
)
Expand Down Expand Up @@ -225,14 +232,30 @@ def process_batch(self, group: str, batch: TensorDictBase) -> TensorDictBase:
batch.get(("next", "reward")).unsqueeze(-1).expand((*group_shape, 1)),
)

with torch.no_grad():
loss = self.get_loss_and_updater(group)[0]
loss.value_estimator(
batch,
params=loss.critic_network_params,
target_params=loss.target_critic_network_params,
loss = self.get_loss_and_updater(group)[0]
if self.minibatch_advantage:
increment = -(
-self.experiment.config.train_minibatch_size(self.on_policy)
// batch.shape[1]
)
else:
increment = batch.batch_size[0] + 1
last_start_index = 0
start_index = increment
minibatches = []
while last_start_index < batch.shape[0]:
minimbatch = batch[last_start_index:start_index]
minibatches.append(minimbatch)
with torch.no_grad():
loss.value_estimator(
minimbatch,
params=loss.critic_network_params,
target_params=loss.target_critic_network_params,
)
last_start_index = start_index
start_index += increment

batch = torch.cat(minibatches, dim=0)
return batch

def process_loss_vals(
Expand Down Expand Up @@ -321,6 +344,7 @@ class MappoConfig(AlgorithmConfig):
lmbda: float = MISSING
scale_mapping: str = MISSING
use_tanh_normal: bool = MISSING
minibatch_advantage: bool = MISSING

@staticmethod
def associated_class() -> Type[Algorithm]:
Expand Down
1 change: 1 addition & 0 deletions benchmarl/conf/algorithm/ippo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ loss_critic_type: "l2"
lmbda: 0.9
scale_mapping: "biased_softplus_1.0"
use_tanh_normal: True
minibatch_advantage: False
1 change: 1 addition & 0 deletions benchmarl/conf/algorithm/mappo.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,4 @@ loss_critic_type: "l2"
lmbda: 0.9
scale_mapping: "biased_softplus_1.0"
use_tanh_normal: True
minibatch_advantage: False
6 changes: 4 additions & 2 deletions benchmarl/experiment/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,8 +606,10 @@ def _collection_loop(self):
training_tds = []
for _ in range(self.config.n_optimizer_steps(self.on_policy)):
for _ in range(
self.config.train_batch_size(self.on_policy)
// self.config.train_minibatch_size(self.on_policy)
-(
-self.config.train_batch_size(self.on_policy)
// self.config.train_minibatch_size(self.on_policy)
)
):
training_tds.append(self._optimizer_loop(group))
training_td = torch.stack(training_tds)
Expand Down

0 comments on commit 11c55c2

Please sign in to comment.