
opt_einsum.contract is slower than raw numpy.einsum #234

Open
lin-cp opened this issue May 27, 2024 · 3 comments

lin-cp commented May 27, 2024

Dear developer,

I am considering substituting every numpy.einsum in my code with opt_einsum.contract. However, I found that in some cases contract is much slower than raw numpy.einsum without any path optimization.

I have three complex-valued tensors with the following shapes:
A: (1166, 8000, 3, 3)
B: (8000, 6, 1166, 3)
C: (8000, 6, 1166, 3)

Then I used:

    import numpy as np
    D = np.einsum("ijkl,jmik,jmil->jm", A, B, C)

and

    from opt_einsum import contract
    D = contract("ijkl,jmik,jmil->jm", A, B, C)
Using contract is about 3 to 4 times slower than np.einsum (~58 s vs ~16 s, respectively). According to the opt_einsum documentation, if I don't specify the optimize parameter, the default is optimize='auto', which keeps the path-finding time below roughly 1 ms.
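A minimal sketch of the comparison (array contents are random complex128 here just for illustration; the three operands alone occupy roughly 6.7 GB, so exact timings will vary by machine):

    import time
    import numpy as np
    from opt_einsum import contract

    rng = np.random.default_rng(0)
    shapes = [(1166, 8000, 3, 3), (8000, 6, 1166, 3), (8000, 6, 1166, 3)]
    # build complex128 operands of the stated shapes
    A, B, C = (rng.random(s) + 1j * rng.random(s) for s in shapes)

    t0 = time.perf_counter()
    D_np = np.einsum("ijkl,jmik,jmil->jm", A, B, C)
    print("np.einsum:", time.perf_counter() - t0, "s")

    t0 = time.perf_counter()
    D_oe = contract("ijkl,jmik,jmil->jm", A, B, C)
    print("contract: ", time.perf_counter() - t0, "s")

    # both should give the same result
    assert np.allclose(D_np, D_oe)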

I am trying to understand why opt_einsum performs worse than raw numpy.einsum in this case. Any suggestions or comments?

Best,
Changpeng


jcmgray commented May 27, 2024

Hi @lin-cp, thanks for the interesting example. For what it's worth, I get a much smaller performance difference (4.5 vs 6.0 seconds), but worse nonetheless. If you check the optimal path:
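(for reference, printing the PathInfo object returned by contract_path produces the table below; A, B and C here are arrays of the shapes given above:)

    from opt_einsum import contract_path

    path, info = contract_path("ijkl,jmik,jmil->jm", A, B, C)
    print(info)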

  Complete contraction:  ijkl,jmik,jmil->jm
         Naive scaling:  5
     Optimized scaling:  5
      Naive FLOP count:  1.511e+9
  Optimized FLOP count:  1.343e+9
   Theoretical speedup:  1.125e+0
  Largest intermediate:  1.679e+8 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   5              0        jmik,ijkl->jmil                         jmil,jmil->jm
   4              0          jmil,jmil->jm                                jm->jm

you see that there is actually no scaling advantage to the optimized path, only a very small absolute decrease in cost. That decrease is probably more than offset by the time spent writing the intermediate tensor to memory - thus leading to the worse performance.
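To put a rough number on that (assuming complex128, i.e. 16 bytes per element): the intermediate "jmil" holds 8000 * 6 * 1166 * 3 ≈ 1.68e8 elements, i.e. about 2.7 GB that must be written out and read back, while the FLOP saving is only 1.511e9 - 1.343e9 ≈ 1.7e8 operations.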

In general, for these small (i.e. 3-term), very low arithmetic intensity contractions where most indices appear on all the tensors, there is little to be gained from optimizing the order of contractions (which is what opt_einsum does), and performance mostly comes down to a variety of factors that are not purely FLOPs.


lin-cp commented May 28, 2024

@jcmgray Thanks for your inspiring explanation. Could you explain further which intermediate steps require potentially large memory and thereby worsen the performance of opt_einsum?

I have two further questions:

  1. Is there any empirical rule to judge if using opt_einsum will help to save time?
  2. Does the order of indices matter? Will it help to save time? For example, will there be a large difference between "ij,jk->ik" and "ji,jk->ik" (i.e. first transpose the first tensor)?


jcmgray commented May 28, 2024

The intermediate step is as listed above:

  • "jmik,ijkl->jmil"

Is there any empirical rule to judge if using opt_einsum will help to save time?

  1. Generally the speedup listed by contract_path should be quite a bit greater than 1 (see the sketch after this answer). The speedups from ordering grow much larger as the number of terms increases (exponentially so). The speedups from using matrix multiplication are much greater when the arithmetic intensity (~FLOPs / MOPs) is high, which in practice means more 'contracted' indices vs 'batch' indices.
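A sketch of checking this programmatically, via the speedup attribute on the PathInfo object that contract_path returns (the cutoff below is just an illustrative heuristic, not a hard rule):

    from opt_einsum import contract_path

    path, info = contract_path("ijkl,jmik,jmil->jm", A, B, C)
    print(info.speedup)                    # ~1.125 for this contraction
    worth_optimizing = info.speedup > 10   # illustrative cutoff, tune per workload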

Does the order of indices matter?

  1. It does, but it's hard to predict, and opt_einsum does not take it into account - only the order of contractions. Essentially, if the indices/memory layout are already lined up to do the contraction via matrix multiplication, that is best.

If you have a 3-term contraction and need to eke out the best performance, you probably need to benchmark things explicitly; the optimization opt_einsum performs here is really just choosing among the 3 possibilities ((AB)C), ((AC)B) and (A(BC)), then dispatching these to matrix multiplication where possible (which is not the case in your example).
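For the index-order question specifically, a quick sketch of such an explicit benchmark (sizes arbitrary; for plain 2D matrix multiplication the difference may be small, since BLAS handles transposed inputs natively, and the effect tends to be larger for higher-dimensional operands):

    import time
    import numpy as np
    from opt_einsum import contract

    n = 4000
    X = np.random.rand(n, n)
    Y = np.random.rand(n, n)
    Xt = np.ascontiguousarray(X.T)  # same logical data, transposed memory layout

    # "ji,jk->ik" on Xt computes the same result as "ij,jk->ik" on X
    for eq, first in [("ij,jk->ik", X), ("ji,jk->ik", Xt)]:
        t0 = time.perf_counter()
        contract(eq, first, Y)
        print(eq, round(time.perf_counter() - t0, 4), "s")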
