From 1462a55dc63e58a926ffdaf7e963ec6ae5f1bba0 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 14 Aug 2024 09:15:52 -0700 Subject: [PATCH] Deprecate several internal utilities in jax.core --- CHANGELOG.md | 3 +++ jax/core.py | 18 ++++++++++++------ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2b30b08abb01..baadb4f8c9e7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * The `jax.experimental.array_api` module is deprecated, and importing it is no longer required to use the Array API. `jax.numpy` supports the array API directly; see {ref}`python-array-api` for more information. + * The internal utilities `jax.core.check_eqn`, `jax.core.check_types`, and + `jax.core.check_valid_jaxtype` are now deprecated, and will be removed in + the future. ## jaxlib 0.4.32 diff --git a/jax/core.py b/jax/core.py index 80025e8619f3..1f433d6f5c29 100644 --- a/jax/core.py +++ b/jax/core.py @@ -66,10 +66,7 @@ call_bind_with_continuation as call_bind_with_continuation, call_impl as call_impl, call_p as call_p, - check_eqn as check_eqn, check_jaxpr as check_jaxpr, - check_type as check_type, - check_valid_jaxtype as check_valid_jaxtype, closed_call_p as closed_call_p, concrete_aval as concrete_aval, concrete_or_error as concrete_or_error, @@ -110,7 +107,6 @@ new_sublevel as new_sublevel, no_axis_name as no_axis_name, no_effects as no_effects, - non_negative_dim as _deprecated_non_negative_dim, outfeed_primitives as outfeed_primitives, primal_dtype_to_tangent_dtype as primal_dtype_to_tangent_dtype, primitive_uses_outfeed as primitive_uses_outfeed, @@ -144,6 +140,13 @@ from jax._src import core as _src_core _deprecations = { + # Added 2024-08-14 + "check_eqn": ("jax.core.check_eqn is deprecated.", _src_core.check_eqn), + "check_type": ("jax.core.check_type is deprecated.", _src_core.check_type), + "check_valid_jaxtype": ( + ("jax.core.check_valid_jaxtype is deprecated. Instead, you can manually" + " raise an error if core.valid_jaxtype() returns False."), + _src_core.check_valid_jaxtype), # Added 2024-06-12 "pp_aval": ("jax.core.pp_aval is deprecated.", _src_core.pp_aval), "pp_eqn": ("jax.core.pp_eqn is deprecated.", _src_core.pp_eqn), @@ -181,13 +184,16 @@ ), # Added Jan 8, 2024 "non_negative_dim": ( - "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _deprecated_non_negative_dim, + "jax.core.non_negative_dim is deprecated. Use max_dim(..., 0).", _src_core.non_negative_dim, ), } import typing if typing.TYPE_CHECKING: - non_negative_dim = _deprecated_non_negative_dim + check_eqn = _src_core.check_eqn + check_type = _src_core.check_type + check_valid_jaxtype = _src_core.check_valid_jaxtype + non_negative_dim = _src_core.non_negative_dim pp_aval = _src_core.pp_aval pp_eqn = _src_core.pp_eqn pp_eqn_rules = _src_core.pp_eqn_rules