From fc7ceff93748033858747e295d3441ad1b395389 Mon Sep 17 00:00:00 2001 From: Nandeeka Nayak Date: Thu, 26 Oct 2023 15:24:18 -0500 Subject: [PATCH] Fix write tracing --- teaal/trans/collector.py | 16 ++++++++++++++-- tests/trans/test_collector.py | 6 +++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/teaal/trans/collector.py b/teaal/trans/collector.py index 6202c98..40fca04 100644 --- a/teaal/trans/collector.py +++ b/teaal/trans/collector.py @@ -293,6 +293,20 @@ def trace_tree( """ Trace a subtree under the fiber specified """ + loop_ranks = self.program.get_loop_order().get_ranks() + i = loop_ranks.index(rank) + output = self.program.get_equation().get_output() + + # If this is a write trace and this is the current outer-most loop, + # then just use the outermost loop of the tensor + if not is_read_trace and i == 0: + rank = output.get_ranks()[0] + + # Otherwise, get the rank of the current fiber + elif not is_read_trace and i > 0: + output_ranks = output.get_ranks() + rank = output_ranks[output_ranks.index(loop_ranks[i - 1]) + 1] + fiber = tensor.lower() + "_" + rank.lower() trace = "eager_" + fiber @@ -304,7 +318,6 @@ def trace_tree( args: List[Argument] = [AJust(EString(trace))] if not is_read_trace: # We want to use the iteration number for the last loop rank - output = self.program.get_equation().get_tensor(tensor) final_tensor = Tensor(output.root_name(), output.get_init_ranks()) self.program.apply_all_partitioning(final_tensor) self.program.get_loop_order().apply(final_tensor) @@ -317,7 +330,6 @@ def trace_tree( return trace_stmt # If read, only read the first time - loop_ranks = self.program.get_loop_order().get_ranks() tensor_ir = self.program.get_equation().get_tensor(tensor) get_final = self.program.get_partitioning().get_final_rank_id diff --git a/tests/trans/test_collector.py b/tests/trans/test_collector.py index 42608c2..929fa95 100644 --- a/tests/trans/test_collector.py +++ b/tests/trans/test_collector.py @@ -787,7 +787,7 @@ def test_make_loop_footer(): assert collector.make_loop_footer("K0").gen(0) == hifiber hifiber = "K1Intersect_K1.addTraces(Metrics.consumeTrace(\"K1\", \"intersect_0\"), Metrics.consumeTrace(\"K1\", \"intersect_1\"))\n" + \ - "z_k1.trace(\"eager_z_k1_write\", iteration_num=n0_iter_num)" + "z_m0.trace(\"eager_z_m0_write\", iteration_num=n0_iter_num)" assert collector.make_loop_footer("K1").gen(0) == hifiber @@ -896,7 +896,7 @@ def test_make_loop_header_eager_root(): collector = build_collector(yaml, 0) collector.start() - hifiber = "z_k.trace(\"eager_z_k_write\", iteration_num=m_iter_num)" + hifiber = "z_m.trace(\"eager_z_m_write\", iteration_num=m_iter_num)" assert collector.make_loop_footer("K").gen(0) == hifiber @@ -1085,4 +1085,4 @@ def test_trace_tree(): assert collector.trace_tree("A", "M0", True).gen(0) == hifiber hifiber = "z_m0.trace(\"eager_z_m0_write\", iteration_num=n0_iter_num)" - assert collector.trace_tree("Z", "M0", False).gen(0) == hifiber + assert collector.trace_tree("Z", "K1", False).gen(0) == hifiber