Skip to content

Commit

Permalink
Merge pull request #257 from huangshiyu13/main
Browse files Browse the repository at this point in the history
update _worker test
  • Loading branch information
huangshiyu13 authored Oct 23, 2023
2 parents f96e77f + 8889302 commit 7e70569
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 3 deletions.
6 changes: 4 additions & 2 deletions openrl/envs/vec_env/async_venv.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,7 +734,7 @@ def _worker(
index: int,
env_fn: callable,
pipe: Connection,
parent_pipe: Connection,
parent_pipe: Optional[Connection],
shared_memory: bool,
error_queue: Queue,
auto_reset: bool = True,
Expand All @@ -758,10 +758,12 @@ def prepare_obs(observation):
observation = None
return observation

parent_pipe.close()
if parent_pipe is not None:
parent_pipe.close()
try:
while True:
command, data = pipe.recv()
print(command)

if command == "reset":
result = env.reset(**data)
Expand Down
51 changes: 50 additions & 1 deletion tests/test_env/test_vec_env/test_async_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

""""""

import multiprocessing as mp
import os
import sys

import pytest
from gymnasium.wrappers import EnvCompatibility

from openrl.envs.toy_envs import make_toy_envs
from openrl.envs.vec_env.async_venv import AsyncVectorEnv
from openrl.envs.vec_env.async_venv import AsyncVectorEnv, _worker


class CustomEnvCompatibility(EnvCompatibility):
Expand Down Expand Up @@ -62,5 +63,53 @@ def test_async_env():
env.exec_func(assert_env_name, indices=None, env_name=env_name_new)


def main_control(parent_pipe, child_pipe):
child_pipe.close()

parent_pipe.send(("reset", {"seed": 0}))
result, success = parent_pipe.recv()
assert success, result

parent_pipe.send(("step", [0]))
result, success = parent_pipe.recv()
assert success, result

parent_pipe.send(("_call", ("render", [], {})))
result, success = parent_pipe.recv()
assert success, result

parent_pipe.send(("_setattr", ("metadata", {"name": "IdentityEnvNew"})))
result, success = parent_pipe.recv()
assert success, result

parent_pipe.send(
("_func_exec", (assert_env_name, None, [], {"env_name": "IdentityEnvNew"}))
)
result, success = parent_pipe.recv()
assert success, result

parent_pipe.send(("close", None))
result, success = parent_pipe.recv()
assert success, result


@pytest.mark.unittest
def test_worker():
for auto_reset in [True, False]:
ctx = mp.get_context(None)
parent_pipe, child_pipe = ctx.Pipe()

error_queue = ctx.Queue()

process = ctx.Process(
target=main_control,
name="test",
args=(parent_pipe, child_pipe),
)
process.daemon = True
process.start()
_worker(0, init_envs()[0], child_pipe, None, False, error_queue, auto_reset)


if __name__ == "__main__":
sys.exit(pytest.main(["-sv", os.path.basename(__file__)]))

0 comments on commit 7e70569

Please sign in to comment.