Skip to content

Commit

Permalink
Compute overall execution time automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
nandeeka committed Oct 26, 2023
1 parent 24817ad commit cb89745
Show file tree
Hide file tree
Showing 12 changed files with 355 additions and 79 deletions.
43 changes: 31 additions & 12 deletions teaal/ir/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,26 +27,31 @@
from typing import List, Optional, Set

from teaal.ir.component import *
from teaal.ir.metrics import Metrics
from teaal.ir.hardware import Hardware
from teaal.ir.program import Program


class Fusion:
"""
Representation of the fusion schedule of the accelerator
"""

def __init__(self) -> None:
def __init__(self, hardware: Hardware) -> None:
"""
Construct a new fusion object
"""
self.hardware = hardware

self.blocks: List[List[str]] = []
self.curr_block: List[str] = []
self.fused_ranks: List[str] = []

self.curr_config: Optional[str] = None
self.components_used: Set[str] = set()

def add_einsum(self, program: Program, metrics: Metrics) -> None:
self.component_dict: Dict[str, List[str]] = {}

def add_einsum(self, program: Program) -> None:
"""
Add the information corresponding to this Einsum
"""
Expand All @@ -55,11 +60,12 @@ def add_einsum(self, program: Program, metrics: Metrics) -> None:

spacetime = program.get_spacetime()
if not spacetime:
raise ValueError("Undefined spacetime for einsum " + einsum)
raise ValueError("Undefined spacetime for Einsum " + einsum)

space_ranks = spacetime.get_space()

# Get the temporal ranks in all loop orders before the first spatial rank
# Get the temporal ranks in all loop orders before the first spatial
# rank
fused_ranks: List[str]
if space_ranks:
fused_ranks = loop_ranks[:loop_ranks.index(space_ranks[0])]
Expand All @@ -68,19 +74,17 @@ def add_einsum(self, program: Program, metrics: Metrics) -> None:

# Get the components used for this Einsum
components_used = set()
for component in metrics.get_hardware().get_components(einsum, FunctionalComponent):
if einsum not in component.get_bindings().keys():
continue

for component in self.hardware.get_components(
einsum, FunctionalComponent):
if component.get_bindings()[einsum]:
components_used.add(component.get_name())

# Get the config
config = metrics.get_hardware().get_config(einsum)

config = self.hardware.get_config(einsum)

# Check if the fusion conditions are met
if config == self.curr_config and fused_ranks == self.fused_ranks and not self.components_used.intersection(components_used):
if config == self.curr_config and fused_ranks == self.fused_ranks and not self.components_used.intersection(
components_used):
self.curr_block.append(einsum)
self.components_used = self.components_used.union(components_used)

Expand All @@ -91,8 +95,23 @@ def add_einsum(self, program: Program, metrics: Metrics) -> None:
self.fused_ranks = fused_ranks
self.curr_config = config

# Prepare to record the components contributing to the execution time
self.component_dict[einsum] = []

def add_component(self, einsum: str, component: str) -> None:
    """
    Record a component whose time is tracked for the given Einsum

    The Einsum must already have an entry in the component dictionary;
    otherwise the lookup raises a KeyError.
    """
    tracked = self.component_dict[einsum]
    tracked.append(component)

def get_blocks(self) -> List[List[str]]:
    """
    Return the list of fusion blocks, each a list of Einsum names
    """
    fusion_blocks = self.blocks
    return fusion_blocks

def get_components(self, einsum: str) -> List[str]:
    """
    Return the component names recorded for this Einsum

    Raises a KeyError if the Einsum has no entry in the component
    dictionary.
    """
    recorded = self.component_dict[einsum]
    return recorded
8 changes: 6 additions & 2 deletions teaal/ir/hardware.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,14 @@ def get_frequency(self, einsum: str) -> int:
freq = top_level.get_attr("clock_frequency")

if freq is None:
raise ValueError("Unspecified clock frequency for config " + self.configs[einsum])
raise ValueError(
"Unspecified clock frequency for config " +
self.configs[einsum])

if isinstance(freq, str):
raise ValueError("Bad clock frequency for config " + self.configs[einsum])
raise ValueError(
"Bad clock frequency for config " +
self.configs[einsum])

return freq

Expand Down
1 change: 0 additions & 1 deletion teaal/ir/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,6 @@ def __build_traffic_paths(self) -> None:
# Build the set of specs to collect
einsum = self.program.get_equation().get_output().root_name()

