Skip to content

Commit

Permalink
Clean up footer merging
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Sep 12, 2024
1 parent 9a317b5 commit 600ef0a
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 35 deletions.
15 changes: 13 additions & 2 deletions teaal/ir/partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,20 @@ def swizzle_for_flattening(self, tensor_ranks: List[str]) -> List[str]:
used_parts = self.__used_parts(new_ranks, self.all_parts, True)
for part in used_parts:
if len(part) > 1:
for rank in part:
# We do not want to re-order unaffected ranks

# Start and end of flattened ranks
min_pos = min(new_ranks.index(rank) for rank in part)
max_pos = max(new_ranks.index(rank) for rank in part)

# Extra ranks that need to be moved around
extras = max_pos - min_pos - len(part) + 1

for i, rank in enumerate(part):
del new_ranks[new_ranks.index(rank)]
new_ranks.append(rank)

for i, rank in enumerate(part):
new_ranks.insert(min_pos + extras + i, rank)

return new_ranks

Expand Down
51 changes: 22 additions & 29 deletions tests/integration/demo.yaml
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
einsum:
declaration:
A: [K, M]
B: [K, N]
T: [K, M, N]
Z: [M, N]
expressions:
- T[k, m, n] = A[k, m] * B[k, n]
- Z[m, n] = T[k, m, n]
declaration:
A: [K, M]
B: [K, N]
T: [K, M, N]
Z: [M, N]
expressions:
- T[k, m, n] = A[k, m] * B[k, n]
mapping:
rank-order:
A: [K, M]
B: [K, N]
T: [M, K, N]
Z: [M, N]
partitioning:
T:
(K, M): [flatten()]
KM: [uniform_occupancy(A.4), uniform_occupancy(A.2)]
Z:
M: [uniform_occupancy(T.4), uniform_occupancy(T.2)]
loop-order:
T: [KM2, KM1, KM0, N]
Z: [M2, M1, M0, N, K]
spacetime:
T:
space: [KM1, KM0]
time: [KM2, N]
Z:
space: [M1, M0]
time: [M2, N, K]
rank-order:
A: [K, M]
B: [K, N]
T: [M, K, N]
Z: [M, N]
partitioning:
T:
K: [ uniform_shape(16) ]
(K0, M): [ flatten() ]
K0M: [ uniform_occupancy(A.64) ]
Z:
K: [ uniform_occupancy(T.64) ]
loop-order:
T: [K1, K0M1, K0M0, N]
Z: [M, K1, K0, N]
17 changes: 15 additions & 2 deletions tests/ir/test_partitioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def parse_partitioning(parts):

def build_part_dict(parts):
parsed = parse_partitioning(parts)
return {tuple(str(child) for child in key.children): val for key, val in parsed["Z"].items()}
return {tuple(str(child) for child in key.children)
: val for key, val in parsed["Z"].items()}


def build_partitioning(parts):
Expand Down Expand Up @@ -906,7 +907,19 @@ def test_swizzle_for_flattening():

assert partitioning.swizzle_for_flattening(["K", "M"]) == ["K", "M"]
assert partitioning.swizzle_for_flattening(["K1", "K0", "J", "M", "N"]) == [
"K1", "J", "N", "M", "K0"]
"K1", "J", "M", "K0", "N"]


def test_swizzle_for_flattening_flexagon():
all_parts = """
K: [ uniform_shape(16) ]
(K0, M): [ flatten() ]
K0M: [ uniform_occupancy(A.64) ]
"""
partitioning = build_partitioning(all_parts)

assert partitioning.swizzle_for_flattening(["M", "K1", "K0", "N"]) == [
"K1", "K0", "M", "N"]


def test_unpack():
Expand Down
2 changes: 1 addition & 1 deletion tests/ir/test_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def test_apply_partition_swizzling():

program.apply_partitioning(A, ("K",))
program.apply_partition_swizzling(A)
assert A.get_ranks() == ["J", "K1", "N", "M", "K0"]
assert A.get_ranks() == ["J", "K1", "M", "K0", "N"]


def test_get_all_einsums():
Expand Down
3 changes: 2 additions & 1 deletion tests/trans/test_canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,8 @@ def test_add_activity_flatten():

# Static partitioning
program.apply_all_partitioning(program.get_equation().get_output())
program.apply_partitioning(program.get_equation().get_tensor("A"), ("M", "N"))
program.apply_partitioning(
program.get_equation().get_tensor("A"), ("M", "N"))

canvas = Canvas(program)
canvas.create_canvas()
Expand Down

0 comments on commit 600ef0a

Please sign in to comment.