Skip to content

Commit

Permalink
Add channel_first option in procgen (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
Trinkle23897 authored Jan 8, 2023
1 parent b3b2362 commit 92118d7
Show file tree
Hide file tree
Showing 10 changed files with 111 additions and 34 deletions.
6 changes: 3 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ docker-ci-launch: docker-ci
docker run --network=host -v /home/ubuntu:/home/github-action --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) bash

docker-dev: docker-ci
docker run --network=host -v /:/host --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) bash
docker run --network=host -v /:/host -v $(shell pwd):/app -v $(HOME)/.cache:/root/.cache --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) zsh

# for mainland China
docker-dev-cn:
docker build --network=host -t $(PROJECT_NAME):$(DOCKER_TAG) -f docker/dev-cn.dockerfile .
echo successfully build docker image with tag $(PROJECT_NAME):$(DOCKER_TAG)
docker run --network=host -v /:/host --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) bash
docker run --network=host -v /:/host -v $(shell pwd):/app -v $(HOME)/.cache:/root/.cache --shm-size=4gb -it $(PROJECT_NAME):$(DOCKER_TAG) zsh

docker-release:
docker build --network=host -t $(PROJECT_NAME)-release:$(DOCKER_TAG) -f docker/release.dockerfile .
Expand All @@ -177,7 +177,7 @@ docker-release-push: docker-release
docker push $(DOCKER_USER)/$(PROJECT_NAME)-release:$(DOCKER_TAG)

docker-release-launch: docker-release
docker run --network=host -v /:/host --shm-size=4gb -it $(PROJECT_NAME)-release:$(DOCKER_TAG) bash
docker run --network=host -v /:/host -v $(shell pwd):/app -v $(HOME)/.cache:/root/.cache --shm-size=4gb -it $(PROJECT_NAME)-release:$(DOCKER_TAG) zsh

