Skip to content

Commit

Permalink
improve gym speedup (#210)
Browse files Browse the repository at this point in the history
## Description

Improved speedup in gym by: 
 - added `max_episode_steps` into the base config in c++
- moved `max_episode_steps` from c++ default config per env into
`registration.py`
 - moved `trunc` computation logics from python to c++

This PR also includes a big refactor of calling test env. We use
`make_dm` / `make_gym` instead of
`envpool_cls(spec_cls(spec_cls.gen_config(...)))`.

## Motivation and Context

The computation of `trunc` in `gym_envpool.py` is shown as a bottleneck
by `py-spy`.

After fix, envpool's speedup in gym improved from ~1.0x to ~1.3x (using
make_gym instead of make_dm for the test below):

```
Namespace(domain='cheetah', seed=0, task='run', total_step=200000)
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [00:17<00:00, 11301.86it/s]
FPS(dmc) = 11300.15
100%|███████████████████████████████████████████████████████████████████████████████████████████████| 200000/200000 [00:13<00:00, 14408.97it/s]
FPS(envpool) = 14408.43
EnvPool Speedup: 1.28x
```

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
  • Loading branch information
wangsiping97 and Trinkle23897 authored Nov 1, 2022
1 parent e384d09 commit 29d8412
Show file tree
Hide file tree
Showing 74 changed files with 597 additions and 1,012 deletions.
199 changes: 100 additions & 99 deletions docs/content/new_env.rst
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,7 @@ state space, and action space. Create a class ``CartPoleEnvFns``:
class CartPoleEnvFns {
public:
static decltype(auto) DefaultConfig() {
return MakeDict("max_episode_steps"_.Bind(200),
"reward_threshold"_.Bind(195.0));
return MakeDict("reward_threshold"_.Bind(195.0));
}
template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
Expand Down Expand Up @@ -114,12 +113,12 @@ available to see on the python side:
>>> import envpool
>>> spec = envpool.make_spec("CartPole-v0")
>>> spec
CartPoleEnvSpec(num_envs=1, batch_size=0, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, max_episode_steps=200, reward_threshold=195.0)
CartPoleEnvSpec(num_envs=1, batch_size=1, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, gym_reset_return_info=False, max_episode_steps=200, reward_threshold=195.0)

>>> # if we change a config value
>>> env = envpool.make_gym("CartPole-v0", reward_threshold=666)
>>> env
CartPoleGymEnvPool(num_envs=1, batch_size=0, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, max_episode_steps=200, reward_threshold=666.0)
CartPoleGymEnvPool(num_envs=1, batch_size=1, num_threads=0, max_num_players=1, thread_affinity_offset=-1, base_path='envpool', seed=42, gym_reset_return_info=True, max_episode_steps=200, reward_threshold=666.0)

>>> # observation space and action space
>>> env.observation_space
Expand Down Expand Up @@ -421,7 +420,8 @@ Miscellaneous
.. code-block:: c++
#ifdef ENVPOOL_TEST
fprintf(stderr, "here");
fprintf(stderr, "here\n");
LOG(INFO) << "another error log print method.";
#endif
Expand Down Expand Up @@ -472,6 +472,48 @@ instantiate ``CartPoleEnvSpec``, ``CartPoleDMEnvPool``, and
]


Register CartPole-v0/1 in EnvPool
---------------------------------

To register a task in EnvPool, you need to call ``register`` function in
``envpool.registration``. Here is ``registration.py``:
::

from envpool.registration import register

register(
task_id="CartPole-v0",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=200,
reward_threshold=195.0,
)

register(
task_id="CartPole-v1",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=500,
reward_threshold=475.0,
)

``task_id``, ``import_path``, ``spec_cls``, ``dm_cls``, and ``gym_cls`` are
required arguments. Other arguments such as ``max_episode_steps`` and
``reward_threshold`` are env-specific. For example, if someone use
``envpool.make("CartPole-v1")``, the ``reward_threshold`` will be set to 475.0
at ``CartPoleEnvPool`` initialization.

Finally, it is crucial to let the top-level module import this file. In
``envpool/entry.py``, add the following line:
::

import envpool.classic_control.registration # noqa: F401


Write Bazel BUILD File
----------------------

Expand Down Expand Up @@ -556,25 +598,25 @@ Let's first take a look at ``BUILD`` file in ``classic_control``:
deps = ["//envpool/python:api"],
)

