From b35c20ce5da955d3de084f11f7a13b0ed4c14f56 Mon Sep 17 00:00:00 2001 From: Jieying Luo Date: Wed, 31 May 2023 13:41:53 -0700 Subject: [PATCH] Use xla_extension_version and remove some dead version check in xla_bridge_test.py. Min jaxlib requires xla_extension_version >= 144. PiperOrigin-RevId: 536810415 --- tests/xla_bridge_test.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index d115e6efb11f..eb0f1b02b67c 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -21,6 +21,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax.interpreters import xla from jax._src.config import config @@ -99,7 +100,7 @@ def test_register_plugin(self): with mock.patch.object( xc, "load_pjrt_plugin_dynamically", autospec=True ) as mock_load_plugin: - if xc._version >= 152: + if xla_extension_version >= 152: with mock.patch.object( xc, "pjrt_plugin_loaded", autospec=True ) as mock_plugin_loaded: @@ -115,18 +116,13 @@ def test_register_plugin(self): self.assertIn("name1", xb._backend_factories) self.assertIn("name2", xb._backend_factories) self.assertEqual(priotiy, 400) - if xc._version >= 152: + if xla_extension_version >= 152: mock_plugin_loaded.assert_called_once_with("name1") else: mock_load_plugin.assert_called_once_with("name1", "path1") - if xc._version >= 134: - mock_make.assert_called_once_with("name1", None) - else: - mock_make.assert_called_once_with("name1") + mock_make.assert_called_once_with("name1", None) def test_register_plugin_with_config(self): - if xc._version < 134: - return test_json_file_path = os.path.join( os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json" ) @@ -137,7 +133,7 @@ def test_register_plugin_with_config(self): with mock.patch.object( xc, "load_pjrt_plugin_dynamically", autospec=True ) as mock_load_plugin: - if xc._version >= 152: + if xla_extension_version >= 152: with mock.patch.object( xc, "pjrt_plugin_loaded", autospec=True ) as mock_plugin_loaded: @@ -147,7 +143,7 @@ def test_register_plugin_with_config(self): self.assertIn("name1", xb._backend_factories) self.assertEqual(priority, 400) - if xc._version >= 152: + if xla_extension_version >= 152: mock_plugin_loaded.assert_called_once_with("name1") else: mock_load_plugin.assert_called_once_with(