Skip to content

Commit

Permalink
Add tensor name to constructor
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Jul 24, 2024
1 parent f4b1a9f commit c3aa2e8
Show file tree
Hide file tree
Showing 23 changed files with 60 additions and 116 deletions.
3 changes: 2 additions & 1 deletion teaal/trans/header.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ def make_output(self) -> Statement:
self.program.get_loop_order().apply(tensor)

arg0 = TransUtils.build_rank_ids(tensor)
args = self.__make_shape([arg0])
arg1 = AParam("name", EString(tensor.root_name()))
args = self.__make_shape([arg0, arg1])
constr = EFunc("Tensor", args)
return SAssign(AVar(tensor.tensor_name()), constr)

Expand Down
71 changes: 7 additions & 64 deletions tests/integration/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,66 +1,9 @@
einsum:
declaration:
SOB: [UA, UB]
T: [UA, UB, K]
I: [K]
OB: []
declaration: # Ranks are listed alphabetically in this section
TS: [T, P1, P0, E]
ND: [P1, P0, E]
expressions:
- SOB[ua, ub] = T[ua, ub, k] * I[k]
- OB[] = SOB[ua, ub]
# einsum:
# declaration:
# A: [S]
# Z: [T]
# expressions:
# - Z[t] = A[2 * t]
# format:
# A:
# default:
# rank-order: [S]
# S:
# format: C
# pbits: 32
# Z:
# default:
# rank-order: [T]
# T:
# format: C
# pbits: 32
# architecture:
# accel:
# - name: System
# local:
# - name: MainMemory
# class: DRAM
# subtree:
# - name: Chip
# local:
# - name: LLB
# class: Cache
# attributes:
# width: 32
# depth: 1024
# bindings:
# Z:
# - config: accel
# prefix: tmp/demo
# - component: MainMemory
# bindings:
# - tensor: A
# rank: S
# type: payload
# format: default
# - tensor: Z
# rank: T
# type: payload
# format: default
# - component: LLB
# bindings:
# - tensor: A
# rank: S
# type: payload
# format: default
# - tensor: Z
# rank: T
# type: payload
# format: default
- ND[p1, p0, e] = TS[t, p1, p0, e]
mapping:
loop-order:
ND: [T, P1, P0, E]
2 changes: 1 addition & 1 deletion tests/integration/dotprod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_ = Tensor(rank_ids=[])
Z_ = Tensor(rank_ids=[], name="Z")
z_ref = Z_.getRoot()
a_k = A_K.getRoot()
b_k = B_K.getRoot()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/example.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
T1_IJ = Tensor(rank_ids=["I", "J"])
T1_IJ = Tensor(rank_ids=["I", "J"], name="T1")
t1_i = T1_IJ.getRoot()
a_i = A_IJK.getRoot()
b_k = B_KL.getRoot()
Expand All @@ -7,7 +7,7 @@
for k, (a_val, b_l) in a_k & b_k:
for l, b_val in b_l:
t1_ref += a_val * b_val
D_I = Tensor(rank_ids=["I"])
D_I = Tensor(rank_ids=["I"], name="D")
d_i = D_I.getRoot()
c_i = C_IJ.getRoot()
t1_i = T1_IJ.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/example2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
D_ = Tensor(rank_ids=[])
D_ = Tensor(rank_ids=[], name="D")
B_IJ = B_JI.swizzleRanks(rank_ids=["I", "J"])
d_ref = D_.getRoot()
a_i = A_IJ.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/example3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
D_ = Tensor(rank_ids=[])
D_ = Tensor(rank_ids=[], name="D")
B_IJK = B_JKI.swizzleRanks(rank_ids=["I", "J", "K"])
d_ref = D_.getRoot()
a_i = A_IJK.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/example4.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
D_ = Tensor(rank_ids=[])
D_ = Tensor(rank_ids=[], name="D")
B_IJK = B_JKI.swizzleRanks(rank_ids=["I", "J", "K"])
C_IJK = C_JKI.swizzleRanks(rank_ids=["I", "J", "K"])
d_ref = D_.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/example5.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_IK = Tensor(rank_ids=["I", "K"])
Z_IK = Tensor(rank_ids=["I", "K"], name="Z")
B_KJ = B_JK.swizzleRanks(rank_ids=["K", "J"])
z_i = Z_IK.getRoot()
a_i = A_IJ.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/example6.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
D_I = Tensor(rank_ids=["I"])
D_I = Tensor(rank_ids=["I"], name="D")
d_i = D_I.getRoot()
c_i = C_IJ.getRoot()
for i, (d_ref, c_j) in d_i << c_i:
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/example7.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
a_m = A_MN.getRoot()
b_m = B_MN.getRoot()
for m, (z_n, (_, a_n, b_n)) in z_m << (a_m | b_m):
for n, (z_ref, (_, a_val, b_val)) in z_n << (a_n | b_n):
z_ref += a_val + b_val
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
a_m = A_MN.getRoot()
b_m = B_MN.getRoot()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/gemm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
T1_MN = Tensor(rank_ids=["M", "N"])
T1_MN = Tensor(rank_ids=["M", "N"], name="T1")
A_MK = A_KM.swizzleRanks(rank_ids=["M", "K"])
B_NK = B_KN.swizzleRanks(rank_ids=["N", "K"])
t1_m = T1_MN.getRoot()
Expand All @@ -8,7 +8,7 @@
for n, (t1_ref, b_k) in t1_n << b_n:
for k, (a_val, b_val) in a_k & b_k:
t1_ref += a_val * b_val
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
t1_m = T1_MN.getRoot()
c_m = C_MN.getRoot()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/gemv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
T1_M = Tensor(rank_ids=["M"])
T1_M = Tensor(rank_ids=["M"], name="T1")
B_MK = B_KM.swizzleRanks(rank_ids=["M", "K"])
t1_m = T1_M.getRoot()
a_k = A_K.getRoot()
b_m = B_MK.getRoot()
for m, (t1_ref, b_k) in t1_m << b_m:
for k, (a_val, b_val) in a_k & b_k:
t1_ref += a_val * b_val
Z_M = Tensor(rank_ids=["M"])
Z_M = Tensor(rank_ids=["M"], name="Z")
z_m = Z_M.getRoot()
t1_m = T1_M.getRoot()
c_m = C_M.getRoot()
Expand Down
6 changes: 3 additions & 3 deletions tests/integration/gram.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
G0_II_ = Tensor(rank_ids=["I", "I_"])
G0_II_ = Tensor(rank_ids=["I", "I_"], name="G0")
g0_i = G0_II_.getRoot()
a_i = A_IJK.getRoot()
b0_i_ = B0_I_JK.getRoot()
Expand All @@ -7,7 +7,7 @@
for j, (a_k, b0_k) in a_j & b0_j:
for k, (a_val, b0_val) in a_k & b0_k:
g0_ref += a_val * b0_val
G1_JJ_ = Tensor(rank_ids=["J", "J_"])
G1_JJ_ = Tensor(rank_ids=["J", "J_"], name="G1")
A_JIK = A_IJK.swizzleRanks(rank_ids=["J", "I", "K"])
B1_J_IK = B1_IJ_K.swizzleRanks(rank_ids=["J_", "I", "K"])
g1_j = G1_JJ_.getRoot()
Expand All @@ -18,7 +18,7 @@
for i, (a_k, b1_k) in a_i & b1_i:
for k, (a_val, b1_val) in a_k & b1_k:
g1_ref += a_val * b1_val
G2_KK_ = Tensor(rank_ids=["K", "K_"])
G2_KK_ = Tensor(rank_ids=["K", "K_"], name="G2")
A_KIJ = A_IJK.swizzleRanks(rank_ids=["K", "I", "J"])
B2_K_IJ = B2_IJK_.swizzleRanks(rank_ids=["K_", "I", "J"])
g2_k = G2_KK_.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/mttkrp.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
B_NK = B_KN.swizzleRanks(rank_ids=["N", "K"])
C_NJ = C_JN.swizzleRanks(rank_ids=["N", "J"])
z_m = Z_MN.getRoot()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/nrm_sq.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
T_ABIJ = Tensor(rank_ids=["A", "B", "I", "J"])
T_ABIJ = Tensor(rank_ids=["A", "B", "I", "J"], name="T")
t_a = T_ABIJ.getRoot()
v_a = V_ABIJ.getRoot()
for a, (t_b, v_b) in t_a << v_a:
for b, (t_i, v_i) in t_b << v_b:
for i, (t_j, v_j) in t_i << v_i:
for j, (t_ref, v_val) in t_j << v_j:
t_ref += v_val
Q_ = Tensor(rank_ids=[])
Q_ = Tensor(rank_ids=[], name="Q")
q_ref = Q_.getRoot()
v_a = V_ABIJ.getRoot()
t_a = T_ABIJ.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/outerprod.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
a_m = A_M.getRoot()
b_n = B_N.getRoot()
Expand Down
4 changes: 2 additions & 2 deletions tests/integration/sddmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
T1_MN = Tensor(rank_ids=["M", "N"])
T1_MN = Tensor(rank_ids=["M", "N"], name="T1")
B_NK = B_KN.swizzleRanks(rank_ids=["N", "K"])
t1_m = T1_MN.getRoot()
a_m = A_MK.getRoot()
Expand All @@ -7,7 +7,7 @@
for n, (t1_ref, b_k) in t1_n << b_n:
for k, (a_val, b_val) in a_k & b_k:
t1_ref += a_val * b_val
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
c_m = C_MN.getRoot()
t1_m = T1_MN.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/spmm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
A_MK = A_KM.swizzleRanks(rank_ids=["M", "K"])
B_NK = B_KN.swizzleRanks(rank_ids=["N", "K"])
z_m = Z_MN.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/spmv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_M = Tensor(rank_ids=["M"])
Z_M = Tensor(rank_ids=["M"], name="Z")
z_m = Z_M.getRoot()
a_k = A_K.getRoot()
b_m = B_MK.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/ttm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_MNO = Tensor(rank_ids=["M", "N", "O"])
Z_MNO = Tensor(rank_ids=["M", "N", "O"], name="Z")
B_OK = B_KO.swizzleRanks(rank_ids=["O", "K"])
z_m = Z_MNO.getRoot()
a_m = A_MNK.getRoot()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration/ttv.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Z_MN = Tensor(rank_ids=["M", "N"])
Z_MN = Tensor(rank_ids=["M", "N"], name="Z")
z_m = Z_MN.getRoot()
a_m = A_MNK.getRoot()
b_k = B_K.getRoot()
Expand Down
12 changes: 6 additions & 6 deletions tests/trans/test_header.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def test_make_output():
Z: [M1, K, N, M0]
"""

hifiber = "Z_M1NM0 = Tensor(rank_ids=[\"M1\", \"N\", \"M0\"])"
hifiber = "Z_M1NM0 = Tensor(rank_ids=[\"M1\", \"N\", \"M0\"], name=\"Z\")"

header = build_matmul_header(mapping)
assert header.make_output().gen(depth=0) == hifiber
Expand All @@ -135,7 +135,7 @@ def test_make_output_shape():
- Z[m, n] = A[k, m]
"""

hifiber = "Z_MN = Tensor(rank_ids=[\"M\", \"N\"], shape=[M, N])"
hifiber = "Z_MN = Tensor(rank_ids=[\"M\", \"N\"], name=\"Z\", shape=[M, N])"

header = build_header(exprs, "")
assert header.make_output().gen(depth=0) == hifiber
Expand All @@ -152,27 +152,27 @@ def test_make_output_no_shape_flattening():
(M, N): [flatten()]
"""

hifiber = "Z_MN = Tensor(rank_ids=[\"MN\"])"
hifiber = "Z_MN = Tensor(rank_ids=[\"MN\"], name=\"Z\")"
header = build_header(exprs, mapping)
assert header.make_output().gen(depth=0) == hifiber


def test_make_output_conv_no_shape():
hifiber = "O_Q = Tensor(rank_ids=[\"Q\"])"
hifiber = "O_Q = Tensor(rank_ids=[\"Q\"], name=\"O\")"
header = build_header_conv("[S, Q]")

assert header.make_output().gen(0) == hifiber


def test_make_output_conv_shape():
hifiber = "O_Q = Tensor(rank_ids=[\"Q\"], shape=[Q])"
hifiber = "O_Q = Tensor(rank_ids=[\"Q\"], name=\"O\", shape=[Q])"
header = build_header_conv("[Q, S]")

assert header.make_output().gen(0) == hifiber


def test_make_output_metrics_shape():
hifiber = "T_MKN = Tensor(rank_ids=[\"M\", \"K\", \"N\"], shape=[M, K, N])"
hifiber = "T_MKN = Tensor(rank_ids=[\"M\", \"K\", \"N\"], name=\"T\", shape=[M, K, N])"
header = build_header_gamma()

assert header.make_output().gen(0) == hifiber
Expand Down
Loading

0 comments on commit c3aa2e8

Please sign in to comment.