py_library(
name = "classic_control_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)

py_test(
name = "classic_control_test",
srcs = ["classic_control_test.py"],
deps = [
":classic_control",
":classic_control_registration",
requirement("numpy"),
requirement("absl-py"),
],
)

py_library(
name = "classic_control_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)


We have several ways for dependency declaration:

1. use relative path: ``:cartpole`` points to first item (cartpole cc_library);
Expand All @@ -584,6 +626,50 @@ We have several ways for dependency declaration:
runtime dependencies;
4. third-party dependency (not shown above): will explain in the next section.

And don't forget to modify the top-level Bazel BUILD dependency:

.. code-block:: diff
py_library(
name = "entry",
srcs = ["entry.py"],
deps = [
"//envpool/atari:atari_registration",
+ "//envpool/classic_control:classic_control_registration",
],
)
py_library(
name = "envpool",
srcs = ["__init__.py"],
deps = [
":entry",
":registration",
"//envpool/atari",
+ "//envpool/classic_control",
"//envpool/python",
],
)
Also, pay attention to check if ``.so`` file is packed into ``.whl``
successfully. In ``setup.cfg``:

.. code-block:: diff
[options.package_data]
envpool = atari/*.so
atari/roms/*.bin
+ classic_control/*.so
Now you can run ``envpool.make("CartPole-v0")`` by re-installing EnvPool:

.. code-block:: bash
# generate .whl file
make bazel-build
# install .whl
pip install dist/envpool-<version>-*.whl
Testing
~~~~~~~
Expand Down Expand Up @@ -701,91 +787,6 @@ documentation
`Atari BUILD example <https://github.com/sail-sg/envpool/blob/v0.6.1.post1/envpool/atari/BUILD>`_.


Register CartPole-v0/1 in EnvPool
---------------------------------

To register a task in EnvPool, you need to call ``register`` function in
``envpool.registration``. Here is ``registration.py``:
::

from envpool.registration import register

register(
task_id="CartPole-v0",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=200,
reward_threshold=195.0,
)

register(
task_id="CartPole-v1",
import_path="envpool.classic_control",
spec_cls="CartPoleEnvSpec",
dm_cls="CartPoleDMEnvPool",
gym_cls="CartPoleGymEnvPool",
max_episode_steps=500,
reward_threshold=475.0,
)

``task_id``, ``import_path``, ``spec_cls``, ``dm_cls``, and ``gym_cls`` are
required arguments. Other arguments such as ``max_episode_steps`` and
``reward_threshold`` are env-specific. For example, if someone use
``envpool.make("CartPole-v1")``, the ``reward_threshold`` will be set to 475.0
at ``CartPoleEnvPool`` initialization.

Finally, it is crucial to let the top-level module import this file. In
``envpool/entry.py``, add the following line:
::

import envpool.classic_control.registration

And don't forget to modify the Bazel BUILD dependency:

.. code-block:: diff
py_library(
name = "entry",
srcs = ["entry.py"],
deps = [
"//envpool/atari:atari_registration",
+ "//envpool/classic_control:classic_control_registration",
],
)
py_library(
name = "envpool",
srcs = ["__init__.py"],
deps = [
":entry",
":registration",
"//envpool/atari",
+ "//envpool/classic_control",
"//envpool/python",
],
)
Also, pay attention to check if ``.so`` file is packed into ``.whl``
successfully. In ``setup.cfg``:

.. code-block:: diff
[options.package_data]
envpool = atari/*.so
atari/roms/*.bin
+ classic_control/*.so
Now you can run ``envpool.make("CartPole-v0")`` by re-installing EnvPool:

.. code-block:: bash
# generate .whl file
make bazel-build
# install .whl
pip install dist/envpool-<version>-*.whl

Add Unit Test for CartPoleEnv
-----------------------------
Expand Down
4 changes: 3 additions & 1 deletion docs/content/xla_interface.rst
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,10 @@ We can now write the actor loop as:
def actor_step(iter, loop_var):
handle0, states = loop_var
action = policy(states)
# for gym
# for gym < 0.26
handle1, (new_states, rew, done, info) = step(handle0, action)
# for gym >= 0.26
# handle1, (new_states, rew, term, trunc, info) = step(handle0, action)
# for dm
# handle1, new_states = step(handle0, action)
return (handle1, new_states)
Expand Down
19 changes: 11 additions & 8 deletions envpool/atari/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,20 @@ pybind_extension(
],
)

py_library(
name = "atari_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)

py_test(
name = "api_test",
srcs = ["api_test.py"],
deps = [
":atari",
":atari_registration",
"//envpool/python",
requirement("numpy"),
],
Expand All @@ -91,6 +100,7 @@ py_test(
srcs = ["atari_envpool_test.py"],
deps = [
":atari",
":atari_registration",
requirement("numpy"),
requirement("jax"),
requirement("dm-env"),
Expand All @@ -109,18 +119,11 @@ py_test(
deps = [
":atari",
":atari_network",
":atari_registration",
requirement("numpy"),
requirement("absl-py"),
requirement("tianshou"),
requirement("tqdm"),
requirement("torch"),
],
)

py_library(
name = "atari_registration",
srcs = ["registration.py"],
deps = [
"//envpool:registration",
],
)
Loading

0 comments on commit 29d8412

Please sign in to comment.