Disable torch.backends.opt_einsum to avoid duplicate work #205
Conversation
tagging @dgasmith to get it on the radar
Hi! My thoughts would be the following:

So in general my opinion is that this is unlikely to make a difference for most cases. Other considerations:
A bit hacky, but could we try and call _VF.einsum directly?

@janeyx99 Thoughts here?
Ah, @jcmgray thank you for the thoughts. I too believe that this change should not make a huge difference given the same reasons you provided, and I do agree there are edge cases where it may be confusing to set the global state for the user. We could just try directly calling _VF.einsum :p. We do also have a local context manager where I could do something like
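A hedged sketch of what that could look like (not the exact snippet from the comment; the helper name is made up, and it relies on the torch.backends.opt_einsum.flags context manager that also appears in the timings below):

import torch

def _einsum_no_path_opt(eq, *operands):
    # opt_einsum has already chosen the contraction order, so tell
    # torch.einsum to skip its own opt_einsum path optimization.
    with torch.backends.opt_einsum.flags(enabled=False):
        return torch.einsum(eq, *operands)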
Would this be less hacky?
My view on this is that maybe it's not a big deal and so the complexity/additional check may not be worth it. That said, I would be happy to try calling _VF.einsum directly if that is preferable to the context manager.
I think it's a rare and small enough effect that I'd be happy to ignore. Having said that:
import torch
from torch import _VF

x = torch.rand((2,))
y = torch.rand((2,))
z = torch.rand((2,))
eq = "a,a,a->a"

%timeit torch.einsum(eq, x, y, z)
# 37.3 µs ± 70.1 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%%timeit
with torch.backends.opt_einsum.flags(enabled=False):
    torch.einsum(eq, x, y, z)
# 8.94 µs ± 95.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

%timeit _VF.einsum(eq, (x, y, z))
# 3.86 µs ± 36 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

we could do something like:

from functools import lru_cache

@lru_cache(None)
def get_einsum():
    try:
        from torch import _VF
        return _VF.einsum
    except ImportError:  # maybe other errors, AttributeError?
        import torch

        def einsum_no_opt(*args, **kwargs):
            with torch.backends.opt_einsum.flags(enabled=False):
                return torch.einsum(*args, **kwargs)

        return einsum_no_opt

that way, if relying on _VF.einsum ever breaks, the cached helper falls back to disabling the torch opt_einsum backend via the context manager.
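For illustration, downstream code could then use the cached helper like this (hypothetical usage; the operands are passed as a single tuple to match the _VF.einsum calling convention in the timings above):

import torch

x = torch.rand((2,))
y = torch.rand((2,))
z = torch.rand((2,))

einsum = get_einsum()                 # cached after the first call
out = einsum("a,a,a->a", (x, y, z))   # dispatches to _VF.einsum when available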
@dgasmith looks like there are conda package conflicts that aren't related to the PR
def get_einsum():
    torch, _ = _get_torch_and_device()
    if hasattr(torch, "_VF") and hasattr(torch._VF, "einsum"):
        print("returning VF")
If we remove the prints, I think we're good to go.
Are the current test cases sufficient to verify this doesn't break anything?
@janeyx99 Given PyTorch has likely moved on since this PR was made and deprecated older versions, what is optimal to do here?
I'm willing to clean this PR up if it is still advantageous from the opt_einsum perspective, but it has been long enough that maybe this doesn't matter so much.
My understanding was this PR would be optimal from a Torch perspective. If it's no longer needed or if Torch's minimum supported version no longer requires this patch, we can safely close.
Ah, my understanding was that people calling into opt_einsum.contract will get slight silent regressions because the code path will look like: opt_einsum.contract computes a contraction path, dispatches each step back into torch.einsum, and (with the opt_einsum backend enabled) torch.einsum recomputes a path for any step with more than two operands.

But if this is not a big deal, then I'm fine with leaving this unmerged!
@janeyx99 Ah - I would recommend having
@dgasmith we already do that! Though I'm not sure how that's relevant to the code path I had in mind above. I think it's also sufficient to document that torch has already upstreamed opt_einsum, though.
@janeyx99 Got it, then we only deal with the edge case where it's optimal to perform a three-product Hadamard or similar. The performance hit is likely quite minimal. Would you mind closing this PR and adding a note to the torch documentation that you should use the torch implementations instead, unless you have a complex case?
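For a concrete feel of that edge case, a small illustrative sketch (shapes and names are made up) of a chained three-operand Hadamard product versus the same result staggered into pairwise steps:

import torch

x, y, z = (torch.rand(10_000) for _ in range(3))

# Chained: a single three-operand Hadamard product in one einsum call.
chained = torch.einsum("a,a,a->a", x, y, z)

# Staggered: the same result as two pairwise products, which is what a
# purely pairwise contraction path would dispatch.
staggered = torch.einsum("a,a->a", torch.einsum("a,a->a", x, y), z)

assert torch.allclose(chained, staggered)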
okay! closing this PR, and planning to work on PyTorch issue 127109 instead.
Description
Recently, torch.einsum has improved to automatically optimize multi-contractions if the opt_einsum library is installed. This way, torch users can reap benefits easily. However, this change may inadvertently cause regressions for those who use opt_einsum.contract with torch tensors. While these use cases can be improved by just changing opt_einsum.contract calls to torch.einsum, it is not right to cause silent regressions. Consider the following example:
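(A minimal sketch of the kind of call affected; shapes and names are illustrative and not the exact example from the PR, and it assumes a contraction whose best path keeps all three operands in a single step, as in the Hadamard case discussed below.)

import torch
import opt_einsum as oe

x = torch.rand(10_000)
y = torch.rand(10_000)
z = torch.rand(10_000)

# opt_einsum.contract picks a contraction path up front, then dispatches
# each step of that path back to the torch backend's torch.einsum. With
# torch's opt_einsum backend enabled, a dispatched step with more than two
# operands triggers a second round of path optimization -- duplicate work.
out = oe.contract("a,a,a->a", x, y, z)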
Looking through the code, opt_einsum.contract will figure out an ideal path and return that path in the form of contraction tuples (commonly pairs). Then, for each of the contraction tuples, opt_einsum may call BACK into the torch backend and call torch.einsum. Since the contractions are commonly pairs, calling back into torch.einsum will do what it used to--just directly do the contraction.

HOWEVER. There are cases where the contractions are not necessarily pairs (@dgasmith had mentioned that Hadamard products are faster when chained vs staggered in pairs) and in this case, torch.einsum will do unnecessary work in recomputing an ideal path, which is a regression from the previous state.

This PR simply asks "hey, does the opt_einsum backend exist in torch?" If so, this means torch.einsum will do the unnecessary work of recomputing an ideal path, so let's just turn it off.

Todos
Notable points that this PR has either accomplished or will accomplish.
Questions
Status