Sub-optimal contraction path when using broadcasting #220

Open
pimdh opened this issue Sep 23, 2023 · 0 comments

pimdh commented Sep 23, 2023

Hi,
I'm not sure whether an immediate solution is possible, but it seems that opt_einsum applies broadcasting first and only then optimizes the contraction path. This leads to sub-optimal results:

```python
import opt_einsum
print(opt_einsum.__version__)
print(opt_einsum.contract_path("ijk,bj,bk->bi", (32, 32, 32), (10000, 32), (1, 32), optimize="optimal", shapes=True))
print(opt_einsum.contract_path("ijk,bj,k->bi", (32, 32, 32), (10000, 32), (32,), optimize="optimal", shapes=True))
```

This gives:

```
v3.3.0+24.g1a984b7
([(1, 2), (0, 1)],   Complete contraction:  ijk,bj,bk->bi
         Naive scaling:  4
     Optimized scaling:  4
      Naive FLOP count:  9.830e+8
  Optimized FLOP count:  6.656e+8
   Theoretical speedup:  1.477e+0
  Largest intermediate:  1.024e+7 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3              0             bk,bj->bkj                           ijk,bkj->bi
   4           TDOT            bkj,ijk->bi                                bi->bi)
([(0, 2), (0, 1)],   Complete contraction:  ijk,bj,k->bi
         Naive scaling:  4
     Optimized scaling:  3
      Naive FLOP count:  9.830e+8
  Optimized FLOP count:  2.055e+7
   Theoretical speedup:  4.785e+1
  Largest intermediate:  3.200e+5 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              k,ijk->ij                             bj,ij->bi
   3           GEMM              ij,bj->bi                                bi->bi)
```
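The FLOP counts in this output can be reproduced by hand, which shows where the extra cost in the first case comes from. Below is a small plain-Python sketch; `flop_count` here is my own re-implementation, not imported from opt_einsum, and it assumes the cost convention opt_einsum appears to use: the product of the sizes of all indices involved in a step, times `max(1, num_terms - 1)`, plus one more factor if any index is summed over.

```python
from math import prod


def flop_count(indices, inner, num_terms, size):
    """Cost of one einsum step under an assumed opt_einsum-style convention.

    overall size = product of sizes of every index touched by the step;
    op factor    = max(1, num_terms - 1), plus 1 if an index is summed.
    """
    overall = prod(size[i] for i in indices)
    op_factor = max(1, num_terms - 1) + (1 if inner else 0)
    return overall * op_factor


size = {"i": 32, "j": 32, "k": 32, "b": 10000}

# Naive: one 3-operand contraction over all of b, i, j, k.
naive = flop_count("bijk", True, 3, size)

# Case 1 (bk broadcast to b=10000): bk,bj->bkj (no sum), then bkj,ijk->bi.
case1 = flop_count("bkj", False, 2, size) + flop_count("bkji", True, 2, size)

# Case 2 (broadcast axis dropped, bk -> k): k,ijk->ij, then ij,bj->bi.
case2 = flop_count("ijk", True, 2, size) + flop_count("ijb", True, 2, size)

print(f"naive {naive:.3e}, case 1 {case1:.3e}, case 2 {case2:.3e}")
```

Under these assumptions the script prints `naive 9.830e+08, case 1 6.656e+08, case 2 2.055e+07`, matching the reported numbers. Nearly all of the case 1 cost is its second step, which runs over all four indices because `b` was inflated to 10000 on the third operand.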

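In the first case, the third tensor is broadcast from (1, 32) to (10000, 32), and the optimizer then chooses to contract the last two tensors first (`bk,bj->bkj`), creating a 10000 × 32 × 32 intermediate (1.024e+7 elements). Ideally, the size-1 broadcast dimension would be stripped from the third tensor, turning `bk` into `k`. That allows a much cheaper contraction order, as the second case shows: about 2.1e+7 FLOPs instead of 6.7e+8, roughly 32× fewer.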
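One caller-side workaround is to drop size-1 broadcast axes before building the path. Below is a minimal sketch, assuming an explicit `->` output and no `...` ellipsis. `strip_broadcast_dims` is a hypothetical helper, not part of opt_einsum. It removes an axis from an operand only when that axis has size 1 there but a larger size in some other operand. The operand is then constant along that index and the index still appears elsewhere, so the contraction result should be unchanged.

```python
def strip_broadcast_dims(subscripts, *shapes):
    """Hypothetical helper: remove size-1 broadcast axes from einsum operands.

    Assumes an explicit '->' output and no '...' ellipsis. Returns the
    rewritten subscripts, the new shapes, and, per operand, the axes that
    should be squeezed out of the actual arrays.
    """
    inputs, output = subscripts.replace(" ", "").split("->")
    terms = inputs.split(",")

    # Largest size seen for each index across all operands.
    max_size = {}
    for term, shape in zip(terms, shapes):
        for idx, dim in zip(term, shape):
            max_size[idx] = max(max_size.get(idx, 0), dim)

    new_terms, new_shapes, squeezed = [], [], []
    for term, shape in zip(terms, shapes):
        kept_idx, kept_dims, drop_axes = [], [], []
        for axis, (idx, dim) in enumerate(zip(term, shape)):
            # Size 1 here but larger elsewhere: this axis is pure broadcasting,
            # so the operand is constant along it and the index can be dropped.
            if dim == 1 and max_size[idx] > 1:
                drop_axes.append(axis)
            else:
                kept_idx.append(idx)
                kept_dims.append(dim)
        new_terms.append("".join(kept_idx))
        new_shapes.append(tuple(kept_dims))
        squeezed.append(drop_axes)

    return ",".join(new_terms) + "->" + output, new_shapes, squeezed


eq, new_shapes, squeezed = strip_broadcast_dims(
    "ijk,bj,bk->bi", (32, 32, 32), (10000, 32), (1, 32)
)
print(eq, new_shapes, squeezed)
# ijk,bj,k->bi [(32, 32, 32), (10000, 32), (32,)] [[], [], [0]]
```

The per-operand `squeezed` axes would then be removed from the real arrays (for example with `numpy.squeeze(arr, axis=tuple(axes))`) before passing the rewritten equation to `opt_einsum.contract`, which turns the first example in this issue into the second.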

Any ideas on how this could be addressed? I understand that this involves more than choosing a contraction path, so it might not be solvable within this library.
Thanks!
