Skip to content

Commit

Permalink
Fix import bug issue with jax.stages.ArgInfo in past JAX versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jul 3, 2024
1 parent 1db8b33 commit a2f8760
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jax_scalify/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def as_scaled_array_base(
return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore

# TODO: fix bug when scale is not 1.
raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.")
raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") # type:ignore
# return scaled_array_base(val, scale)


Expand Down
13 changes: 11 additions & 2 deletions jax_scalify/core/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any
from typing import Any, Tuple

# import chex
import jax
Expand All @@ -10,12 +10,21 @@
try:
from jax import Array

ArrayTypes = (Array, jax.stages.ArgInfo)
ArrayTypes: Tuple[Any, ...] = (Array,)
except ImportError:
from jaxlib.xla_extension import DeviceArray as Array

# Older version of JAX <0.4
ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)

try:
from jax.stages import ArgInfo

# Additional ArgInfo in recent JAX versions.
ArrayTypes = (*ArrayTypes, ArgInfo)
except ImportError:
pass


def get_numpy_api(val: Any) -> Any:
"""Get the Numpy API corresponding to an array.
Expand Down

0 comments on commit a2f8760

Please sign in to comment.