add einsum, einsum_expression
jcmgray committed Sep 11, 2023
1 parent 1354bf6 commit 76f2b1a
Showing 4 changed files with 137 additions and 58 deletions.
13 changes: 12 additions & 1 deletion cotengra/__init__.py
@@ -26,7 +26,8 @@
get_hypergraph,
)
from .interface import (
contract_expression,
einsum_expression,
einsum,
register_preset,
)

@@ -139,13 +140,23 @@ def HyperMultiOptimizer(*args, **kwargs):
return HyperOptimizer(*args, multicontraction=True, **kwargs)


contract_expression = einsum_expression
"""Alias for :func:`cotengra.einsum_expression`."""

contract = einsum
"""Alias for :func:`cotengra.einsum`."""


__all__ = (
"auto_hq_optimize",
"auto_optimize",
"contract_expression",
"contract",
"ContractionTree",
"ContractionTreeCompressed",
"ContractionTreeMulti",
"einsum_expression",
"einsum",
"FlowCutterOptimizer",
"get_hyper_space",
"get_hypergraph",
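The two aliases defined above keep the established `opt_einsum`-style names working alongside the new ones. A minimal sketch of the equivalence (array values are illustrative, not from the commit):

```python
import numpy as np
import cotengra as ctg

x = np.random.rand(2, 3)
y = np.random.rand(3, 4)

# `contract` is simply an alias for `einsum`
a = ctg.einsum("ab,bc->ac", x, y)
b = ctg.contract("ab,bc->ac", x, y)
assert np.allclose(a, b)
```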
173 changes: 121 additions & 52 deletions cotengra/interface.py
@@ -49,7 +49,7 @@ def preset_to_optimizer(preset):
_find_path_handlers = {}


def find_path_explicit(inputs, output, size_dict, optimize):
def find_path_explicit_path(inputs, output, size_dict, optimize):
return optimize


@@ -99,7 +99,7 @@ def find_path(inputs, output, size_dict, optimize="auto", **kwargs):
elif isinstance(optimize, ContractionTree):
fn = _find_path_handlers[cls] = find_path_tree
elif isinstance(optimize, (tuple, list)):
fn = _find_path_handlers[cls] = find_path_explicit
fn = _find_path_handlers[cls] = find_path_explicit_path
else:
fn = _find_path_handlers[cls] = find_path_optimizer

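The handler cache above dispatches on the type of ``optimize``, so the same keyword accepts a preset string, an explicit path, a ``ContractionTree``, or a ``PathOptimizer`` instance. An illustrative sketch of the interchangeable call styles (shapes and values are arbitrary, not from the commit):

```python
import numpy as np
import cotengra as ctg

x, y, z = (np.ones((4, 4)) for _ in range(3))

# each type of `optimize` is routed to a different find_path handler
ctg.einsum("ab,bc,cd->ad", x, y, z, optimize="greedy")              # preset string
ctg.einsum("ab,bc,cd->ad", x, y, z, optimize=[(0, 1), (0, 1)])      # explicit path
ctg.einsum("ab,bc,cd->ad", x, y, z, optimize=ctg.HyperOptimizer())  # optimizer
```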
@@ -209,7 +209,8 @@ def __call__(self, *arrays, **kwargs):


class WithBackend:
"""Wrapper to make any autoray written function take a ``backend`` kwarg.
"""Wrapper to make any autoray written function take a ``backend`` kwarg,
by simply using `autoray.backend_like`.
"""

__slots__ = ("fn",)
@@ -224,7 +225,64 @@ def __call__(self, *args, backend=None, **kwargs):
return self.fn(*args, **kwargs)

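The class body is truncated in this hunk; a sketch of how such a wrapper might look in full, using `autoray.backend_like` as the docstring says (an illustration under that assumption, not the verbatim committed code):

```python
import autoray as ar


class WithBackend:
    """Make an autoray-written function accept a ``backend`` kwarg."""

    __slots__ = ("fn",)

    def __init__(self, fn):
        self.fn = fn

    def __call__(self, *args, backend=None, **kwargs):
        if backend is None:
            return self.fn(*args, **kwargs)
        # temporarily direct `autoray.numpy` calls to the requested backend
        with ar.backend_like(backend):
            return self.fn(*args, **kwargs)
```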

def contract_expression(
def _einsum_expression_with_constants(
eq,
*shapes,
optimize="auto",
constants=None,
implementation=None,
autojit=False,
prefer_einsum=False,
sort_contraction_indices=False,
):
import autoray as ar

constants = set(constants)

variables = []
variables_with_constants = []
shapes_only = []
for i, s in enumerate(shapes):
if i in constants:
variables_with_constants.append(s)
shapes_only.append(ar.shape(s))
else:
# want to generate function as if it were written with autoray
v = ar.lazy.Variable(s, backend="autoray.numpy")
variables.append(v)
variables_with_constants.append(v)
shapes_only.append(s)

# get the full expression, without constants
full_expr = einsum_expression(
eq,
*shapes_only,
optimize=optimize,
constants=None,
implementation=implementation,
# wait to jit until after constants are folded
autojit=False,
prefer_einsum=prefer_einsum,
sort_contraction_indices=sort_contraction_indices,
)

# trace through, and then get function with constants folded
lz_output = full_expr(*variables_with_constants)
fn = lz_output.get_function(variables, fold_constants=True)

# now we can jit
if autojit:
from autoray import autojit as _autojit

fn = _autojit(fn)
else:
# allow for backend kwarg (which will set what autoray.numpy uses)
fn = WithBackend(fn)

return fn

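In isolation, the constant-folding mechanism above works by tracing the full expression with lazy variables standing in for the non-constant tensors, then extracting a function with the constants baked in. A condensed sketch reusing only the calls seen in the function above (shapes are illustrative):

```python
import numpy as np
import autoray as ar
import cotengra as ctg

# suppose tensor 0 is constant and tensor 1 remains variable
const = np.random.rand(3, 3)
v = ar.lazy.Variable((3, 4), backend="autoray.numpy")

full_expr = ctg.einsum_expression("ab,bc->ac", (3, 3), (3, 4))
lz_out = full_expr(const, v)  # trace through the expression lazily
fn = lz_out.get_function([v], fold_constants=True)

# fn now takes only the variable tensor, with `const` folded in
out = fn(np.random.rand(3, 4))
```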

def einsum_expression(
eq,
*shapes,
optimize="auto",
@@ -253,8 +311,8 @@ def contract_expression(
if marked as constant in ``constants``.
optimize : str, path_like, PathOptimizer, or ContractionTree
The optimization strategy to use. If a ``HyperOptimizer`` or
``ContractionTree`` instance is passed then te expression will make use
of any sliced indices.
``ContractionTree`` instance is passed then the expression will make
use of any sliced indices.
constants : sequence of int, optional
The indices of tensors to treat as constant, the final expression will
take the remaining non-constant tensors as inputs.
@@ -293,7 +351,7 @@ def contract_expression(
with shapes ``shapes``.
"""
if constants is not None:
fn = _contract_expression_with_constants(
fn = _einsum_expression_with_constants(
eq,
*shapes,
optimize=optimize,
@@ -379,58 +437,69 @@ def fn(*arrays, backend=None):
return fn

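Going by the docstring above, a short usage sketch of `einsum_expression`, including the `constants` option where the constant tensor itself is supplied in place of its shape (values are illustrative):

```python
import numpy as np
import cotengra as ctg

# plain reusable expression: build from shapes, call with arrays
expr = ctg.einsum_expression("ab,bc->ac", (2, 3), (3, 4))
out = expr(np.ones((2, 3)), np.ones((3, 4)))

# with constants: pass the actual array in place of its shape; the
# returned expression then takes only the remaining non-constant tensors
const = np.random.rand(3, 4)
expr_c = ctg.einsum_expression("ab,bc->ac", (2, 3), const, constants=[1])
out_c = expr_c(np.ones((2, 3)))
```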

def _contract_expression_with_constants(
    eq,
    *shapes,
    optimize="auto",
    constants=None,
    implementation=None,
    autojit=False,
    prefer_einsum=False,
    sort_contraction_indices=False,
):
    import autoray as ar

    constants = set(constants)

    variables = []
    variables_with_constants = []
    shapes_only = []
    for i, s in enumerate(shapes):
        if i in constants:
            variables_with_constants.append(s)
            shapes_only.append(ar.shape(s))
        else:
            # want to generate function as if it were written with autoray
            v = ar.lazy.Variable(s, backend="autoray.numpy")
            variables.append(v)
            variables_with_constants.append(v)
            shapes_only.append(s)

    # get the full expression, without constants
    full_expr = contract_expression(
        eq,
        *shapes_only,
        optimize=optimize,
        constants=None,
        implementation=implementation,
        # wait to jit until after constants are folded
        autojit=False,
        prefer_einsum=prefer_einsum,
        sort_contraction_indices=sort_contraction_indices,
    )

    # trace through, and then get function with constants folded
    lz_output = full_expr(*variables_with_constants)
    fn = lz_output.get_function(variables, fold_constants=True)

    # now we can jit
    if autojit:
        from autoray import autojit as _autojit

        fn = _autojit(fn)
    else:
        # allow for backend kwarg (which will set what autoray.numpy uses)
        fn = WithBackend(fn)

    return fn


_EINSUM_EXPR_CACHE = {}


def einsum(
    eq,
    *arrays,
    optimize="auto",
    cache_expression=True,
    backend=None,
    **kwargs,
):
    """Perform an einsum contraction, using `cotengra`. By default the path
    finding and expression building is cached, so that if the same contraction
    is performed multiple times the overhead is negated.

    Parameters
    ----------
    eq : str
        The equation to use for contraction, for example ``'ab,bc->ac'``.
    arrays : sequence[array]
        The arrays to contract.
    optimize : str, path_like, PathOptimizer, or ContractionTree
        The optimization strategy to use. If a ``HyperOptimizer`` or
        ``ContractionTree`` instance is passed then the contraction will make
        use of any sliced indices.
    cache_expression : bool, optional
        If ``True``, cache the expression used to contract the arrays. This
        negates the overhead of pathfinding and building the expression when
        a contraction is performed multiple times.
    backend : str, optional
        If given, the explicit backend to use for the contraction, by default
        the backend is dispatched automatically.
    kwargs
        Passed to :func:`einsum_expression`.

    Returns
    -------
    array
    """
    import autoray as ar

    shapes = tuple(map(ar.shape, arrays))

    if cache_expression and isinstance(optimize, str):
        try:
            key = (eq, shapes, optimize, frozenset(kwargs.items()))
            try:
                expr = _EINSUM_EXPR_CACHE[key]
            except KeyError:
                # missing from cache
                expr = _EINSUM_EXPR_CACHE[key] = einsum_expression(
                    eq, *shapes, optimize=optimize, **kwargs
                )
        except TypeError:
            # unhashable kwargs
            import warnings

            warnings.warn(
                "einsum cache disabled as one of the "
                f"arguments is not hashable: {kwargs}"
            )

            expr = einsum_expression(eq, *shapes, optimize=optimize, **kwargs)
    else:
        expr = einsum_expression(eq, *shapes, optimize=optimize, **kwargs)

    return expr(*arrays, backend=backend)
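To illustrate the caching described in the docstring: repeat calls with the same equation, shapes, and a hashable string `optimize` reuse the stored expression, while other optimize types bypass the cache (a qualitative sketch, values illustrative):

```python
import numpy as np
import cotengra as ctg

arrays = [np.random.rand(8, 8) for _ in range(3)]

# first call pays the path-finding cost and caches the expression...
a = ctg.einsum("ab,bc,cd->ad", *arrays)
# ...an identical repeat call (same eq, shapes, optimize) reuses it
b = ctg.einsum("ab,bc,cd->ad", *arrays)
assert np.allclose(a, b)

# a non-string `optimize` skips the cache entirely
c = ctg.einsum("ab,bc,cd->ad", *arrays, optimize=[(0, 1), (0, 1)])
```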
1 change: 1 addition & 0 deletions setup.py
@@ -33,6 +33,7 @@
],
extras_require={
"recommended": [
"cotengrust",
"kahypar",
"networkx",
"opt_einsum",
8 changes: 3 additions & 5 deletions tests/test_compute.py
@@ -96,10 +96,8 @@
@pytest.mark.parametrize("eq", test_case_eqs)
def test_basic_equations(eq):
arrays = ctg.utils.make_arrays_from_eq(eq)
shapes = [a.shape for a in arrays]
x = np.einsum(eq, *arrays)
expr = ctg.contract_expression(eq, *shapes)
y = expr(*arrays)
y = ctg.einsum(eq, *arrays)
assert_allclose(x, y)


@@ -246,7 +244,7 @@ def test_exponent_stripping(autojit):
@pytest.mark.parametrize("constants", [None, True])
@pytest.mark.parametrize("optimize_type", ["path", "tree", "optimizer", "str"])
@pytest.mark.parametrize("sort_contraction_indices", [False, True])
def test_contract_expression(
def test_einsum_expression(
autojit,
constants,
optimize_type,
Expand Down Expand Up @@ -275,7 +273,7 @@ def test_contract_expression(
for c in sorted(constants, reverse=True):
shapes[c] = arrays.pop(c)

expr = ctg.contract_expression(
expr = ctg.einsum_expression(
eq,
*shapes,
optimize=optimize,
