diff --git a/docs/getting_started/backends.md b/docs/getting_started/backends.md index 224e8162..69ab445f 100644 --- a/docs/getting_started/backends.md +++ b/docs/getting_started/backends.md @@ -2,12 +2,10 @@ `opt_einsum` is quite agnostic to the type of n-dimensional arrays (tensors) it uses, since finding the contraction path only relies on getting the shape -attribute of each array supplied. -It can perform the underlying tensor contractions with various -libraries. In fact, any library that provides a `numpy.tensordot` and -`numpy.transpose` implementation can perform most normal contractions. -While more special functionality such as axes reduction is reliant on a -`numpy.einsum` implementation. +attribute of each array supplied. It can perform the underlying tensor contractions with many libraries, including any that support the Python array API standard. +Furthermore, any library that provides a `numpy.tensordot` and +`numpy.transpose` implementation can perform most normal contractions, while more special functionality such as axes reduction is reliant on a `numpy.einsum` implementation. + The following is a brief overview of libraries which have been tested with `opt_einsum`: @@ -26,15 +24,6 @@ The following is a brief overview of libraries which have been tested with - [jax](https://github.com/google/jax): compiled GPU tensor expressions including `autograd`-like functionality -`opt_einsum` is agnostic to the type of n-dimensional arrays (tensors) -it uses, since finding the contraction path only relies on getting the shape -attribute of each array supplied. -It can perform the underlying tensor contractions with various -libraries. In fact, any library that provides a `numpy.tensordot` and -`~numpy.transpose` implementation can perform most normal contractions. -While more special functionality such as axes reduction is reliant on a -`numpy.einsum` implementation. - !!! 
note For a contraction to be possible without using a backend einsum, it must satisfy the following rule: in the full expression (*including* output @@ -44,9 +33,8 @@ While more special functionality such as axes reduction is reliant on a ## Backend agnostic contractions -The automatic backend detection will be detected based on the first supplied -array (default), this can be overridden by specifying the correct `backend` -argument for the type of arrays supplied when calling +By default, the backend will be automatically detected based on the first supplied +array. This can be overridden by specifying the desired `backend` as a keyword argument when calling [`opt_einsum.contract`](../api_reference.md##opt_einsumcontract). For example, if you had a library installed called `'foo'` which provided an `numpy.ndarray` like object with a `.shape` attribute as well as `foo.tensordot` and `foo.transpose` then @@ -56,9 +44,9 @@ you could contract them with something like: contract(einsum_str, *foo_arrays, backend='foo') ``` -Behind the scenes `opt_einsum` will find the contraction path, perform -pairwise contractions using e.g. `foo.tensordot` and finally return the canonical -type those functions return. +Behind the scenes `opt_einsum` will find the contraction path and perform +pairwise contractions (using e.g. `foo.tensordot`). The return type is backend-dependent; for example, it's up to backends to decide whether `foo.tensordot` performs type promotion. + ### Dask @@ -197,7 +185,7 @@ Currently `opt_einsum` can handle this automatically for: - [jax](https://github.com/google/jax) all of which offer GPU support. Since `tensorflow` and `theano` both require -compiling the expression, this functionality is encapsulated in generating a +compiling the expression. 
This functionality is encapsulated in generating a [`opt_einsum.ContractExpression`](../api_reference.md#opt_einsumcontractcontractexpression) using [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression), which can then be called using numpy arrays whilst specifying `backend='tensorflow'` etc. diff --git a/docs/index.md b/docs/index.md index 358b888d..fb08c08c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,8 +4,7 @@ Optimized einsum can significantly reduce the overall execution time of einsum-l expressions by optimizing the expression's contraction order and dispatching many operations to canonical BLAS, cuBLAS, or other specialized routines. Optimized einsum is agnostic to the backend and can handle NumPy, Dask, -PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as -potentially any library which conforms to a standard API. +PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as any library which conforms to the Python array API standard. Other libraries may work so long as some key methods are available as attributes of their array objects. ## Features diff --git a/opt_einsum/backends/__init__.py b/opt_einsum/backends/__init__.py index a9b85795..bf97617c 100644 --- a/opt_einsum/backends/__init__.py +++ b/opt_einsum/backends/__init__.py @@ -8,6 +8,7 @@ from .tensorflow import to_tensorflow from .theano import to_theano from .torch import to_torch +from .array_api import to_array_api, discover_array_apis __all__ = [ "get_func", @@ -20,4 +21,6 @@ "to_theano", "to_cupy", "to_torch", + "to_array_api", + "discover_array_apis", ] diff --git a/opt_einsum/backends/array_api.py b/opt_einsum/backends/array_api.py new file mode 100644 index 00000000..9de2a24f --- /dev/null +++ b/opt_einsum/backends/array_api.py @@ -0,0 +1,75 @@ +""" +Required functions for optimized contractions of arrays using array API-compliant backends. 
+""" +import sys +from typing import Callable +from types import ModuleType + +import numpy as np + +from ..sharing import to_backend_cache_wrap + + +def discover_array_apis(): + """Discover array API backends.""" + if sys.version_info >= (3, 8): + from importlib.metadata import entry_points + + if sys.version_info >= (3, 10): + eps = entry_points(group="array_api") + else: + # Deprecated - will raise warning in Python versions >= 3.10 + eps = entry_points().get("array_api", []) + return [ep.load() for ep in eps] + else: + # importlib.metadata was introduced in Python 3.8, so it isn't available here. Unable to discover any array APIs. + return [] + + +def make_to_array_function(array_api: ModuleType) -> Callable: + """Make a ``to_[array_api]`` function for the given array API.""" + + @to_backend_cache_wrap + def to_array(array): # pragma: no cover + if isinstance(array, np.ndarray): + return array_api.asarray(array) + return array + + return to_array + + +def make_build_expression_function(array_api: ModuleType) -> Callable: + """Make a ``build_expression`` function for the given array API.""" + _to_array_api = to_array_api[array_api.__name__] + + def build_expression(_, expr): # pragma: no cover + """Build an array API function based on ``arrays`` and ``expr``.""" + + def array_api_contract(*arrays): + return expr._contract([_to_array_api(x) for x in arrays], backend=array_api.__name__) + + return array_api_contract + + return build_expression + + +def make_evaluate_constants_function(array_api: ModuleType) -> Callable: + _to_array_api = to_array_api[array_api.__name__] + + def evaluate_constants(const_arrays, expr): # pragma: no cover + """Convert constant arguments to cupy arrays, and perform any possible constant contractions.""" + return expr( + *[_to_array_api(x) for x in const_arrays], + backend=array_api.__name__, + evaluate_constants=True, + ) + + return evaluate_constants + + +_array_apis = discover_array_apis() +to_array_api = {api.__name__: 
make_to_array_function(api) for api in _array_apis} +build_expression = {api.__name__: make_build_expression_function(api) for api in _array_apis} +evaluate_constants = {api.__name__: make_evaluate_constants_function(api) for api in _array_apis} + +__all__ = ["discover_array_apis", "to_array_api", "build_expression", "evaluate_constants"] diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index 7c7d7e8c..86d13b1f 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -15,6 +15,7 @@ from . import tensorflow as _tensorflow from . import theano as _theano from . import torch as _torch +from . import array_api as _array_api __all__ = [ "get_func", @@ -122,6 +123,7 @@ def has_tensordot(backend: str) -> bool: "cupy": _cupy.build_expression, "torch": _torch.build_expression, "jax": _jax.build_expression, + **_array_api.build_expression, } EVAL_CONSTS_BACKENDS = { @@ -130,6 +132,7 @@ def has_tensordot(backend: str) -> bool: "cupy": _cupy.evaluate_constants, "torch": _torch.evaluate_constants, "jax": _jax.evaluate_constants, + **_array_api.evaluate_constants, } diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 0f703795..63c88a55 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -411,8 +411,13 @@ def _einsum(*operands, **kwargs): def _default_transpose(x: ArrayType, axes: Tuple[int, ...]) -> ArrayType: - # most libraries implement a method version - return x.transpose(axes) + # Many libraries implement a method version, but the array API-conforming arrays do not (as of 2021.12). 
+ if hasattr(x, "transpose"): + return x.transpose(axes) + elif hasattr(x, "__array_namespace__"): + return x.__array_namespace__().permute_dims(x, axes) + else: + raise NotImplementedError(f"No implementation for transpose or equivalent found for {x}") @sharing.transpose_cache_wrap @@ -549,7 +554,12 @@ def _infer_backend_class_cached(cls: type) -> str: def infer_backend(x: Any) -> str: - return _infer_backend_class_cached(x.__class__) + if hasattr(x, "__array_namespace__"): + # Having an ``__array_namespace__`` is a 'guarantee' from the developers of the given array's module that + # it conforms to the Python array API. Use this as a backend, if available. + return x.__array_namespace__().__name__ + else: + return _infer_backend_class_cached(x.__class__) def parse_backend(arrays: Sequence[ArrayType], backend: Optional[str]) -> str: diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index 7c81c671..1ceb852f 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -1,4 +1,6 @@ +from types import ModuleType import numpy as np +from pkg_resources import EntryPoint import pytest from opt_einsum import backends, contract, contract_expression, helpers, sharing @@ -56,6 +58,12 @@ except ImportError: found_autograd = False + +@pytest.fixture(params=backends.discover_array_apis()) +def array_api(request) -> ModuleType: + return request.param + + tests = [ "ab,bc->ca", "abc,bcd,dea", @@ -466,3 +474,63 @@ def test_object_arrays_backend(string): obj_opt = expr(*obj_views, backend="object") assert obj_opt.dtype == object assert np.allclose(ein, obj_opt.astype(float)) + + +@pytest.mark.parametrize("string", tests) +def test_array_api(array_api: ModuleType, string): # pragma: no cover + array_api_qname: str = array_api.__name__ + + views = helpers.build_views(string) + ein = contract(string, *views, optimize=False, use_blas=False) + shps = [v.shape for v in views] + + expr = contract_expression(string, *shps, 
optimize=True) + + opt = expr(*views, backend=array_api.__name__) + assert np.allclose(ein, opt) + + # test non-conversion mode + array_api_views = [backends.to_array_api[array_api_qname](view) for view in views] + array_api_opt = expr(*array_api_views) + assert all(type(array_api_opt) is type(view) for view in array_api_views) + assert np.allclose(ein, np.asarray(array_api_opt)) + + +@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) +def test_array_api_with_constants(array_api: ModuleType, constants): # pragma: no cover + array_api_qname: str = array_api.__name__ + + eq = "ij,jk,kl->li" + shapes = (2, 3), (3, 4), (4, 5) + (non_const,) = {0, 1, 2} - constants + ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] + var = np.random.rand(*shapes[non_const]) + res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3))) + + expr = contract_expression(eq, *ops, constants=constants) + + # check array API + res_got = expr(var, backend=array_api_qname) + # check array API versions of constants exist + assert all( + array is None or infer_backend(array) == array_api_qname for array in expr._evaluated_constants[array_api_qname] + ) + assert np.allclose(res_exp, res_got) + + # check can call with numpy still + res_got2 = expr(var, backend="numpy") + assert np.allclose(res_exp, res_got2) + + # check array API call returns an array object belonging to the same backend + # NOTE: the array API standard does not require that the returned array is the same type as the input array, + # only that the returned array also obeys the array API standard. Indeed, the standard does not stipulate + # even a *name* for the array type. + # + # For this reason, we won't check the type of the returned array, but only that it is has an + # ``__array_namespace__`` attribute and hence claims to comply with standard. 
+ # + # In future versions, if opt_einsum uses newer array API features, we will also need to check that the + # returned array complies with the appropriate version of the standard. + res_got3 = expr(array_api.asarray(var)) + assert hasattr(res_got3, "__array_namespace__") + assert np.allclose(res_exp, res_got3)