From 8d76656138cf1b2fbfba70808cbb8a0a753d79df Mon Sep 17 00:00:00 2001 From: Danila Sinopalnikov Date: Tue, 16 Jul 2024 04:37:10 -0700 Subject: [PATCH] Fix the check for jax.Array. PiperOrigin-RevId: 652787597 Change-Id: I7aff382d61475c35cc24b6bbc42d62b10ccebe76 --- acme/utils/loggers/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/acme/utils/loggers/base.py b/acme/utils/loggers/base.py index 1517a1c27d..071ce0148f 100644 --- a/acme/utils/loggers/base.py +++ b/acme/utils/loggers/base.py @@ -17,6 +17,7 @@ import abc from typing import Any, Mapping, Optional +import jax import numpy as np import tree from typing_extensions import Protocol @@ -68,8 +69,8 @@ def close(self): def tensor_to_numpy(value: Any): if hasattr(value, 'numpy'): return value.numpy() # tf.Tensor (TF2). - if hasattr(value, 'device_buffer'): - return np.asarray(value) # jnp.DeviceArray. + if isinstance(value, jax.Array): + return np.asarray(value) return value