[Performance] consolidate TDs in ParallelEnv without buffers #2231

Merged · 4 commits · Jun 28, 2024
torchrl/envs/batched_envs.py — 62 changes: 44 additions & 18 deletions
@@ -1434,8 +1434,13 @@ def _step_and_maybe_reset_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:

-        for i, _data in enumerate(tensordict.unbind(0)):
-            self.parent_channels[i].send(("step_and_maybe_reset", _data))
+        td = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        for i in range(td.shape[0]):
+            # We send the same td to every worker: it lives in shared memory,
+            # so each process only needs to index its own slice.
+            # Without this we would have to unbind the td, and the custom
+            # pickler would then require extra metadata to be collected.
+            self.parent_channels[i].send(("step_and_maybe_reset", (td, i)))

         results = [None] * self.num_workers

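The pattern above, in isolation: consolidate once into shared memory, then hand every worker the same object plus its row index. A minimal sketch, assuming a tensordict version that provides `TensorDict.consolidate`; `worker_idx` is a hypothetical stand-in for the per-worker index:

```python
import torch
from tensordict import TensorDict

# A batched tensordict with one row per worker.
td = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])

# Consolidate into a single contiguous shared-memory storage; the same
# object can then be sent to every worker cheaply.
td = td.consolidate(share_memory=True, inplace=True, num_threads=1)

# Worker i reads only its own row:
worker_idx = 2
local_td = td[worker_idx]  # indexing returns a view, no copy
```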
@@ -1556,8 +1561,11 @@ def step_and_maybe_reset(
     def _step_no_buffers(
         self, tensordict: TensorDictBase
     ) -> Tuple[TensorDictBase, TensorDictBase]:
-        for i, data in enumerate(tensordict.unbind(0)):
-            self.parent_channels[i].send(("step", data))
+        data = tensordict.consolidate(share_memory=True, inplace=True, num_threads=1)
+        for i, local_data in enumerate(data.unbind(0)):
+            self.parent_channels[i].send(("step", local_data))
+        # Alternative: send the consolidated td once and index in the worker.
+        # for i in range(data.shape[0]):
+        #     self.parent_channels[i].send(("step", (data, i)))
         out_tds = []
         for i, channel in enumerate(self.parent_channels):
             self._events[i].wait()
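For context, what consolidation buys here: all leaves of the consolidated copy become views into one contiguous storage, so IPC serializes a single buffer rather than one storage per leaf. A sketch under the same tensordict assumption (the single-storage check reflects what `consolidate` is meant to produce, not code from this PR):

```python
import torch
from tensordict import TensorDict

td = TensorDict(
    {"obs": torch.randn(2, 3), ("next", "reward"): torch.zeros(2, 1)},
    batch_size=[2],
)
ctd = td.consolidate(share_memory=True, num_threads=1)

# Expected to hold for plain tensor leaves: every leaf of the consolidated
# copy is a view into the same underlying storage.
storages = {v.untyped_storage().data_ptr() for v in ctd.values(True, True)}
assert len(storages) == 1
```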
@@ -1663,17 +1671,24 @@ def _reset_no_buffers(
         reset_kwargs_list,
         needs_resetting,
     ) -> Tuple[TensorDictBase, TensorDictBase]:
-        tdunbound = (
-            tensordict.unbind(0)
-            if is_tensor_collection(tensordict)
-            else [None] * self.num_workers
-        )
+        if is_tensor_collection(tensordict):
+            # tensordict = tensordict.consolidate(share_memory=True, num_threads=1)
+            tensordict = tensordict.consolidate(
+                share_memory=True, num_threads=1
+            ).unbind(0)
+        else:
+            tensordict = [None] * self.num_workers
         out_tds = [None] * self.num_workers
-        for i, (data, reset_kwargs) in enumerate(zip(tdunbound, reset_kwargs_list)):
+        for i, (local_data, reset_kwargs) in enumerate(
+            zip(tensordict, reset_kwargs_list)
+        ):
             if not needs_resetting[i]:
-                out_tds[i] = tdunbound[i].exclude(*self.reset_keys)
+                localtd = local_data
+                if localtd is not None:
+                    localtd = localtd.exclude(*self.reset_keys)
+                out_tds[i] = localtd
                 continue
-            self.parent_channels[i].send(("reset", (data, reset_kwargs)))
+            self.parent_channels[i].send(("reset", (local_data, reset_kwargs)))

         for i, channel in enumerate(self.parent_channels):
             if not needs_resetting[i]:
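Condensed, the reset dispatch above behaves like this sketch; `needs_resetting` and the reset keys are hypothetical placeholder values:

```python
import torch
from tensordict import TensorDict

needs_resetting = [True, False]
reset_keys = ["_reset"]
td = TensorDict(
    {"obs": torch.zeros(2, 3), "_reset": torch.ones(2, dtype=torch.bool)},
    batch_size=[2],
)
slices = td.consolidate(share_memory=True, num_threads=1).unbind(0)

out_tds = [None] * 2
for i, local_data in enumerate(slices):
    if not needs_resetting[i]:
        # Keep this worker's input, minus the reset flags.
        out_tds[i] = local_data.exclude(*reset_keys)
        continue
    # Otherwise (local_data, reset_kwargs) goes down worker i's pipe.
```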
@@ -1995,10 +2010,11 @@ def look_for_cuda(tensor, has_cuda=has_cuda):
             i = 0
             next_shared_tensordict = shared_tensordict.get("next")
             root_shared_tensordict = shared_tensordict.exclude("next")
-            if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
-                raise RuntimeError(
-                    "tensordict must be placed in shared memory (share_memory_() or memmap_())"
-                )
+            # TODO: restore this
+            # if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()):
+            #     raise RuntimeError(
+            #         "tensordict must be placed in shared memory (share_memory_() or memmap_())"
+            #     )
             shared_tensordict = shared_tensordict.clone(False).unlock_()

             initialized = True
@@ -2243,6 +2259,8 @@ def _run_worker_pipe_direct(
                 raise RuntimeError("call 'init' before resetting")
             # we use 'data' to pass the keys that we need to pass to reset,
             # because passing the entire buffer may have unwanted consequences
+            # data, idx, reset_kwargs = data
+            # data = data[idx]
             data, reset_kwargs = data
             if data is not None:
                 data._fast_apply(
@@ -2256,25 +2274,33 @@
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(cur_td)
+            child_pipe.send(
+                cur_td.consolidate(share_memory=True, inplace=True, num_threads=1)
+            )
             del cur_td

         elif cmd == "step":
             if not initialized:
                 raise RuntimeError("call 'init' before stepping")
             i += 1
+            # data, idx = data
+            # data = data[idx]
             next_td = env._step(data)
             if event is not None:
                 event.record()
                 event.synchronize()
             mp_event.set()
-            child_pipe.send(next_td)
+            child_pipe.send(
+                next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
+            )
             del next_td

         elif cmd == "step_and_maybe_reset":
             if not initialized:
                 raise RuntimeError("call 'init' before stepping")
             i += 1
+            # data, idx = data
+            # data = data[idx]
             data._fast_apply(
                 lambda x: x.clone() if x.device.type == "cuda" else x, out=data
             )
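Putting the worker side together: the result tensordict is consolidated into shared memory before crossing the pipe, so the parent deserializes one buffer per message. A self-contained sketch, not the library's exact worker loop; all names here are illustrative:

```python
import torch
from torch import multiprocessing as mp
from tensordict import TensorDict

def worker(child_pipe):
    next_td = TensorDict({"reward": torch.zeros(1)}, batch_size=[])
    # One consolidated storage crosses the pipe instead of one per leaf.
    child_pipe.send(
        next_td.consolidate(share_memory=True, inplace=True, num_threads=1)
    )

if __name__ == "__main__":
    parent_pipe, child_pipe = mp.Pipe()
    proc = mp.Process(target=worker, args=(child_pipe,))
    proc.start()
    out = parent_pipe.recv()  # a TensorDict backed by shared memory
    proc.join()
```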