Disable torch.backends.opt_einsum to avoid duplicate work #205

Closed
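Context for this change, not part of the diff: newer PyTorch releases (1.13+) ship torch.backends.opt_einsum, and torch.einsum uses it to re-optimize the contraction path whenever three or more operands are passed. Since this library already supplies pre-optimized pairwise contractions, that second path search is redundant work. A minimal sketch of the flag involved, assuming a PyTorch build that exposes torch.backends.opt_einsum (tensor names here are illustrative only):

import torch

# Defaults to True on builds that ship torch.backends.opt_einsum; when enabled,
# torch.einsum runs its own contraction-path search for 3+ operands.
print(torch.backends.opt_einsum.enabled)

a, b, c = (torch.randn(4, 4) for _ in range(3))

# With the flag disabled, torch.einsum skips the path search and contracts the
# operands left to right -- harmless here, since the calls this library makes
# are already single pairwise contractions.
with torch.backends.opt_einsum.flags(enabled=False):
    out = torch.einsum("ij,jk,kl->il", a, b, c)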
24 changes: 21 additions & 3 deletions opt_einsum/backends/torch.py
@@ -6,6 +6,7 @@

from ..parser import convert_to_valid_einsum_chars
from ..sharing import to_backend_cache_wrap
from functools import lru_cache

__all__ = [
"transpose",
@@ -41,14 +42,31 @@ def transpose(a, axes):
    return a.permute(*axes)


@lru_cache(None)
def get_einsum():
    torch, _ = _get_torch_and_device()
    if hasattr(torch, "_VF") and hasattr(torch._VF, "einsum"):
        print("returning VF")
Owner:

If we remove the prints, I think we're good to go!

Contributor Author:

Are the current test cases sufficient to verify this doesn't break anything?

        return torch._VF.einsum

    if not hasattr(torch, "backends") or not hasattr(torch.backends, "opt_einsum"):
        print("returning normal torch")
        return torch.einsum

    def einsum_no_opt(*args, **kwargs):
        with torch.backends.opt_einsum.flags(enabled=False):
            return torch.einsum(*args, **kwargs)

    print("returning normal torch with opt off")
    return einsum_no_opt


def einsum(equation, *operands):
"""Variadic version of torch.einsum to match numpy api."""
# rename symbols to support PyTorch 0.4.1 and earlier,
# which allow only symbols a-z.
equation = convert_to_valid_einsum_chars(equation)

torch, _ = _get_torch_and_device()
-    return torch.einsum(equation, operands)
+    return get_einsum()(equation, operands)


def tensordot(x, y, axes=2):
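Regarding the coverage question above, a quick local sanity check (a sketch only, assuming this branch is installed; get_einsum and _get_torch_and_device are the helpers in opt_einsum/backends/torch.py) shows which implementation the cached dispatcher picks on a given torch build:

from opt_einsum.backends.torch import _get_torch_and_device, get_einsum

torch, _ = _get_torch_and_device()

# get_einsum() is memoized with lru_cache(None), so the hasattr probing
# runs only once per process.
fn = get_einsum()

# On builds that expose the private torch._VF.einsum, it is returned directly,
# bypassing torch.einsum's Python wrapper (and its path optimization);
# otherwise the wrapper that disables torch.backends.opt_einsum, or plain
# torch.einsum, is returned instead.
print(fn is getattr(getattr(torch, "_VF", None), "einsum", None))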
13 changes: 12 additions & 1 deletion opt_einsum/tests/test_contract.py
@@ -1,5 +1,5 @@
"""
-Tets a series of opt_einsum contraction paths to ensure the results are the same for different paths
+Tests a series of opt_einsum contraction paths to ensure the results are the same for different paths
"""

import numpy as np
@@ -198,6 +198,17 @@ def test_contract_expression_interleaved_input():
    assert np.allclose(out, expected)


def test_torch_contract_expression_interleaved_input():
    import torch

    x, y, z = (torch.randn(2, 2) for _ in "xyz")
    expected = np.einsum(x, [0, 1], y, [1, 2], z, [2, 3], [3, 0])
    xshp, yshp, zshp = ((2, 2) for _ in "xyz")
    expr = contract_expression(xshp, [0, 1], yshp, [1, 2], zshp, [2, 3], [3, 0])
    out = expr(x, y, z)
    assert np.allclose(out, expected)


@pytest.mark.parametrize(
    "string,constants",
    [
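To exercise just the new test locally, assuming pytest and torch are installed in the environment:

pytest opt_einsum/tests/test_contract.py -k torch_contract_expression_interleaved_input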