Skip to content

Commit

Permalink
Add jax.errors.JaxRuntimeError as a public alias for the XlaRuntimeEr…
Browse files Browse the repository at this point in the history
…ror class.

Deprecate jax.lib.xla_client.XlaRuntimeError, which is not a public API.

PiperOrigin-RevId: 679136569
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Sep 26, 2024
1 parent 5cef547 commit d68a099
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 15 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* New Functionality
* This release includes wheels for Python 3.13. Free-threading mode is not yet
supported.
* `jax.errors.JaxRuntimeError` has been added as a public alias for the
formerly private `XlaRuntimeError` type.

* Breaking changes
* `jax_pmap_no_rank_reduction` flag is set to `True` by default.
Expand All @@ -32,6 +34,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
in an error.
* Internal pretty-printing tools `jax.core.pp_*` have been removed, after
being deprecated in JAX v0.4.30.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
`jax.errors.JaxRuntimeError` instead.

* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since it's deprecation
Expand Down
1 change: 1 addition & 0 deletions docs/errors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ along with representative examples of how one might fix them.
.. currentmodule:: jax.errors
.. autoclass:: ConcretizationTypeError
.. autoclass:: KeyReuseError
.. autoclass:: JaxRuntimeError
.. autoclass:: NonConcreteBooleanIndexError
.. autoclass:: TracerArrayConversionError
.. autoclass:: TracerBoolConversionError
Expand Down
5 changes: 5 additions & 0 deletions jax/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,9 @@
UnexpectedTracerError as UnexpectedTracerError,
KeyReuseError as KeyReuseError,
)

from jax._src.lib import xla_client as _xc
JaxRuntimeError = _xc.XlaRuntimeError
del _xc

from jax._src.traceback_util import SimplifiedTraceback as SimplifiedTraceback
30 changes: 20 additions & 10 deletions jax/lib/xla_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,36 @@
Traceback = _xc.Traceback
XlaBuilder = _xc.XlaBuilder
XlaComputation = _xc.XlaComputation
XlaRuntimeError = _xc.XlaRuntimeError

_deprecations = {
# Added Aug 5 2024
"_xla" : (
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
_xc._xla
),
"bfloat16" : (
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
_xc.bfloat16
),
# Added Aug 5 2024
"_xla": (
"jax.lib.xla_client._xla is deprecated; use jax.lib.xla_extension.",
_xc._xla,
),
"bfloat16": (
"jax.lib.xla_client.bfloat16 is deprecated; use ml_dtypes.bfloat16.",
_xc.bfloat16,
),
# Added Sep 26 2024
"XlaRuntimeError": (
(
"jax.lib.xla_client.XlaRuntimeError is deprecated; use"
" jax.errors.JaxRuntimeError."
),
_xc.XlaRuntimeError,
),
}

import typing as _typing

if _typing.TYPE_CHECKING:
_xla = _xc._xla
bfloat16 = _xc.bfloat16
XlaRuntimeError = _xc.XlaRuntimeError
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr

__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del _typing
Expand Down
16 changes: 12 additions & 4 deletions tests/errors_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,19 @@ def test_grad_norm(self):


class CustomErrorsTest(jtu.JaxTestCase):

@jtu.sample_product(
errorclass=[
errorclass for errorclass in dir(jax.errors)
if errorclass.endswith('Error') and errorclass not in ['JaxIndexError', 'JAXTypeError']
],
errorclass=[
errorclass
for errorclass in dir(jax.errors)
if errorclass.endswith('Error')
and errorclass
not in [
'JaxIndexError',
'JAXTypeError',
'JaxRuntimeError',
]
],
)
def testErrorsURL(self, errorclass):
class FakeTracer(core.Tracer):
Expand Down
2 changes: 1 addition & 1 deletion tests/package_structure_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class PackageStructureTest(jtu.JaxTestCase):

@parameterized.parameters([
# TODO(jakevdp): expand test to other public modules.
_mod("jax.errors"),
_mod("jax.errors", exclude=["JaxRuntimeError"]),
_mod("jax.nn.initializers"),
_mod(
"jax.tree_util",
Expand Down

0 comments on commit d68a099

Please sign in to comment.