From 08ec5ea4f75ffa19b66c906ba12a27188b546e5a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 8 Nov 2024 00:44:42 -0800 Subject: [PATCH] SDFG API additions for version 1.0 (#1740) This PR adds additional API calls and fields as requested by DaCe users. This includes: * `SDFG.auto_optimize` * `SDFG.regenerate_code` * `SDFG.as_schedule_tree` --- dace/frontend/python/parser.py | 4 +-- dace/sdfg/sdfg.py | 61 +++++++++++++++++++++++++++++++++- 2 files changed, 62 insertions(+), 3 deletions(-) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index d03759fa8e..20018effd0 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -760,7 +760,7 @@ def _load_sdfg(self, path: str, *args, **kwargs): if sdfg is not None: # Set regenerate and recompile flags - sdfg._regenerate_code = self.regenerate_code + sdfg.regenerate_code = self.regenerate_code sdfg._recompile = self.recompile return sdfg, self._cache.make_key(argtypes, given_args, self.closure_array_keys, self.closure_constant_keys, @@ -928,7 +928,7 @@ def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], # TODO: Add to parsed SDFG cache # Set regenerate and recompile flags - sdfg._regenerate_code = self.regenerate_code + sdfg.regenerate_code = self.regenerate_code sdfg._recompile = self.recompile return sdfg, cached diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 3be268e44d..716bb9accc 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -38,6 +38,7 @@ from dace.codegen.instrumentation.report import InstrumentationReport from dace.codegen.instrumentation.data.data_report import InstrumentedDataReport from dace.codegen.compiled_sdfg import CompiledSDFG + from dace.sdfg.analysis.schedule_tree.treenodes import ScheduleTreeScope class NestedDict(dict): @@ -802,6 +803,14 @@ def start_state(self): def start_state(self, state_id): self.start_block = state_id + @property + def regenerate_code(self): + return self._regenerate_code + + @regenerate_code.setter + def regenerate_code(self, value): + self._regenerate_code = value + def set_global_code(self, cpp_code: str, location: str = 'frame'): """ Sets C++ code that will be generated in a global scope on @@ -1070,6 +1079,24 @@ def call_with_instrumented_data(self, dreport: 'InstrumentedDataReport', *args, ########################################## + def as_schedule_tree(self, in_place: bool = False) -> 'ScheduleTreeScope': + """ + Creates a schedule tree from this SDFG and all nested SDFGs. The schedule tree is a tree of nodes that represent + the execution order of the SDFG. + Each node in the tree can either represent a single statement (symbol assignment, tasklet, copy, library node, + etc.) or a ``ScheduleTreeScope`` block (map, for-loop, pipeline, etc.) that contains other nodes. + + It can be used to generate code from an SDFG, or to perform schedule transformations on the SDFG. For example, + erasing an empty if branch, or merging two consecutive for-loops. + + :param in_place: If True, the SDFG is modified in-place. Otherwise, a copy is made. Note that the SDFG might + not be usable after the conversion if ``in_place`` is True! + :return: A schedule tree representing the given SDFG. + """ + # Avoid import loop + from dace.sdfg.analysis.schedule_tree import sdfg_to_tree as s2t + return s2t.as_schedule_tree(self, in_place=in_place) + @property def build_folder(self) -> str: """ Returns a relative path to the build cache folder for this SDFG. """ @@ -2293,7 +2320,7 @@ def compile(self, output_file=None, validate=True, ############################ # DaCe Compilation Process # - if self._regenerate_code or not os.path.isdir(build_folder): + if self.regenerate_code or not os.path.isdir(build_folder): # Clone SDFG as the other modules may modify its contents sdfg = copy.deepcopy(self) # Fix the build folder name on the copied SDFG to avoid it changing @@ -2463,6 +2490,38 @@ def simplify(self, validate=True, validate_all=False, verbose=False): from dace.transformation.passes.simplify import SimplifyPass return SimplifyPass(validate=validate, validate_all=validate_all, verbose=verbose).apply_pass(self, {}) + def auto_optimize(self, + device: dtypes.DeviceType, + validate: bool = True, + validate_all: bool = False, + symbols: Dict[str, int] = None, + use_gpu_storage: bool = False): + """ + Runs a basic sequence of transformations to optimize a given SDFG to decent + performance. In particular, performs the following: + + * Simplify + * Auto-parallelization (loop-to-map) + * Greedy application of SubgraphFusion + * Tiled write-conflict resolution (MapTiling -> AccumulateTransient) + * Tiled stream accumulation (MapTiling -> AccumulateTransient) + * Collapse all maps to parallelize across all dimensions + * Set all library nodes to expand to ``fast`` expansion, which calls + the fastest library on the target device + + :param device: the device to optimize for. + :param validate: If True, validates the SDFG after all transformations + have been applied. + :param validate_all: If True, validates the SDFG after every step. + :param symbols: Optional dict that maps symbols (str/symbolic) to int/float + :param use_gpu_storage: If True, changes the storage of non-transient data to GPU global memory. + :note: Operates in-place on the given SDFG. + :note: This function is still experimental and may harm correctness in + certain cases. Please report an issue if it does. + """ + from dace.transformation.auto.auto_optimize import auto_optimize + auto_optimize(device, validate, validate_all, symbols, use_gpu_storage) + def _initialize_transformations_from_type( self, xforms: Union[Type, List[Type], 'dace.transformation.PatternTransformation'],