Skip to content

Commit

Permalink
Fix partitioner bug
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Sep 24, 2024
1 parent b5d5c53 commit ea99815
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion teaal/trans/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,12 @@ def unpartition(self, tensor: Tensor) -> Statement:
swizzled_ranks, part_ir.get_all_parts(), False)
for part in valid_parts:
trans.append((part, tensor.get_ranks()))

# If this is a partition, apply all partitioning
# If this is a flatten, just flatten
tensor.update_ranks(
part_ir.partition_ranks(
tensor.get_ranks(), {part}, True, False))
tensor.get_ranks(), {part}, len(part) == 1, False))

new_ranks = tensor.get_ranks()

Expand Down

0 comments on commit ea99815

Please sign in to comment.