Skip to content

Commit

Permalink
Supports batched sub-environment with batch_size_per_env=1
Browse files Browse the repository at this point in the history
Also made the following improvements:

1. Differentiate array based and tensor based alf environment to
generate informative error message and apply TensorWrapper when necessary.

2. pass num_parallel_environments to env_load_fn for create_environment.
The sub-environments need to know this information if they need to
divide a big dataset into disjoint parts.
  • Loading branch information
emailweixu committed Sep 21, 2023
1 parent a37049a commit 9e06bb5
Show file tree
Hide file tree
Showing 16 changed files with 178 additions and 37 deletions.
9 changes: 9 additions & 0 deletions alf/environments/alf_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,15 @@ def task_names(self):
"""The name of each tasks."""
return [str(i) for i in range(self.num_tasks)]

@property
def is_tensor_based(self):
    """Whether the environment is tensor-based or not.

    A tensor-based environment represents its observations and actions as
    tensors; otherwise they are represented as numpy arrays. The base class
    defaults to array-based (``False``); tensor-based environments override
    this property.
    """
    return False

@property
def batched(self):
"""Whether the environment is batched or not.
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/alf_gym3_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,10 @@ def _image_channel_first_permute_spec(spec):
# finishes the current episode
self._prev_first = [False] * self.batch_size

@property
def is_tensor_based(self):
    """Whether observations/actions are tensors (``True``) rather than numpy arrays."""
    return True

@property
def batched(self):
    """Whether this environment is batched (gym3 envs expose a batch dimension)."""
    return True
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/alf_gym_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,10 @@ def gym(self):
"""Return the gym environment. """
return self._gym_env

@property
def is_tensor_based(self):
    """Gym-wrapped environments are array-based: observations/actions are numpy arrays."""
    return False

def _obtain_zero_info(self):
"""Get an env info of zeros only once when the env is created.
This info will be filled in each ``FIRST`` time step as a placeholder.
Expand Down
22 changes: 22 additions & 0 deletions alf/environments/alf_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def __getattr__(self, name):
"attempted to get missing private attribute '{}'".format(name))
return getattr(self._env, name)

@property
def is_tensor_based(self):
    """Forward ``is_tensor_based`` to the wrapped environment."""
    return self._env.is_tensor_based

@property
def batched(self):
    """Forward ``batched`` to the wrapped environment."""
    return self._env.batched
Expand Down Expand Up @@ -818,6 +822,14 @@ def __init__(self, env):
'BatchedTensorWrapper can only be used to wrap non-batched env')
super().__init__(env)

@property
def is_tensor_based(self):
    """``True``: this wrapper converts the wrapped env's output to tensors."""
    return True

@property
def batched(self):
    """``True``: this wrapper adds a batch dimension (batch_size = 1) to a non-batched env."""
    return True

@staticmethod
def _to_batched_tensor(raw):
"""Convert the structured input into batched (batch_size = 1) tensors
Expand Down Expand Up @@ -846,6 +858,10 @@ def __init__(self, env):
'TensorWrapper can only be used to wrap batched env')
super().__init__(env)

@property
def is_tensor_based(self):
    """``True``: this wrapper converts the wrapped (already batched) env's output to tensors."""
    return True

@staticmethod
def _to_tensor(raw):
"""Convert the structured input into batched (batch_size = 1) tensors
Expand Down Expand Up @@ -1098,6 +1114,8 @@ def __init__(self, envs: List[AlfEnvironment]):
self._env_info_spec = self._envs[0].env_info_spec()
self._num_tasks = self._envs[0].num_tasks
self._task_names = self._envs[0].task_names
if any(env.is_tensor_based for env in self._envs):
raise ValueError('All environments must be array-based.')
if any(env.action_spec() != self._action_spec for env in self._envs):
raise ValueError(
'All environments must have the same action spec.')
Expand All @@ -1116,6 +1134,10 @@ def __init__(self, envs: List[AlfEnvironment]):
def metadata(self):
return self._envs[0].metadata

@property
def is_tensor_based(self):
    """``False``: the combined environment is array-based, matching its
    sub-environments (the constructor rejects tensor-based sub-envs)."""
    return False

@property
def batched(self):
    """``True``: combining multiple sub-environments yields a batched env."""
    return True
Expand Down
14 changes: 9 additions & 5 deletions alf/environments/fast_parallel_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,10 @@ def __init__(
time_step_with_env_info_spec = self._time_step_spec._replace(
env_info=self._env_info_spec)
batch_size_per_env = self._envs[0].batch_size
if batch_size_per_env == 1:
assert not self._envs[
0].batched, "Does not support batched environment for if batch_size is 1"
batched = batch_size_per_env > 1
batched = self._envs[0].batched
if any(env.is_tensor_based for env in self._envs):
raise ValueError(
'All environments must be array-based environments.')
if any(env.action_spec() != self._action_spec for env in self._envs):
raise ValueError(
'All environments must have the same action spec.')
Expand All @@ -146,7 +146,7 @@ def __init__(
raise ValueError('All environments must have the same batched.')
self._closed = False
self._penv = _penv.ParallelEnvironment(
num_envs, num_spare_envs_for_reload, batch_size_per_env,
num_envs, num_spare_envs_for_reload, batch_size_per_env, batched,
self._action_spec, time_step_with_env_info_spec, name)

@property
Expand All @@ -172,6 +172,10 @@ def start(self):
env.wait_start()
logging.info('All processes started.')

@property
def is_tensor_based(self):
    """``True``: the parallel env exposes tensor outputs, even though its
    sub-environments must be array-based (checked in ``__init__``)."""
    return True

@property
def batched(self):
    """``True``: results from all sub-environments are stacked into one batch."""
    return True
Expand Down
53 changes: 38 additions & 15 deletions alf/environments/parallel_environment.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,26 +259,30 @@ class EnvironmentBase {
std::string name_;
int num_envs_;
int batch_size_per_env_;
bool batched_;

protected:
SharedDataBuffer action_buffer_;
SharedDataBuffer timestep_buffer_;

EnvironmentBase(int num_envs,
int batch_size_per_env,
bool batched,
const py::object& action_spec,
const py::object& timestep_spec,
const std::string& name);
};

EnvironmentBase::EnvironmentBase(int num_envs,
int batch_size_per_env,
bool batched,
const py::object& action_spec,
const py::object& timestep_spec,
const std::string& name)
: name_(name),
num_envs_(num_envs),
batch_size_per_env_(batch_size_per_env),
batched_(batched || batch_size_per_env > 1),
action_buffer_(
action_spec, num_envs * batch_size_per_env, name + ".action"),
timestep_buffer_(
Expand All @@ -299,6 +303,7 @@ class ParallelEnvironment : public EnvironmentBase {
ParallelEnvironment(int num_envs,
int num_spare_envs,
int batch_size_per_env,
bool batched,
py::object action_spec,
py::object timestep_spec,
const std::string& name);
Expand All @@ -318,11 +323,16 @@ struct Job {
ParallelEnvironment::ParallelEnvironment(int num_envs,
int num_spare_envs,
int batch_size_per_env,
bool batched,
py::object action_spec,
py::object timestep_spec,
const std::string& name)
: EnvironmentBase(
num_envs, batch_size_per_env, action_spec, timestep_spec, name),
: EnvironmentBase(num_envs,
batch_size_per_env,
batched,
action_spec,
timestep_spec,
name),
num_spare_envs_(num_spare_envs),
original_env_ids_(num_envs),
reseted_(num_envs, 0),
Expand Down Expand Up @@ -443,6 +453,7 @@ class ProcessEnvironment : public EnvironmentBase {
int env_id,
int num_envs,
int batch_size_per_env,
bool batched,
const py::object& action_spec,
const py::object& timestep_spec,
const std::string& name);
Expand All @@ -461,11 +472,16 @@ ProcessEnvironment::ProcessEnvironment(py::object env,
int env_id,
int num_envs,
int batch_size_per_env,
bool batched,
const py::object& action_spec,
const py::object& timestep_spec,
const std::string& name)
: EnvironmentBase(
num_envs, batch_size_per_env, action_spec, timestep_spec, name),
: EnvironmentBase(num_envs,
batch_size_per_env,
batched,
action_spec,
timestep_spec,
name),
name_(name),
env_id_(env_id),
env_(env),
Expand Down Expand Up @@ -536,8 +552,7 @@ void ProcessEnvironment::Step(int env_id) {
just_reseted_ = false;
} else {
auto action = action_buffer_.ViewAsNestedArray(
env_id * batch_size_per_env_,
(batch_size_per_env_ > 1) ? batch_size_per_env_ : 0);
env_id * batch_size_per_env_, batched_ ? batch_size_per_env_ : 0);
timestep = py_step_(action);
}
SendTimestep(timestep, env_id);
Expand All @@ -549,7 +564,7 @@ void ProcessEnvironment::Reset() {
}

void ProcessEnvironment::SendTimestep(const py::object& timestep, int env_id) {
if (batch_size_per_env_ == 1) {
if (!batched_) {
timestep_buffer_.WriteSlice(timestep, env_id);
env_id_buf_[env_id] = env_id;
} else {
Expand Down Expand Up @@ -590,14 +605,20 @@ void ProcessEnvironment::Worker() {

PYBIND11_MODULE(_penv, m) {
py::class_<ParallelEnvironment>(m, "ParallelEnvironment")
.def(
py::init<int, int, int, py::object, py::object, const std::string&>(),
py::arg("num_envs"),
py::arg("num_spare_envs"),
py::arg("batch_size_per_env"),
py::arg("action_spec"),
py::arg("timestep_spec"),
py::arg("name"))
.def(py::init<int,
int,
int,
int,
py::object,
py::object,
const std::string&>(),
py::arg("num_envs"),
py::arg("num_spare_envs"),
py::arg("batch_size_per_env"),
py::arg("batched"),
py::arg("action_spec"),
py::arg("timestep_spec"),
py::arg("name"))
.def("step",
&ParallelEnvironment::Step,
R"pbdoc(
Expand All @@ -613,6 +634,7 @@ PYBIND11_MODULE(_penv, m) {
int,
int,
int,
int,
py::object,
py::object,
const std::string&>(),
Expand All @@ -621,6 +643,7 @@ PYBIND11_MODULE(_penv, m) {
py::arg("env_id"),
py::arg("num_envs"),
py::arg("batch_size_per_env"),
py::arg("batched"),
py::arg("action_spec"),
py::arg("timestep_spec"),
py::arg("name"))
Expand Down
9 changes: 8 additions & 1 deletion alf/environments/parallel_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def __init__(self,
self._task_names = self._envs[0].task_names
self._time_step_with_env_info_spec = self._time_step_spec._replace(
env_info=self._env_info_spec)
if any(env.is_tensor_based for env in self._envs):
raise ValueError(
'All environments must be array-based environments.')
if any(env.action_spec() != self._action_spec for env in self._envs):
raise ValueError(
'All environments must have the same action spec.')
Expand Down Expand Up @@ -130,6 +133,10 @@ def start(self):
env.wait_start()
logging.info('All processes started.')

@property
def is_tensor_based(self):
    """``True``: the parallel env exposes tensor outputs, even though its
    sub-environments must be array-based (checked in ``__init__``)."""
    return True

@property
def batched(self):
    """``True``: results from all sub-environments are stacked into one batch."""
    return True
Expand All @@ -153,7 +160,7 @@ def call_envs_with_same_args(self, func_name, *args, **kwargs):
func_name (str): name of the function to call
*args: args to pass to the function
**kwargs: kwargs to pass to the function
Returns:
list: list of results from each environment
"""
Expand Down
2 changes: 1 addition & 1 deletion alf/environments/process_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _worker(conn: multiprocessing.connection,
penv = _penv.ProcessEnvironment(
env, partial(process_call, conn, env, flatten,
action_spec), env_id, num_envs, env.batch_size,
env.action_spec(),
env.batched, env.action_spec(),
env.time_step_spec()._replace(env_info=env.env_info_spec()),
name)
conn.send(_MessageType.READY) # Ready.
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/random_alf_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ def observation_spec(self):
def action_spec(self):
return self._action_spec

@property
def is_tensor_based(self):
    """Tensor-based iff configured to generate tensor time steps."""
    return self._use_tensor_time_step

@property
def batch_size(self):
    """Configured batch size when batched; 1 for a non-batched env."""
    return self._batch_size if self.batched else 1
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/suite_carla.py
Original file line number Diff line number Diff line change
Expand Up @@ -1681,6 +1681,10 @@ def _clear(self):
self._other_vehicles.clear()
self._walkers.clear()

@property
def is_tensor_based(self):
    """``True``: this environment produces tensor observations/actions."""
    return True

@property
def batched(self):
    """``True``: this environment is batched."""
    return True
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/suite_go.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,10 @@ def __init__(self,
logging.info("P : pass")
logging.info("SPACE : refresh display")

@property
def is_tensor_based(self):
    """``True``: this environment produces tensor observations/actions."""
    return True

@property
def batched(self):
    """``True``: this environment is batched."""
    return True
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/suite_metadrive.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def __init__(self, metadrive_env: metadrive.MetaDriveEnv, env_id: int = 0):

self._current_observation = None

@property
def is_tensor_based(self):
    """``False``: MetaDrive observations/actions are numpy arrays."""
    return False

@property
def batched(self):
# TODO(breakds): Add support for multiple algorithm controlled agents in
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/suite_tic_tac_toe.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ def __init__(self, batch_size):
self._player_0 = torch.tensor(-1.)
self._player_1 = torch.tensor(1.)

@property
def is_tensor_based(self):
    """``True``: this environment operates on torch tensors."""
    return True

@property
def batched(self):
    """``True``: this environment is batched."""
    return True
Expand Down
4 changes: 4 additions & 0 deletions alf/environments/suite_unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ def _create_action_spec(act_type):

self.reset()

@property
def is_tensor_based(self):
    """``True``: this environment produces tensor observations/actions."""
    return True

@property
def batched(self):
    """``True``: this environment is batched."""
    return True
Expand Down
5 changes: 5 additions & 0 deletions alf/environments/thread_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def __init__(self, env_constructor):
super().__init__()
self._pool = mp_threads.Pool(1)
self._env = self._pool.apply(env_constructor)
assert not self._env.batched
self._closed = False

@property
def is_tensor_based(self):
    # NOTE(review): reports tensor-based even though the constructor only
    # asserts the wrapped env is non-batched — presumably the thread wrapper
    # converts outputs to tensors; confirm against the step/reset path.
    return True

@property
def batched(self):
    # ``True`` although the wrapped env is asserted non-batched in
    # ``__init__`` — presumably a batch dimension of 1 is added downstream.
    return True
Expand Down
Loading

0 comments on commit 9e06bb5

Please sign in to comment.