From 97598144544d1364aa5c20b2f33682372779bea0 Mon Sep 17 00:00:00 2001 From: Wei Xu Date: Mon, 16 Dec 2024 16:01:56 -0800 Subject: [PATCH] Fix unittests for DataBuffer --- alf/utils/data_buffer.py | 35 +++++++++++++++++++++-------------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/alf/utils/data_buffer.py b/alf/utils/data_buffer.py index de3746fa7..4cdeaebf3 100644 --- a/alf/utils/data_buffer.py +++ b/alf/utils/data_buffer.py @@ -119,6 +119,7 @@ def __init__(self, self._enqueued = None buffer_id = [0] + self._env_ids = torch.arange(self._num_envs, device=device) def _create_buffer(spec_path, tensor_spec): buf = tensor_spec.zeros((num_environments, max_length)) @@ -246,32 +247,38 @@ def _enqueue(self, batch, env_ids=None): "batch and env_ids do not have same length %s vs. %s" % (batch_size, env_ids.shape[0])) - # Make sure that there is no duplicate in `env_id` - # torch.unique(env_ids, return_counts=True)[1] is the counts for each unique item - assert torch.unique( - env_ids, return_counts=True)[1].max() == 1, ( - "There are duplicated ids in env_ids %s" % env_ids) - - current_pos = self._current_pos[env_ids] - if (current_pos == current_pos[0]).all() and ( - env_ids == torch.arange(self._num_envs)).all(): - pos = self.circular(current_pos[0].item()) + current_pos = self._current_pos[0].item() + if (env_ids.numel() == self._num_envs + and (self._current_pos == current_pos).all() + and (env_ids == self._env_ids).all()): + # fast path for the common case + pos = self.circular(current_pos) def _set(buf, bat): buf[:, pos] = bat alf.nest.map_structure(_set, self._buffer, batch) + self._current_pos.fill_(current_pos + 1) + self._current_size.fill_( + min(current_pos + 1, self._max_length)) else: + # Make sure that there is no duplicate in `env_id` + # torch.unique(env_ids, return_counts=True)[1] is the counts for each unique item + assert torch.unique( + env_ids, return_counts=True)[1].max() == 1, ( + "There are duplicated ids in env_ids %s" % env_ids) + + current_pos = self._current_pos[env_ids] indices = env_ids * self._max_length + self.circular( current_pos) alf.nest.map_structure( lambda buf, bat: buf.__setitem__(indices, bat.detach()), self._flattened_buffer, batch) - self._current_pos[env_ids] += 1 - current_size = self._current_size[env_ids] - self._current_size[env_ids] = torch.clamp( - current_size + 1, max=self._max_length) + self._current_pos[env_ids] += 1 + current_size = self._current_size[env_ids] + self._current_size[env_ids] = torch.clamp( + current_size + 1, max=self._max_length) # set flags if they exist to unblock potential consumers if self._enqueued: self._enqueued.set()