From 968237080fcb4c5297435f9b146f0a661519ab9c Mon Sep 17 00:00:00 2001 From: Skye Wanderman-Milne Date: Wed, 31 May 2023 18:28:09 +0000 Subject: [PATCH] Add importlib_metadata to project requirements. This is necessary to ensure we can correctly detect PJRT plugins via entry_points without compatibility errors. Prior to this change, there was conditional logic to handle if importlib_metadata wasn't installed at all. However, it doesn't handle the case where importlib_metadata is installed by not high enough version to support Python 3.10 compat. This change gets rid of that logic and just ensures the right version is installed. All of this logic can be removed if/when jax requires Python version >= 3.10 This also removes an unnecessary `requests` dep for the [tpu] install. --- jax/_src/xla_bridge.py | 18 ++++++------------ setup.py | 8 +++++--- 2 files changed, 11 insertions(+), 15 deletions(-) diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 18fcb354aedb..4c11d15e9d12 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -360,20 +360,14 @@ def discover_pjrt_plugins() -> None: if sys.version_info < (3, 10): # Use the backport library because it provides a forward-compatible # implementation. - try: - from importlib_metadata import entry_points - except ModuleNotFoundError: - logger.debug( - "No importlib_metadata found (for Python < 3.10): " - "Plugins advertised from entrypoints will not be found.") - entry_points = None + from importlib_metadata import entry_points else: from importlib.metadata import entry_points - if entry_points: - for entry_point in entry_points(group="jax_plugins"): - logger.debug("Discovered entry-point based JAX plugin: %s", - entry_point.value) - plugin_modules.add(entry_point.value) + + for entry_point in entry_points(group="jax_plugins"): + logger.debug("Discovered entry-point based JAX plugin: %s", + entry_point.value) + plugin_modules.add(entry_point.value) # Now load and initialize them all. for plugin_module_name in plugin_modules: diff --git a/setup.py b/setup.py index 75cf89dba95e..0bcb23b78bf7 100644 --- a/setup.py +++ b/setup.py @@ -67,6 +67,10 @@ def generate_proto(source): 'numpy>=1.21', 'opt_einsum', 'scipy>=1.7', + # Required by xla_bridge.discover_pjrt_plugins for forwards compat with + # Python versions < 3.10. Can be dropped when 3.10 is the minimum + # required Python version. + 'importlib_metadata>=4.6;python_version<"3.10"', ], extras_require={ # Minimum jaxlib version; used in testing. @@ -82,9 +86,7 @@ def generate_proto(source): # Cloud TPU VM jaxlib can be installed via: # $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html 'tpu': [f'jaxlib=={_current_jaxlib_version}', - f'libtpu-nightly=={_libtpu_version}', - # Required by cloud_tpu_init.py - 'requests'], + f'libtpu-nightly=={_libtpu_version}'], # $ pip install jax[australis] 'australis': ['protobuf>=3.13,<4'],