diff --git a/alf/algorithms/rl_algorithm_test.py b/alf/algorithms/rl_algorithm_test.py index b690098c4..96c1088ca 100644 --- a/alf/algorithms/rl_algorithm_test.py +++ b/alf/algorithms/rl_algorithm_test.py @@ -102,6 +102,10 @@ def __init__(self, batch_size, obs_shape=(2, ), reward_dim=1): shape=(), dtype='int64', minimum=0, maximum=2) self.reset() + @property + def is_tensor_based(self): + return True + def observation_spec(self): return self._observation_spec