Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable torch.backends.opt_einsum to avoid duplicate work #205

Closed

Conversation

janeyx99
Copy link
Contributor

@janeyx99 janeyx99 commented Nov 7, 2022

Description

Recently, torch.einsum has improved to automatically optimize for multi-contractions if the opt_einsum library is installed. This way, torch users can reap benefits easily. However, this change may inadvertently cause regressions for those who use opt_einsum.contract with torch tensors. While these use cases can be improved by just changing opt_einsum.contract calls to torch.einsum, it is not right to cause silent regressions.

Consider the following example:

import torch
import opt_einsum

A = torch.rand(4, 5)
B = torch.rand(5, 6)
C = torch.rand(6, 7) 
D = opt_einsum.contract('ij,jk,kl->il', A, B, C)

Looking through the code, opt_einsum.contract will figure out an ideal path + return that path in the form of contraction tuples (commonly pairs). Then, for each of the contraction tuples, opt_einsum may call BACK into the torch backend and call torch.einsum.

Since the contractions are commonly pairs, calling back into torch.einsum will do what it used to--just directly do the contraction.

HOWEVER. There are cases where the contractions are not necessarily pairs (@dgasmith had mentioned that hadamard products are faster when chained vs staggered in pairs) and in this case, torch.einsum will do unnecessary work in recomputing an ideal path, which is a regression from the previous state.

This PR simply asks "hey, does the opt_einsum backend exist in torch?" If so, this means torch.einsum will do the unnecessary work in recomputing an ideal path, so let's just turn it off.

Todos

Notable points that this PR has either accomplished or will accomplish.

  • Add the code to disable torch's opt_einsum optimization
  • Test cases?

Questions

  • Unsure how to best test this/what use case would ensure that, without this code, the path would be computed more than once, and after this code, the path would only be computed at the top level. We could just test that torch.backends.opt_einsum.enabled is globally False, I guess.

Status

  • Ready to go

@codecov
Copy link

codecov bot commented Nov 7, 2022

Codecov Report

Merging #205 (9a836cf) into master (1a984b7) will increase coverage by 3.06%.
The diff coverage is 50.00%.

❗ Current head 9a836cf differs from pull request most recent head fbd3f9c. Consider uploading reports for the commit fbd3f9c to get more accurate results

@janeyx99
Copy link
Contributor Author

janeyx99 commented Nov 9, 2022

tagging @dgasmith to get it on the radar

@dgasmith
Copy link
Owner

dgasmith commented Nov 9, 2022

@janeyx99 Apologies, I'm traveling for the next week and will be slow to respond.

@jcmgray Can you get eyes on this?

@jcmgray
Copy link
Collaborator

jcmgray commented Nov 9, 2022

Hi! I'm guessing torch does not have a way to turn off the path optimization at call time (e.g. torch.einsum(..., optimize=False))? That would obviously be ideal.

My thoughts would be the following:

  • it's rare to come across paths where not contracting pairwise is practically beneficial (theoretically ops-wise pairwise is always optimal). I think currently only imposing a size limit makes opt_einsum produce these?
  • In the case that say a 3 tensor contraction is encountered, the path optimization would only run on the those 3, not the whole expression - i.e. unlikely to have much overhead

So in general my opinion is that is unlikely to make a difference for most cases. Other considerations:

@jcmgray
Copy link
Collaborator

jcmgray commented Nov 9, 2022

A bit hacky, but could we try and call _VF.einsum directly and skip the torch side python processing?

@dgasmith
Copy link
Owner

@janeyx99 Thoughts here?

@janeyx99
Copy link
Contributor Author

Ah, @jcmgray thank you for the thoughts. I too believe that this change should not make a huge difference given the same reasons you provided, and I do agree there are edge cases where it may be confusing to set the global state for the user.

We could just try directly calling _VF.einsum :p. We do also have a local context manager where I could do something like

if torch has the attributes and whatnot:
    with torch.backends.opt_einsum.flags(enabled=False):
         return torch.einsum(eq, ops)
return torch.einsum(eq, ops)

Would this be less hacky?

@dgasmith
Copy link
Owner

dgasmith commented Jan 4, 2023

@jcmgray / @janeyx99 Gentle pings here on solutions. I don't have a strong opinion, but it would be nice to keep this moving.

