From 9a317b59ad1407577254db63ec3cda6da4bcd4e1 Mon Sep 17 00:00:00 2001 From: Nandeeka Nayak Date: Tue, 10 Sep 2024 14:20:58 -0500 Subject: [PATCH] Fix flattened and repartitioned canvas --- teaal/trans/canvas.py | 6 +++++- tests/integration/demo.yaml | 41 +++++++++++++++++++++---------------- tests/trans/test_canvas.py | 17 +++++++++------ 3 files changed, 39 insertions(+), 25 deletions(-) diff --git a/teaal/trans/canvas.py b/teaal/trans/canvas.py index 8aa6edf..bb48d6e 100644 --- a/teaal/trans/canvas.py +++ b/teaal/trans/canvas.py @@ -152,7 +152,11 @@ def __build_access(self, rank: str) -> Expression: # as a tuple of the constituent ranks if part_ir.is_flattened(rank.upper()): flat_ranks = self.program.get_loop_order().get_iter_ranks(rank.upper()) - return ETuple([EVar(frank.lower()) for frank in flat_ranks]) + + if len(flat_ranks) == 1: + return EVar(flat_ranks[0].lower()) + else: + return ETuple([EVar(frank.lower()) for frank in flat_ranks]) sexpr = self.program.get_coord_math().get_trans(root.lower()) diff --git a/tests/integration/demo.yaml b/tests/integration/demo.yaml index 3989890..10d6669 100644 --- a/tests/integration/demo.yaml +++ b/tests/integration/demo.yaml @@ -1,26 +1,31 @@ einsum: declaration: - I: [B, C, H, W] - F: [C, M, R, S] - O: [B, M, P, Q] + A: [K, M] + B: [K, N] + T: [K, M, N] + Z: [M, N] expressions: - - O[b, m, p, q] = I[b, c, 4*p+r, 4*q+s]*F[c, m, r, s] + - T[k, m, n] = A[k, m] * B[k, n] + - Z[m, n] = T[k, m, n] mapping: rank-order: - I: [B, C, H, W] - F: [M, C, R, S] - O: [B, M, P, Q] + A: [K, M] + B: [K, N] + T: [M, K, N] + Z: [M, N] partitioning: - O: - M: - - uniform_shape(32) - - uniform_shape(16) - P: - - uniform_shape(P0) - H: [follow(P)] + T: + (K, M): [flatten()] + KM: [uniform_occupancy(A.4), uniform_occupancy(A.2)] + Z: + M: [uniform_occupancy(T.4), uniform_occupancy(T.2)] loop-order: - O: [C, M2, B, M1, P1, P0, R, Q, S, M0] + T: [KM2, KM1, KM0, N] + Z: [M2, M1, M0, N, K] spacetime: - O: - space: [P0] - time: [C, M2, B, M1, P1, R, Q, S, M0] + T: + space: [KM1, KM0] + time: [KM2, N] + Z: + space: [M1, M0] + time: [M2, N, K] diff --git a/tests/trans/test_canvas.py b/tests/trans/test_canvas.py index 6fe0383..8b1be95 100644 --- a/tests/trans/test_canvas.py +++ b/tests/trans/test_canvas.py @@ -274,24 +274,29 @@ def test_add_activity_flatten(): partitioning: Z: (M, N): [flatten()] + MN: [uniform_occupancy(A.5)] loop-order: - Z: [MN] + Z: [MN1, MN0] spacetime: Z: - space: [] - time: [MN] + space: [MN0] + time: [MN1] """ program = Program(Einsum.from_str(yaml), Mapping.from_str(yaml)) program.add_einsum(0) part_ir = program.get_partitioning() - for tensor in program.get_equation().get_tensors(): - program.apply_all_partitioning(tensor) + # Static partitioning + program.apply_all_partitioning(program.get_equation().get_output()) + program.apply_partitioning(program.get_equation().get_tensor("A"), ("M", "N")) canvas = Canvas(program) canvas.create_canvas() - hifiber = "canvas.addActivity(((m, n),), ((m, n),), spacetime=((), (mn_pos,)))" + # Dynamic partitioning + program.apply_partitioning(program.get_equation().get_tensor("A"), ("MN",)) + + hifiber = "canvas.addActivity(((m, n),), (mn1, (m, n)), spacetime=((mn0_pos,), (mn1_pos,)))" assert canvas.add_activity().gen(0) == hifiber