Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Op] Enhanced Einsum #283

Closed
wants to merge 20 commits into from
Closed

[Op] Enhanced Einsum #283

wants to merge 20 commits into from

Conversation

LeshengJin
Copy link
Contributor

@LeshengJin LeshengJin commented Aug 13, 2023

Typically, Einsum only performs element-wise multiplication and summation across indices. This pr expands Einsum's capabilities:

  1. Customize element-wise computation and index combination.
  2. Einsum can now produce Tuple(tensor) outputs.

This enhanced Einsum could represents more complex computations within a few lines of code. For example,

  • sum(x), sum(x^2)
einsum("ij -> i, i", x, fcompute=lambda x_ij: (x_ij, x_ij * x_ij))
  • sum(x), prod(x)
einsum(
    "ij -> i, i",
    x,
    fcombine=lambda x, y: (x[0] + y[0], x[1] * y[1]),
    fidentity=lambda dtype1, dtype2: (tvm.tir.const(0, dtype1), tvm.tir.const(1, dtype2)),
)
  • Online Softmax
def fcombine(tensor1, tensor2):
    mi = tensor1[0]
    di = tensor1[1]
    mj = tensor2[0]
    dj = tensor2[1]
    r0 = tvm.tir.max(mi, mj)
    r1 = di * tvm.tir.exp(mi - r0) + dj * tvm.tir.exp(mj - r0)
    return r0, r1

def fidentity(dtype1, dtype2):
    return tvm.te.min_value(dtype1), tvm.tir.const(0, dtype2)

mv, dv = einsum(
    "ij -> i, i",
    x,
    fcompute=lambda x_ij: (x_ij, 1.0),
    fcombine=fcombine,
    fidentity=fidentity,
)

softmax_x = einsum(
    "ij, i, i -> ij",
    (x, mv, dv),
    fcompute=lambda x_ij, mv_i, dv_i: (tvm.tir.exp(x_ij) - mv_i) / dv_i,
)

@LeshengJin LeshengJin changed the title Einsum [Op] Einsum with customized compute function and combine function Aug 13, 2023
@LeshengJin LeshengJin changed the title [Op] Einsum with customized compute function and combine function [Op] Enhanced Einsum Aug 13, 2023
@junrushao
Copy link
Member

You may take a look at this: https://einops.rocks/

@MasterJH5574 MasterJH5574 force-pushed the mlc branch 2 times, most recently from f8b2ff1 to 59c3556 Compare February 20, 2024 14:55
@MasterJH5574 MasterJH5574 force-pushed the mlc branch 2 times, most recently from f178458 to 4b79ceb Compare March 4, 2024 22:24
@MasterJH5574 MasterJH5574 force-pushed the mlc branch 4 times, most recently from d20be8b to c06ec1f Compare March 17, 2024 02:08
@MasterJH5574 MasterJH5574 force-pushed the mlc branch 3 times, most recently from 1ce4a34 to dcdd541 Compare March 24, 2024 00:36
@MasterJH5574 MasterJH5574 force-pushed the mlc branch 2 times, most recently from 0c81069 to ae057a2 Compare April 12, 2024 21:11
@MasterJH5574 MasterJH5574 force-pushed the mlc branch 3 times, most recently from 597664a to ce58d63 Compare May 15, 2024 05:49
@tqchen tqchen closed this May 24, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants