Skip to content

Commit

Permalink
Index metrics dictionary with string tensor name
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Oct 26, 2023
1 parent ff8c108 commit 415f190
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 9 deletions.
17 changes: 13 additions & 4 deletions teaal/trans/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,11 @@ def __build_intersections(self) -> Statement:
for intersector in self.metrics.get_hardware().get_components(einsum,
IntersectorComponent):
isect_name = intersector.get_name()
block.add(SAssign(AAccess(metrics_einsum, EString(isect_name)), EDict({})))
block.add(
SAssign(
AAccess(
metrics_einsum, EString(isect_name)), EDict(
{})))
metrics_isect = EAccess(metrics_einsum, EString(isect_name))
metrics_isect_op = AAccess(metrics_isect, EString("intersect"))
block.add(SAssign(metrics_isect_op, EInt(0)))
Expand All @@ -563,7 +567,12 @@ def __build_intersections(self) -> Statement:
# op_freq = cycles / s * ops / cycle
op_freq = self.metrics.get_hardware().get_frequency(einsum) * \
intersector.get_num_instances()
time = EBinOp(EAccess(metrics_isect, EString("intersect")), ODiv(), EInt(op_freq))
time = EBinOp(
EAccess(
metrics_isect,
EString("intersect")),
ODiv(),
EInt(op_freq))

metrics_time = AAccess(metrics_isect, EString("time"))
block.add(SAssign(metrics_time, time))
Expand Down Expand Up @@ -624,8 +633,8 @@ def __build_merges(self) -> Statement:
final_ranks = binding["final-ranks"]

input_ = binding["tensor"] + "_" + "".join(init_ranks)
tensors.append(input_)
tensor_name = EVar(input_)
tensors.append(tensor_name)

# TODO: Way more complicated merges are possible than a single
# swap
Expand Down Expand Up @@ -665,7 +674,7 @@ def __build_merges(self) -> Statement:
time = EBinOp(
EAccess(
metrics_merger,
tensors[0]),
EString(tensors[0])),
ODiv(),
EInt(op_freq))

Expand Down
4 changes: 2 additions & 2 deletions tests/trans/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ def test_dump_gamma_Z():
"metrics[\"Z\"][\"MainMemory\"][\"time\"] = (metrics[\"Z\"][\"MainMemory\"][\"Z\"][\"read\"] + metrics[\"Z\"][\"MainMemory\"][\"Z\"][\"write\"]) / 1099511627776\n" + \
"metrics[\"Z\"][\"HighRadixMerger\"] = {}\n" + \
"metrics[\"Z\"][\"HighRadixMerger\"][\"T_MKN\"] = Compute.numSwaps(T_MKN, 1, 64, 1)\n" + \
"metrics[\"Z\"][\"HighRadixMerger\"][\"time\"] = metrics[\"Z\"][\"HighRadixMerger\"][T_MKN] / 32000000000\n" + \
"metrics[\"Z\"][\"HighRadixMerger\"][\"time\"] = metrics[\"Z\"][\"HighRadixMerger\"][\"T_MKN\"] / 32000000000\n" + \
"metrics[\"Z\"][\"FPMul\"] = {}\n" + \
"metrics[\"Z\"][\"FPMul\"][\"mul\"] = Metrics.dump()[\"Compute\"][\"payload_mul\"]\n" + \
"metrics[\"Z\"][\"FPMul\"][\"time\"] = metrics[\"Z\"][\"FPMul\"][\"mul\"] / 32000000000\n" + \
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_dump_outerspace_Z():
"metrics[\"Z\"][\"MainMemory\"][\"time\"] = (metrics[\"Z\"][\"MainMemory\"][\"Z\"][\"read\"] + metrics[\"Z\"][\"MainMemory\"][\"Z\"][\"write\"]) / 1099511627776\n" + \
"metrics[\"Z\"][\"SortHW\"] = {}\n" + \
"metrics[\"Z\"][\"SortHW\"][\"T1_MKN\"] = Compute.numSwaps(T1_MKN, 1, float(\"inf\"), \"N\")\n" + \
"metrics[\"Z\"][\"SortHW\"][\"time\"] = metrics[\"Z\"][\"SortHW\"][T1_MKN] / 193500000000\n" + \
"metrics[\"Z\"][\"SortHW\"][\"time\"] = metrics[\"Z\"][\"SortHW\"][\"T1_MKN\"] / 193500000000\n" + \
"metrics[\"Z\"][\"FPAdd\"] = {}\n" + \
"metrics[\"Z\"][\"FPAdd\"][\"add\"] = Metrics.dump()[\"Compute\"][\"payload_add\"]\n" + \
"metrics[\"Z\"][\"FPAdd\"][\"time\"] = metrics[\"Z\"][\"FPAdd\"][\"add\"] / 193500000000\n" + \
Expand Down
7 changes: 4 additions & 3 deletions tests/trans/test_hifiber.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,10 @@ def test_hifiber_intersect():
"metrics = {}\n" + \
"metrics[\"Z\"] = {}\n" + \
"formats = {}\n" + \
"metrics[\"Z\"][\"TF\"] = 0\n" + \
"metrics[\"Z\"][\"TF\"] += TF_K.getNumIntersects()\n" + \
"metrics[\"Z\"][\"TF\"][\"time\"] = metrics[\"Z\"][\"TF\"] / 2048\n" + \
"metrics[\"Z\"][\"TF\"] = {}\n" + \
"metrics[\"Z\"][\"TF\"][\"intersect\"] = 0\n" + \
"metrics[\"Z\"][\"TF\"][\"intersect\"] += TF_K.getNumIntersects()\n" + \
"metrics[\"Z\"][\"TF\"][\"time\"] = metrics[\"Z\"][\"TF\"][\"intersect\"] / 2048\n" + \
"metrics[\"blocks\"] = [[\"Z\"]]\n" + \
"metrics[\"time\"] = metrics[\"Z\"][\"TF\"][\"time\"]"

Expand Down

0 comments on commit 415f190

Please sign in to comment.