Skip to content

Commit

Permalink
lazy_loader.attach now only imports the submodule once on first access
Browse files Browse the repository at this point in the history
This is a partial roll-forward of #22998.

PiperOrigin-RevId: 668460501
  • Loading branch information
superbobry authored and jax authors committed Aug 28, 2024
1 parent 78d5b75 commit 1f3954d
Showing 1 changed file with 17 additions and 6 deletions.
23 changes: 17 additions & 6 deletions jax/_src/lazy_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from collections.abc import Callable, Sequence
import importlib
import sys
from typing import Any


Expand All @@ -26,17 +27,27 @@ def attach(package_name: str, submodules: Sequence[str]) -> tuple[
]:
"""Lazily loads submodules of a package.
Example use:
```
__getattr__, __dir__, __all__ = lazy_loader.attach(__name__, ["sub1", "sub2"])
```
Returns:
A tuple of ``__getattr__``, ``__dir__`` function and ``__all__`` --
a list of available global names, which can be used to replace the
corresponding definitions in the package.
Raises:
RuntimeError: If the ``__name__`` of the caller cannot be determined.
"""
owner_name = sys._getframe(1).f_globals.get("__name__")
if owner_name is None:
raise RuntimeError("Cannot determine the ``__name__`` of the caller.")

__all__: list[str] = list(submodules)
__all__ = list(submodules)

def __getattr__(name: str) -> Any:
if name in submodules:
return importlib.import_module(f"{package_name}.{name}")
value = importlib.import_module(f"{package_name}.{name}")
# Update module-level globals to avoid calling ``__getattr__`` again
# for this ``name``.
setattr(sys.modules[owner_name], name, value)
return value
raise AttributeError(f"module '{package_name}' has no attribute '{name}")

def __dir__() -> list[str]:
Expand Down

0 comments on commit 1f3954d

Please sign in to comment.