diff --git a/alf/environments/alf_environment.py b/alf/environments/alf_environment.py index fe9e11779..0b8dcf598 100644 --- a/alf/environments/alf_environment.py +++ b/alf/environments/alf_environment.py @@ -77,6 +77,15 @@ def task_names(self): """The name of each tasks.""" return [str(i) for i in range(self.num_tasks)] + @property + def is_tensor_based(self): + """Whether the environment is tensor-based or not. + + Tensor-based environment means that the observations and actions are + represented as tensors. Otherwise, they are represented as numpy arrays. + """ + return False + @property def batched(self): """Whether the environment is batched or not. diff --git a/alf/environments/alf_gym3_wrapper.py b/alf/environments/alf_gym3_wrapper.py index 619249a97..1f5a94158 100644 --- a/alf/environments/alf_gym3_wrapper.py +++ b/alf/environments/alf_gym3_wrapper.py @@ -264,6 +264,10 @@ def _image_channel_first_permute_spec(spec): # finishes the current episode self._prev_first = [False] * self.batch_size + @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/alf_gym_wrapper.py b/alf/environments/alf_gym_wrapper.py index 55716d1db..267b87750 100644 --- a/alf/environments/alf_gym_wrapper.py +++ b/alf/environments/alf_gym_wrapper.py @@ -167,6 +167,10 @@ def gym(self): """Return the gym environment. """ return self._gym_env + @property + def is_tensor_based(self): + return False + def _obtain_zero_info(self): """Get an env info of zeros only once when the env is created. This info will be filled in each ``FIRST`` time step as a placeholder. 
diff --git a/alf/environments/alf_wrappers.py b/alf/environments/alf_wrappers.py index 6fb189770..b57cf3947 100644 --- a/alf/environments/alf_wrappers.py +++ b/alf/environments/alf_wrappers.py @@ -62,6 +62,10 @@ def __getattr__(self, name): "attempted to get missing private attribute '{}'".format(name)) return getattr(self._env, name) + @property + def is_tensor_based(self): + return self._env.is_tensor_based + @property def batched(self): return self._env.batched @@ -818,6 +822,14 @@ def __init__(self, env): 'BatchedTensorWrapper can only be used to wrap non-batched env') super().__init__(env) + @property + def is_tensor_based(self): + return True + + @property + def batched(self): + return True + @staticmethod def _to_batched_tensor(raw): """Conver the structured input into batched (batch_size = 1) tensors @@ -846,6 +858,10 @@ def __init__(self, env): 'TensorWrapper can only be used to wrap batched env') super().__init__(env) + @property + def is_tensor_based(self): + return True + @staticmethod def _to_tensor(raw): """Conver the structured input into batched (batch_size = 1) tensors @@ -1098,6 +1114,8 @@ def __init__(self, envs: List[AlfEnvironment]): self._env_info_spec = self._envs[0].env_info_spec() self._num_tasks = self._envs[0].num_tasks self._task_names = self._envs[0].task_names + if any(env.is_tensor_based for env in self._envs): + raise ValueError('All environments must be array-based.') if any(env.action_spec() != self._action_spec for env in self._envs): raise ValueError( 'All environments must have the same action spec.') @@ -1116,6 +1134,10 @@ def __init__(self, envs: List[AlfEnvironment]): def metadata(self): return self._envs[0].metadata + @property + def is_tensor_based(self): + return False + @property def batched(self): return True diff --git a/alf/environments/fast_parallel_environment.py b/alf/environments/fast_parallel_environment.py index 7b85e5355..6a4498301 100644 --- a/alf/environments/fast_parallel_environment.py +++ 
b/alf/environments/fast_parallel_environment.py @@ -125,10 +125,10 @@ def __init__( time_step_with_env_info_spec = self._time_step_spec._replace( env_info=self._env_info_spec) batch_size_per_env = self._envs[0].batch_size - if batch_size_per_env == 1: - assert not self._envs[ - 0].batched, "Does not support batched environment for if batch_size is 1" - batched = batch_size_per_env > 1 + batched = self._envs[0].batched + if any(env.is_tensor_based for env in self._envs): + raise ValueError( + 'All environments must be array-based environments.') if any(env.action_spec() != self._action_spec for env in self._envs): raise ValueError( 'All environments must have the same action spec.') @@ -146,7 +146,7 @@ def __init__( raise ValueError('All environments must have the same batched.') self._closed = False self._penv = _penv.ParallelEnvironment( - num_envs, num_spare_envs_for_reload, batch_size_per_env, + num_envs, num_spare_envs_for_reload, batch_size_per_env, batched, self._action_spec, time_step_with_env_info_spec, name) @property @@ -172,6 +172,10 @@ def start(self): env.wait_start() logging.info('All processes started.') + @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/parallel_environment.cpp b/alf/environments/parallel_environment.cpp index df6af6083..7418307f3 100644 --- a/alf/environments/parallel_environment.cpp +++ b/alf/environments/parallel_environment.cpp @@ -259,6 +259,7 @@ class EnvironmentBase { std::string name_; int num_envs_; int batch_size_per_env_; + bool batched_; protected: SharedDataBuffer action_buffer_; @@ -266,6 +267,7 @@ class EnvironmentBase { EnvironmentBase(int num_envs, int batch_size_per_env, + bool batched, const py::object& action_spec, const py::object& timestep_spec, const std::string& name); @@ -273,12 +275,14 @@ class EnvironmentBase { EnvironmentBase::EnvironmentBase(int num_envs, int batch_size_per_env, + bool batched, const py::object& action_spec, 
const py::object& timestep_spec, const std::string& name) : name_(name), num_envs_(num_envs), batch_size_per_env_(batch_size_per_env), + batched_(batched || batch_size_per_env > 1), action_buffer_( action_spec, num_envs * batch_size_per_env, name + ".action"), timestep_buffer_( @@ -299,6 +303,7 @@ class ParallelEnvironment : public EnvironmentBase { ParallelEnvironment(int num_envs, int num_spare_envs, int batch_size_per_env, + bool batched, py::object action_spec, py::object timestep_spec, const std::string& name); @@ -318,11 +323,16 @@ struct Job { ParallelEnvironment::ParallelEnvironment(int num_envs, int num_spare_envs, int batch_size_per_env, + bool batched, py::object action_spec, py::object timestep_spec, const std::string& name) - : EnvironmentBase( - num_envs, batch_size_per_env, action_spec, timestep_spec, name), + : EnvironmentBase(num_envs, + batch_size_per_env, + batched, + action_spec, + timestep_spec, + name), num_spare_envs_(num_spare_envs), original_env_ids_(num_envs), reseted_(num_envs, 0), @@ -443,6 +453,7 @@ class ProcessEnvironment : public EnvironmentBase { int env_id, int num_envs, int batch_size_per_env, + bool batched, const py::object& action_spec, const py::object& timestep_spec, const std::string& name); @@ -461,11 +472,16 @@ ProcessEnvironment::ProcessEnvironment(py::object env, int env_id, int num_envs, int batch_size_per_env, + bool batched, const py::object& action_spec, const py::object& timestep_spec, const std::string& name) - : EnvironmentBase( - num_envs, batch_size_per_env, action_spec, timestep_spec, name), + : EnvironmentBase(num_envs, + batch_size_per_env, + batched, + action_spec, + timestep_spec, + name), name_(name), env_id_(env_id), env_(env), @@ -536,8 +552,7 @@ void ProcessEnvironment::Step(int env_id) { just_reseted_ = false; } else { auto action = action_buffer_.ViewAsNestedArray( - env_id * batch_size_per_env_, - (batch_size_per_env_ > 1) ? batch_size_per_env_ : 0); + env_id * batch_size_per_env_, batched_ ? 
batch_size_per_env_ : 0); timestep = py_step_(action); } SendTimestep(timestep, env_id); @@ -549,7 +564,7 @@ void ProcessEnvironment::Reset() { } void ProcessEnvironment::SendTimestep(const py::object& timestep, int env_id) { - if (batch_size_per_env_ == 1) { + if (!batched_) { timestep_buffer_.WriteSlice(timestep, env_id); env_id_buf_[env_id] = env_id; } else { @@ -590,14 +605,20 @@ void ProcessEnvironment::Worker() { PYBIND11_MODULE(_penv, m) { py::class_(m, "ParallelEnvironment") - .def( - py::init(), - py::arg("num_envs"), - py::arg("num_spare_envs"), - py::arg("batch_size_per_env"), - py::arg("action_spec"), - py::arg("timestep_spec"), - py::arg("name")) + .def(py::init(), + py::arg("num_envs"), + py::arg("num_spare_envs"), + py::arg("batch_size_per_env"), + py::arg("batched"), + py::arg("action_spec"), + py::arg("timestep_spec"), + py::arg("name")) .def("step", &ParallelEnvironment::Step, R"pbdoc( @@ -613,6 +634,7 @@ PYBIND11_MODULE(_penv, m) { int, int, int, + int, py::object, py::object, const std::string&>(), @@ -621,6 +643,7 @@ PYBIND11_MODULE(_penv, m) { py::arg("env_id"), py::arg("num_envs"), py::arg("batch_size_per_env"), + py::arg("batched"), py::arg("action_spec"), py::arg("timestep_spec"), py::arg("name")) diff --git a/alf/environments/parallel_environment.py b/alf/environments/parallel_environment.py index 1fae9122d..6d127b25d 100644 --- a/alf/environments/parallel_environment.py +++ b/alf/environments/parallel_environment.py @@ -97,6 +97,9 @@ def __init__(self, self._task_names = self._envs[0].task_names self._time_step_with_env_info_spec = self._time_step_spec._replace( env_info=self._env_info_spec) + if any(env.is_tensor_based for env in self._envs): + raise ValueError( + 'All environments must be array-based environments.') if any(env.action_spec() != self._action_spec for env in self._envs): raise ValueError( 'All environments must have the same action spec.') @@ -130,6 +133,10 @@ def start(self): env.wait_start() logging.info('All processes 
started.') + @property + def is_tensor_based(self): + return True + @property def batched(self): return True @@ -153,7 +160,7 @@ def call_envs_with_same_args(self, func_name, *args, **kwargs): func_name (str): name of the function to call *args: args to pass to the function **kwargs: kwargs to pass to the function - + return: list: list of results from each environment """ diff --git a/alf/environments/process_environment.py b/alf/environments/process_environment.py index 865bac56b..8f226c4a1 100644 --- a/alf/environments/process_environment.py +++ b/alf/environments/process_environment.py @@ -85,7 +85,7 @@ def _worker(conn: multiprocessing.connection, penv = _penv.ProcessEnvironment( env, partial(process_call, conn, env, flatten, action_spec), env_id, num_envs, env.batch_size, - env.action_spec(), + env.batched, env.action_spec(), env.time_step_spec()._replace(env_info=env.env_info_spec()), name) conn.send(_MessageType.READY) # Ready. diff --git a/alf/environments/random_alf_environment.py b/alf/environments/random_alf_environment.py index 300b2182c..b4961fd7c 100644 --- a/alf/environments/random_alf_environment.py +++ b/alf/environments/random_alf_environment.py @@ -124,6 +124,10 @@ def observation_spec(self): def action_spec(self): return self._action_spec + @property + def is_tensor_based(self): + return self._use_tensor_time_step + @property def batch_size(self): return self._batch_size if self.batched else 1 diff --git a/alf/environments/suite_carla.py b/alf/environments/suite_carla.py index e947b554b..2a7c3696c 100644 --- a/alf/environments/suite_carla.py +++ b/alf/environments/suite_carla.py @@ -1681,6 +1681,10 @@ def _clear(self): self._other_vehicles.clear() self._walkers.clear() + @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/suite_go.py b/alf/environments/suite_go.py index 743a82380..bedd99989 100644 --- a/alf/environments/suite_go.py +++ b/alf/environments/suite_go.py @@ 
-574,6 +574,10 @@ def __init__(self, logging.info("P : pass") logging.info("SPACE : refresh display") + @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/suite_metadrive.py b/alf/environments/suite_metadrive.py index 20e71b60a..5b2ecb4b6 100644 --- a/alf/environments/suite_metadrive.py +++ b/alf/environments/suite_metadrive.py @@ -104,6 +104,10 @@ def __init__(self, metadrive_env: metadrive.MetaDriveEnv, env_id: int = 0): self._current_observation = None + @property + def is_tensor_based(self): + return False + @property def batched(self): # TODO(breakds): Add support for multiple algorithm controlled agents in diff --git a/alf/environments/suite_tic_tac_toe.py b/alf/environments/suite_tic_tac_toe.py index f477a40f7..8cec72ecf 100644 --- a/alf/environments/suite_tic_tac_toe.py +++ b/alf/environments/suite_tic_tac_toe.py @@ -51,6 +51,10 @@ def __init__(self, batch_size): self._player_0 = torch.tensor(-1.) self._player_1 = torch.tensor(1.) 
+ @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/suite_unittest.py b/alf/environments/suite_unittest.py index 52d1213c1..4e921bf93 100644 --- a/alf/environments/suite_unittest.py +++ b/alf/environments/suite_unittest.py @@ -76,6 +76,10 @@ def _create_action_spec(act_type): self.reset() + @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/thread_environment.py b/alf/environments/thread_environment.py index e9a44bdcf..edcaae733 100644 --- a/alf/environments/thread_environment.py +++ b/alf/environments/thread_environment.py @@ -48,8 +48,13 @@ def __init__(self, env_constructor): super().__init__() self._pool = mp_threads.Pool(1) self._env = self._pool.apply(env_constructor) + assert not self._env.batched self._closed = False + @property + def is_tensor_based(self): + return True + @property def batched(self): return True diff --git a/alf/environments/utils.py b/alf/environments/utils.py index 11f92108f..c0ae39a6b 100644 --- a/alf/environments/utils.py +++ b/alf/environments/utils.py @@ -13,6 +13,7 @@ # limitations under the License. import functools +import inspect import numpy as np import random @@ -56,6 +57,13 @@ def check_and_update(self, wrap_with_process): self.update(wrap_with_process) +def _get_wrapped_fn(fn): + """Get the function that is wrapped by ``functools.partial``""" + while isinstance(fn, functools.partial): + fn = fn.func + return fn + + def _env_constructor(env_load_fn, env_name, batch_size_per_env, seed, env_id): # We need to set random seed before env_load_fn because some environment # perform use random numbers in its constructor, so we need to randomize @@ -67,7 +75,8 @@ def _env_constructor(env_load_fn, env_name, batch_size_per_env, seed, env_id): # # NOTE: here it ASSUMES that the created batched environment will take the # following env IDs: env_id, env_id + 1, ... 
,env_id + batch_size - 1 - if hasattr(env_load_fn, "batched") and env_load_fn.batched: + batched = getattr(_get_wrapped_fn(env_load_fn), 'batched', False) + if batched: return env_load_fn( env_name, env_id=env_id, batch_size=batch_size_per_env) if batch_size_per_env == 1: @@ -85,7 +94,8 @@ def create_environment(env_name='CartPole-v0', eval_env_load_fn=None, for_evaluation=False, num_parallel_environments=30, - batch_size_per_env=1, + batch_size_per_env=None, + eval_batch_size_per_env=None, nonparallel=False, flatten=True, start_serially=True, @@ -103,9 +113,13 @@ def create_environment(env_name='CartPole-v0', consists of the environments listed in ``env_name``. env_load_fn (Callable) : callable that create an environment If env_load_fn has attribute ``batched`` and it is True, - ``evn_load_fn(env_name, batch_size=num_parallel_environments)`` - will be used to create the batched environment. Otherwise, a - ``ParallAlfEnvironment`` will be created. + ``env_load_fn(env_name, env_id=env_id, batch_size=batch_size_per_env)`` + will be used to create the batched environment. Otherwise, + ``env_load_fn(env_name, env_id)`` will be used to create the environment. + env_id is the index of the environment in the batch in the range of + ``[0, num_parallel_environments / batch_size_per_env)``. + And if "num_parallel_environments" is in the signature of ``env_load_fn``, + num_parallel_environments will be provided as a keyword argument. eval_env_load_fn (Callable) : callable that create an environment for evaluation. If None, use ``env_load_fn``. This argument is useful for cases when the evaluation environment is different from the @@ -115,7 +129,7 @@ def create_environment(env_name='CartPole-v0', will be used for creating the environment if provided. Otherwise, ``env_load_fn`` will be used. 
num_parallel_environments (int): num of parallel environments - batch_size_per_env (int): if >1, will create + batch_size_per_env (Optional[int]): if >1, will create ``num_parallel_environments/batch_size_per_env`` ``ProcessEnvironment``. Each of these ``ProcessEnvironment`` holds ``batch_size_per_env`` environments. If each underlying environment @@ -127,9 +141,15 @@ def create_environment(env_name='CartPole-v0', ``batch_size_per_env>1`` is to reduce the number of processes being used, or to take advantages of the batched nature of the underlying environment. + If None, it will be ``num_parallel_environments`` if ``env_load_fn`` + is batched and 1 otherwise. + eval_batch_size_per_env (Optional[int]): if provided, it will be used as the + batch size for evaluation environment. Otherwise, use + ``batch_size_per_env``. num_spare_envs (int): num of spare parallel envs for speed up reset. nonparallel (bool): force to create a single env in the current process. Used for correctly exposing game gin confs to tensorboard. + If True, ``num_parallel_environments`` will be ignored and set to 1. start_serially (bool): start environments serially or in parallel. flatten (bool): whether to use flatten action and time_steps during communication to reduce overhead. 
@@ -154,7 +174,23 @@ def create_environment(env_name='CartPole-v0', if for_evaluation: # for creating an evaluation environment, use ``eval_env_load_fn`` if # provided and fall back to ``env_load_fn`` otherwise - env_load_fn = eval_env_load_fn if eval_env_load_fn else env_load_fn + env_load_fn = eval_env_load_fn or env_load_fn + batch_size_per_env = eval_batch_size_per_env or batch_size_per_env + + # env_load_fn may be a functools.partial, so we need to get the wrapped + # function to get its attributes + batched = getattr(_get_wrapped_fn(env_load_fn), 'batched', False) + no_thread_env = getattr( + _get_wrapped_fn(env_load_fn), 'no_thread_env', False) + + if nonparallel: + num_parallel_environments = 1 + + if batch_size_per_env is None: + if batched: + batch_size_per_env = num_parallel_environments + else: + batch_size_per_env = 1 assert num_parallel_environments % batch_size_per_env == 0, ( f"num_parallel_environments ({num_parallel_environments}) cannot be" @@ -163,20 +199,23 @@ def create_environment(env_name='CartPole-v0', if batch_size_per_env > 1: assert num_spare_envs == 0, "Do not support spare environments for batch_size_per_env > 1" assert parallel_environment_ctor == fast_parallel_environment.FastParallelEnvironment + + if 'num_parallel_environments' in inspect.signature( + env_load_fn).parameters: + env_load_fn = functools.partial( + env_load_fn, num_parallel_environments=num_parallel_environments) + if isinstance(env_name, (list, tuple)): env_load_fn = functools.partial(alf_wrappers.MultitaskWrapper.load, env_load_fn) - if hasattr(env_load_fn, - 'batched') and env_load_fn.batched and batch_size_per_env == 1: - if nonparallel: - alf_env = env_load_fn(env_name, batch_size=1) - else: - alf_env = env_load_fn( - env_name, batch_size=num_parallel_environments) + if batched and batch_size_per_env == num_parallel_environments: + alf_env = env_load_fn(env_name, batch_size=num_parallel_environments) + if not alf_env.is_tensor_based: + alf_env = 
alf_wrappers.TensorWrapper(alf_env) elif nonparallel: # Each time we can only create one unwrapped env at most - if getattr(env_load_fn, 'no_thread_env', False): + if no_thread_env: # In this case the environment is marked as "not compatible with # thread environment", and we will create it in the main thread. # BatchedTensorWrapper is applied to make sure the I/O is batched