pypi-wheel: auditwheel-install bazel-release
ls dist/*.whl -Art | tail -n 1 | xargs auditwheel repair --plat manylinux_2_24_x86_64
Expand Down
16 changes: 13 additions & 3 deletions docker/dev-cn.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,17 @@ ARG HOME=/root
ARG PATH=$PATH:$HOME/go/bin

RUN apt-get update \
&& apt-get install -y python3-pip python3-dev golang-1.18 clang-format-11 git wget swig tmux clang-tidy vim qtdeclarative5-dev \
&& apt-get install -y python3-pip python3-dev golang-1.18 git wget curl zsh tmux vim \
&& rm -rf /var/lib/apt/lists/*
RUN ln -s /usr/bin/python3 /usr/bin/python
RUN ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go
RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"
WORKDIR $HOME
RUN git clone https://github.com/gpakosz/.tmux.git
RUN ln -s -f .tmux/.tmux.conf
RUN cp .tmux/.tmux.conf.local .
RUN echo "set-option -g default-shell /bin/zsh" >> .tmux.conf.local
RUN echo "set-option -g history-limit 10000" >> .tmux.conf.local
RUN go env -w GOPROXY=https://goproxy.cn

RUN wget https://mirrors.huaweicloud.com/bazel/6.0.0/bazel-6.0.0-linux-x86_64
Expand All @@ -20,7 +27,10 @@ RUN mv bazel-6.0.0-linux-x86_64 $HOME/go/bin/bazel
RUN go install github.com/bazelbuild/buildtools/buildifier@latest
RUN $HOME/go/bin/bazel version

RUN useradd -ms /bin/bash github-action
RUN useradd -ms /bin/zsh github-action

RUN apt-get update \
&& apt-get install -y clang-format clang-tidy swig qtdeclarative5-dev \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app
COPY . .
17 changes: 14 additions & 3 deletions docker/dev.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,27 @@ ARG HOME=/root
ARG PATH=$PATH:$HOME/go/bin

RUN apt-get update \
&& apt-get install -y python3-pip python3-dev golang-1.18 clang-format-11 git wget swig tmux clang-tidy vim qtdeclarative5-dev \
&& apt-get install -y python3-pip python3-dev golang-1.18 git wget curl zsh tmux vim \
&& rm -rf /var/lib/apt/lists/*
RUN ln -s /usr/bin/python3 /usr/bin/python
RUN ln -sf /usr/lib/go-1.18/bin/go /usr/bin/go
RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"
WORKDIR $HOME
RUN git clone https://github.com/gpakosz/.tmux.git
RUN ln -s -f .tmux/.tmux.conf
RUN cp .tmux/.tmux.conf.local .
RUN echo "set-option -g default-shell /bin/zsh" >> .tmux.conf.local
RUN echo "set-option -g history-limit 10000" >> .tmux.conf.local
RUN echo "export PATH=$PATH:$HOME/go/bin" >> .zshrc

RUN go install github.com/bazelbuild/bazelisk@latest && ln -sf $HOME/go/bin/bazelisk $HOME/go/bin/bazel
RUN go install github.com/bazelbuild/buildtools/buildifier@latest
RUN $HOME/go/bin/bazel version

RUN useradd -ms /bin/bash github-action
RUN useradd -ms /bin/zsh github-action

RUN apt-get update \
&& apt-get install -y clang-format clang-tidy swig qtdeclarative5-dev \
&& rm -rf /var/lib/apt/lists/*

WORKDIR /app
COPY . .
30 changes: 22 additions & 8 deletions docker/release.dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,24 @@ ENV LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH

WORKDIR $HOME

# install base dependencies
# setup env

RUN apt-get update && apt-get install -y software-properties-common && add-apt-repository ppa:ubuntu-toolchain-r/test

RUN apt-get update \
&& apt-get install -y git curl wget gcc-9 g++-9 build-essential swig make \
zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev libncursesw5-dev libffi-dev liblzma-dev \
llvm xz-utils tk-dev libxml2-dev libxmlsec1-dev qtdeclarative5-dev \
&& apt-get install -y git curl wget zsh gcc-9 g++-9 build-essential make tmux \
&& rm -rf /var/lib/apt/lists/*
# use self-compiled openssl instead of system provided (1.0.2)
RUN apt-get remove -y libssl-dev

RUN curl https://pyenv.run | sh

RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-9 60 --slave /usr/bin/g++ g++ /usr/bin/g++-9
RUN sh -c "$(curl -fsSL https://raw.githubusercontent.com/ohmyzsh/ohmyzsh/master/tools/install.sh)"
WORKDIR $HOME
RUN git clone https://github.com/gpakosz/.tmux.git
RUN ln -s -f .tmux/.tmux.conf
RUN cp .tmux/.tmux.conf.local .
RUN echo "set-option -g default-shell /bin/zsh" >> .tmux.conf.local
RUN echo "set-option -g history-limit 10000" >> .tmux.conf.local

RUN curl https://pyenv.run | sh

# install go from source

Expand All @@ -35,6 +38,15 @@ RUN go install github.com/bazelbuild/bazelisk@latest && ln -sf $HOME/go/bin/baze

RUN bazel version

# install base dependencies

RUN apt-get update \
&& apt-get install -y swig zlib1g-dev libbz2-dev libreadline-dev libsqlite3-dev libncursesw5-dev libffi-dev liblzma-dev \
llvm xz-utils tk-dev libxml2-dev libxmlsec1-dev qtdeclarative5-dev \
&& rm -rf /var/lib/apt/lists/*
# use self-compiled openssl instead of system provided (1.0.2)
RUN apt-get remove -y libssl-dev

# install newest openssl (for py3.10 and py3.11)

RUN wget https://www.openssl.org/source/openssl-1.1.1s.tar.gz
Expand Down Expand Up @@ -63,3 +75,5 @@ COPY . .

RUN bazel build //envpool/utils:image_process_test --config=release
RUN bazel build //envpool/vizdoom/bin:vizdoom_bin --config=release

WORKDIR /app
5 changes: 4 additions & 1 deletion docs/env/procgen.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Options
* ``seed (int)``: the environment seed, default to ``42``;
* ``max_episode_steps (int)``: the maximum number of steps for one episode,
each procgen game has a different timeout value;
* ``channel_first (bool)``: whether to transpose the observation image to
``(3, 64, 64)``, default to ``True``;
* ``env_name (str)``: one of the 16 procgen env names;
* ``num_levels (int)``: default to ``0``;
* ``start_level (int)``: default to ``0``;
Expand All @@ -41,7 +43,8 @@ original version of procgen (PRs for fixing this issue are highly welcome).
Observation Space
-----------------

The observation image size is ``(64, 64, 3)``.
The observation image shape is ``(3, 64, 64)`` when ``channel_first`` is
``True``, and ``(64, 64, 3)`` when ``channel_first`` is ``False``.


Action Space
Expand Down
2 changes: 1 addition & 1 deletion envpool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
register,
)

__version__ = "0.8.0"
__version__ = "0.8.1"
__all__ = [
"register",
"make",
Expand Down
43 changes: 32 additions & 11 deletions envpool/procgen/procgen_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,26 @@ class ProcgenEnvFns {
public:
static decltype(auto) DefaultConfig() {
return MakeDict(
"env_name"_.Bind(std::string("bigfish")), "num_levels"_.Bind(0),
"start_level"_.Bind(0), "use_sequential_levels"_.Bind(false),
"center_agent"_.Bind(true), "use_backgrounds"_.Bind(true),
"use_monochrome_assets"_.Bind(false), "restrict_themes"_.Bind(false),
"use_generated_assets"_.Bind(false), "paint_vel_info"_.Bind(false),
"use_easy_jump"_.Bind(false), "distribution_mode"_.Bind(1));
"env_name"_.Bind(std::string("bigfish")), "channel_first"_.Bind(true),
"num_levels"_.Bind(0), "start_level"_.Bind(0),
"use_sequential_levels"_.Bind(false), "center_agent"_.Bind(true),
"use_backgrounds"_.Bind(true), "use_monochrome_assets"_.Bind(false),
"restrict_themes"_.Bind(false), "use_generated_assets"_.Bind(false),
"paint_vel_info"_.Bind(false), "use_easy_jump"_.Bind(false),
"distribution_mode"_.Bind(1));
}

template <typename Config>
static decltype(auto) StateSpec(const Config& conf) {
// The observation is an RGB image: 3 x 64 x 64 when channel_first is set,
// otherwise 64 x 64 x 3
return MakeDict("obs"_.Bind(Spec<uint8_t>({kRes, kRes, 3}, {0, 255})),
"info:prev_level_seed"_.Bind(Spec<int>({-1})),
"info:prev_level_complete"_.Bind(Spec<int>({-1})),
"info:level_seed"_.Bind(Spec<int>({-1})));
return MakeDict(
"obs"_.Bind(Spec<uint8_t>(conf["channel_first"_]
? std::vector<int>{3, kRes, kRes}
: std::vector<int>{kRes, kRes, 3},
{0, 255})),
"info:prev_level_seed"_.Bind(Spec<int>({-1})),
"info:prev_level_complete"_.Bind(Spec<int>({-1})),
"info:level_seed"_.Bind(Spec<int>({-1})));
}

template <typename Config>
Expand All @@ -91,6 +96,7 @@ class ProcgenEnv : public Env<ProcgenEnvSpec> {
protected:
std::shared_ptr<Game> game_;
std::string env_name_;
bool channel_first_;
// buffer used by game
FrameSpec obs_spec_;
Array obs_;
Expand All @@ -102,6 +108,7 @@ class ProcgenEnv : public Env<ProcgenEnvSpec> {
ProcgenEnv(const Spec& spec, int env_id)
: Env<ProcgenEnvSpec>(spec, env_id),
env_name_(spec.config["env_name"_]),
channel_first_(spec.config["channel_first"_]),
obs_spec_({kRes, kRes, 3}),
obs_(obs_spec_) {
/* Initialize the single game we are holding in this EnvPool environment
Expand Down Expand Up @@ -177,7 +184,21 @@ class ProcgenEnv : public Env<ProcgenEnvSpec> {
private:
void WriteObs() {
State state = Allocate();
state["obs"_].Assign(obs_);
if (channel_first_) {
// convert from HWC to CHW
auto* data = static_cast<uint8_t*>(state["obs"_].Data());
auto* buffer = static_cast<uint8_t*>(obs_.Data());
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < kRes; ++j) {
for (int k = 0; k < kRes; ++k) {
data[i * kRes * kRes + j * kRes + k] =
buffer[j * kRes * 3 + k * 3 + i];
}
}
}
} else {
state["obs"_].Assign(obs_);
}
state["reward"_] = reward_;
state["info:prev_level_seed"_] = prev_level_seed_;
state["info:prev_level_complete"_] = prev_level_complete_;
Expand Down
2 changes: 1 addition & 1 deletion envpool/procgen/procgen_env_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TEST(PRocgenEnvTest, BasicStep) {
auto state_vec = envpool.Recv();
ProcgenState state(&state_vec);
EXPECT_EQ(state["obs"_].Shape(),
std::vector<std::size_t>({batch, 64, 64, 3}));
std::vector<std::size_t>({batch, 3, 64, 64}));
uint8_t* data = static_cast<uint8_t*>(state["obs"_].Data());
int index = 0;
for (std::size_t j = 0; j < batch; ++j) {
Expand Down
22 changes: 20 additions & 2 deletions envpool/procgen/procgen_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Unit tests for Procgen environments alignment & deterministic check."""
"""Unit tests for Procgen environments."""

# import cv2
import numpy as np
Expand Down Expand Up @@ -49,7 +49,7 @@ def test_deterministic(self) -> None:
def test_align(self) -> None:
task_id = "CoinrunHard-v0"
seed = 0
env = make_gym(task_id, seed=seed)
env = make_gym(task_id, seed=seed, channel_first=False)
env.action_space.seed(seed)
done = [False]
cnt = sum_reward = sum_obs = 0
Expand All @@ -71,6 +71,24 @@ def test_align(self) -> None:
pixel_mean = (sum_obs / cnt).mean(axis=0).mean(axis=0) # type: ignore
np.testing.assert_allclose(pixel_mean, pixel_mean_ref)

def test_channel_first(
    self,
    task_id: str = "CoinrunHard-v0",
    seed: int = 0,
    total: int = 1000,
) -> None:
    """Compare observations of a channel-first env and a channel-last env.

    Two environments are built from the same task id and seed, differing
    only in the ``channel_first`` option. Both are stepped ``total`` times
    with the same sampled action, and after every step the channel-first
    observation must equal the channel-last one transposed with
    ``(2, 0, 1)``.

    Args:
      task_id: task id passed to ``make_gym`` for both environments.
      seed: seed passed to ``make_gym`` for both environments.
      total: number of steps to run.
    """
    chw_shape = (3, 64, 64)
    hwc_shape = (64, 64, 3)
    chw_env = make_gym(task_id, seed=seed, channel_first=True)
    hwc_env = make_gym(task_id, seed=seed, channel_first=False)
    # The declared observation spaces must already reflect the layout.
    self.assertEqual(chw_env.observation_space.shape, chw_shape)
    self.assertEqual(hwc_env.observation_space.shape, hwc_shape)
    remaining = total
    while remaining > 0:
        remaining -= 1
        # Actions are always sampled from the channel-first env and fed to
        # both envs, each wrapped in its own fresh single-element array.
        action = chw_env.action_space.sample()
        chw_step = chw_env.step(np.array([action]))
        chw_obs = chw_step[0][0]
        hwc_step = hwc_env.step(np.array([action]))
        hwc_obs = hwc_step[0][0]
        self.assertEqual(chw_obs.shape, chw_shape)
        self.assertEqual(hwc_obs.shape, hwc_shape)
        np.testing.assert_allclose(chw_obs, hwc_obs.transpose(2, 0, 1))


if __name__ == "__main__":
absltest.main()
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = envpool
version = 0.8.0
version = 0.8.1
author = "EnvPool Contributors"
author_email = "sail@sea.com"
description = "C++-based high-performance parallel environment execution engine (vectorized env) for general RL environments."
Expand Down

0 comments on commit 92118d7

Please sign in to comment.