TensorWrapper for converting numpy based batched alf environment to Tensor based alf environment (#1535)

Batched environments are sometimes written in numpy so that they can be
used with a parallel environment, which exposes a Tensor interface to
training. Play does not use a parallel environment but still requires
Tensors, so we need a mechanism to convert a numpy environment to a
Tensor environment.
emailweixu authored Sep 14, 2023
1 parent 806336e commit dd8e338
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions alf/environments/alf_wrappers.py
@@ -837,6 +837,32 @@ def _reset(self):
        return BatchedTensorWrapper._to_batched_tensor(super()._reset())


class TensorWrapper(AlfEnvironmentBaseWrapper):
    """Wrapper that converts numpy-based I/O to tensors.
    """

    def __init__(self, env):
        assert env.batched, (
            'TensorWrapper can only be used to wrap batched env')
        super().__init__(env)

    @staticmethod
    def _to_tensor(raw):
        """Convert the structured input into batched (batch_size = 1) tensors
        of the same structure.
        """
        return nest.map_structure(
            lambda x: (torch.as_tensor(x) if isinstance(
                x, (np.ndarray, np.number, float, int)) else x), raw)

    def _step(self, action):
        numpy_action = nest.map_structure(lambda x: x.cpu().numpy(), action)
        return TensorWrapper._to_tensor(super()._step(numpy_action))

    def _reset(self):
        return TensorWrapper._to_tensor(super()._reset())
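
The round trip performed by `_step` and `_reset` can be illustrated outside of alf with a minimal sketch. Everything below is an assumption made for illustration: `map_structure` is a simplified stand-in for `nest.map_structure` that only handles dicts, lists and plain tuples, and `ToyNumpyEnv` and `ToyTensorWrapper` are hypothetical classes, not part of alf. Only the conversion logic mirrors the diff above.

```python
import numpy as np
import torch


def map_structure(fn, value):
    """Hypothetical stand-in for alf's nest.map_structure.

    Applies fn to every leaf of a nest made of dicts, lists and plain tuples.
    """
    if isinstance(value, dict):
        return {k: map_structure(fn, v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return type(value)(map_structure(fn, v) for v in value)
    return fn(value)


def to_tensor(raw):
    """Convert numpy/scalar leaves to tensors; leave other leaves untouched."""
    return map_structure(
        lambda x: (torch.as_tensor(x) if isinstance(
            x, (np.ndarray, np.number, float, int)) else x), raw)


class ToyNumpyEnv:
    """Hypothetical batched environment that only speaks numpy."""
    batch_size = 2

    def reset(self):
        self._state = np.zeros((self.batch_size, 3), dtype=np.float32)
        return {"observation": self._state.copy(),
                "reward": np.zeros(self.batch_size, dtype=np.float32)}

    def step(self, action):
        # The wrapped environment expects numpy input, never tensors.
        assert isinstance(action, np.ndarray)
        self._state = self._state + action
        return {"observation": self._state.copy(),
                "reward": self._state.sum(axis=1)}


class ToyTensorWrapper:
    """Hypothetical wrapper: tensor actions in, tensor time steps out."""

    def __init__(self, env):
        self._env = env

    def reset(self):
        return to_tensor(self._env.reset())

    def step(self, action):
        # Tensor -> numpy before calling the env, numpy -> tensor afterwards.
        numpy_action = map_structure(lambda x: x.cpu().numpy(), action)
        return to_tensor(self._env.step(numpy_action))


env = ToyTensorWrapper(ToyNumpyEnv())
first = env.reset()
result = env.step(torch.ones(2, 3))
```

Note that `torch.as_tensor` shares memory with a CPU numpy array where it can instead of copying, so an environment that mutates its output arrays in place would also change the returned tensors. The toy environment avoids this by returning copies.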


@alf.configurable
class DiscreteActionWrapper(AlfEnvironmentBaseWrapper):
    """Discretize each continuous action dim into several evenly distributed