From 325488179b3faa35fa3fe786ec046a0a2e332867 Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Wed, 3 Jul 2024 10:09:53 +0100 Subject: [PATCH] Fix import bug issue with `jax.stages.ArgInfo` in past JAX versions. --- .github/workflows/tests.yaml | 4 ++-- jax_scalify/core/datatype.py | 2 +- jax_scalify/core/typing.py | 13 +++++++++++-- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 464a072..13e1422 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -24,10 +24,10 @@ jobs: access_token: ${{ github.token }} if: ${{github.ref != 'refs/head/main'}} - uses: actions/checkout@v3 - - name: Set up Python 3.8 + - name: Set up Python 3.10 uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: 3.10 - uses: pre-commit/action@v3.0.0 unit_tests: diff --git a/jax_scalify/core/datatype.py b/jax_scalify/core/datatype.py index f61f5c6..062438d 100644 --- a/jax_scalify/core/datatype.py +++ b/jax_scalify/core/datatype.py @@ -204,7 +204,7 @@ def as_scaled_array_base( return ScaledArray(val / scale.astype(val.dtype), scale) # type:ignore # TODO: fix bug when scale is not 1. - raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") + raise NotImplementedError(f"Constructing `ScaledArray` from {val} and {scale} not supported.") # type:ignore # return scaled_array_base(val, scale) diff --git a/jax_scalify/core/typing.py b/jax_scalify/core/typing.py index e26cf95..8291577 100644 --- a/jax_scalify/core/typing.py +++ b/jax_scalify/core/typing.py @@ -1,5 +1,5 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. -from typing import Any +from typing import Any, Tuple # import chex import jax @@ -10,12 +10,21 @@ try: from jax import Array - ArrayTypes = (Array, jax.stages.ArgInfo) + ArrayTypes: Tuple[Any, ...] = (Array,) except ImportError: from jaxlib.xla_extension import DeviceArray as Array + # Older version of JAX <0.4 ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer) +try: + from jax.stages import ArgInfo + + # Additional ArgInfo in recent JAX versions. + ArrayTypes = (*ArrayTypes, ArgInfo) +except ImportError: + pass + def get_numpy_api(val: Any) -> Any: """Get the Numpy API corresponding to an array.