Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes: 'dp' and memory_limit + tensordot axes order #154

Merged
merged 6 commits into from
Nov 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 14 additions & 9 deletions opt_einsum/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
from collections import namedtuple
from decimal import Decimal

import numpy as np

from . import backends, blas, helpers, parser, paths, sharing

__all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only"]
Expand Down Expand Up @@ -563,14 +561,21 @@ def _core_contract(operands, contraction_list, backend='auto', evaluate_constant

tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)

# Find indices to contract over
left_pos, right_pos = [], []
for s in idx_rm:
left_pos.append(input_left.find(s))
right_pos.append(input_right.find(s))
if idx_rm:
# Find indices to contract over
left_pos, right_pos = [], []
for s in idx_rm:
left_pos.append(input_left.find(s))
right_pos.append(input_right.find(s))

# Construct the axes tuples in a canonical order
axes = tuple(zip(*sorted(zip(left_pos, right_pos))))
else:
# Ensure axes is always pair of tuples
axes = ((), ())

# Contract!
new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)
new_view = _tensordot(*tmp_operands, axes=axes, backend=backend)

# Build a new view if needed
if (tensor_result != results_index) or handle_out:
Expand Down Expand Up @@ -757,7 +762,7 @@ def __call__(self, *arrays, **kwargs):
try:
# Check if the backend requires special preparation / calling
# but also ignore non-numpy arrays -> assume user wants same type back
if backends.has_backend(backend) and all(isinstance(x, np.ndarray) for x in arrays):
if backends.has_backend(backend) and all(infer_backend(x) == 'numpy' for x in arrays):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleaner, nice.

return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)

return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
Expand Down
5 changes: 5 additions & 0 deletions opt_einsum/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import heapq
import itertools
import random
import operator
from collections import Counter, OrderedDict, defaultdict

import numpy as np
Expand Down Expand Up @@ -961,6 +962,7 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
output = set(symbol2int[c] for c in output)
size_dict = {symbol2int[c]: v for c, v in size_dict.items() if c in symbol2int}
size_dict = [size_dict[j] for j in range(len(size_dict))]
naive_cost = len(inputs) * functools.reduce(operator.mul, size_dict)

inputs, inputs_done, inputs_contractions = _dp_parse_out_single_term_ops(inputs, all_inds, ind_counts)

Expand Down Expand Up @@ -1033,6 +1035,9 @@ def __call__(self, inputs, output, size_dict, memory_limit=None):
xn, g, all_tensors, inputs, i1_cut_i2_wo_output,
memory_limit, cntrct1, cntrct2)

if (cost_cap > naive_cost) and (len(x[-1]) == 0):
raise RuntimeError("No contraction found for given `memory_limit`.")

# increase cost cap for next iteration:
cost_cap = cost_increment * cost_cap

Expand Down
16 changes: 16 additions & 0 deletions opt_einsum/tests/test_paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,22 @@ def test_custom_dp_can_set_cost_cap():
assert info1.opt_cost == info2.opt_cost == info3.opt_cost


def test_dp_errors_when_no_contractions_found():
eq, shapes, size_dict = oe.helpers.rand_equation(10, 3, seed=42, return_size_dict=True)

# first get the actual minimum cost
opt = oe.DynamicProgramming(minimize='size')
path, info = oe.contract_path(eq, *shapes, shapes=True, optimize=opt)
mincost = info.largest_intermediate

# check we can still find it without minimizing size explicitly
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost, optimize='dp')

# but check just below this threshold raises
with pytest.raises(RuntimeError):
oe.contract_path(eq, *shapes, shapes=True, memory_limit=mincost - 1, optimize='dp')


@pytest.mark.parametrize("optimize", ['greedy', 'branch-2', 'branch-all', 'optimal', 'dp'])
def test_can_optimize_outer_products(optimize):
a, b, c = [np.random.randn(10, 10) for _ in range(3)]
Expand Down