Skip to content

Commit

Permalink
Fix import bug issue with jax.stages.ArgInfo in past JAX versions.
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap committed Jul 3, 2024
1 parent 1db8b33 commit 3254881
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ jobs:
access_token: ${{ github.token }}
if: ${{github.ref != 'refs/head/main'}}
- uses: actions/checkout@v3
- name: Set up Python 3.8
- name: Set up Python 3.10
uses: actions/setup-python@v4
with:
python-version: 3.8
python-version: 3.10
- uses: pre-commit/action@v3.0.0

unit_tests:
Expand Down
2 changes: 1 addition & 1 deletion jax_scalify/core/datatype.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def as_scaled_array_base(
return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore

# TODO: fix bug when scale is not 1.
raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.")
raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") # type:ignore
# return scaled_array_base(val, scale)


Expand Down
13 changes: 11 additions & 2 deletions jax_scalify/core/typing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any
from typing import Any, Tuple

# import chex
import jax
Expand All @@ -10,12 +10,21 @@
try:
from jax import Array

ArrayTypes = (Array, jax.stages.ArgInfo)
ArrayTypes: Tuple[Any, ...] = (Array,)
except ImportError:
from jaxlib.xla_extension import DeviceArray as Array

# Older version of JAX <0.4
ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)

try:
from jax.stages import ArgInfo

# Additional ArgInfo in recent JAX versions.
ArrayTypes = (*ArrayTypes, ArgInfo)
except ImportError:
pass


def get_numpy_api(val: Any) -> Any:
"""Get the Numpy API corresponding to an array.
Expand Down

0 comments on commit 3254881

Please sign in to comment.