diff --git a/jax_scalify/core/datatype.py b/jax_scalify/core/datatype.py index f61f5c6..062438d 100644 --- a/jax_scalify/core/datatype.py +++ b/jax_scalify/core/datatype.py @@ -204,7 +204,7 @@ def as_scaled_array_base( return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore # TODO: fix bug when scale is not 1. - raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") + raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") # type:ignore # return scaled_array_base(val, scale) diff --git a/jax_scalify/core/typing.py b/jax_scalify/core/typing.py index e26cf95..8291577 100644 --- a/jax_scalify/core/typing.py +++ b/jax_scalify/core/typing.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from typing import Any +from typing import Any, Tuple # import chex import jax @@ -10,12 +10,21 @@ try: from jax import Array - ArrayTypes = (Array, jax.stages.ArgInfo) + ArrayTypes: Tuple[Any, ...] = (Array,) except ImportError: from jaxlib.xla_extension import DeviceArray as Array + # Older version of JAX <0.4 ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer) +try: + from jax.stages import ArgInfo + + # Additional ArgInfo in recent JAX versions. + ArrayTypes = (*ArrayTypes, ArgInfo) +except ImportError: + pass + def get_numpy_api(val: Any) -> Any: """Get the Numpy API corresponding to an array.