diff --git a/acme/utils/loggers/base.py b/acme/utils/loggers/base.py index 1517a1c27d..071ce0148f 100644 --- a/acme/utils/loggers/base.py +++ b/acme/utils/loggers/base.py @@ -17,6 +17,7 @@ import abc from typing import Any, Mapping, Optional +import jax import numpy as np import tree from typing_extensions import Protocol @@ -68,8 +69,8 @@ def close(self): def tensor_to_numpy(value: Any): if hasattr(value, 'numpy'): return value.numpy() # tf.Tensor (TF2). - if hasattr(value, 'device_buffer'): - return np.asarray(value) # jnp.DeviceArray. + if isinstance(value, jax.Array): + return np.asarray(value) return value