Skip to content

Commit

Permalink
Fix JAX 0.4.31 compatibility, with sharding argument in `convert_element_type`. (#134)
Browse files Browse the repository at this point in the history
  • Loading branch information
balancap authored Sep 9, 2024
1 parent 32f81b9 commit b9b5c57
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion jax_scalify/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,5 +24,5 @@
scalify,
)
from .pow2 import Pow2RoundMode, pow2_decompose, pow2_round, pow2_round_down, pow2_round_up # noqa: F401
from .typing import Array, ArrayTypes, get_numpy_api # noqa: F401
from .typing import Array, ArrayTypes, Sharding, get_numpy_api # noqa: F401
from .utils import safe_div, safe_reciprocal # noqa: F401
2 changes: 2 additions & 0 deletions jax_scalify/core/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,13 @@
# Type aliasing. To be compatible with JAX 0.3 as well.
try:
from jax import Array
from jax.sharding import Sharding

ArrayTypes: Tuple[Any, ...] = (Array,)
except ImportError:
from jaxlib.xla_extension import DeviceArray as Array

Sharding = Any
# Older version of JAX <0.4
ArrayTypes = (Array, jax.interpreters.partial_eval.DynamicJaxprTracer)

Expand Down
5 changes: 4 additions & 1 deletion jax_scalify/lax/scaled_ops_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
DTypeLike,
ScaledArray,
Shape,
Sharding,
as_scaled_array,
get_scale_dtype,
is_static_anyscale,
Expand Down Expand Up @@ -76,7 +77,9 @@ def scaled_broadcast_in_dim(A: ScaledArray, shape: Shape, broadcast_dimensions:


@core.register_scaled_lax_op
def scaled_convert_element_type(A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False) -> ScaledArray:
def scaled_convert_element_type(
A: ScaledArray, new_dtype: DTypeLike, weak_type: bool = False, sharding: Sharding | None = None
) -> ScaledArray:
# NOTE: by default, no rescaling done before casting.
# Choice of adding an optional rescaling op before is up to the user (and which strategy to use).
# NOTE bis: scale not casted as well by default!
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
]
dependencies = [
"chex>=0.1.6",
"jax>=0.3.16,<0.4.31",
"jax>=0.3.16",
"jaxlib>=0.3.15",
"ml_dtypes",
"numpy>=1.22.4"
Expand Down

0 comments on commit b9b5c57

Please sign in to comment.