Skip to content

Commit

Permalink
Fix the check for jax.Array.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 652787597
Change-Id: I7aff382d61475c35cc24b6bbc42d62b10ccebe76
  • Loading branch information
sinopalnikov authored and copybara-github committed Jul 16, 2024
1 parent bea6d6b commit 8d76656
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions acme/utils/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import abc
from typing import Any, Mapping, Optional

import jax
import numpy as np
import tree
from typing_extensions import Protocol
Expand Down Expand Up @@ -68,8 +69,8 @@ def close(self):
def tensor_to_numpy(value: Any):
if hasattr(value, 'numpy'):
return value.numpy() # tf.Tensor (TF2).
if hasattr(value, 'device_buffer'):
return np.asarray(value) # jnp.DeviceArray.
if isinstance(value, jax.Array):
return np.asarray(value)
return value


Expand Down

0 comments on commit 8d76656

Please sign in to comment.