Support for Python array API standard #197

Open · wants to merge 13 commits into base: `main`
32 changes: 10 additions & 22 deletions docs/getting_started/backends.md
@@ -2,12 +2,10 @@

`opt_einsum` is quite agnostic to the type of n-dimensional arrays (tensors)
it uses, since finding the contraction path only relies on getting the shape
attribute of each array supplied.
It can perform the underlying tensor contractions with various
libraries. In fact, any library that provides a `numpy.tensordot` and
`numpy.transpose` implementation can perform most normal contractions.
While more special functionality such as axes reduction is reliant on a
`numpy.einsum` implementation.
attribute of each array supplied. It can perform the underlying tensor contractions with many libraries, including any that support the Python array API standard.
Furthermore, any library that provides a `numpy.tensordot` and
`numpy.transpose` implementation can perform most normal contractions, while more specialized functionality such as axis reduction relies on a `numpy.einsum` implementation.
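As a sketch of why those two functions suffice, a single pairwise contraction plus an output permutation can be written with `tensordot` and `transpose` alone (NumPy is used here purely for illustration):

```python
import numpy as np

# "ab,bc->ca" using only tensordot and transpose: contract over the
# shared index b, then permute the result from "ac" to "ca".
a = np.random.rand(2, 3)
b = np.random.rand(3, 4)

ac = np.tensordot(a, b, axes=[(1,), (0,)])  # shape (2, 4)
ca = np.transpose(ac, (1, 0))               # shape (4, 2)
```

Any backend exposing equivalents of these two functions can therefore execute each pairwise step of a contraction path.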

The following is a brief overview of libraries which have been tested with
`opt_einsum`:

@@ -26,15 +24,6 @@ The following is a brief overview of libraries which have been tested with
- [jax](https://github.com/google/jax): compiled GPU tensor expressions
including `autograd`-like functionality

`opt_einsum` is agnostic to the type of n-dimensional arrays (tensors)
it uses, since finding the contraction path only relies on getting the shape
attribute of each array supplied.
It can perform the underlying tensor contractions with various
libraries. In fact, any library that provides a `numpy.tensordot` and
`~numpy.transpose` implementation can perform most normal contractions.
While more special functionality such as axes reduction is reliant on a
`numpy.einsum` implementation.

!!! note
For a contraction to be possible without using a backend einsum, it must
satisfy the following rule: in the full expression (*including* output
@@ -44,9 +33,8 @@ While more special functionality such as axes reduction is reliant on a

## Backend agnostic contractions

The automatic backend detection will be detected based on the first supplied
array (default), this can be overridden by specifying the correct `backend`
argument for the type of arrays supplied when calling
By default, the backend will be automatically detected based on the first supplied
array. This can be overridden by specifying the desired `backend` as a keyword argument when calling
[`opt_einsum.contract`](../api_reference.md##opt_einsumcontract). For example, if you had a library installed
called `'foo'` which provided a `numpy.ndarray`-like object with a
`.shape` attribute as well as `foo.tensordot` and `foo.transpose` then
@@ -56,9 +44,9 @@ you could contract them with something like:
contract(einsum_str, *foo_arrays, backend='foo')
```

Behind the scenes `opt_einsum` will find the contraction path, perform
pairwise contractions using e.g. `foo.tensordot` and finally return the canonical
type those functions return.
Behind the scenes `opt_einsum` will find the contraction path and perform
pairwise contractions (using e.g. `foo.tensordot`). The return type is backend-dependent; for example, it's up to backends to decide whether `foo.tensordot` performs type promotion.
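For instance, forcing the always-available NumPy backend explicitly looks like the following, with `'numpy'` standing in for the hypothetical `'foo'` above:

```python
import numpy as np
from opt_einsum import contract

a = np.random.rand(3, 4)
b = np.random.rand(4, 5)

# the backend is normally inferred from the first array; here we pass it explicitly
out = contract("ij,jk->ik", a, b, backend="numpy")
```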


### Dask

@@ -197,7 +185,7 @@ Currently `opt_einsum` can handle this automatically for:
- [jax](https://github.com/google/jax)

all of which offer GPU support. Since `tensorflow` and `theano` both require
compiling the expression, this functionality is encapsulated in generating a
[`opt_einsum.ContractExpression`](../api_reference.md#opt_einsumcontractcontractexpression) using
[`opt_einsum.contract_expression`](../api_reference.md#opt_einsumcontract_expression), which can then be called using numpy
arrays whilst specifying `backend='tensorflow'` etc.
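A minimal sketch of that compile-once, call-many pattern, shown here with plain NumPy arrays and the default backend (`backend='tensorflow'` etc. would be passed at call time):

```python
import numpy as np
from opt_einsum import contract_expression

# build the reusable expression from shapes alone - no arrays needed yet
expr = contract_expression("ab,bc->ca", (2, 3), (3, 4))

x = np.random.rand(2, 3)
y = np.random.rand(3, 4)
out = expr(x, y)  # optionally: expr(x, y, backend='tensorflow')
```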
3 changes: 1 addition & 2 deletions docs/index.md
@@ -4,8 +4,7 @@
expressions by optimizing the expression's contraction order and dispatching
many operations to canonical BLAS, cuBLAS, or other specialized routines.
Optimized einsum is agnostic to the backend and can handle NumPy, Dask,
PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as
potentially any library which conforms to a standard API.
PyTorch, Tensorflow, CuPy, Sparse, Theano, JAX, and Autograd arrays as well as any library which conforms to the Python array API standard. Other libraries may work so long as some key methods are available as attributes of their array objects.

## Features

3 changes: 3 additions & 0 deletions opt_einsum/backends/__init__.py
@@ -8,6 +8,7 @@
from .tensorflow import to_tensorflow
from .theano import to_theano
from .torch import to_torch
from .array_api import to_array_api, discover_array_apis

__all__ = [
    "get_func",
@@ -20,4 +21,6 @@
    "to_theano",
    "to_cupy",
    "to_torch",
    "to_array_api",
    "discover_array_apis",
]
75 changes: 75 additions & 0 deletions opt_einsum/backends/array_api.py
@@ -0,0 +1,75 @@
"""
Required functions for optimized contractions of arrays using array API-compliant backends.
"""
import sys
from typing import Callable
from types import ModuleType

import numpy as np

from ..sharing import to_backend_cache_wrap


def discover_array_apis():

> **Owner commented:** Possible to type this function?

    """Discover array API backends."""
    if sys.version_info >= (3, 8):
        from importlib.metadata import entry_points

        if sys.version_info >= (3, 10):
            eps = entry_points(group="array_api")
        else:
            # Deprecated - will raise a warning in Python versions >= 3.10
            eps = entry_points().get("array_api", [])
        return [ep.load() for ep in eps]
    else:
        # importlib.metadata was introduced in Python 3.8, so it isn't available here. Unable to discover any array APIs.
        return []
> **Owner commented:** NumPy does not officially support Python below 3.8: https://numpy.org/neps/nep-0029-deprecation_policy.html. It would be worth considering dropping Python 3.7 as well. @jcmgray what do you think?



def make_to_array_function(array_api: ModuleType) -> Callable:
    """Make a ``to_[array_api]`` function for the given array API."""

    @to_backend_cache_wrap
    def to_array(array):  # pragma: no cover
        if isinstance(array, np.ndarray):
            return array_api.asarray(array)
        return array

    return to_array


def make_build_expression_function(array_api: ModuleType) -> Callable:
    """Make a ``build_expression`` function for the given array API."""
    _to_array_api = to_array_api[array_api.__name__]

    def build_expression(_, expr):  # pragma: no cover
        """Build an array API function based on ``arrays`` and ``expr``."""

        def array_api_contract(*arrays):
            return expr._contract([_to_array_api(x) for x in arrays], backend=array_api.__name__)

> **Author (IsaacBreen) commented on Jul 18, 2022:** @jcmgray How about this?

        return array_api_contract

    return build_expression


def make_evaluate_constants_function(array_api: ModuleType) -> Callable:
    _to_array_api = to_array_api[array_api.__name__]

    def evaluate_constants(const_arrays, expr):  # pragma: no cover
        """Convert constant arguments to arrays of the given array API, and perform any possible constant contractions."""
        return expr(
            *[_to_array_api(x) for x in const_arrays],
            backend=array_api.__name__,
            evaluate_constants=True,
        )

    return evaluate_constants


_array_apis = discover_array_apis()
to_array_api = {api.__name__: make_to_array_function(api) for api in _array_apis}
build_expression = {api.__name__: make_build_expression_function(api) for api in _array_apis}
evaluate_constants = {api.__name__: make_evaluate_constants_function(api) for api in _array_apis}

__all__ = ["discover_array_apis", "to_array_api", "build_expression", "evaluate_constants"]
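The closing lines above build one conversion function per discovered API, keyed by module name. A self-contained sketch of that closure-factory pattern, using NumPy itself as the stand-in backend module (`make_converter` is a hypothetical name for illustration):

```python
from types import ModuleType
from typing import Callable

import numpy as np

def make_converter(mod: ModuleType) -> Callable:
    # each closure captures its own backend module, in the style of
    # make_to_array_function above
    def to_backend(arr):
        if isinstance(arr, np.ndarray):
            return mod.asarray(arr)
        return arr
    return to_backend

# one entry per backend, keyed by module name
converters = {mod.__name__: make_converter(mod) for mod in (np,)}
converted = converters["numpy"](np.ones((2, 2)))
```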
3 changes: 3 additions & 0 deletions opt_einsum/backends/dispatch.py
@@ -15,6 +15,7 @@
from . import tensorflow as _tensorflow
from . import theano as _theano
from . import torch as _torch
from . import array_api as _array_api

__all__ = [
"get_func",
@@ -122,6 +123,7 @@ def has_tensordot(backend: str) -> bool:
    "cupy": _cupy.build_expression,
    "torch": _torch.build_expression,
    "jax": _jax.build_expression,
    **_array_api.build_expression,

> **Owner commented:** Is there any way to check if Jax adds an array interface so that we can catch it and test beforehand?

}

EVAL_CONSTS_BACKENDS = {
@@ -130,6 +132,7 @@ def has_tensordot(backend: str) -> bool:
    "cupy": _cupy.evaluate_constants,
    "torch": _torch.evaluate_constants,
    "jax": _jax.evaluate_constants,
    **_array_api.evaluate_constants,
}


16 changes: 13 additions & 3 deletions opt_einsum/contract.py
@@ -411,8 +411,13 @@ def _einsum(*operands, **kwargs):


def _default_transpose(x: ArrayType, axes: Tuple[int, ...]) -> ArrayType:
    # Many libraries implement a method version, but array API-conforming arrays do not (as of the 2021.12 standard).
    if hasattr(x, "transpose"):
        return x.transpose(axes)
    elif hasattr(x, "__array_namespace__"):
        return x.__array_namespace__().permute_dims(x, axes)
    else:
        raise NotImplementedError(f"No implementation for transpose or equivalent found for {x}")
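A standalone sketch of the same dispatch order, exercised here with a plain `ndarray` (which takes the method branch; array API arrays without a `transpose` method would fall through to `permute_dims`):

```python
import numpy as np

def default_transpose(x, axes):
    # prefer the method version; fall back to the array API
    # namespace's permute_dims function
    if hasattr(x, "transpose"):
        return x.transpose(axes)
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__().permute_dims(x, axes)
    raise NotImplementedError(f"no transpose equivalent found for {type(x)!r}")

y = default_transpose(np.arange(6).reshape(2, 3), (1, 0))
```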


@sharing.transpose_cache_wrap
@@ -549,7 +554,12 @@ def _infer_backend_class_cached(cls: type) -> str:


def infer_backend(x: Any) -> str:
    if hasattr(x, "__array_namespace__"):
        # Having an ``__array_namespace__`` is a 'guarantee' from the developers of the given array's module that
        # it conforms to the Python array API. Use this as a backend, if available.
        return x.__array_namespace__().__name__
    else:
        return _infer_backend_class_cached(x.__class__)
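A standalone sketch of this inference rule; `infer_backend_name` is a hypothetical name, and the class-based fallback is simplified here to the array class's top-level module:

```python
import numpy as np

def infer_backend_name(x):
    # arrays advertising __array_namespace__ name their own backend module
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__().__name__
    # simplified stand-in for the cached class-based lookup
    return type(x).__module__.split(".")[0]

name = infer_backend_name(np.ones(3))
```

Either branch resolves a NumPy array to the `"numpy"` backend, so the kwarg-free call path behaves identically before and after a library adopts the standard.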


def parse_backend(arrays: Sequence[ArrayType], backend: Optional[str]) -> str:
68 changes: 68 additions & 0 deletions opt_einsum/tests/test_backends.py
@@ -1,4 +1,6 @@
from types import ModuleType
import numpy as np
from pkg_resources import EntryPoint
import pytest

from opt_einsum import backends, contract, contract_expression, helpers, sharing
from opt_einsum.contract import infer_backend
@@ -56,6 +58,12 @@
except ImportError:
    found_autograd = False


@pytest.fixture(params=backends.discover_array_apis())
def array_api(request) -> ModuleType:
    return request.param


tests = [
    "ab,bc->ca",
    "abc,bcd,dea",
@@ -466,3 +474,63 @@ def test_object_arrays_backend(string):
    obj_opt = expr(*obj_views, backend="object")
    assert obj_opt.dtype == object
    assert np.allclose(ein, obj_opt.astype(float))


@pytest.mark.parametrize("string", tests)
def test_array_api(array_api: ModuleType, string):  # pragma: no cover
    array_api_qname: str = array_api.__name__

    views = helpers.build_views(string)
    ein = contract(string, *views, optimize=False, use_blas=False)
    shps = [v.shape for v in views]

    expr = contract_expression(string, *shps, optimize=True)

    opt = expr(*views, backend=array_api_qname)
    assert np.allclose(ein, opt)

    # test non-conversion mode
    array_api_views = [backends.to_array_api[array_api_qname](view) for view in views]
    array_api_opt = expr(*array_api_views)
    assert all(type(array_api_opt) is type(view) for view in array_api_views)
    assert np.allclose(ein, np.asarray(array_api_opt))


@pytest.mark.parametrize("constants", [{0, 1}, {0, 2}, {1, 2}])
def test_array_api_with_constants(array_api: ModuleType, constants):  # pragma: no cover
    array_api_qname: str = array_api.__name__

    eq = "ij,jk,kl->li"
    shapes = (2, 3), (3, 4), (4, 5)
    (non_const,) = {0, 1, 2} - constants
    ops = [np.random.rand(*shp) if i in constants else shp for i, shp in enumerate(shapes)]
    var = np.random.rand(*shapes[non_const])
    res_exp = contract(eq, *(ops[i] if i in constants else var for i in range(3)))

    expr = contract_expression(eq, *ops, constants=constants)

    # check array API
    res_got = expr(var, backend=array_api_qname)
    # check array API versions of constants exist
    assert all(
        array is None or infer_backend(array) == array_api_qname for array in expr._evaluated_constants[array_api_qname]
    )
    assert np.allclose(res_exp, res_got)

    # check can call with numpy still
    res_got2 = expr(var, backend="numpy")
    assert np.allclose(res_exp, res_got2)

    # check the array API call returns an array object belonging to the same backend
    # NOTE: the array API standard does not require that the returned array is the same type as the input array,
    # only that the returned array also obeys the array API standard. Indeed, the standard does not stipulate
    # even a *name* for the array type.
    #
    # For this reason, we won't check the type of the returned array, but only that it has an
    # ``__array_namespace__`` attribute and hence claims to comply with the standard.
    #
    # In future versions, if opt_einsum uses newer array API features, we will also need to check that the
    # returned array complies with the appropriate version of the standard.
    res_got3 = expr(array_api.asarray(var))
    assert hasattr(res_got3, "__array_namespace__")
    assert np.allclose(res_exp, res_got3)