From d68a0993a48e3eafce2849294cae12623c471fe3 Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Thu, 26 Sep 2024 07:14:53 -0700 Subject: [PATCH] Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeError class. Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API. PiperOrigin-RevId: 679136569 --- CHANGELOG.md | 4 ++++ docs/errors.rst | 1 + jax/errors.py | 5 +++++ jax/lib/xla_client.py | 30 ++++++++++++++++++++---------- tests/errors_test.py | 16 ++++++++++++---- tests/package_structure_test.py | 2 +- 6 files changed, 43 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 063324f0179c..a424e645144c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. * New Functionality * This release includes wheels for Python 3.13. Free-threading mode is not yet supported. + * `jax.errors.JaxRuntimeError` has been added as a public alias for the + formerly private `XlaRuntimeError` type. * Breaking changes * `jax_pmap_no_rank_reduction` flag is set to `True` by default. @@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. in an error. * Internal pretty-printing tools `jax.core.pp_*` have been removed, after being deprecated in JAX v0.4.30. + * `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use + `jax.errors.JaxRuntimeError` instead. * Deletion: * `jax.xla_computation` is deleted. It's been 3 months since it's deprecation diff --git a/docs/errors.rst b/docs/errors.rst index 96e14ed8d817..9965d6698bd4 100644 --- a/docs/errors.rst +++ b/docs/errors.rst @@ -9,6 +9,7 @@ along with representative examples of how one might fix them. .. currentmodule:: jax.errors .. autoclass:: ConcretizationTypeError .. autoclass:: KeyReuseError +.. autoclass:: JaxRuntimeError .. autoclass:: NonConcreteBooleanIndexError .. autoclass:: TracerArrayConversionError .. autoclass:: TracerBoolConversionError diff --git a/jax/errors.py b/jax/errors.py index 2a811661d1ae..6da7b717cb5f 100644 --- a/jax/errors.py +++ b/jax/errors.py @@ -26,4 +26,9 @@ UnexpectedTracerError as UnexpectedTracerError, KeyReuseError as KeyReuseError, ) + +from jax._src.lib import xla_client as _xc +JaxRuntimeError = _xc.XlaRuntimeError +del _xc + from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback diff --git a/jax/lib/xla_client.py b/jax/lib/xla_client.py index 7422e9fcc56d..9c57d82e837b 100644 --- a/jax/lib/xla_client.py +++ b/jax/lib/xla_client.py @@ -37,26 +37,36 @@ Traceback = _xc.Traceback XlaBuilder = _xc.XlaBuilder XlaComputation = _xc.XlaComputation -XlaRuntimeError = _xc.XlaRuntimeError _deprecations = { - # Added Aug 5 2024 - "_xla" : ( - "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", - _xc._xla - ), - "bfloat16" : ( - "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", - _xc.bfloat16 - ), + # Added Aug 5 2024 + "_xla": ( + "jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.", + _xc._xla, + ), + "bfloat16": ( + "jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.", + _xc.bfloat16, + ), + # Added Sep 26 2024 + "XlaRuntimeError": ( + ( + "jax.lib.xla_client.XlaRuntimeError is deprecated; use" + " jax.errors.JaxRuntimeError." + ), + _xc.XlaRuntimeError, + ), } import typing as _typing + if _typing.TYPE_CHECKING: _xla = _xc._xla bfloat16 = _xc.bfloat16 + XlaRuntimeError = _xc.XlaRuntimeError else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr + __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr del _typing diff --git a/tests/errors_test.py b/tests/errors_test.py index fa2dec95f0fa..7dfc4e51a6de 100644 --- a/tests/errors_test.py +++ b/tests/errors_test.py @@ -394,11 +394,19 @@ def test_grad_norm(self): class CustomErrorsTest(jtu.JaxTestCase): + @jtu.sample_product( - errorclass=[ - errorclass for errorclass in dir(jax.errors) - if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError'] - ], + errorclass=[ + errorclass + for errorclass in dir(jax.errors) + if errorclass.endswith('Error') + and errorclass + not in [ + 'JaxIndexError', + 'JAXTypeError', + 'JaxRuntimeError', + ] + ], ) def testErrorsURL(self, errorclass): class FakeTracer(core.Tracer): diff --git a/tests/package_structure_test.py b/tests/package_structure_test.py index e9944ec084af..71d48c2b121c 100644 --- a/tests/package_structure_test.py +++ b/tests/package_structure_test.py @@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase): @parameterized.parameters([ # TODO(jakevdp): expand test to other public modules. - _mod("jax.errors"), + _mod("jax.errors", exclude=["JaxRuntimeError"]), _mod("jax.nn.initializers"), _mod( "jax.tree_util",