Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable torch.backends.opt_einsum to avoid duplicate work #205

Closed
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions opt_einsum/backends/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from ..parser import convert_to_valid_einsum_chars
from ..sharing import to_backend_cache_wrap
from functools import lru_cache

__all__ = [
"transpose",
Expand Down Expand Up @@ -41,14 +42,29 @@ def transpose(a, axes):
return a.permute(*axes)


@lru_cache(None)
def get_einsum():
torch, _ = _get_torch_and_device()
if hasattr(torch, "_VF") and hasattr(torch._VF, "einsum"):
print("called")
return torch._VF.einsum
print("did not get called")

def einsum_no_opt(*args, **kwargs):
if hasattr(torch, "backends") and hasattr(torch.backends, "opt_einsum"):
janeyx99 marked this conversation as resolved.
Show resolved Hide resolved
with torch.backends.opt_einsum.flags(enabled=False):
return torch.einsum(*args, **kwargs)
return torch.einsum(*args, **kwargs)

return einsum_no_opt


def einsum(equation, *operands):
"""Variadic version of torch.einsum to match numpy api."""
# rename symbols to support PyTorch 0.4.1 and earlier,
# which allow only symbols a-z.
equation = convert_to_valid_einsum_chars(equation)

torch, _ = _get_torch_and_device()
return torch.einsum(equation, operands)
return get_einsum()(equation, operands)


def tensordot(x, y, axes=2):
Expand Down
13 changes: 12 additions & 1 deletion opt_einsum/tests/test_contract.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths
Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""

import numpy as np
Expand Down Expand Up @@ -198,6 +198,17 @@ def test_contract_expression_interleaved_input():
assert np.allclose(out, expected)


def test_torch_contract_expression_interleaved_input():
import torch

x, y, z = (torch.randn(2, 2) for _ in "xyz")
expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])
xshp, yshp, zshp = ((2, 2) for _ in "xyz")
expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0])
out = expr(x, y, z)
assert np.allclose(out, expected)


@pytest.mark.parametrize(
"string,constants",
[
Expand Down