# TODO: What about eager binding?
for format_ in self.format_options[tensor]:
for rank in spec[format_]:
if rank == "rank-order":
Expand Down
113 changes: 100 additions & 13 deletions teaal/trans/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from teaal.hifiber import *
from teaal.ir.component import *
from teaal.ir.fusion import Fusion
from teaal.ir.metrics import Metrics
from teaal.ir.program import Program
from teaal.ir.tensor import Tensor
Expand All @@ -37,12 +38,17 @@ class Collector:
Translate the metrics collection
"""

def __init__(self, program: Program, metrics: Metrics) -> None:
def __init__(
self,
program: Program,
metrics: Metrics,
fusion: Fusion) -> None:
"""
Construct a collector object
"""
self.program = program
self.metrics = metrics
self.fusion = fusion

# tree_traces: Optional[Dict[rank, Dict[is_read, Set[tensor]]]]
self.tree_traces: Optional[Dict[str, Dict[bool, Set[str]]]] = None
Expand Down Expand Up @@ -122,6 +128,11 @@ def dump(self) -> Statement:
# Track the sequences
block.add(self.__build_sequencers())

# Add the final execution time modeling
num_einsums = len(self.program.get_all_einsums())
if self.program.get_einsum_ind() + 1 == num_einsums:
block.add(self.__build_time())

return block

@staticmethod
Expand Down Expand Up @@ -435,12 +446,13 @@ def __build_compute(self) -> Statement:
assert len(ops) == 1

# op_freq = cycles / s * ops / cycle
op_freq = self.metrics.get_hardware().get_frequency(einsum) * fu.get_num_instances()
op_freq = self.metrics.get_hardware().get_frequency(einsum) * \
fu.get_num_instances()
time = EBinOp(EAccess(metrics_fu, ops[0]), ODiv(), EInt(op_freq))

metrics_time = AAccess(metrics_fu, EString("time"))
block.add(SAssign(metrics_time, time))

self.fusion.add_component(einsum, fu.get_name())

return block

Expand Down Expand Up @@ -535,12 +547,14 @@ def __build_intersections(self) -> Statement:
block.add(SIAssign(metrics_isect, OAdd(), isects))

# op_freq = cycles / s * ops / cycle
op_freq = self.metrics.get_hardware().get_frequency(einsum) * intersector.get_num_instances()
op_freq = self.metrics.get_hardware().get_frequency(einsum) * \
intersector.get_num_instances()
metrics_isect_expr = EAccess(metrics_einsum, EString(isect_name))
time = EBinOp(metrics_isect_expr, ODiv(), EInt(op_freq))

metrics_time = AAccess(metrics_isect_expr, EString("time"))
block.add(SAssign(metrics_time, time))
self.fusion.add_component(einsum, intersector.get_name())

return block

Expand Down Expand Up @@ -633,11 +647,18 @@ def __build_merges(self) -> Statement:
assert len(tensors) == 1

# op_freq = cycles / s * ops / cycle
op_freq = self.metrics.get_hardware().get_frequency(einsum) * merger.get_num_instances()
time = EBinOp(EAccess(metrics_merger, tensors[0]), ODiv(), EInt(op_freq))
op_freq = self.metrics.get_hardware().get_frequency(einsum) * \
merger.get_num_instances()
time = EBinOp(
EAccess(
metrics_merger,
tensors[0]),
ODiv(),
EInt(op_freq))

metrics_time = AAccess(metrics_merger, EString("time"))
block.add(SAssign(metrics_time, time))
self.fusion.add_component(einsum, merger.get_name())

return block

Expand Down Expand Up @@ -678,15 +699,73 @@ def __build_sequencers(self) -> Statement:

assert steps is not None

op_freq = self.metrics.get_hardware().get_frequency(einsum) * seq.get_num_instances()
op_freq = self.metrics.get_hardware().get_frequency(einsum) * \
seq.get_num_instances()
time = EBinOp(EParens(steps), ODiv(), EInt(op_freq))

metrics_time = AAccess(seq_expr, EString("time"))
block.add(SAssign(metrics_time, time))
print(SAssign(metrics_time, time).gen(0))
self.fusion.add_component(einsum, seq.get_name())

return block

def __build_time(self) -> Statement:
    """
    Add the code necessary to compute the final execution time

    Builds a statement block containing two assignments on the
    "metrics" variable:

    - "blocks": the fusion blocks reported by self.fusion.get_blocks()
    - "time": the sum, over all fusion blocks, of each block's time

    Within a block, the per-component "time" entries of every Einsum in
    the block are added together per component, and the block's time is
    the "max" over those per-component sums (0 if the block tracked no
    components).
    """
    sblock = SBlock([])

    # Save the Einsum blocks
    metrics = EVar("metrics")
    blocks = TransUtils.build_expr(self.fusion.get_blocks())
    sblock.add(SAssign(AAccess(metrics, EString("blocks")), blocks))

    # Compute the execution time
    time: Optional[Expression] = None
    for block in self.fusion.get_blocks():

        # Collect up the statistics for the block: one summed time
        # expression per component name
        component_time: Dict[str, Expression] = {}
        for einsum in block:
            metrics_einsum = EAccess(metrics, EString(einsum))
            for comp in self.fusion.get_components(einsum):
                new_time = EAccess(
                    EAccess(metrics_einsum, EString(comp)),
                    EString("time"))

                if comp in component_time:
                    component_time[comp] = EBinOp(
                        component_time[comp], OAdd(), new_time)
                else:
                    component_time[comp] = new_time

        # Sort components to enable testing
        comps = sorted(component_time)

        # Compute block time by taking the max
        block_time: Expression
        if not comps:
            block_time = EInt(0)
        elif len(comps) == 1:
            # Index through the sorted list rather than relying on the
            # loop variable left over from the collection loop above
            block_time = component_time[comps[0]]
        else:
            comp_args = [AJust(component_time[name]) for name in comps]
            block_time = EFunc("max", comp_args)

        # The execution time is the sum of all of the blocks; test the
        # accumulator against None explicitly instead of by truthiness
        if time is None:
            time = block_time
        else:
            time = EBinOp(time, OAdd(), block_time)

    # At least one fusion block is expected by the time this is called
    assert time is not None

    sblock.add(SAssign(AAccess(metrics, EString("time")), time))

    return sblock

def __build_trace_ranks(self) -> Tuple[Statement, bool]:
"""
Add code to trace all necessary ranks
Expand Down Expand Up @@ -981,16 +1060,18 @@ def __build_traffic(self) -> Statement:
bits: Optional[Expression] = None
metrics_src = EAccess(metrics_einsum, EString(src))

