Skip to content

Commit

Permalink
Fix eager evict in the presence of flattening
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Oct 26, 2023
1 parent b784690 commit f4b1a9f
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 19 deletions.
8 changes: 6 additions & 2 deletions teaal/ir/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,10 @@ def get_eager_evict_on(self, tensor: str, rank: str) -> List[str]:
if (tensor, rank) in evicts:
ranks.append(loop_rank)

ranks.sort(key=self.program.get_loop_order().get_ranks().index)
ranks.sort(
key=(
["root"] +
self.program.get_loop_order().get_ranks()).index)
return ranks

def get_eager_evicts(self, rank: str) -> List[Tuple[str, str]]:
Expand Down Expand Up @@ -248,7 +251,8 @@ def get_source_memory(
component +
" not a memory")

inds = [i for i, (comp, _) in enumerate(path) if comp.get_name() == component]
inds = [i for i, (comp, _) in enumerate(
path) if comp.get_name() == component]
if not inds:
return None

Expand Down
24 changes: 15 additions & 9 deletions teaal/trans/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,19 +295,27 @@ def trace_tree(
"""
loop_ranks = self.program.get_loop_order().get_ranks()
i = loop_ranks.index(rank)
output = self.program.get_equation().get_output()
tensor_ir = self.program.get_equation().get_tensor(tensor)

# If this is a write trace and this is the current outer-most loop,
# then just use the outermost loop of the tensor
if not is_read_trace and i == 0:
rank = output.get_ranks()[0]
tensor_rank = tensor_ir.get_ranks()[0]

# Otherwise, get the rank of the current fiber
elif not is_read_trace and i > 0:
output_ranks = output.get_ranks()
rank = output_ranks[output_ranks.index(loop_ranks[i - 1]) + 1]
output_ranks = tensor_ir.get_ranks()
tensor_rank = output_ranks[output_ranks.index(
loop_ranks[i - 1]) + 1]

else:
avail = self.program.get_partitioning().get_available(rank)
poss_ranks = [
rank for rank in tensor_ir.get_ranks() if rank in avail]
assert len(poss_ranks) == 1
tensor_rank = poss_ranks[0]

fiber = tensor.lower() + "_" + rank.lower()
fiber = tensor.lower() + "_" + tensor_rank.lower()

trace = "eager_" + fiber
if is_read_trace:
Expand All @@ -318,7 +326,7 @@ def trace_tree(
args: List[Argument] = [AJust(EString(trace))]
if not is_read_trace:
# We want to use the iteration number for the last loop rank
final_tensor = Tensor(output.root_name(), output.get_init_ranks())
final_tensor = Tensor(tensor, tensor_ir.get_init_ranks())
self.program.apply_all_partitioning(final_tensor)
self.program.get_loop_order().apply(final_tensor)

Expand All @@ -330,10 +338,8 @@ def trace_tree(
return trace_stmt

# If read, only read the first time
tensor_ir = self.program.get_equation().get_tensor(tensor)

evict_rank = self.metrics.get_eager_evict_on(tensor, tensor_rank)[-1]
get_final = self.program.get_partitioning().get_final_rank_id
evict_rank = self.metrics.get_eager_evict_on(tensor, rank)[-1]
er_ind = loop_ranks.index(get_final([evict_rank], evict_rank))
tree_ind = loop_ranks.index(get_final([rank], rank))

Expand Down
27 changes: 19 additions & 8 deletions tests/trans/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,7 @@ def test_dump_gamma_Z():

assert collector.dump().gen(0) == hifiber


def test_dump_outerspace_T0():
yaml = build_outerspace_yaml()
collector = build_collector(yaml, 0)
Expand Down Expand Up @@ -847,6 +848,16 @@ def test_make_loop_header_unconfigured():
def test_make_loop_header():
yaml = build_extensor_yaml()
collector = build_collector(yaml, 0)
program = collector.program
part_ir = program.get_partitioning()
for tensor in program.get_equation().get_tensors():
tensor.update_ranks(
part_ir.partition_ranks(
tensor.get_ranks(),
part_ir.get_all_parts(),
True,
True))
program.get_loop_order().apply(tensor)
collector.start()

assert collector.make_loop_header("N2").gen(0) == ""
Expand All @@ -856,18 +867,18 @@ def test_make_loop_header():

assert collector.make_loop_header("M1").gen(0) == hifiber

hifiber_option1 = "if () not in eager_z_m0_read:\n" + \
" eager_z_m0_read.add(())\n" + \
hifiber_option1 = "if (m1, n1) not in eager_z_m0_read:\n" + \
" eager_z_m0_read.add((m1, n1))\n" + \
" z_m0.trace(\"eager_z_m0_read\")\n" + \
"if () not in eager_a_m0_read:\n" + \
" eager_a_m0_read.add(())\n" + \
"if (m1, k1) not in eager_a_m0_read:\n" + \
" eager_a_m0_read.add((m1, k1))\n" + \
" a_m0.trace(\"eager_a_m0_read\")"

hifiber_option2 = "if () not in eager_a_m0_read:\n" + \
" eager_a_m0_read.add(())\n" + \
hifiber_option2 = "if (m1, k1) not in eager_a_m0_read:\n" + \
" eager_a_m0_read.add((m1, k1))\n" + \
" a_m0.trace(\"eager_a_m0_read\")\n" + \
"if () not in eager_z_m0_read:\n" + \
" eager_z_m0_read.add(())\n" + \
"if (m1, n1) not in eager_z_m0_read:\n" + \
" eager_z_m0_read.add((m1, n1))\n" + \
" z_m0.trace(\"eager_z_m0_read\")"

assert collector.make_loop_header("M0").gen(
Expand Down

0 comments on commit f4b1a9f

Please sign in to comment.