diff --git a/teaal/ir/metrics.py b/teaal/ir/metrics.py index 31c43f3..bc7d5af 100644 --- a/teaal/ir/metrics.py +++ b/teaal/ir/metrics.py @@ -146,7 +146,10 @@ def get_eager_evict_on(self, tensor: str, rank: str) -> List[str]: if (tensor, rank) in evicts: ranks.append(loop_rank) - ranks.sort(key=self.program.get_loop_order().get_ranks().index) + ranks.sort( + key=( + ["root"] + + self.program.get_loop_order().get_ranks()).index) return ranks def get_eager_evicts(self, rank: str) -> List[Tuple[str, str]]: @@ -248,7 +251,8 @@ def get_source_memory( component + " not a memory") - inds = [i for i, (comp, _) in enumerate(path) if comp.get_name() == component] + inds = [i for i, (comp, _) in enumerate( + path) if comp.get_name() == component] if not inds: return None diff --git a/teaal/trans/collector.py b/teaal/trans/collector.py index ade160b..2824662 100644 --- a/teaal/trans/collector.py +++ b/teaal/trans/collector.py @@ -295,19 +295,27 @@ def trace_tree( """ loop_ranks = self.program.get_loop_order().get_ranks() i = loop_ranks.index(rank) - output = self.program.get_equation().get_output() + tensor_ir = self.program.get_equation().get_tensor(tensor) # If this is a write trace and this is the current outer-most loop, # then just use the outermost loop of the tensor if not is_read_trace and i == 0: - rank = output.get_ranks()[0] + tensor_rank = tensor_ir.get_ranks()[0] # Otherwise, get the rank of the current fiber elif not is_read_trace and i > 0: - output_ranks = output.get_ranks() - rank = output_ranks[output_ranks.index(loop_ranks[i - 1]) + 1] + output_ranks = tensor_ir.get_ranks() + tensor_rank = output_ranks[output_ranks.index( + loop_ranks[i - 1]) + 1] + + else: + avail = self.program.get_partitioning().get_available(rank) + poss_ranks = [ + rank for rank in tensor_ir.get_ranks() if rank in avail] + assert len(poss_ranks) == 1 + tensor_rank = poss_ranks[0] - fiber = tensor.lower() + "_" + rank.lower() + fiber = tensor.lower() + "_" + tensor_rank.lower() trace = "eager_" + fiber if is_read_trace: @@ -318,7 +326,7 @@ def trace_tree( args: List[Argument] = [AJust(EString(trace))] if not is_read_trace: # We want to use the iteration number for the last loop rank - final_tensor = Tensor(output.root_name(), output.get_init_ranks()) + final_tensor = Tensor(tensor, tensor_ir.get_init_ranks()) self.program.apply_all_partitioning(final_tensor) self.program.get_loop_order().apply(final_tensor) @@ -330,10 +338,8 @@ def trace_tree( return trace_stmt # If read, only read the first time - tensor_ir = self.program.get_equation().get_tensor(tensor) - + evict_rank = self.metrics.get_eager_evict_on(tensor, tensor_rank)[-1] get_final = self.program.get_partitioning().get_final_rank_id - evict_rank = self.metrics.get_eager_evict_on(tensor, rank)[-1] er_ind = loop_ranks.index(get_final([evict_rank], evict_rank)) tree_ind = loop_ranks.index(get_final([rank], rank)) diff --git a/tests/trans/test_collector.py b/tests/trans/test_collector.py index 02574fe..18f6163 100644 --- a/tests/trans/test_collector.py +++ b/tests/trans/test_collector.py @@ -357,6 +357,7 @@ def test_dump_gamma_Z(): assert collector.dump().gen(0) == hifiber + def test_dump_outerspace_T0(): yaml = build_outerspace_yaml() collector = build_collector(yaml, 0) @@ -847,6 +848,16 @@ def test_make_loop_header_unconfigured(): def test_make_loop_header(): yaml = build_extensor_yaml() collector = build_collector(yaml, 0) + program = collector.program + part_ir = program.get_partitioning() + for tensor in program.get_equation().get_tensors(): + tensor.update_ranks( + part_ir.partition_ranks( + tensor.get_ranks(), + part_ir.get_all_parts(), + True, + True)) + program.get_loop_order().apply(tensor) collector.start() assert collector.make_loop_header("N2").gen(0) == "" @@ -856,18 +867,18 @@ def test_make_loop_header(): assert collector.make_loop_header("M1").gen(0) == hifiber - hifiber_option1 = "if () not in eager_z_m0_read:\n" + \ - " eager_z_m0_read.add(())\n" + \ + hifiber_option1 = "if (m1, n1) not in eager_z_m0_read:\n" + \ + " eager_z_m0_read.add((m1, n1))\n" + \ " z_m0.trace(\"eager_z_m0_read\")\n" + \ - "if () not in eager_a_m0_read:\n" + \ - " eager_a_m0_read.add(())\n" + \ + "if (m1, k1) not in eager_a_m0_read:\n" + \ + " eager_a_m0_read.add((m1, k1))\n" + \ " a_m0.trace(\"eager_a_m0_read\")" - hifiber_option2 = "if () not in eager_a_m0_read:\n" + \ - " eager_a_m0_read.add(())\n" + \ + hifiber_option2 = "if (m1, k1) not in eager_a_m0_read:\n" + \ + " eager_a_m0_read.add((m1, k1))\n" + \ " a_m0.trace(\"eager_a_m0_read\")\n" + \ - "if () not in eager_z_m0_read:\n" + \ - " eager_z_m0_read.add(())\n" + \ + "if (m1, n1) not in eager_z_m0_read:\n" + \ + " eager_z_m0_read.add((m1, n1))\n" + \ " z_m0.trace(\"eager_z_m0_read\")" assert collector.make_loop_header("M0").gen(