Skip to content

Commit

Permalink
Deprecate several internal utilities in jax.core
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 14, 2024
1 parent df2e9c3 commit 1462a55
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* The `jax.experimental.array_api` module is deprecated, and importing it is no
longer required to use the Array API. `jax.numpy` supports the array API
directly; see {ref}`python-array-api` for more information.
* The internal utilities `jax.core.check_eqn`, `jax.core.check_types`, and
`jax.core.check_valid_jaxtype` are now deprecated, and will be removed in
the future.

## jaxlib 0.4.32

Expand Down
18 changes: 12 additions & 6 deletions jax/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,7 @@
call_bind_with_continuation as call_bind_with_continuation,
call_impl as call_impl,
call_p as call_p,
check_eqn as check_eqn,
check_jaxpr as check_jaxpr,
check_type as check_type,
check_valid_jaxtype as check_valid_jaxtype,
closed_call_p as closed_call_p,
concrete_aval as concrete_aval,
concrete_or_error as concrete_or_error,
Expand Down Expand Up @@ -110,7 +107,6 @@
new_sublevel as new_sublevel,
no_axis_name as no_axis_name,
no_effects as no_effects,
non_negative_dim as _deprecated_non_negative_dim,
outfeed_primitives as outfeed_primitives,
primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype,
primitive_uses_outfeed as primitive_uses_outfeed,
Expand Down Expand Up @@ -144,6 +140,13 @@

from jax._src import core as _src_core
_deprecations = {
# Added 2024-08-14
"check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn),
"check_type": ("jax.core.check_type is deprecated.", _src_core.check_type),
"check_valid_jaxtype": (
("jax.core.check_valid_jaxtype is deprecated. Instead, you can manually"
" raise an error if core.valid_jaxtype() returns False."),
_src_core.check_valid_jaxtype),
# Added 2024-06-12
"pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval),
"pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn),
Expand Down Expand Up @@ -181,13 +184,16 @@
),
# Added Jan 8, 2024
"non_negative_dim": (
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _deprecated_non_negative_dim,
"jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim,
),
}

import typing
if typing.TYPE_CHECKING:
non_negative_dim = _deprecated_non_negative_dim
check_eqn = _src_core.check_eqn
check_type = _src_core.check_type
check_valid_jaxtype = _src_core.check_valid_jaxtype
non_negative_dim = _src_core.non_negative_dim
pp_aval = _src_core.pp_aval
pp_eqn = _src_core.pp_eqn
pp_eqn_rules = _src_core.pp_eqn_rules
Expand Down

0 comments on commit 1462a55

Please sign in to comment.