Skip to content

Commit

Permalink
Merge pull request #16204 from skye:importlib_metadata_version
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 536823622
  • Loading branch information
jax authors committed May 31, 2023
2 parents b35c20c + 9682370 commit 525ba49
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 15 deletions.
18 changes: 6 additions & 12 deletions jax/_src/xla_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,20 +360,14 @@ def discover_pjrt_plugins() -> None:
if sys.version_info < (3, 10):
# Use the backport library because it provides a forward-compatible
# implementation.
try:
from importlib_metadata import entry_points
except ModuleNotFoundError:
logger.debug(
"No importlib_metadata found (for Python < 3.10): "
"Plugins advertised from entrypoints will not be found.")
entry_points = None
from importlib_metadata import entry_points
else:
from importlib.metadata import entry_points
if entry_points:
for entry_point in entry_points(group="jax_plugins"):
logger.debug("Discovered entry-point based JAX plugin: %s",
entry_point.value)
plugin_modules.add(entry_point.value)

for entry_point in entry_points(group="jax_plugins"):
logger.debug("Discovered entry-point based JAX plugin: %s",
entry_point.value)
plugin_modules.add(entry_point.value)

# Now load and initialize them all.
for plugin_module_name in plugin_modules:
Expand Down
8 changes: 5 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ def generate_proto(source):
'numpy>=1.21',
'opt_einsum',
'scipy>=1.7',
# Required by xla_bridge.discover_pjrt_plugins for forwards compat with
# Python versions < 3.10. Can be dropped when 3.10 is the minimum
# required Python version.
'importlib_metadata>=4.6;python_version<"3.10"',
],
extras_require={
# Minimum jaxlib version; used in testing.
Expand All @@ -82,9 +86,7 @@ def generate_proto(source):
# Cloud TPU VM jaxlib can be installed via:
# $ pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
'tpu': [f'jaxlib=={_current_jaxlib_version}',
f'libtpu-nightly=={_libtpu_version}',
# Required by cloud_tpu_init.py
'requests'],
f'libtpu-nightly=={_libtpu_version}'],

# $ pip install jax[australis]
'australis': ['protobuf>=3.13,<4'],
Expand Down

0 comments on commit 525ba49

Please sign in to comment.