Skip to content

Commit

Permalink
Fix unittests for DataBuffer
Browse files Browse the repository at this point in the history
  • Loading branch information
emailweixu committed Dec 17, 2024
1 parent 230c47f commit 9759814
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions alf/utils/data_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self,
self._enqueued = None

buffer_id = [0]
self._env_ids = torch.arange(self._num_envs, device=device)

def _create_buffer(spec_path, tensor_spec):
buf = tensor_spec.zeros((num_environments, max_length))
Expand Down Expand Up @@ -246,32 +247,38 @@ def _enqueue(self, batch, env_ids=None):
"batch and env_ids do not have same length %s vs. %s" %
(batch_size, env_ids.shape[0]))

# Make sure that there is no duplicate in `env_id`
# torch.unique(env_ids, return_counts=True)[1] is the counts for each unique item
assert torch.unique(
env_ids, return_counts=True)[1].max() == 1, (
"There are duplicated ids in env_ids %s" % env_ids)

current_pos = self._current_pos[env_ids]
if (current_pos == current_pos[0]).all() and (
env_ids == torch.arange(self._num_envs)).all():
pos = self.circular(current_pos[0].item())
current_pos = self._current_pos[0].item()
if (env_ids.numel() == self._num_envs
and (self._current_pos == current_pos).all()
and (env_ids == self._env_ids).all()):
# fast path for the common case
pos = self.circular(current_pos)

def _set(buf, bat):
buf[:, pos] = bat

alf.nest.map_structure(_set, self._buffer, batch)
self._current_pos.fill_(current_pos + 1)
self._current_size.fill_(
min(current_pos + 1, self._max_length))
else:
# Make sure that there is no duplicate in `env_id`
# torch.unique(env_ids, return_counts=True)[1] is the counts for each unique item
assert torch.unique(
env_ids, return_counts=True)[1].max() == 1, (
"There are duplicated ids in env_ids %s" % env_ids)

current_pos = self._current_pos[env_ids]
indices = env_ids * self._max_length + self.circular(
current_pos)
alf.nest.map_structure(
lambda buf, bat: buf.__setitem__(indices, bat.detach()),
self._flattened_buffer, batch)

self._current_pos[env_ids] += 1
current_size = self._current_size[env_ids]
self._current_size[env_ids] = torch.clamp(
current_size + 1, max=self._max_length)
self._current_pos[env_ids] += 1
current_size = self._current_size[env_ids]
self._current_size[env_ids] = torch.clamp(
current_size + 1, max=self._max_length)
# set flags if they exist to unblock potential consumers
if self._enqueued:
self._enqueued.set()
Expand Down

0 comments on commit 9759814

Please sign in to comment.