Skip to content

Commit

Permalink
Use xla_extension_version and remove some dead version check in xla_b…
Browse files Browse the repository at this point in the history
…ridge_test.py.

Min jaxlib requires xla_extension_version >= 144.

PiperOrigin-RevId: 536810415
  • Loading branch information
Jieying Luo authored and jax authors committed May 31, 2023
1 parent 727c121 commit b35c20c
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions tests/xla_bridge_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax.interpreters import xla

from jax._src.config import config
Expand Down Expand Up @@ -99,7 +100,7 @@ def test_register_plugin(self):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
if xc._version >= 152:
if xla_extension_version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
Expand All @@ -115,18 +116,13 @@ def test_register_plugin(self):
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400)
if xc._version >= 152:
if xla_extension_version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with("name1", "path1")
if xc._version >= 134:
mock_make.assert_called_once_with("name1", None)
else:
mock_make.assert_called_once_with("name1")
mock_make.assert_called_once_with("name1", None)

def test_register_plugin_with_config(self):
if xc._version < 134:
return
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
Expand All @@ -137,7 +133,7 @@ def test_register_plugin_with_config(self):
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
if xc._version >= 152:
if xla_extension_version >= 152:
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
Expand All @@ -147,7 +143,7 @@ def test_register_plugin_with_config(self):

self.assertIn("name1", xb._backend_factories)
self.assertEqual(priority, 400)
if xc._version >= 152:
if xla_extension_version >= 152:
mock_plugin_loaded.assert_called_once_with("name1")
else:
mock_load_plugin.assert_called_once_with(
Expand Down

0 comments on commit b35c20c

Please sign in to comment.