diff --git a/cotengra/__init__.py b/cotengra/__init__.py index 1280843..46090e4 100644 --- a/cotengra/__init__.py +++ b/cotengra/__init__.py @@ -26,7 +26,8 @@ get_hypergraph, ) from .interface import ( - contract_expression, + einsum_expression, + einsum, register_preset, ) @@ -139,13 +140,23 @@ def HyperMultiOptimizer(*args, **kwargs): return HyperOptimizer(*args, multicontraction=True, **kwargs) +contract_expression = einsum_expression +"""Alias for :func:`cotengra.einsum_expression`.""" + +contract = einsum +"""Alias for :func:`cotengra.einsum`.""" + + __all__ = ( "auto_hq_optimize", "auto_optimize", "contract_expression", + "contract", "ContractionTree", "ContractionTreeCompressed", "ContractionTreeMulti", + "einsum_expression", + "einsum", "FlowCutterOptimizer", "get_hyper_space", "get_hypergraph", diff --git a/cotengra/interface.py b/cotengra/interface.py index 0fb9414..f74435c 100644 --- a/cotengra/interface.py +++ b/cotengra/interface.py @@ -49,7 +49,7 @@ def preset_to_optimizer(preset): _find_path_handlers = {} -def find_path_explicit(inputs, output, size_dict, optimize): +def find_path_explicit_path(inputs, output, size_dict, optimize): return optimize @@ -99,7 +99,7 @@ def find_path(inputs, output, size_dict, optimize="auto", **kwargs): elif isinstance(optimize, ContractionTree): fn = _find_path_handlers[cls] = find_path_tree elif isinstance(optimize, (tuple, list)): - fn = _find_path_handlers[cls] = find_path_explicit + fn = _find_path_handlers[cls] = find_path_explicit_path else: fn = _find_path_handlers[cls] = find_path_optimizer @@ -209,7 +209,8 @@ def __call__(self, *arrays, **kwargs): class WithBackend: - """Wrapper to make any autoray written function take a ``backend`` kwarg. + """Wrapper to make any autoray written function take a ``backend`` kwarg, + by simply using `autoray.backend_like`. """ __slots__ = ("fn",) @@ -224,7 +225,64 @@ def __call__(self, *args, backend=None, **kwargs): return self.fn(*args, **kwargs) -def contract_expression( +def _einsum_expression_with_constants( + eq, + *shapes, + optimize="auto", + constants=None, + implementation=None, + autojit=False, + prefer_einsum=False, + sort_contraction_indices=False, +): + import autoray as ar + + constants = set(constants) + + variables = [] + variables_with_constants = [] + shapes_only = [] + for i, s in enumerate(shapes): + if i in constants: + variables_with_constants.append(s) + shapes_only.append(ar.shape(s)) + else: + # want to generate function as if it were written with autoray + v = ar.lazy.Variable(s, backend="autoray.numpy") + variables.append(v) + variables_with_constants.append(v) + shapes_only.append(s) + + # get the full expression, without constants + full_expr = einsum_expression( + eq, + *shapes_only, + optimize=optimize, + constants=None, + implementation=implementation, + # wait to jit until after constants are folded + autojit=False, + prefer_einsum=prefer_einsum, + sort_contraction_indices=sort_contraction_indices, + ) + + # trace through, and then get function with constants folded + lz_output = full_expr(*variables_with_constants) + fn = lz_output.get_function(variables, fold_constants=True) + + # now we can jit + if autojit: + from autoray import autojit as _autojit + + fn = _autojit(fn) + else: + # allow for backend kwarg (which will set what autoray.numpy uses) + fn = WithBackend(fn) + + return fn + + +def einsum_expression( eq, *shapes, optimize="auto", @@ -253,8 +311,8 @@ def contract_expression( if marked as constant in ``constants``. optimize : str, path_like, PathOptimizer, or ContractionTree The optimization strategy to use. If a ``HyperOptimizer`` or - ``ContractionTree`` instance is passed then te expression will make use - of any sliced indices. + ``ContractionTree`` instance is passed then the expression will make + use of any sliced indices. constants : sequence of int, optional The indices of tensors to treat as constant, the final expression will take the remaining non-constant tensors as inputs. @@ -293,7 +351,7 @@ def contract_expression( with shapes ``shapes``. """ if constants is not None: - fn = _contract_expression_with_constants( + fn = _einsum_expression_with_constants( eq, *shapes, optimize=optimize, @@ -379,58 +437,69 @@ def fn(*arrays, backend=None): return fn -def _contract_expression_with_constants( +_EINSUM_EXPR_CACHE = {} + + +def einsum( eq, - *shapes, + *arrays, optimize="auto", - constants=None, - implementation=None, - autojit=False, - prefer_einsum=False, - sort_contraction_indices=False, + cache_expression=True, + backend=None, + **kwargs, ): - import autoray as ar - - constants = set(constants) + """Perform an einsum contraction, using `cotengra`. By default the path + finding and expression building is cached, so that if the same contraction + is performed multiple times the overhead is negated. - variables = [] - variables_with_constants = [] - shapes_only = [] - for i, s in enumerate(shapes): - if i in constants: - variables_with_constants.append(s) - shapes_only.append(ar.shape(s)) - else: - # want to generate function as if it were written with autoray - v = ar.lazy.Variable(s, backend="autoray.numpy") - variables.append(v) - variables_with_constants.append(v) - shapes_only.append(s) + Parameters + ---------- + eq : str + The equation to use for contraction, for example ``'ab,bc->ac'``. + arrays : sequence[array] + The arrays to contract. + optimize : str, path_like, PathOptimizer, or ContractionTree + The optimization strategy to use. If a ``HyperOptimizer`` or + ``ContractionTree`` instance is passed then the contraction will make + use of any sliced indices. + cache_expression : bool, optional + If ``True``, cache the expression used to contract the arrays. This + negates the overhead of pathfinding and building the expression when + a contraction is performed multiple times. + backend : str, optional + If given, the explicit backend to use for the contraction, by default + the backend is dispatched automatically. + kwargs + Passed to :func:`einsum_expression`. - # get the full expression, without constants - full_expr = contract_expression( - eq, - *shapes_only, - optimize=optimize, - constants=None, - implementation=implementation, - # wait to jit until after constants are folded - autojit=False, - prefer_einsum=prefer_einsum, - sort_contraction_indices=sort_contraction_indices, - ) + Returns + ------- + array + """ + shapes = tuple(map(ar.shape, arrays)) - # trace through, and then get function with constants folded - lz_output = full_expr(*variables_with_constants) - fn = lz_output.get_function(variables, fold_constants=True) + if cache_expression and isinstance(optimize, str): + try: + key = (eq, shapes, optimize, frozenset(kwargs.items())) + try: + expr = _EINSUM_EXPR_CACHE[key] + except KeyError: + # missing from cache + expr = _EINSUM_EXPR_CACHE[key] = einsum_expression( + eq, *shapes, optimize=optimize, **kwargs + ) + except TypeError: + # unhashale kwargs + import warnings + + warnings.warn( + "einsum cache disabled as one of the " + f"arguments is not hashable: {kwargs}" + ) - # now we can jit - if autojit: - from autoray import autojit as _autojit + expr = einsum_expression(eq, *shapes, optimize=optimize, **kwargs) - fn = _autojit(fn) else: - # allow for backend kwarg (which will set what autoray.numpy uses) - fn = WithBackend(fn) + expr = einsum_expression(eq, *shapes, optimize=optimize, **kwargs) - return fn + return expr(*arrays, backend=backend) diff --git a/setup.py b/setup.py index 3dee38a..129940f 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ ], extras_require={ "recommended": [ + "cotengrust", "kahypar", "networkx", "opt_einsum", diff --git a/tests/test_compute.py b/tests/test_compute.py index 612655a..b630fe2 100644 --- a/tests/test_compute.py +++ b/tests/test_compute.py @@ -96,10 +96,8 @@ @pytest.mark.parametrize("eq", test_case_eqs) def test_basic_equations(eq): arrays = ctg.utils.make_arrays_from_eq(eq) - shapes = [a.shape for a in arrays] x = np.einsum(eq, *arrays) - expr = ctg.contract_expression(eq, *shapes) - y = expr(*arrays) + y = ctg.einsum(eq, *arrays) assert_allclose(x, y) @@ -246,7 +244,7 @@ def test_exponent_stripping(autojit): @pytest.mark.parametrize("constants", [None, True]) @pytest.mark.parametrize("optimize_type", ["path", "tree", "optimizer", "str"]) @pytest.mark.parametrize("sort_contraction_indices", [False, True]) -def test_contract_expression( +def test_einsum_expression( autojit, constants, optimize_type, @@ -275,7 +273,7 @@ def test_contract_expression( for c in sorted(constants, reverse=True): shapes[c] = arrays.pop(c) - expr = ctg.contract_expression( + expr = ctg.einsum_expression( eq, *shapes, optimize=optimize,