diff --git a/CHANGELOG.md b/CHANGELOG.md index 063324f0179c..a424e645144c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * New Functionality * This release includes wheels for Python 3.13. Free-threading mode is not yet supported. + * `jax.errors.JaxRuntimeError` has been added as a public alias for the + formerly private `XlaRuntimeError` type. * Breaking changes * `jax_pmap_no_rank_reduction` flag is set to `True` by default. @@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. in an error. * Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30. + * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use + `jax.errors.JaxRuntimeError` instead. * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation diff --git a/docs/errors.rst b/docs/errors.rst index 96e14ed8d817..9965d6698bd4 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -9,6 +9,7 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError .. autoclass:: KeyReuseError +.. autoclass:: JaxRuntimeError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/jax/errors.py b/jax/errors.py index 2a811661d1ae..6da7b717cb5f 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -26,4 +26,9 @@ UnexpectedTracerError as UnexpectedTracerError, KeyReuseError as KeyReuseError, ) + +from jax._src.lib import xla_client as _xc +JaxRuntimeError = _xc.XlaRuntimeError +del _xc + from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 7422e9fcc56d..9466b72ca452 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -37,26 +37,35 @@ Traceback = _xc.Traceback XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation -XlaRuntimeError = _xc.XlaRuntimeError _deprecations = { - # Added Aug 5 2024 - "_xla" : ( - "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", - _xc._xla - ), - "bfloat16" : ( - "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", - _xc.bfloat16 - ), + # Added Aug 5 2024 + "_xla": ( + "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", + _xc._xla, + ), + "bfloat16": ( + "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", + _xc.bfloat16, + ), + "XlaRuntimeError": ( + ( + "jax.lib.xla_client.XlaRuntimeError is deprecated; use" + " jax.errors.JaxRuntimeError." + ), + _xc.XlaRuntimeError, + ), } import typing as _typing + if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing