From e7dc2c33ceab2abb4b3e724d6144e8887eb8b153 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 14:44:38 +0800 Subject: [PATCH 01/12] Add compatibility for Python array API backends --- opt_einsum/backends/__init__.py | 3 ++ opt_einsum/backends/array_api.py | 62 ++++++++++++++++++++++++++++++++ opt_einsum/backends/dispatch.py | 4 ++- opt_einsum/contract.py | 16 +++++++-- 4 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 opt_einsum/backends/array_api.py diff --git a/opt_einsum/backends/__init__.py b/opt_einsum/backends/__init__.py index a9b85795..48e367a8 100644 --- a/opt_einsum/backends/__init__.py +++ b/opt_einsum/backends/__init__.py @@ -8,6 +8,7 @@ from .tensorflow import to_tensorflow from .theano import to_theano from .torch import to_torch +from .array_api import to_array_api, discover_array_api_eps __all__ = [ "get_func", @@ -20,4 +21,6 @@ "to_theano", "to_cupy", "to_torch", + "to_array_api", + "discover_array_api_eps", ] diff --git a/opt_einsum/backends/array_api.py b/opt_einsum/backends/array_api.py new file mode 100644 index 00000000..e73c78cc --- /dev/null +++ b/opt_einsum/backends/array_api.py @@ -0,0 +1,62 @@ +""" +Required functions for optimized contractions of arrays using array API-compliant backends. +""" +import sys +from importlib.metadata import entry_points +from typing import Callable +from types import ModuleType + +import numpy as np + +from ..sharing import to_backend_cache_wrap + + +def discover_array_api_eps(): + """Discover array API backends and return their entry points.""" + if sys.version_info >= (3, 10): + return entry_points(group='array_api') + else: + # Deprecated - will raise warning in Python versions >= 3.10 + return entry_points().get('array_api', []) + +def make_to_array(array_api: ModuleType) -> Callable: + """Make a ``to_[array_api]`` function for the given array API.""" + @to_backend_cache_wrap + def to_array(array): # pragma: no cover + if isinstance(array, np.ndarray): + return array_api.asarray(array) + return array + return to_array + +def make_build_expression(array_api: ModuleType) -> Callable: + """Make a ``build_expression`` function for the given array API.""" + def build_expression(_, expr): # pragma: no cover + """Build an array API function based on ``arrays`` and ``expr``.""" + def array_api_contract(*arrays): + return expr._contract([make_to_array(array_api)(x) for x in arrays], backend=array_api.__name__) + return array_api_contract + return build_expression + +def make_evaluate_constants(array_api: ModuleType) -> Callable: + def evaluate_constants(const_arrays, expr): # pragma: no cover + """Convert constant arguments to cupy arrays, and perform any possible constant contractions. + """ + return expr(*[make_to_array(array_api)(x) for x in const_arrays], backend="cupy", evaluate_constants=True) + return evaluate_constants + +to_array_api = {} +build_expression = {} +evaluate_constants = {} + +for ep in discover_array_api_eps(): + _array_api = ep.load() + to_array_api[ep.value] = make_to_array(_array_api) + build_expression[ep.value] = make_build_expression(_array_api) + evaluate_constants[ep.value] = make_evaluate_constants(_array_api) + +__all__ = [ + 'discover_array_api_eps', + "to_array_api", + "build_expression", + "evaluate_constants" +] \ No newline at end of file diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index 7c7d7e8c..a181bb35 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -15,6 +15,7 @@ from . import tensorflow as _tensorflow from . import theano as _theano from . import torch as _torch +from . import array_api as _array_api __all__ = [ "get_func", @@ -122,6 +123,7 @@ def has_tensordot(backend: str) -> bool: "cupy": _cupy.build_expression, "torch": _torch.build_expression, "jax": _jax.build_expression, + **_array_api.build_expression, } EVAL_CONSTS_BACKENDS = { @@ -130,9 +132,9 @@ def has_tensordot(backend: str) -> bool: "cupy": _cupy.evaluate_constants, "torch": _torch.evaluate_constants, "jax": _jax.evaluate_constants, + **_array_api.evaluate_constants, } - def build_expression(backend, arrays, expr): """Build an expression, based on ``expr`` and initial arrays ``arrays``, that evaluates using backend ``backend``. diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 0df94d03..d5ae9c33 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -410,8 +410,13 @@ def _einsum(*operands, **kwargs): def _default_transpose(x: ArrayType, axes: Tuple[int, ...]) -> ArrayType: - # most libraries implement a method version - return x.transpose(axes) + # Many libraries implement a method version, but the array API-conforming arrys do not (as of 2021.12). + if hasattr(x, "transpose"): + return x.transpose(axes) + elif hasattr(x, '__array_namespace__'): + return x.__array_namespace__().permute_dims(x, axes) + else: + raise NotImplementedError(f"No implementation for transpose or equivalent found for {x}") @sharing.transpose_cache_wrap @@ -543,7 +548,12 @@ def contract(*operands_: Any, **kwargs: Any) -> ArrayType: def infer_backend(x: Any) -> str: - return x.__class__.__module__.split(".")[0] + if hasattr(x, "__array_namespace__"): + # Having an ``__array_namespace__`` is a 'guarantee' from the developers of the given array's module that + # it conforms to the Python array API. Use this as a backend, if available. + return x.__array_namespace__().__name__ + else: + return x.__class__.__module__.split(".")[0] def parse_backend(arrays: Sequence[ArrayType], backend: str) -> str: From 13a6054b3670983aef4fa8fad1cc316ef8ae2508 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 14:47:05 +0800 Subject: [PATCH 02/12] Add tests for array API backends --- opt_einsum/tests/test_backends.py | 58 +++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index 87d2ba22..d0f6d6e7 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -1,4 +1,8 @@ +from types import ModuleType +import sys +import importlib import numpy as np +from pkg_resources import EntryPoint import pytest from opt_einsum import backends, contract, contract_expression, helpers, sharing @@ -56,6 +60,12 @@ except ImportError: found_autograd = False + +@pytest.fixture(params=backends.discover_array_api_eps()) +def array_api_ep(request) -> importlib.metadata.EntryPoint: + return request.param + + tests = [ "ab,bc->ca", "abc,bcd,dea", @@ -465,3 +475,51 @@ def test_object_arrays_backend(string): obj_opt = expr(*obj_views, backend="object") assert obj_opt.dtype == object assert np.allclose(ein, obj_opt.astype(float)) + +@pytest.mark.parametrize("string", tests) +def test_array_api(array_api_ep: EntryPoint, string): # pragma: no cover + array_api: ModuleType = array_api_ep.load() + array_api_qname: str = array_api_ep.value + + views = helpers.build_views(string) + ein = contract(string, *views, optimize=False, use_blas=False) + shps = [v.shape for v in views] + + expr = contract_expression(string, *shps, optimize=True) + + opt = expr(*views, backend=array_api.__name__) + assert np.allclose(ein, opt) + + # test non-conversion mode + array_api_views = [backends.to_array_api[array_api_qname](view) for view in views] + array_api_opt = expr(*array_api_views) + assert all(type(array_api_opt) is type(view) for view in array_api_views) + assert np.allclose(ein, np.asarray(array_api_opt)) + + +# @pytest.mark.parametrize("array_api", array_apis) +# @pytest.mark.parametrize("string", tests) +# def test_array_api_with_constants(constants: importlib.metadata.EntryPoint): # pragma: no cover +# eq = "ij,jk,kl->li" +# shapes = (2, 3), (3, 4), (4, 5) +# (non_const,) = {0, 1, 2} - constants +# ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] +# var = np.random.rand(*shapes[non_const]) +# res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3))) + +# expr = contract_expression(eq, *ops, constants=constants) + +# # check cupy +# res_got = expr(var, backend="cupy") +# # check cupy versions of constants exist +# assert all(array is None or infer_backend(array) == "cupy" for array in expr._evaluated_constants["cupy"]) +# assert np.allclose(res_exp, res_got) + +# # check can call with numpy still +# res_got2 = expr(var, backend="numpy") +# assert np.allclose(res_exp, res_got2) + +# # check cupy call returns cupy still +# res_got3 = expr(cupy.asarray(var)) +# assert isinstance(res_got3, cupy.ndarray) +# assert np.allclose(res_exp, res_got3.get()) From 0fb031247fb75833c57dfa9a8fb1558d619106d5 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 14:47:27 +0800 Subject: [PATCH 03/12] Add (failing) sharing tests for array API backends --- opt_einsum/tests/test_sharing.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index 0df72765..44fa7d82 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -6,7 +6,7 @@ import pytest from opt_einsum import contract, contract_expression, contract_path, get_symbol, helpers, shared_intermediates -from opt_einsum.backends import to_cupy, to_torch +from opt_einsum.backends import to_cupy, to_torch, to_array_api, discover_array_api_eps from opt_einsum.contract import _einsum from opt_einsum.parser import parse_einsum_input from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache @@ -25,7 +25,9 @@ except ImportError: torch_if_found = pytest.param("torch", marks=[pytest.mark.skip(reason="PyTorch not installed.")]) # type: ignore -backends = ["numpy", torch_if_found, cupy_if_found] +array_api_qnames = [ep.value for ep in discover_array_api_eps()] + +backends = ["numpy", torch_if_found, cupy_if_found, *array_api_qnames] equations = [ "ab,bc->ca", "abc,bcd,dea", @@ -41,6 +43,7 @@ "numpy": lambda x: x, "torch": to_torch, "cupy": to_cupy, + **to_array_api, } @@ -55,7 +58,7 @@ def test_sharing_value(eq, backend): with shared_intermediates(): actual = expr(*views, backend=backend) - assert (actual == expected).all() + assert np.all(actual == expected) @pytest.mark.parametrize("backend", backends) From 381f15a78b49c93959f748d3c7cc2e352c195b61 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 14:51:59 +0800 Subject: [PATCH 04/12] Revert failing sharing tests for array API backends --- opt_einsum/tests/test_sharing.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/opt_einsum/tests/test_sharing.py b/opt_einsum/tests/test_sharing.py index 44fa7d82..0df72765 100644 --- a/opt_einsum/tests/test_sharing.py +++ b/opt_einsum/tests/test_sharing.py @@ -6,7 +6,7 @@ import pytest from opt_einsum import contract, contract_expression, contract_path, get_symbol, helpers, shared_intermediates -from opt_einsum.backends import to_cupy, to_torch, to_array_api, discover_array_api_eps +from opt_einsum.backends import to_cupy, to_torch from opt_einsum.contract import _einsum from opt_einsum.parser import parse_einsum_input from opt_einsum.sharing import count_cached_ops, currently_sharing, get_sharing_cache @@ -25,9 +25,7 @@ except ImportError: torch_if_found = pytest.param("torch", marks=[pytest.mark.skip(reason="PyTorch not installed.")]) # type: ignore -array_api_qnames = [ep.value for ep in discover_array_api_eps()] - -backends = ["numpy", torch_if_found, cupy_if_found, *array_api_qnames] +backends = ["numpy", torch_if_found, cupy_if_found] equations = [ "ab,bc->ca", "abc,bcd,dea", @@ -43,7 +41,6 @@ "numpy": lambda x: x, "torch": to_torch, "cupy": to_cupy, - **to_array_api, } @@ -58,7 +55,7 @@ def test_sharing_value(eq, backend): with shared_intermediates(): actual = expr(*views, backend=backend) - assert np.all(actual == expected) + assert (actual == expected).all() @pytest.mark.parametrize("backend", backends) From 50c3693bb7323b35dde9b3b6d752a44cf0d2c8ed Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 15:40:46 +0800 Subject: [PATCH 05/12] Fix and renaming some functions for the array API backend --- opt_einsum/backends/__init__.py | 4 ++-- opt_einsum/backends/array_api.py | 33 +++++++++++++++---------------- opt_einsum/tests/test_backends.py | 12 +++++------ 3 files changed, 23 insertions(+), 26 deletions(-) diff --git a/opt_einsum/backends/__init__.py b/opt_einsum/backends/__init__.py index 48e367a8..bf97617c 100644 --- a/opt_einsum/backends/__init__.py +++ b/opt_einsum/backends/__init__.py @@ -8,7 +8,7 @@ from .tensorflow import to_tensorflow from .theano import to_theano from .torch import to_torch -from .array_api import to_array_api, discover_array_api_eps +from .array_api import to_array_api, discover_array_apis __all__ = [ "get_func", @@ -22,5 +22,5 @@ "to_cupy", "to_torch", "to_array_api", - "discover_array_api_eps", + "discover_array_apis", ] diff --git a/opt_einsum/backends/array_api.py b/opt_einsum/backends/array_api.py index e73c78cc..28484286 100644 --- a/opt_einsum/backends/array_api.py +++ b/opt_einsum/backends/array_api.py @@ -11,13 +11,14 @@ from ..sharing import to_backend_cache_wrap -def discover_array_api_eps(): - """Discover array API backends and return their entry points.""" +def discover_array_apis(): + """Discover array API backends.""" if sys.version_info >= (3, 10): - return entry_points(group='array_api') + eps = entry_points(group='array_api') else: # Deprecated - will raise warning in Python versions >= 3.10 - return entry_points().get('array_api', []) + eps = entry_points().get('array_api', []) + return [ep.load() for ep in eps] def make_to_array(array_api: ModuleType) -> Callable: """Make a ``to_[array_api]`` function for the given array API.""" @@ -39,23 +40,21 @@ def array_api_contract(*arrays): def make_evaluate_constants(array_api: ModuleType) -> Callable: def evaluate_constants(const_arrays, expr): # pragma: no cover - """Convert constant arguments to cupy arrays, and perform any possible constant contractions. - """ - return expr(*[make_to_array(array_api)(x) for x in const_arrays], backend="cupy", evaluate_constants=True) + """Convert constant arguments to cupy arrays, and perform any possible constant contractions.""" + return expr( + *[make_to_array(array_api)(x) for x in const_arrays], + backend=array_api.__name__, + evaluate_constants=True + ) return evaluate_constants -to_array_api = {} -build_expression = {} -evaluate_constants = {} - -for ep in discover_array_api_eps(): - _array_api = ep.load() - to_array_api[ep.value] = make_to_array(_array_api) - build_expression[ep.value] = make_build_expression(_array_api) - evaluate_constants[ep.value] = make_evaluate_constants(_array_api) +_array_apis = discover_array_apis() +to_array_api = {api.__name__: make_to_array(api) for api in _array_apis} +build_expression = {api.__name__: make_build_expression(api) for api in _array_apis} +evaluate_constants = {api.__name__: make_evaluate_constants(api) for api in _array_apis} __all__ = [ - 'discover_array_api_eps', + 'discover_array_apis', "to_array_api", "build_expression", "evaluate_constants" diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index d0f6d6e7..ee1aecff 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -1,6 +1,4 @@ from types import ModuleType -import sys -import importlib import numpy as np from pkg_resources import EntryPoint import pytest @@ -61,8 +59,8 @@ found_autograd = False -@pytest.fixture(params=backends.discover_array_api_eps()) -def array_api_ep(request) -> importlib.metadata.EntryPoint: +@pytest.fixture(params=backends.discover_array_apis()) +def array_api(request) -> ModuleType: return request.param @@ -476,10 +474,10 @@ def test_object_arrays_backend(string): assert obj_opt.dtype == object assert np.allclose(ein, obj_opt.astype(float)) + @pytest.mark.parametrize("string", tests) -def test_array_api(array_api_ep: EntryPoint, string): # pragma: no cover - array_api: ModuleType = array_api_ep.load() - array_api_qname: str = array_api_ep.value +def test_array_api(array_api: ModuleType, string): # pragma: no cover + array_api_qname: str = array_api.__name__ views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) From 576c129d13d1a0a10e9dc6da393839a3f88b95fa Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 15:41:11 +0800 Subject: [PATCH 06/12] Add constants test for array API backends --- opt_einsum/tests/test_backends.py | 62 ++++++++++++++++++------------- 1 file changed, 36 insertions(+), 26 deletions(-) diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index ee1aecff..180b999e 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -495,29 +495,39 @@ def test_array_api(array_api: ModuleType, string): # pragma: no cover assert np.allclose(ein, np.asarray(array_api_opt)) -# @pytest.mark.parametrize("array_api", array_apis) -# @pytest.mark.parametrize("string", tests) -# def test_array_api_with_constants(constants: importlib.metadata.EntryPoint): # pragma: no cover -# eq = "ij,jk,kl->li" -# shapes = (2, 3), (3, 4), (4, 5) -# (non_const,) = {0, 1, 2} - constants -# ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] -# var = np.random.rand(*shapes[non_const]) -# res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3))) - -# expr = contract_expression(eq, *ops, constants=constants) - -# # check cupy -# res_got = expr(var, backend="cupy") -# # check cupy versions of constants exist -# assert all(array is None or infer_backend(array) == "cupy" for array in expr._evaluated_constants["cupy"]) -# assert np.allclose(res_exp, res_got) - -# # check can call with numpy still -# res_got2 = expr(var, backend="numpy") -# assert np.allclose(res_exp, res_got2) - -# # check cupy call returns cupy still -# res_got3 = expr(cupy.asarray(var)) -# assert isinstance(res_got3, cupy.ndarray) -# assert np.allclose(res_exp, res_got3.get()) +@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}]) +def test_array_api_with_constants(array_api: ModuleType, constants): # pragma: no cover + array_api_qname: str = array_api.__name__ + + eq = "ij,jk,kl->li" + shapes = (2, 3), (3, 4), (4, 5) + (non_const,) = {0, 1, 2} - constants + ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)] + var = np.random.rand(*shapes[non_const]) + res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3))) + + expr = contract_expression(eq, *ops, constants=constants) + + # check array API + res_got = expr(var, backend=array_api_qname) + # check array API versions of constants exist + assert all(array is None or infer_backend(array) == array_api_qname for array in expr._evaluated_constants[array_api_qname]) + assert np.allclose(res_exp, res_got) + + # check can call with numpy still + res_got2 = expr(var, backend="numpy") + assert np.allclose(res_exp, res_got2) + + # check array API call returns an array object belonging to the same backend + # NOTE: the array API standard does not require that the returned array is the same type as the input array, + # only that the returned array also obeys the array API standard. Indeed, the standard does not stipulate + # even a *name* for the array type. + # + # For this reason, we won't check the type of the returned array, but only that it is has an + # ``__array_namespace__`` attribute and hence claims to comply with standard. + # + # In future versions, if einexpr uses newer array API features, we will also need to check that the + # returned array complies with the appropriate version of the standard. + res_got3 = expr(array_api.asarray(var)) + assert hasattr(res_got3, "__array_namespace__") + assert np.allclose(res_exp, res_got3) \ No newline at end of file From db785c9ce470ad217c1d5d781e33c5911f14cbe7 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 16:16:59 +0800 Subject: [PATCH 07/12] Update docs for array API standard (with some minor grammatical/clarity changes) --- docs/getting_started/backends.md | 32 ++++++++++---------------------- docs/index.md | 3 +-- 2 files changed, 11 insertions(+), 24 deletions(-) diff --git a/docs/getting_started/backends.md b/docs/getting_started/backends.md index 224e8162..69ab445f 100644 --- a/docs/getting_started/backends.md +++ b/docs/getting_started/backends.md @@ -2,12 +2,10 @@ `opt_einsum` is quite agnostic to the type of n-dimensional arrays (tensors) it uses, since finding the contraction path only relies on getting the shape -attribute of each array supplied. -It can perform the underlying tensor contractions with various -libraries. In fact, any library that provides a `numpy.tensordot` and -`numpy.transpose` implementation can perform most normal contractions. -While more special functionality such as axes reduction is reliant on a -`numpy.einsum` implementation. +attribute of each array supplied. It can perform the underlying tensor contractions with many libraries, including any that support the Python array API standard. +Furthermore, any library that provides a `numpy.tensordot` and +`numpy.transpose` implementation can perform most normal contractions, while more special functionality such as axes reduction is reliant on an `numpy.einsum` implementation. + The following is a brief overview of libraries which have been tested with `opt_einsum`: @@ -26,15 +24,6 @@ The following is a brief overview of libraries which have been tested with - [jax](https://github.com/google/jax): compiled GPU tensor expressions including `autograd`-like functionality -`opt_einsum` is agnostic to the type of n-dimensional arrays (tensors) -it uses, since finding the contraction path only relies on getting the shape -attribute of each array supplied. -It can perform the underlying tensor contractions with various -libraries. In fact, any library that provides a `numpy.tensordot` and -`~numpy.transpose` implementation can perform most normal contractions. -While more special functionality such as axes reduction is reliant on a -`numpy.einsum` implementation. - !!! note For a contraction to be possible without using a backend einsum, it must satisfy the following rule: in the full expression (*including* output @@ -44,9 +33,8 @@ While more special functionality such as axes reduction is reliant on a ## Backend agnostic contractions -The automatic backend detection will be detected based on the first supplied -array (default), this can be overridden by specifying the correct `backend` -argument for the type of arrays supplied when calling +By default, backend will be automatically detected based on the first supplied +array. This can be overridden by specifying the desired `backend` as a keyword argument when calling [`opt_einsum.contract`](../api_reference.md##opt_einsumcontract). For example, if you had a library installed called `'foo'` which provided an `numpy.ndarray` like object with a `.shape` attribute as well as `foo.tensordot` and `foo.transpose` then @@ -56,9 +44,9 @@ you could contract them with something like: contract(einsum_str, *foo_arrays, backend='foo') ``` -Behind the scenes `opt_einsum` will find the contraction path, perform -pairwise contractions using e.g. `foo.tensordot` and finally return the canonical -type those functions return. +Behind the scenes `opt_einsum` will find the contraction path and perform +pairwise contractions (using e.g. `foo.tensordot`). The return type is backend-dependent; for example, it's up to backends to decide whether `foo.tensordot` performs type promotion. + ### Dask @@ -197,7 +185,7 @@ Currently `opt_einsum` can handle this automatically for: - [jax](https://github.com/google/jax) all of which offer GPU support. Since `tensorflow` and `theano` both require -compiling the expression, this functionality is encapsulated in generating a +compiling the expression. This functionality is encapsulated in generating a [`opt_einsum.ContractExpression`](../api_reference.md#opt_einsumcontractcontractexpression) using [`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression), which can then be called using numpy arrays whilst specifying `backend='tensorflow'` etc. diff --git a/docs/index.md b/docs/index.md index 358b888d..c0781f19 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,8 +4,7 @@ Optimized einsum can significantly reduce the overall execution time of einsum-l expressions by optimizing the expression's contraction order and dispatching many operations to canonical BLAS, cuBLAS, or other specialized routines. Optimized einsum is agnostic to the backend and can handle NumPy, Dask, -PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as -potentially any library which conforms to a standard API. +PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as any library which conforms to the Python array API standard. Other libraries may work if some key methods are available as attributes of their array objects. ## Features From 5de3b6e6cd347440630f247d70adf99036e5f6d9 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Sun, 17 Jul 2022 16:20:37 +0800 Subject: [PATCH 08/12] doc fix --- docs/index.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/index.md b/docs/index.md index c0781f19..fb08c08c 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,7 +4,7 @@ Optimized einsum can significantly reduce the overall execution time of einsum-l expressions by optimizing the expression's contraction order and dispatching many operations to canonical BLAS, cuBLAS, or other specialized routines. Optimized einsum is agnostic to the backend and can handle NumPy, Dask, -PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as any library which conforms to the Python array API standard. Other libraries may work if some key methods are available as attributes of their array objects. +PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as any library which conforms to the Python array API standard. Other libraries may work so long as some key methods are available as attributes of their array objects. ## Features From 75f9d9092c0a47a693331c46daf546a2e7082226 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Mon, 18 Jul 2022 17:37:32 +0800 Subject: [PATCH 09/12] Only attempt to discover array API backends in Python versions >=3.8 (since earlier versions don't have importlib.metadata) --- opt_einsum/backends/array_api.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/opt_einsum/backends/array_api.py b/opt_einsum/backends/array_api.py index 28484286..7df635e8 100644 --- a/opt_einsum/backends/array_api.py +++ b/opt_einsum/backends/array_api.py @@ -2,7 +2,6 @@ Required functions for optimized contractions of arrays using array API-compliant backends. """ import sys -from importlib.metadata import entry_points from typing import Callable from types import ModuleType @@ -13,12 +12,19 @@ def discover_array_apis(): """Discover array API backends.""" - if sys.version_info >= (3, 10): - eps = entry_points(group='array_api') + if sys.version_info >= (3, 8): + from importlib.metadata import entry_points + + if sys.version_info >= (3, 10): + eps = entry_points(group="array_api") + else: + # Deprecated - will raise warning in Python versions >= 3.10 + eps = entry_points().get("array_api", []) + return [ep.load() for ep in eps] else: - # Deprecated - will raise warning in Python versions >= 3.10 - eps = entry_points().get('array_api', []) - return [ep.load() for ep in eps] + # importlib.metadata was introduced in Python 3.8, so it isn't available here. Unable to discover any array APIs. + return [] + def make_to_array(array_api: ModuleType) -> Callable: """Make a ``to_[array_api]`` function for the given array API.""" From 200402e549a9d8a952eb5133cf575ce216318d5a Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Mon, 18 Jul 2022 17:38:10 +0800 Subject: [PATCH 10/12] Fix formatting and rename some functions for clarity --- opt_einsum/backends/array_api.py | 39 ++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/opt_einsum/backends/array_api.py b/opt_einsum/backends/array_api.py index 7df635e8..f5889f22 100644 --- a/opt_einsum/backends/array_api.py +++ b/opt_einsum/backends/array_api.py @@ -26,42 +26,47 @@ def discover_array_apis(): return [] -def make_to_array(array_api: ModuleType) -> Callable: +def make_to_array_function(array_api: ModuleType) -> Callable: """Make a ``to_[array_api]`` function for the given array API.""" + @to_backend_cache_wrap def to_array(array): # pragma: no cover if isinstance(array, np.ndarray): return array_api.asarray(array) return array + return to_array -def make_build_expression(array_api: ModuleType) -> Callable: + +def make_build_expression_function(array_api: ModuleType) -> Callable: """Make a ``build_expression`` function for the given array API.""" + def build_expression(_, expr): # pragma: no cover """Build an array API function based on ``arrays`` and ``expr``.""" + def array_api_contract(*arrays): - return expr._contract([make_to_array(array_api)(x) for x in arrays], backend=array_api.__name__) + return expr._contract([make_to_array_function(array_api)(x) for x in arrays], backend=array_api.__name__) + return array_api_contract + return build_expression -def make_evaluate_constants(array_api: ModuleType) -> Callable: + +def make_evaluate_constants_function(array_api: ModuleType) -> Callable: def evaluate_constants(const_arrays, expr): # pragma: no cover """Convert constant arguments to cupy arrays, and perform any possible constant contractions.""" return expr( - *[make_to_array(array_api)(x) for x in const_arrays], - backend=array_api.__name__, - evaluate_constants=True + *[make_to_array_function(array_api)(x) for x in const_arrays], + backend=array_api.__name__, + evaluate_constants=True, ) + return evaluate_constants + _array_apis = discover_array_apis() -to_array_api = {api.__name__: make_to_array(api) for api in _array_apis} -build_expression = {api.__name__: make_build_expression(api) for api in _array_apis} -evaluate_constants = {api.__name__: make_evaluate_constants(api) for api in _array_apis} - -__all__ = [ - 'discover_array_apis', - "to_array_api", - "build_expression", - "evaluate_constants" -] \ No newline at end of file +to_array_api = {api.__name__: make_to_array_function(api) for api in _array_apis} +build_expression = {api.__name__: make_build_expression_function(api) for api in _array_apis} +evaluate_constants = {api.__name__: make_evaluate_constants_function(api) for api in _array_apis} + +__all__ = ["discover_array_apis", "to_array_api", "build_expression", "evaluate_constants"] From 714b52053c5d5b65cf88c39d269989eaf31d9f12 Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Mon, 18 Jul 2022 17:38:32 +0800 Subject: [PATCH 11/12] Fix formatting --- opt_einsum/backends/dispatch.py | 1 + opt_einsum/contract.py | 2 +- opt_einsum/tests/test_backends.py | 14 ++++++++------ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/opt_einsum/backends/dispatch.py b/opt_einsum/backends/dispatch.py index a181bb35..86d13b1f 100644 --- a/opt_einsum/backends/dispatch.py +++ b/opt_einsum/backends/dispatch.py @@ -135,6 +135,7 @@ def has_tensordot(backend: str) -> bool: **_array_api.evaluate_constants, } + def build_expression(backend, arrays, expr): """Build an expression, based on ``expr`` and initial arrays ``arrays``, that evaluates using backend ``backend``. diff --git a/opt_einsum/contract.py b/opt_einsum/contract.py index 13252fc5..63c88a55 100644 --- a/opt_einsum/contract.py +++ b/opt_einsum/contract.py @@ -414,7 +414,7 @@ def _default_transpose(x: ArrayType, axes: Tuple[int, ...]) -> ArrayType: # Many libraries implement a method version, but the array API-conforming arrys do not (as of 2021.12). if hasattr(x, "transpose"): return x.transpose(axes) - elif hasattr(x, '__array_namespace__'): + elif hasattr(x, "__array_namespace__"): return x.__array_namespace__().permute_dims(x, axes) else: raise NotImplementedError(f"No implementation for transpose or equivalent found for {x}") diff --git a/opt_einsum/tests/test_backends.py b/opt_einsum/tests/test_backends.py index 9af0b743..1ceb852f 100644 --- a/opt_einsum/tests/test_backends.py +++ b/opt_einsum/tests/test_backends.py @@ -479,7 +479,7 @@ def test_object_arrays_backend(string): @pytest.mark.parametrize("string", tests) def test_array_api(array_api: ModuleType, string): # pragma: no cover array_api_qname: str = array_api.__name__ - + views = helpers.build_views(string) ein = contract(string, *views, optimize=False, use_blas=False) shps = [v.shape for v in views] @@ -512,7 +512,9 @@ def test_array_api_with_constants(array_api: ModuleType, constants): # pragma: # check array API res_got = expr(var, backend=array_api_qname) # check array API versions of constants exist - assert all(array is None or infer_backend(array) == array_api_qname for array in expr._evaluated_constants[array_api_qname]) + assert all( + array is None or infer_backend(array) == array_api_qname for array in expr._evaluated_constants[array_api_qname] + ) assert np.allclose(res_exp, res_got) # check can call with numpy still @@ -523,12 +525,12 @@ def test_array_api_with_constants(array_api: ModuleType, constants): # pragma: # NOTE: the array API standard does not require that the returned array is the same type as the input array, # only that the returned array also obeys the array API standard. Indeed, the standard does not stipulate # even a *name* for the array type. - # - # For this reason, we won't check the type of the returned array, but only that it is has an + # + # For this reason, we won't check the type of the returned array, but only that it is has an # ``__array_namespace__`` attribute and hence claims to comply with standard. - # + # # In future versions, if einexpr uses newer array API features, we will also need to check that the # returned array complies with the appropriate version of the standard. res_got3 = expr(array_api.asarray(var)) assert hasattr(res_got3, "__array_namespace__") - assert np.allclose(res_exp, res_got3) \ No newline at end of file + assert np.allclose(res_exp, res_got3) From 13ec28045510688279672259094a6c54af171eab Mon Sep 17 00:00:00 2001 From: IsaacBreen <57783927+IsaacBreen@users.noreply.github.com> Date: Mon, 18 Jul 2022 19:39:15 +0800 Subject: [PATCH 12/12] Set _to_array_api in outer scope --- opt_einsum/backends/array_api.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/opt_einsum/backends/array_api.py b/opt_einsum/backends/array_api.py index f5889f22..9de2a24f 100644 --- a/opt_einsum/backends/array_api.py +++ b/opt_einsum/backends/array_api.py @@ -40,12 +40,13 @@ def to_array(array): # pragma: no cover def make_build_expression_function(array_api: ModuleType) -> Callable: """Make a ``build_expression`` function for the given array API.""" + _to_array_api = to_array_api[array_api.__name__] def build_expression(_, expr): # pragma: no cover """Build an array API function based on ``arrays`` and ``expr``.""" def array_api_contract(*arrays): - return expr._contract([make_to_array_function(array_api)(x) for x in arrays], backend=array_api.__name__) + return expr._contract([_to_array_api(x) for x in arrays], backend=array_api.__name__) return array_api_contract @@ -53,10 +54,12 @@ def array_api_contract(*arrays): def make_evaluate_constants_function(array_api: ModuleType) -> Callable: + _to_array_api = to_array_api[array_api.__name__] + def evaluate_constants(const_arrays, expr): # pragma: no cover """Convert constant arguments to cupy arrays, and perform any possible constant contractions.""" return expr( - *[make_to_array_function(array_api)(x) for x in const_arrays], + *[_to_array_api(x) for x in const_arrays], backend=array_api.__name__, evaluate_constants=True, )