Commit

parallel collection

matteobettini committed Dec 15, 2024
1 parent 271edee commit 61c0a22
Showing 10 changed files with 36 additions and 18 deletions.
7 changes: 5 additions & 2 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -15,6 +15,9 @@ share_policy_params: True
 prefer_continuous_actions: True
 # If False collection is done using a collector (under no grad). If True, collection is done with gradients.
 collect_with_grad: False
+# In case of non-vectorized environments, whether to run collection in multiple processes
+# If this is used, there will be n_envs_per_worker processes, each collecting frames_per_batch/n_envs_per_worker frames
+parallel_collection: False
 
 # Discount factor
 gamma: 0.9
@@ -51,7 +54,7 @@ max_n_frames: 3_000_000
 on_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially.
+# Otherwise batching will be simulated and each env will be run sequentially or in parallel, depending on parallel_collection.
 on_policy_n_envs_per_worker: 10
 # This is the number of times collected_frames_per_batch will be split into minibatches and trained
 on_policy_n_minibatch_iters: 45
@@ -63,7 +66,7 @@ on_policy_minibatch_size: 400
 off_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
-# Otherwise batching will be simulated and each env will be run sequentially.
+# Otherwise batching will be simulated and each env will be run sequentially or in parallel, depending on parallel_collection.
 off_policy_n_envs_per_worker: 10
 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
 off_policy_n_optimizer_steps: 1000
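
As the new comment states, for non-vectorized environments the per-batch frame budget is split across collection processes when parallel_collection is enabled. A quick arithmetic sketch with the default values above (the Python variable names are only for illustration):

    # Defaults taken from base_experiment.yaml (on-policy side).
    on_policy_collected_frames_per_batch = 6000
    on_policy_n_envs_per_worker = 10

    # With parallel_collection: True and a non-vectorized environment,
    # n_envs_per_worker processes are spawned and each one collects
    # frames_per_batch / n_envs_per_worker frames per collection batch.
    frames_per_process = (
        on_policy_collected_frames_per_batch // on_policy_n_envs_per_worker
    )
    print(frames_per_process)  # 600
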
2 changes: 1 addition & 1 deletion benchmarl/environments/common.py
@@ -34,7 +34,7 @@ def _type_check_task_config(
     else:
         if warn_on_missing_dataclass:
             warnings.warn(
-                "TaskConfig python dataclass not foud, task is being loaded without type checks"
+                "TaskConfig python dataclass not found, task is being loaded without type checks"
             )
         return config
 
9 changes: 5 additions & 4 deletions benchmarl/environments/magent/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 from torchrl.data import Composite
@@ -31,17 +31,18 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
 
         return lambda: PettingZooWrapper(
-            env=self.__get_env(),
+            env=self.__get_env(config),
             return_state=True,
             seed=seed,
             done_on_any=False,
             use_mask=False,
             device=device,
         )
 
-    def __get_env(self) -> EnvBase:
+    def __get_env(self, config) -> EnvBase:
         try:
             from magent2.environments import (
                 adversarial_pursuit_v4,
@@ -66,7 +67,7 @@ def __get_env(self) -> EnvBase:
         }
         if self.name not in envs:
             raise Exception(f"{self.name} is not an environment of MAgent2")
-        return envs[self.name].parallel_env(**self.config, render_mode="rgb_array")
+        return envs[self.name].parallel_env(**config, render_mode="rgb_array")
 
     def supports_continuous_actions(self) -> bool:
         return False
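
The config deep-copy introduced above in the MAgent2 wrapper repeats in the meltingpot, pettingzoo, smacv2 and vmas wrappers below: the task config is snapshotted before the environment-constructor lambda is built, and the lambda closes over the copy instead of self.config. A plausible reading (an assumption on my part, not stated in the diff) is that with parallel_collection the constructor can now be invoked inside ParallelEnv worker processes, so snapshotting keeps per-call edits, such as the continuous_actions update in the PettingZoo wrapper further down, off the shared task object. A minimal toy sketch of the pattern, with made-up names:

    import copy
    from typing import Callable

    class ToyTask:
        """Hypothetical stand-in for a BenchMARL task wrapper."""

        def __init__(self, config: dict):
            self.config = config

        def get_env_fun(self) -> Callable[[], dict]:
            # Snapshot the config now so the returned closure does not share
            # the live self.config dict with other workers or later calls.
            config = copy.deepcopy(self.config)
            config.update({"continuous_actions": True})  # the edit stays on the copy
            return lambda: dict(config)  # placeholder for the real env constructor

    task = ToyTask({"max_cycles": 25})
    make_env = task.get_env_fun()
    print(make_env())  # {'max_cycles': 25, 'continuous_actions': True}
    assert "continuous_actions" not in task.config  # the task's own config is untouched
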
6 changes: 4 additions & 2 deletions benchmarl/environments/meltingpot/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 import torch
@@ -84,11 +84,13 @@ def get_env_fun(
     ) -> Callable[[], EnvBase]:
         from torchrl.envs.libs.meltingpot import MeltingpotEnv
 
+        config = copy.deepcopy(self.config)
+
         return lambda: MeltingpotEnv(
             substrate=self.name.lower(),
             categorical_actions=True,
             device=device,
-            **self.config,
+            **config,
         )
 
     def supports_continuous_actions(self) -> bool:
7 changes: 4 additions & 3 deletions benchmarl/environments/pettingzoo/common.py
@@ -4,6 +4,7 @@
 # LICENSE file in the root directory of this source tree.
 #
 
+import copy
 from typing import Callable, Dict, List, Optional
 
 from torchrl.data import Composite
@@ -35,17 +36,17 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
         if self.supports_continuous_actions() and self.supports_discrete_actions():
-            self.config.update({"continuous_actions": continuous_actions})
-
+            config.update({"continuous_actions": continuous_actions})
         return lambda: PettingZooEnv(
             categorical_actions=True,
             device=device,
             seed=seed,
             parallel=True,
             return_state=self.has_state(),
             render_mode="rgb_array",
-            **self.config
+            **config
         )
 
     def supports_continuous_actions(self) -> bool:
5 changes: 3 additions & 2 deletions benchmarl/environments/smacv2/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 import torch
@@ -42,8 +42,9 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
         return lambda: SMACv2Env(
-            categorical_actions=True, seed=seed, device=device, **self.config
+            categorical_actions=True, seed=seed, device=device, **config
        )
 
     def supports_continuous_actions(self) -> bool:
5 changes: 3 additions & 2 deletions benchmarl/environments/vmas/common.py
@@ -3,7 +3,7 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 #
-
+import copy
 from typing import Callable, Dict, List, Optional
 
 from torchrl.data import Composite
@@ -52,6 +52,7 @@ def get_env_fun(
         seed: Optional[int],
         device: DEVICE_TYPING,
     ) -> Callable[[], EnvBase]:
+        config = copy.deepcopy(self.config)
         return lambda: VmasEnv(
             scenario=self.name.lower(),
             num_envs=num_envs,
@@ -60,7 +61,7 @@
             device=device,
             categorical_actions=True,
             clamp_actions=True,
-            **self.config,
+            **config,
         )
 
     def supports_continuous_actions(self) -> bool:
9 changes: 7 additions & 2 deletions benchmarl/experiment/experiment.py
@@ -19,8 +19,9 @@
 import torch
 from tensordict import TensorDictBase
 from tensordict.nn import TensorDictSequential
+
 from torchrl.collectors import SyncDataCollector
-from torchrl.envs import SerialEnv, TransformedEnv
+from torchrl.envs import ParallelEnv, SerialEnv, TransformedEnv
 from torchrl.envs.transforms import Compose
 from torchrl.envs.utils import ExplorationType, set_exploration_type, step_mdp
 from torchrl.record.loggers import generate_exp_name
@@ -58,6 +59,7 @@ class ExperimentConfig:
     share_policy_params: bool = MISSING
     prefer_continuous_actions: bool = MISSING
     collect_with_grad: bool = MISSING
+    parallel_collection: bool = MISSING
 
     gamma: float = MISSING
     lr: float = MISSING
@@ -435,8 +437,11 @@ def _setup_task(self):
         transforms_training = Compose(*transforms_training)
 
         if test_env.batch_size == ():
+            env_class = (
+                SerialEnv if not self.config.parallel_collection else ParallelEnv
+            )
             self.env_func = lambda: TransformedEnv(
-                SerialEnv(self.config.n_envs_per_worker(self.on_policy), env_func),
+                env_class(self.config.n_envs_per_worker(self.on_policy), env_func),
                 transforms_training.clone(),
             )
         else:
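
Below is a minimal standalone sketch of the collection-env construction added above, for the non-vectorized case (test_env.batch_size == ()). The Gym environment and the StepCounter transform are placeholders chosen for illustration; in the experiment they stand in for whatever env_func and transforms_training hold at this point in _setup_task.

    from torchrl.envs import GymEnv, ParallelEnv, SerialEnv, TransformedEnv
    from torchrl.envs.transforms import Compose, StepCounter

    def make_single_env():
        # Placeholder single-env factory; stands in for env_func in _setup_task.
        return GymEnv("Pendulum-v1")

    if __name__ == "__main__":
        parallel_collection = True  # mirrors ExperimentConfig.parallel_collection
        n_envs_per_worker = 4
        transforms_training = Compose(StepCounter())

        # SerialEnv steps the env copies one after another in the main process;
        # ParallelEnv runs each copy in its own worker process.
        env_class = SerialEnv if not parallel_collection else ParallelEnv
        env = TransformedEnv(
            env_class(n_envs_per_worker, make_single_env),
            transforms_training.clone(),
        )
        rollout = env.rollout(10)  # batched rollout from the 4 environment copies
        print(rollout.shape)  # torch.Size([4, 10])

Since SerialEnv remains the choice when the flag is False (the default), existing experiments keep their previous sequential behaviour.
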
1 change: 1 addition & 0 deletions test/conftest.py
@@ -29,6 +29,7 @@ def experiment_config(tmp_path) -> ExperimentConfig:
     experiment_config.on_policy_n_envs_per_worker = (
         experiment_config.off_policy_n_envs_per_worker
     ) = 2
+    experiment_config.parallel_collection = False
     experiment_config.off_policy_n_optimizer_steps = 2
     experiment_config.off_policy_train_batch_size = 3
     experiment_config.off_policy_memory_size = 200
3 changes: 3 additions & 0 deletions test/test_pettingzoo.py
@@ -68,13 +68,16 @@ def test_all_algos(
 
     @pytest.mark.parametrize("algo_config", [IppoConfig, MasacConfig])
     @pytest.mark.parametrize("task", list(PettingZooTask))
+    @pytest.mark.parametrize("parallel_collection", [True, False])
     def test_all_tasks(
         self,
         algo_config: AlgorithmConfig,
         task: Task,
+        parallel_collection,
         experiment_config,
         mlp_sequence_config,
     ):
+        experiment_config.parallel_collection = parallel_collection
         task = task.get_from_yaml()
         experiment = Experiment(
             algorithm_config=algo_config.get_from_yaml(),
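
For completeness, a hedged sketch of enabling the new option from user code. The test above only shows the algorithm_config and task arguments, so the remaining Experiment arguments and the MLP model config below follow common BenchMARL usage and are assumptions rather than something this commit introduces.

    from benchmarl.algorithms import IppoConfig
    from benchmarl.environments import PettingZooTask
    from benchmarl.experiment import Experiment, ExperimentConfig
    from benchmarl.models.mlp import MlpConfig

    experiment_config = ExperimentConfig.get_from_yaml()
    # New in this commit: collect from non-vectorized envs in separate processes.
    experiment_config.parallel_collection = True

    experiment = Experiment(
        task=PettingZooTask.SIMPLE_TAG.get_from_yaml(),
        algorithm_config=IppoConfig.get_from_yaml(),
        model_config=MlpConfig.get_from_yaml(),
        critic_model_config=MlpConfig.get_from_yaml(),
        seed=0,
        config=experiment_config,
    )
    experiment.run()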