# Note: not technically necessary, just to make the testing deterministic
sorted_tensors = list(tensors)
sorted_tensors.sort()
# Note: not technically necessary, just to make the testing
# deterministic
sorted_tensors = sorted(tensors)

for tensor in sorted_tensors:
metrics_tensor = EAccess(metrics_src, EString(tensor))
new_bits: Expression = EAccess(metrics_tensor, EString("read"))

if tensor == einsum:
new_bits = EBinOp(new_bits, OAdd(), EAccess(metrics_tensor, EString("write")))
new_bits = EBinOp(
new_bits, OAdd(), EAccess(
metrics_tensor, EString("write")))

if bits:
bits = EBinOp(bits, OAdd(), new_bits)
Expand All @@ -1006,9 +1087,15 @@ def __build_traffic(self) -> Statement:

metrics_time = AAccess(metrics_src, EString("time"))
# Note: the current model assumes perfect load balance
time = EBinOp(bits, ODiv(), EInt(component.get_bandwidth() * component.get_num_instances()))
time = EBinOp(
bits,
ODiv(),
EInt(
component.get_bandwidth() *
component.get_num_instances()))

block.add(SAssign(metrics_time, time))
self.fusion.add_component(einsum, src)

return block

Expand Down
6 changes: 3 additions & 3 deletions teaal/trans/hifiber.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def __init__(
self.format = format_
if arch and bindings and arch.get_spec():
self.hardware = Hardware(arch, bindings, self.program)
self.fusion = Fusion()
self.fusion = Fusion(self.hardware)

self.trans_utils = TransUtils(self.program)

Expand All @@ -85,7 +85,7 @@ def __translate(self, i: int) -> Statement:
self.metrics: Optional[Metrics] = None
if self.hardware and self.format:
self.metrics = Metrics(self.program, self.hardware, self.format)
self.fusion.add_einsum(self.program, self.metrics)
self.fusion.add_einsum(self.program)

# Create the flow graph and get the relevant nodes
flow_graph = FlowGraph(self.program, self.metrics, ["hoist"])
Expand All @@ -99,7 +99,7 @@ def __translate(self, i: int) -> Statement:
self.eqn = Equation(self.program, self.metrics)

if self.metrics:
self.collector = Collector(self.program, self.metrics)
self.collector = Collector(self.program, self.metrics, self.fusion)

stmt = self.__trans_nodes(nodes)[1]

Expand Down
2 changes: 2 additions & 0 deletions tests/integration/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,6 @@ def test_integration():
output = str(HiFiber(einsum, mapping))

hifiber = read_hifiber(filename + ".py")
if output != hifiber:
print(output)
assert output == hifiber, test_name + " integration test failed!"
6 changes: 4 additions & 2 deletions tests/ir/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ def test_component_get_name():
component = Component("Test", 1, {}, {})
assert component.get_name() == "Test"


def test_component_get_num_instances():
component = Component("Test", 5, {}, {})
assert component.get_num_instances() == 5



def test_component_eq():
component0 = Component("Test", 1, {"attr0": 5}, {})
component1 = Component("Test", 1, {"attr0": 5}, {})
Expand Down Expand Up @@ -319,7 +319,9 @@ def test_compute_component():
def test_dram_component():
bindings = {"Z": [{"tensor": "A", "rank": "M",
"type": "payload", "format": "default"}]}
dram = DRAMComponent("DRAM", 1, {"datawidth": 8, "bandwidth": 128}, bindings)
dram = DRAMComponent(
"DRAM", 1, {
"datawidth": 8, "bandwidth": 128}, bindings)


def test_intersector_component_binding_errs():
Expand Down
Loading

0 comments on commit cb89745

Please sign in to comment.