@janeyx99
Copy link
Contributor Author

My view on this is that maybe it’s not a big deal and so the complexity/additional check may not be worth it. That said, I would be happy to try calling _VF.einsum directly if that is preferable to the context manager.

@jcmgray
Copy link
Collaborator

jcmgray commented Jan 19, 2023

I think it's a rare and small enough effect that I'd be happy to ignore. Having said that:

  1. I suppose to respect size_limit, calling the raw einsum might be important.
  2. The time difference is small but not zero:
import torch
from torch import _VF

x = torch.rand((2,))
y = torch.rand((2,))
z = torch.rand((2,))
eq = "a,a,a->a"

%timeit torch.einsum(eq, x, y, z)
# 37.3 µs ± 70.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
with torch.backends.opt_einsum.flags(enabled=False):
    torch.einsum(eq, x, y, z)
# 8.94 µs ± 95.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit _VF.einsum(eq, (x, y, z))
# 3.86 µs ± 36 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

we could do something like:

from functools import lru_cache

@lru_cache(None)
def get_einsum():
    try:
        from torch import _VF
        return _VF.einsum
    except ImportError:  # maybe other errors, AttributeError?
        import torch

        def einsum_no_opt(*args, **kwargs):
            with torch.backends.opt_einsum.flags(enabled=False):
                return torch.einsum(*args, **kwargs)

        return einsum_no_opt

that way if relying on _VF.einsum breaks somehow, we have a sensible fallback.

@janeyx99 janeyx99 force-pushed the torch-backend-disable-opt_einsum branch from d2800f9 to a63aedd Compare January 24, 2023 16:32
@janeyx99 janeyx99 force-pushed the torch-backend-disable-opt_einsum branch from a63aedd to 0a8579c Compare January 24, 2023 16:45
@janeyx99
Copy link
Contributor Author

@dgasmith looks like there are conda package conflicts that aren't related to the PR

def get_einsum():
torch, _ = _get_torch_and_device()
if hasattr(torch, "_VF") and hasattr(torch._VF, "einsum"):
print("returning VF")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove the prints, I think we're good to go1

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the current test cases sufficient to verify this doesn't break anything?

@dgasmith
Copy link
Owner

dgasmith commented May 5, 2024

@janeyx99 Given PyTorch has likely moved on since this PR was made and deprecated older versions, what is optimal to do here?

@janeyx99
Copy link
Contributor Author

I'm willing to clean this PR up if it is still advantageous from the opt_einsum perspective, but it has been sufficiently long enough where maybe this doesn't matter so much.

@dgasmith
Copy link
Owner

My understanding was this PR would be optimal from a Torch perspective. If it's no longer needed or if Torch's minimum supported version no longer requires this patch, we can safely close.

@janeyx99
Copy link
Contributor Author

janeyx99 commented May 22, 2024

Ah, my understanding was that people calling into opt_einsum.contract will get slight silent regressions because the code path will look like

  • opt_einsum.contract
  • calls into torch.einsum
  • which calls contract_path (duplicating the work) to get the path
  • and then lastly calls the C++ einsum with the path

But if this is not a big deal, then I'm fine with leaving this unmerged!

@dgasmith
Copy link
Owner

@janeyx99 Ah- I would recommend having PyTorch skip contract_path when len(operands) <= 2 as there is nothing the code can do. While we do optimize for this fast path within opt_einsum there is still some parsing overhead. By this small optimization PyTorch can lower torch.einsum's latency in general.

@janeyx99
Copy link
Contributor Author

@dgasmith we already do that! Though I’m not sure how that’s relevant to the code path I had in mind above. I think it’s also sufficient to document that torch has already upstreamed opt_einsum though

@dgasmith
Copy link
Owner

dgasmith commented May 27, 2024

@janeyx99 Got it, then we only deal with the edge case where it's optimal to perform a three-product Hadamard or similar. The performance hit is likely quite minimal. Would you mind closing this PR and adding a note to the torch documentation that you should use the torch implementations instead, unless you have a complex case?

@janeyx99
Copy link
Contributor Author

janeyx99 commented Jun 7, 2024

okay! closing this PR, and planning to work on PyTorch issue 127109 instead.

@janeyx99 janeyx99 closed this Jun 7, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants