diff --git a/src/quaxed/__init__.py b/src/quaxed/__init__.py index 4eaaf9d..b118ce0 100644 --- a/src/quaxed/__init__.py +++ b/src/quaxed/__init__.py @@ -2,22 +2,18 @@ quaxed: Pre-quaxed libraries for multiple dispatch over abstract array types in JAX """ +# pylint: disable=C0415,W0621 -# pylint: disable=redefined-builtin +from __future__ import annotations -__all__ = ["__version__", "lax", "scipy"] +from typing import TYPE_CHECKING -import sys -from typing import Any - -import plum -from jaxtyping import ArrayLike - -from . import _jax, lax, scipy +from . import _jax, lax, numpy, scipy from ._jax import * from ._setup import JAX_VERSION -from ._version import version as __version__ +from ._version import version as __version__ # noqa: F401 +__all__ = ["lax", "numpy", "scipy"] __all__ += _jax.__all__ if JAX_VERSION < (0, 4, 32): @@ -26,15 +22,16 @@ __all__ += ["array_api"] -# Simplify the display of ArrayLike -plum.activate_union_aliases() -plum.set_union_alias(ArrayLike, "ArrayLike") +if TYPE_CHECKING: + from typing import Any def __getattr__(name: str) -> Any: # TODO: fuller annotation """Forward all other attribute accesses to Quaxified JAX.""" - import jax # pylint: disable=C0415,W0621 - from quax import quaxify # pylint: disable=C0415,W0621 + import sys + + import jax + from quax import quaxify # TODO: detect if the attribute is a function or a module. # If it is a function, quaxify it. If it is a module, return a proxy object @@ -45,3 +42,7 @@ def __getattr__(name: str) -> Any: # TODO: fuller annotation setattr(sys.modules[__name__], name, out) return out + + +# Clean up the namespace +del TYPE_CHECKING diff --git a/src/quaxed/_types.py b/src/quaxed/_types.py index ba6161f..ea87295 100644 --- a/src/quaxed/_types.py +++ b/src/quaxed/_types.py @@ -1,13 +1,16 @@ -"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved. - -quaxed: Pre-quaxed libraries for multiple dispatch over abstract array types in JAX -""" +"""Copyright (c) 2023 Nathaniel Starkman. All rights reserved.""" __all__: list[str] = [] from typing import Any, Protocol, runtime_checkable import jax.numpy as jnp +import plum +from jaxtyping import ArrayLike + +# Simplify the display of ArrayLike +plum.activate_union_aliases() +plum.set_union_alias(ArrayLike, "ArrayLike") @runtime_checkable