From 28a65589f7489d9d0a1e62ca502e7c70e00b7e80 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Wed, 28 Aug 2024 15:35:54 -0700 Subject: [PATCH] `lazy_loader.attach` now only imports the submodule once on first access This is a partial roll-forward of #22998. PiperOrigin-RevId: 668633307 --- jax/_src/lazy_loader.py | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/jax/_src/lazy_loader.py b/jax/_src/lazy_loader.py index cf6e68e49c81..14822bff3eff 100644 --- a/jax/_src/lazy_loader.py +++ b/jax/_src/lazy_loader.py @@ -16,6 +16,7 @@ from collections.abc import Callable, Sequence import importlib +import sys from typing import Any @@ -26,17 +27,27 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[ ]: """Lazily loads submodules of a package. - Example use: - ``` - __getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"]) - ``` + Returns: + A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` -- + a list of available global names, which can be used to replace the + corresponding definitions in the package. + + Raises: + RuntimeError: If the ``__name__`` of the caller cannot be determined. """ + owner_name = sys._getframe(1).f_globals.get("__name__") + if owner_name is None: + raise RuntimeError("Cannot determine the ``__name__`` of the caller.") - __all__: list[str] = list(submodules) + __all__ = list(submodules) def __getattr__(name: str) -> Any: if name in submodules: - return importlib.import_module(f"{package_name}.{name}") + value = importlib.import_module(f"{package_name}.{name}") + # Update module-level globals to avoid calling ``__getattr__`` again + # for this ``name``. + setattr(sys.modules[owner_name], name, value) + return value raise AttributeError(f"module '{package_name}' has no attribute '{name}") def __dir__() -> list[str]: