Skip to content

Commit

Permalink
Fix flattened and repartitioned canvas
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Sep 10, 2024
1 parent 5e4f221 commit 9a317b5
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 25 deletions.
6 changes: 5 additions & 1 deletion teaal/trans/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ def __build_access(self, rank: str) -> Expression:
# as a tuple of the constituent ranks
if part_ir.is_flattened(rank.upper()):
flat_ranks = self.program.get_loop_order().get_iter_ranks(rank.upper())
return ETuple([EVar(frank.lower()) for frank in flat_ranks])

if len(flat_ranks) == 1:
return EVar(flat_ranks[0].lower())
else:
return ETuple([EVar(frank.lower()) for frank in flat_ranks])

sexpr = self.program.get_coord_math().get_trans(root.lower())

Expand Down
41 changes: 23 additions & 18 deletions tests/integration/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,26 +1,31 @@
einsum:
declaration:
I: [B, C, H, W]
F: [C, M, R, S]
O: [B, M, P, Q]
A: [K, M]
B: [K, N]
T: [K, M, N]
Z: [M, N]
expressions:
- O[b, m, p, q] = I[b, c, 4*p+r, 4*q+s]*F[c, m, r, s]
- T[k, m, n] = A[k, m] * B[k, n]
- Z[m, n] = T[k, m, n]
mapping:
rank-order:
I: [B, C, H, W]
F: [M, C, R, S]
O: [B, M, P, Q]
A: [K, M]
B: [K, N]
T: [M, K, N]
Z: [M, N]
partitioning:
O:
M:
- uniform_shape(32)
- uniform_shape(16)
P:
- uniform_shape(P0)
H: [follow(P)]
T:
(K, M): [flatten()]
KM: [uniform_occupancy(A.4), uniform_occupancy(A.2)]
Z:
M: [uniform_occupancy(T.4), uniform_occupancy(T.2)]
loop-order:
O: [C, M2, B, M1, P1, P0, R, Q, S, M0]
T: [KM2, KM1, KM0, N]
Z: [M2, M1, M0, N, K]
spacetime:
O:
space: [P0]
time: [C, M2, B, M1, P1, R, Q, S, M0]
T:
space: [KM1, KM0]
time: [KM2, N]
Z:
space: [M1, M0]
time: [M2, N, K]
17 changes: 11 additions & 6 deletions tests/trans/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,24 +274,29 @@ def test_add_activity_flatten():
partitioning:
Z:
(M, N): [flatten()]
MN: [uniform_occupancy(A.5)]
loop-order:
Z: [MN]
Z: [MN1, MN0]
spacetime:
Z:
space: []
time: [MN]
space: [MN0]
time: [MN1]
"""
program = Program(Einsum.from_str(yaml), Mapping.from_str(yaml))
program.add_einsum(0)
part_ir = program.get_partitioning()

for tensor in program.get_equation().get_tensors():
program.apply_all_partitioning(tensor)
# Static partitioning
program.apply_all_partitioning(program.get_equation().get_output())
program.apply_partitioning(program.get_equation().get_tensor("A"), ("M", "N"))

canvas = Canvas(program)
canvas.create_canvas()

hifiber = "canvas.addActivity(((m, n),), ((m, n),), spacetime=((), (mn_pos,)))"
# Dynamic partitioning
program.apply_partitioning(program.get_equation().get_tensor("A"), ("MN",))

hifiber = "canvas.addActivity(((m, n),), (mn1, (m, n)), spacetime=((mn0_pos,), (mn1_pos,)))"
assert canvas.add_activity().gen(0) == hifiber


Expand Down

0 comments on commit 9a317b5

Please sign in to comment.