diff --git a/.gitignore b/.gitignore index 89b3e8d..c0beee2 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ node_modules *.code-workspace .vscode .env +__pycache__/ diff --git a/backend/dace_vscode/__init__.py b/backend/dace_vscode/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/dace_vscode/arith_ops.py b/backend/dace_vscode/arith_ops.py new file mode 100644 index 0000000..3364237 --- /dev/null +++ b/backend/dace_vscode/arith_ops.py @@ -0,0 +1,220 @@ +# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors. +# All rights reserved. + +import ast +from astunparse import unparse +from dace.symbolic import pystr_to_symbolic +from dace.sdfg import propagation +from dace.libraries.blas import MatMul, Transpose +from dace.libraries.standard import Reduce +from dace import nodes, dtypes +from sympy import function as spf + +from dace_vscode.utils import get_uuid, load_sdfg_from_json + + +def symeval(val, symbols): + first_replacement = { + pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) + for k in symbols.keys() + } + second_replacement = { + pystr_to_symbolic('__REPLSYM_' + k): v + for k, v in symbols.items() + } + return val.subs(first_replacement).subs(second_replacement) + + +def evaluate_symbols(base, new): + result = {} + for k, v in new.items(): + result[k] = symeval(v, base) + return result + + +def count_matmul(node, symbols, state): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + result = 2 # Multiply, add + # Batch + if len(C_memlet.data.subset) == 3: + result *= symeval(C_memlet.data.subset.size()[0], symbols) + # M*N + result *= symeval(C_memlet.data.subset.size()[-2], symbols) + result *= symeval(C_memlet.data.subset.size()[-1], symbols) + # K + result *= symeval(A_memlet.data.subset.size()[-1], symbols) + return result + + +def count_reduce(node, symbols, state): + result = 0 + if node.wcr is not None: + result += count_arithmetic_ops_code(node.wcr) + in_memlet = None + in_edges = state.in_edges(node) + if in_edges is not None and len(in_edges) == 1: + in_memlet = in_edges[0] + if in_memlet is not None and in_memlet.data.volume is not None: + result *= in_memlet.data.volume + else: + result = 0 + return result + + +bigo = spf.Function('bigo') +PYFUNC_TO_ARITHMETICS = { + 'float': 0, + 'math.exp': 1, + 'math.tanh': 1, + 'math.sqrt': 1, + 'min': 0, + 'max': 0, + 'ceiling': 0, + 'floor': 0, +} +LIBNODES_TO_ARITHMETICS = { + MatMul: count_matmul, + Transpose: lambda *args: 0, + Reduce: count_reduce, +} + + +class ArithmeticCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print('WARNING: Unrecognized python function "%s"' % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + + +def count_arithmetic_ops_code(code): + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + + +def create_arith_ops_map_state(state, arith_map, symbols): + scope_tree_root = state.scope_tree()[None] + scope_dict = state.scope_children() + + def traverse(scope): + repetitions = 1 + traversal_result = 0 + if scope.entry is not None: + repetitions = scope.entry.map.range.num_elements() + for node in scope_dict[scope.entry]: + node_result = 0 + if isinstance(node, nodes.NestedSDFG): + nested_syms = {} + nested_syms.update(symbols) + nested_syms.update( + evaluate_symbols(symbols, node.symbol_mapping)) + node_result += create_arith_ops_map(node.sdfg, arith_map, + nested_syms) + elif isinstance(node, nodes.LibraryNode): + node_result += LIBNODES_TO_ARITHMETICS[type(node)](node, + symbols, + state) + elif isinstance(node, nodes.Tasklet): + if node.code.language == dtypes.Language.CPP: + for oedge in state.out_edges(node): + node_result += bigo(oedge.data.num_accesses) + else: + node_result += count_arithmetic_ops_code(node.code.code) + elif isinstance(node, nodes.MapEntry): + map_scope = None + for child_scope in scope.children: + if child_scope.entry == node: + map_scope = child_scope + break + map_result = 0 + if map_scope is not None: + map_result = traverse(map_scope) + node_result += map_result + elif isinstance(node, nodes.MapExit): + # Don't do anything for map exists. + pass + elif isinstance(node, + (nodes.CodeNode, nodes.AccessNode)): + for oedge in state.out_edges(node): + if oedge.data.wcr is not None: + node_result += count_arithmetic_ops_code( + oedge.data.wcr) + + arith_map[get_uuid(node, state)] = str(node_result) + traversal_result += node_result + return repetitions * traversal_result + + state_result = traverse(scope_tree_root) + + if state.executions is not None: + if (state.dynamic_executions is not None and state.dynamic_executions + and state.executions == 0): + state_result = 0 + else: + state_result *= state.executions + + arith_map[get_uuid(state)] = str(state_result) + return state_result + + +def create_arith_ops_map(sdfg, arith_map, symbols): + sdfg_ops = 0 + for state in sdfg.nodes(): + sdfg_ops += create_arith_ops_map_state(state, arith_map, symbols) + arith_map[get_uuid(sdfg)] = str(sdfg_ops) + + # Replace any operations that math.js does not understand. + for uuid in arith_map: + arith_map[uuid] = arith_map[uuid].replace('**', '^') + + return sdfg_ops + + +def get_arith_ops(sdfg_json): + loaded = load_sdfg_from_json(sdfg_json) + if loaded['error'] is not None: + return loaded['error'] + sdfg = loaded['sdfg'] + + propagation.propagate_memlets_sdfg(sdfg) + + arith_map = {} + create_arith_ops_map(sdfg, arith_map, {}) + return { + 'arith_ops_map': arith_map, + } diff --git a/backend/dace_vscode/editing.py b/backend/dace_vscode/editing.py new file mode 100644 index 0000000..4f318c8 --- /dev/null +++ b/backend/dace_vscode/editing.py @@ -0,0 +1,157 @@ +# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors. +# All rights reserved. + +from dace import ( + serialize, nodes, SDFG, SDFGState, InterstateEdge, Memlet, dtypes +) +from dace_vscode.utils import ( + load_sdfg_from_json, + find_graph_element_by_uuid, + get_uuid, +) +from pydoc import locate + + +def remove_sdfg_elements(sdfg_json, uuids): + from dace.sdfg.graph import Edge + + old_meta = serialize.JSON_STORE_METADATA + serialize.JSON_STORE_METADATA = False + + loaded = load_sdfg_from_json(sdfg_json) + if loaded['error'] is not None: + return loaded['error'] + sdfg = loaded['sdfg'] + + elements = [] + for uuid in uuids: + elements.append(find_graph_element_by_uuid(sdfg, uuid)) + + for element_ret in elements: + element = element_ret['element'] + parent = element_ret['parent'] + + if parent is not None and element is not None: + if isinstance(element, Edge): + parent.remove_edge(element) + else: + parent.remove_node(element) + else: + return { + 'error': { + 'message': 'Failed to delete element', + 'details': 'Element or parent not found', + }, + } + + new_sdfg = sdfg.to_json() + serialize.JSON_STORE_METADATA = old_meta + + return { + 'sdfg': new_sdfg, + } + + +def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid): + sdfg_answer = load_sdfg_from_json(sdfg_str) + sdfg = sdfg_answer['sdfg'] + uuid = 'error' + ret = find_graph_element_by_uuid(sdfg, parent_uuid) + parent = ret['element'] + + libname = None + if type is not None and isinstance(type, str): + split_type = type.split('|') + if len(split_type) == 2: + type = split_type[0] + libname = split_type[1] + + if type == 'SDFGState': + if parent is None: + parent = sdfg + elif isinstance(parent, nodes.NestedSDFG): + parent = parent.sdfg + state = parent.add_state() + uuid = [get_uuid(state)] + elif type == 'AccessNode': + arrays = list(parent.parent.arrays.keys()) + if len(arrays) == 0: + parent.parent.add_array('tmp', [1], dtype=dtypes.float64) + arrays = list(parent.parent.arrays.keys()) + node = parent.add_access(arrays[0]) + uuid = [get_uuid(node, parent)] + elif type == 'Map': + map_entry, map_exit = parent.add_map('map', dict(i='0:1')) + uuid = [get_uuid(map_entry, parent), get_uuid(map_exit, parent)] + elif type == 'Consume': + consume_entry, consume_exit = parent.add_consume('consume', ('i', '1')) + uuid = [get_uuid(consume_entry, parent), get_uuid(consume_exit, parent)] + elif type == 'Tasklet': + tasklet = parent.add_tasklet( + name='placeholder', + inputs={'in'}, + outputs={'out'}, + code='') + uuid = [get_uuid(tasklet, parent)] + elif type == 'NestedSDFG': + sub_sdfg = SDFG('nested_sdfg') + sub_sdfg.add_array('in', [1], dtypes.float32) + sub_sdfg.add_array('out', [1], dtypes.float32) + + nsdfg = parent.add_nested_sdfg(sub_sdfg, sdfg, {'in'}, {'out'}) + uuid = [get_uuid(nsdfg, parent)] + elif type == 'LibraryNode': + if libname is None: + return { + 'error': { + 'message': 'Failed to add library node', + 'details': 'Must provide a valid library node type', + }, + } + libnode_class = locate(libname) + libnode = libnode_class() + parent.add_node(libnode) + uuid = [get_uuid(libnode, parent)] + elif type == 'Edge': + edge_start_ret = find_graph_element_by_uuid(sdfg, edge_a_uuid) + edge_start = edge_start_ret['element'] + edge_parent = edge_start_ret['parent'] + if edge_start is not None: + if edge_parent is None: + edge_parent = sdfg + + if isinstance(edge_parent, SDFGState): + if not (isinstance(edge_start, nodes.Node) and + isinstance(parent, nodes.Node)): + return { + 'error': { + 'message': 'Failed to add edge', + 'details': 'Must connect two nodes or two states', + }, + } + memlet = Memlet() + edge_parent.add_edge(edge_start, None, parent, None, memlet) + elif isinstance(edge_parent, SDFG): + if not (isinstance(edge_start, SDFGState) and + isinstance(parent, SDFGState)): + return { + 'error': { + 'message': 'Failed to add edge', + 'details': 'Must connect two nodes or two states', + }, + } + isedge = InterstateEdge() + edge_parent.add_edge(edge_start, parent, isedge) + uuid = ['NONE'] + else: + raise ValueError('No edge starting point provided') + + old_meta = serialize.JSON_STORE_METADATA + serialize.JSON_STORE_METADATA = False + new_sdfg_str = sdfg.to_json() + serialize.JSON_STORE_METADATA = old_meta + + return { + 'sdfg': new_sdfg_str, + 'uuid': uuid, + } diff --git a/backend/dace_vscode/transformations.py b/backend/dace_vscode/transformations.py new file mode 100644 index 0000000..e2acef4 --- /dev/null +++ b/backend/dace_vscode/transformations.py @@ -0,0 +1,214 @@ +# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors. +# All rights reserved. + +from dace import nodes, serialize +from dace.transformation.transformation import SubgraphTransformation +from dace_vscode import utils +import sys +import traceback + +def expand_library_node(json_in): + """ + Expand a specific library node in a given SDFG. If no specific library node + is provided, expand all library nodes in the given SDFG. + :param json_in: The entire provided request JSON. + """ + old_meta = serialize.JSON_STORE_METADATA + serialize.JSON_STORE_METADATA = False + + sdfg = None + try: + loaded = utils.load_sdfg_from_json(json_in['sdfg']) + if loaded['error'] is not None: + return loaded['error'] + sdfg = loaded['sdfg'] + except KeyError: + return { + 'error': { + 'message': 'Failed to expand library node', + 'details': 'No SDFG provided', + }, + } + + try: + sdfg_id, state_id, node_id = json_in['nodeid'] + except KeyError: + sdfg_id, state_id, node_id = None, None, None + + if sdfg_id is None: + sdfg.expand_library_nodes() + else: + context_sdfg = sdfg.sdfg_list[sdfg_id] + state = context_sdfg.node(state_id) + node = state.node(node_id) + if isinstance(node, nodes.LibraryNode): + node.expand(context_sdfg, state) + else: + return { + 'error': { + 'message': 'Failed to expand library node', + 'details': 'The provided node is not a valid library node', + }, + } + + new_sdfg = sdfg.to_json() + serialize.JSON_STORE_METADATA = old_meta + return { + 'sdfg': new_sdfg, + } + + +def reapply_history_until(sdfg_json, index): + """ + Rewind a given SDFG back to a specific point in its history by reapplying + all transformations until a given index in its history to its original + state. + :param sdfg_json: The SDFG to rewind. + :param index: Index of the last history item to apply. + """ + old_meta = serialize.JSON_STORE_METADATA + serialize.JSON_STORE_METADATA = False + + loaded = utils.load_sdfg_from_json(sdfg_json) + if loaded['error'] is not None: + return loaded['error'] + sdfg = loaded['sdfg'] + + original_sdfg = sdfg.orig_sdfg + history = sdfg.transformation_hist + + for i in range(index + 1): + transformation = history[i] + try: + if isinstance(transformation, SubgraphTransformation): + transformation.apply( + original_sdfg.sdfg_list[transformation.sdfg_id]) + else: + transformation.apply_pattern( + original_sdfg.sdfg_list[transformation.sdfg_id]) + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + sys.stderr.flush() + return { + 'error': { + 'message': + 'Failed to play back the transformation history', + 'details': utils.get_exception_message(e), + }, + } + + new_sdfg = original_sdfg.to_json() + serialize.JSON_STORE_METADATA = old_meta + return { + 'sdfg': new_sdfg, + } + + +def apply_transformation(sdfg_json, transformation_json): + old_meta = serialize.JSON_STORE_METADATA + serialize.JSON_STORE_METADATA = False + + loaded = utils.load_sdfg_from_json(sdfg_json) + if loaded['error'] is not None: + return loaded['error'] + sdfg = loaded['sdfg'] + + try: + transformation = serialize.from_json(transformation_json) + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + sys.stderr.flush() + return { + 'error': { + 'message': 'Failed to parse the applied transformation', + 'details': utils.get_exception_message(e), + }, + } + try: + target_sdfg = sdfg.sdfg_list[transformation.sdfg_id] + if isinstance(transformation, SubgraphTransformation): + sdfg.append_transformation(transformation) + transformation.apply(target_sdfg) + else: + transformation.apply_pattern(target_sdfg) + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + sys.stderr.flush() + return { + 'error': { + 'message': 'Failed to apply the transformation to the SDFG', + 'details': utils.get_exception_message(e), + }, + } + + new_sdfg = sdfg.to_json() + serialize.JSON_STORE_METADATA = old_meta + return { + 'sdfg': new_sdfg, + } + + +def get_transformations(sdfg_json, selected_elements): + # We lazy import DaCe, not to break cyclic imports, but to avoid any large + # delays when booting in daemon mode. + from dace.transformation.optimizer import SDFGOptimizer + from dace.sdfg.graph import SubgraphView + + old_meta = serialize.JSON_STORE_METADATA + serialize.JSON_STORE_METADATA = False + + loaded = utils.load_sdfg_from_json(sdfg_json) + if loaded['error'] is not None: + return loaded['error'] + sdfg = loaded['sdfg'] + + optimizer = SDFGOptimizer(sdfg) + matches = optimizer.get_pattern_matches() + + transformations = [] + docstrings = {} + for transformation in matches: + transformations.append(transformation.to_json()) + docstrings[type(transformation).__name__] = transformation.__doc__ + + selected_states = [ + utils.sdfg_find_state_from_element(sdfg, n) for n in selected_elements + if n['type'] == 'state' + ] + selected_nodes = [ + utils.sdfg_find_node_from_element(sdfg, n) for n in selected_elements + if n['type'] == 'node' + ] + subgraph = None + if len(selected_states) > 0: + subgraph = SubgraphView(sdfg, selected_states) + else: + violated = False + state = None + for node in selected_nodes: + if state is None: + state = node.state + elif state != node.state: + violated = True + break + if not violated and state is not None: + subgraph = SubgraphView(state, selected_nodes) + + if subgraph is not None: + extensions = SubgraphTransformation.extensions() + for xform in extensions: + xform_data = extensions[xform] + if ('singlestate' in xform_data and + xform_data['singlestate'] and + len(selected_states) > 0): + continue + xform_obj = xform(subgraph) + if xform_obj.can_be_applied(sdfg, subgraph): + transformations.append(xform_obj.to_json()) + docstrings[xform.__name__] = xform_obj.__doc__ + + serialize.JSON_STORE_METADATA = old_meta + return { + 'transformations': transformations, + 'docstrings': docstrings, + } diff --git a/backend/dace_vscode/utils.py b/backend/dace_vscode/utils.py new file mode 100644 index 0000000..a12575b --- /dev/null +++ b/backend/dace_vscode/utils.py @@ -0,0 +1,177 @@ +# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors. +# All rights reserved. + +from dace import SDFG, SDFGState, nodes +import sys +import traceback + + +UUID_SEPARATOR = '/' + + +def get_exception_message(exception): + return '%s: %s' % (type(exception).__name__, exception) + + +def ids_to_string(sdfg_id, state_id=-1, node_id=-1, edge_id=-1): + return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + + str(node_id) + UUID_SEPARATOR + str(edge_id)) + + +def get_uuid(element, state=None): + if isinstance(element, SDFG): + return ids_to_string(element.sdfg_id) + elif isinstance(element, SDFGState): + return ids_to_string(element.parent.sdfg_id, + element.parent.node_id(element)) + elif isinstance(element, nodes.Node): + return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), + state.node_id(element)) + else: + return ids_to_string(-1) + + +def recursively_find_graph(graph, graph_id, ns_node = None): + if graph.sdfg_id == graph_id: + return { + 'graph': graph, + 'node': ns_node, + } + + res = { + 'graph': None, + 'node': None, + } + + for state in graph.nodes(): + for node in state.nodes(): + if isinstance(node, nodes.NestedSDFG): + graph_result = recursively_find_graph( + node.sdfg, graph_id, node + ) + if graph_result != None: + return graph_result + + return res + + +def find_graph_element_by_uuid(sdfg, uuid): + uuid_split = uuid.split(UUID_SEPARATOR) + + graph_id = int(uuid_split[0]) + state_id = int(uuid_split[1]) + node_id = int(uuid_split[2]) + edge_id = int(uuid_split[3]) + + ret = { + 'parent': None, + 'element': None, + } + + graph = sdfg + if graph_id > 0: + found_graph = recursively_find_graph(graph, graph_id) + graph = found_graph['graph'] + ret = { + 'parent': graph, + 'element': found_graph['node'], + } + + state = None + if state_id != -1 and graph is not None: + state = graph.node(state_id) + ret = { + 'parent': graph, + 'element': state, + } + + if node_id != -1 and state is not None: + ret = { + 'parent': state, + 'element': state.node(node_id), + } + elif edge_id != -1 and state is not None: + ret = { + 'parent': state, + 'element': state.edges()[edge_id], + } + elif edge_id != -1 and state is None: + ret = { + 'parent': graph, + 'element': graph.edges()[edge_id], + } + + return ret + + +def sdfg_find_state_from_element(sdfg, element): + graph = sdfg.sdfg_list[element['sdfg_id']] + if element['id'] >= 0: + return graph.nodes()[element['id']] + else: + return None + + +def sdfg_find_node_from_element(sdfg, element): + graph = sdfg.sdfg_list[element['sdfg_id']] + if element['state_id'] >= 0: + state = graph.nodes()[element['state_id']] + node = state.nodes()[element['id']] + node.state = state + return node + else: + node = graph.nodes()[element['id']] + node.state = None + return node + + +def load_sdfg_from_file(path): + try: + sdfg = SDFG.from_file(path) + error = None + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + sys.stderr.flush() + error = { + 'error': { + 'message': 'Failed to load the provided SDFG file path', + 'details': get_exception_message(e), + }, + } + sdfg = None + return { + 'error': error, + 'sdfg': sdfg, + } + + +def load_sdfg_from_json(json): + if 'error' in json: + message = '' + if ('message' in json['error']): + message = json['error']['message'] + error = { + 'error': { + 'message': 'Invalid SDFG provided', + 'details': message, + } + } + sdfg = None + else: + try: + sdfg = SDFG.from_json(json) + error = None + except Exception as e: + print(traceback.format_exc(), file=sys.stderr) + sys.stderr.flush() + error = { + 'error': { + 'message': 'Failed to parse the provided SDFG', + 'details': get_exception_message(e), + }, + } + sdfg = None + return { + 'error': error, + 'sdfg': sdfg, + } diff --git a/backend/run_dace.py b/backend/run_dace.py index 87ad66f..627c7cd 100644 --- a/backend/run_dace.py +++ b/backend/run_dace.py @@ -41,558 +41,18 @@ # Then, load the rest of the modules import aenum from argparse import ArgumentParser -import ast, astunparse import dace -from dace.sdfg import propagation -from dace.symbolic import pystr_to_symbolic -from dace.libraries.blas import MatMul, Transpose -from dace.libraries.standard import Reduce import inspect -import sympy +from os import path import sys -import traceback - -# Prepare a whitelist of DaCe enumeration types -enum_list = [ - typename - for typename, dtype in inspect.getmembers(dace.dtypes, inspect.isclass) - if issubclass(dtype, aenum.Enum) -] - - -def count_matmul(node, symbols, state): - A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') - B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') - C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') - result = 2 # Multiply, add - # Batch - if len(C_memlet.data.subset) == 3: - result *= symeval(C_memlet.data.subset.size()[0], symbols) - # M*N - result *= symeval(C_memlet.data.subset.size()[-2], symbols) - result *= symeval(C_memlet.data.subset.size()[-1], symbols) - # K - result *= symeval(A_memlet.data.subset.size()[-1], symbols) - return result - - -def count_reduce(node, symbols, state): - result = 0 - if node.wcr is not None: - result += count_arithmetic_ops_code(node.wcr) - in_memlet = None - in_edges = state.in_edges(node) - if in_edges is not None and len(in_edges) == 1: - in_memlet = in_edges[0] - if in_memlet is not None and in_memlet.data.volume is not None: - result *= in_memlet.data.volume - else: - result = 0 - return result - - -bigo = sympy.Function('bigo') -UUID_SEPARATOR = '/' -PYFUNC_TO_ARITHMETICS = { - 'float': 0, - 'math.exp': 1, - 'math.tanh': 1, - 'math.sqrt': 1, - 'min': 0, - 'max': 0, - 'ceiling': 0, - 'floor': 0, -} -LIBNODES_TO_ARITHMETICS = { - MatMul: count_matmul, - Transpose: lambda *args: 0, - Reduce: count_reduce, -} - - -class ArithmeticCounter(ast.NodeVisitor): - def __init__(self): - self.count = 0 - - def visit_BinOp(self, node): - if isinstance(node.op, ast.MatMult): - raise NotImplementedError('MatMult op count requires shape ' - 'inference') - self.count += 1 - return self.generic_visit(node) - - def visit_UnaryOp(self, node): - self.count += 1 - return self.generic_visit(node) - - def visit_Call(self, node): - fname = astunparse.unparse(node.func)[:-1] - if fname not in PYFUNC_TO_ARITHMETICS: - print('WARNING: Unrecognized python function "%s"' % fname) - return self.generic_visit(node) - self.count += PYFUNC_TO_ARITHMETICS[fname] - return self.generic_visit(node) - - def visit_AugAssign(self, node): - return self.visit_BinOp(node) - - def visit_For(self, node): - raise NotImplementedError - - def visit_While(self, node): - raise NotImplementedError - - -def ids_to_string(sdfg_id, state_id=-1, node_id=-1, edge_id=-1): - return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + - str(node_id) + UUID_SEPARATOR + str(edge_id)) - - -def get_uuid(element, state=None): - if isinstance(element, dace.SDFG): - return ids_to_string(element.sdfg_id) - elif isinstance(element, dace.SDFGState): - return ids_to_string(element.parent.sdfg_id, - element.parent.node_id(element)) - elif isinstance(element, dace.nodes.Node): - return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), - state.node_id(element)) - else: - return ids_to_string(-1) - - -def symeval(val, symbols): - first_replacement = { - pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) - for k in symbols.keys() - } - second_replacement = { - pystr_to_symbolic('__REPLSYM_' + k): v - for k, v in symbols.items() - } - return val.subs(first_replacement).subs(second_replacement) - - -def evaluate_symbols(base, new): - result = {} - for k, v in new.items(): - result[k] = symeval(v, base) - return result - - -def count_arithmetic_ops_code(code): - ctr = ArithmeticCounter() - if isinstance(code, (tuple, list)): - for stmt in code: - ctr.visit(stmt) - elif isinstance(code, str): - ctr.visit(ast.parse(code)) - else: - ctr.visit(code) - return ctr.count - - -def create_arith_ops_map_state(state, arith_map, symbols): - scope_tree_root = state.scope_tree()[None] - scope_dict = state.scope_children() - - def traverse(scope): - repetitions = 1 - traversal_result = 0 - if scope.entry is not None: - repetitions = scope.entry.map.range.num_elements() - for node in scope_dict[scope.entry]: - node_result = 0 - if isinstance(node, dace.nodes.NestedSDFG): - nested_syms = {} - nested_syms.update(symbols) - nested_syms.update( - evaluate_symbols(symbols, node.symbol_mapping)) - node_result += create_arith_ops_map(node.sdfg, arith_map, - nested_syms) - elif isinstance(node, dace.nodes.LibraryNode): - node_result += LIBNODES_TO_ARITHMETICS[type(node)](node, - symbols, - state) - elif isinstance(node, dace.nodes.Tasklet): - if node.code.language == dace.dtypes.Language.CPP: - for oedge in state.out_edges(node): - node_result += bigo(oedge.data.num_accesses) - else: - node_result += count_arithmetic_ops_code(node.code.code) - elif isinstance(node, dace.nodes.MapEntry): - map_scope = None - for child_scope in scope.children: - if child_scope.entry == node: - map_scope = child_scope - break - map_result = 0 - if map_scope is not None: - map_result = traverse(map_scope) - node_result += map_result - elif isinstance(node, dace.nodes.MapExit): - # Don't do anything for map exists. - pass - elif isinstance(node, - (dace.nodes.CodeNode, dace.nodes.AccessNode)): - for oedge in state.out_edges(node): - if oedge.data.wcr is not None: - node_result += count_arithmetic_ops_code( - oedge.data.wcr) - - arith_map[get_uuid(node, state)] = str(node_result) - traversal_result += node_result - return repetitions * traversal_result - - state_result = traverse(scope_tree_root) - - if state.executions is not None: - if (state.dynamic_executions is not None and state.dynamic_executions - and state.executions == 0): - state_result = 0 - else: - state_result *= state.executions - - arith_map[get_uuid(state)] = str(state_result) - return state_result - - -def create_arith_ops_map(sdfg, arith_map, symbols): - sdfg_ops = 0 - for state in sdfg.nodes(): - sdfg_ops += create_arith_ops_map_state(state, arith_map, symbols) - arith_map[get_uuid(sdfg)] = str(sdfg_ops) - - # Replace any operations that math.js does not understand. - for uuid in arith_map: - arith_map[uuid] = arith_map[uuid].replace('**', '^') - - return sdfg_ops - - -def get_exception_message(exception): - return '%s: %s' % (type(exception).__name__, exception) - - -def load_sdfg_from_file(path): - # We lazy import SDFGs, not to break cyclic imports, but to avoid any large - # delays when booting in daemon mode. - from dace.sdfg import SDFG - - try: - sdfg = SDFG.from_file(path) - error = None - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - sys.stderr.flush() - error = { - 'error': { - 'message': 'Failed to load the provided SDFG file path', - 'details': get_exception_message(e), - }, - } - sdfg = None - return { - 'error': error, - 'sdfg': sdfg, - } - - -def load_sdfg_from_json(json): - # We lazy import SDFGs, not to break cyclic imports, but to avoid any large - # delays when booting in daemon mode. - from dace.sdfg import SDFG - - if 'error' in json: - message = '' - if ('message' in json['error']): - message = json['error']['message'] - error = { - 'error': { - 'message': 'Invalid SDFG provided', - 'details': message, - } - } - sdfg = None - else: - try: - sdfg = SDFG.from_json(json) - error = None - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - sys.stderr.flush() - error = { - 'error': { - 'message': 'Failed to parse the provided SDFG', - 'details': get_exception_message(e), - }, - } - sdfg = None - return { - 'error': error, - 'sdfg': sdfg, - } - - -def expand_library_node(json_in): - """ - Expand a specific library node in a given SDFG. If no specific library node - is provided, expand all library nodes in the given SDFG. - :param json_in: The entire provided request JSON. - """ - from dace import serialize - old_meta = serialize.JSON_STORE_METADATA - serialize.JSON_STORE_METADATA = False - sdfg = None - try: - loaded = load_sdfg_from_json(json_in['sdfg']) - if loaded['error'] is not None: - return loaded['error'] - sdfg = loaded['sdfg'] - except KeyError: - return { - 'error': { - 'message': 'Failed to expand library node', - 'details': 'No SDFG provided', - }, - } +sys.path.append(path.abspath(path.dirname(__file__))) - try: - sdfg_id, state_id, node_id = json_in['nodeid'] - except KeyError: - sdfg_id, state_id, node_id = None, None, None - - if sdfg_id is None: - sdfg.expand_library_nodes() - else: - context_sdfg = sdfg.sdfg_list[sdfg_id] - state = context_sdfg.node(state_id) - node = state.node(node_id) - if isinstance(node, dace.nodes.LibraryNode): - node.expand(context_sdfg, state) - else: - return { - 'error': { - 'message': 'Failed to expand library node', - 'details': 'The provided node is not a valid library node', - }, - } - - new_sdfg = sdfg.to_json() - serialize.JSON_STORE_METADATA = old_meta - return { - 'sdfg': new_sdfg, - } - - -def reapply_history_until(sdfg_json, index): - """ - Rewind a given SDFG back to a specific point in its history by reapplying - all transformations until a given index in its history to its original - state. - :param sdfg_json: The SDFG to rewind. - :param index: Index of the last history item to apply. - """ - from dace import serialize - old_meta = serialize.JSON_STORE_METADATA - serialize.JSON_STORE_METADATA = False - - loaded = load_sdfg_from_json(sdfg_json) - if loaded['error'] is not None: - return loaded['error'] - sdfg = loaded['sdfg'] - - original_sdfg = sdfg.orig_sdfg - history = sdfg.transformation_hist - - for i in range(index + 1): - transformation = history[i] - try: - if isinstance( - transformation, - dace.transformation.transformation.SubgraphTransformation): - transformation.apply( - original_sdfg.sdfg_list[transformation.sdfg_id]) - else: - transformation.apply_pattern( - original_sdfg.sdfg_list[transformation.sdfg_id]) - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - sys.stderr.flush() - return { - 'error': { - 'message': - 'Failed to play back the transformation history', - 'details': get_exception_message(e), - }, - } - - new_sdfg = original_sdfg.to_json() - serialize.JSON_STORE_METADATA = old_meta - return { - 'sdfg': new_sdfg, - } - - -def apply_transformation(sdfg_json, transformation_json): - # We lazy import DaCe, not to break cyclic imports, but to avoid any large - # delays when booting in daemon mode. - from dace import serialize - old_meta = serialize.JSON_STORE_METADATA - serialize.JSON_STORE_METADATA = False - - loaded = load_sdfg_from_json(sdfg_json) - if loaded['error'] is not None: - return loaded['error'] - sdfg = loaded['sdfg'] - - try: - transformation = serialize.from_json(transformation_json) - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - sys.stderr.flush() - return { - 'error': { - 'message': 'Failed to parse the applied transformation', - 'details': get_exception_message(e), - }, - } - try: - target_sdfg = sdfg.sdfg_list[transformation.sdfg_id] - if isinstance( - transformation, - dace.transformation.transformation.SubgraphTransformation): - sdfg.append_transformation(transformation) - transformation.apply(target_sdfg) - else: - transformation.apply_pattern(target_sdfg) - except Exception as e: - print(traceback.format_exc(), file=sys.stderr) - sys.stderr.flush() - return { - 'error': { - 'message': 'Failed to apply the transformation to the SDFG', - 'details': get_exception_message(e), - }, - } - - new_sdfg = sdfg.to_json() - serialize.JSON_STORE_METADATA = old_meta - return { - 'sdfg': new_sdfg, - } - - -def sdfg_find_state(sdfg, element): - graph = sdfg.sdfg_list[element['sdfg_id']] - if element['id'] >= 0: - return graph.nodes()[element['id']] - else: - return None - - -def sdfg_find_node(sdfg, element): - graph = sdfg.sdfg_list[element['sdfg_id']] - if element['state_id'] >= 0: - state = graph.nodes()[element['state_id']] - node = state.nodes()[element['id']] - node.state = state - return node - else: - node = graph.nodes()[element['id']] - node.state = None - return node - - -def get_arith_ops(sdfg_json): - loaded = load_sdfg_from_json(sdfg_json) - if loaded['error'] is not None: - return loaded['error'] - sdfg = loaded['sdfg'] - - propagation.propagate_memlets_sdfg(sdfg) - - arith_map = {} - create_arith_ops_map(sdfg, arith_map, {}) - return { - 'arith_ops_map': arith_map, - } - - -def get_transformations(sdfg_json, selected_elements): - # We lazy import DaCe, not to break cyclic imports, but to avoid any large - # delays when booting in daemon mode. - from dace.transformation.optimizer import SDFGOptimizer - from dace import serialize - - old_meta = serialize.JSON_STORE_METADATA - serialize.JSON_STORE_METADATA = False - from dace.sdfg.graph import SubgraphView - from dace.transformation.transformation import SubgraphTransformation - - loaded = load_sdfg_from_json(sdfg_json) - if loaded['error'] is not None: - return loaded['error'] - sdfg = loaded['sdfg'] - - optimizer = SDFGOptimizer(sdfg) - matches = optimizer.get_pattern_matches() - - transformations = [] - docstrings = {} - for transformation in matches: - transformations.append(transformation.to_json()) - docstrings[type(transformation).__name__] = transformation.__doc__ - - selected_states = [ - sdfg_find_state(sdfg, n) for n in selected_elements - if n['type'] == 'state' - ] - selected_nodes = [ - sdfg_find_node(sdfg, n) for n in selected_elements - if n['type'] == 'node' - ] - subgraph = None - if len(selected_states) > 0: - subgraph = SubgraphView(sdfg, selected_states) - else: - violated = False - state = None - for node in selected_nodes: - if state is None: - state = node.state - elif state != node.state: - violated = True - break - if not violated and state is not None: - subgraph = SubgraphView(state, selected_nodes) - - if subgraph is not None: - for xform in SubgraphTransformation.extensions(): - xform_obj = xform(subgraph) - if xform_obj.can_be_applied(sdfg, subgraph): - transformations.append(xform_obj.to_json()) - docstrings[xform.__name__] = xform_obj.__doc__ - - serialize.JSON_STORE_METADATA = old_meta - return { - 'transformations': transformations, - 'docstrings': docstrings, - } - - -def get_enum(name): - if name not in enum_list: - return { - 'error': { - 'message': 'Failed to get Enum', - 'details': 'Enum type "' + str(name) + '" is not in whitelist', - }, - } - return { - 'enum': [str(e).split('.')[-1] for e in getattr(dace.dtypes, name)] - } +from dace_vscode.utils import ( + load_sdfg_from_file, +) +from dace_vscode.arith_ops import get_arith_ops +from dace_vscode import transformations, editing def get_property_metdata(): @@ -617,6 +77,7 @@ def get_property_metdata(): meta_dict = {} meta_dict['__reverse_type_lookup__'] = {} + meta_dict['__libs__'] = {} for typename in dace.serialize._DACE_SERIALIZE_TYPES: t = dace.serialize._DACE_SERIALIZE_TYPES[typename] if hasattr(t, '__properties__'): @@ -669,7 +130,18 @@ def get_property_metdata(): meta_dict['__reverse_type_lookup__'][ meta_type] = meta_dict[meta_key][propname] + # For library nodes we want to make sure they are all easily + # accessible under '__libs__', to be able to list them all out. + if (issubclass(t, dace.sdfg.nodes.LibraryNode) + and not t == dace.sdfg.nodes.LibraryNode): + meta_dict['__libs__'][typename] = meta_key + # Save a lookup for enum values not present yet. + enum_list = [ + typename + for typename, dtype in inspect.getmembers(dace.dtypes, inspect.isclass) + if issubclass(dtype, aenum.Enum) + ] for enum_name in enum_list: if not enum_name in meta_dict['__reverse_type_lookup__']: choices = [] @@ -760,41 +232,55 @@ def _root(): @daemon.route('/transformations', methods=['POST']) def _get_transformations(): request_json = request.get_json() - return get_transformations(request_json['sdfg'], - request_json['selected_elements']) + return transformations.get_transformations( + request_json['sdfg'], request_json['selected_elements'] + ) @daemon.route('/apply_transformation', methods=['POST']) def _apply_transformation(): request_json = request.get_json() - return apply_transformation(request_json['sdfg'], - request_json['transformation']) + return transformations.apply_transformation( + request_json['sdfg'], request_json['transformation'] + ) @daemon.route('/expand_library_node', methods=['POST']) def _expand_library_node(): request_json = request.get_json() - return expand_library_node(request_json) + return transformations.expand_library_node(request_json) @daemon.route('/reapply_history_until', methods=['POST']) def _reapply_history_until(): request_json = request.get_json() - return reapply_history_until(request_json['sdfg'], - request_json['index']) + return transformations.reapply_history_until( + request_json['sdfg'], request_json['index'] + ) @daemon.route('/get_arith_ops', methods=['POST']) def _get_arith_ops(): request_json = request.get_json() return get_arith_ops(request_json['sdfg']) - @daemon.route('/get_enum/', methods=['GET']) - def _get_enum(name): - return get_enum(name) - @daemon.route('/compile_sdfg_from_file', methods=['POST']) def _compile_sdfg_from_file(): request_json = request.get_json() return compile_sdfg(request_json['path'], request_json['suppress_instrumentation']) + @daemon.route('/insert_sdfg_element', methods=['POST']) + def _insert_sdfg_element(): + request_json = request.get_json() + return editing.insert_sdfg_element( + request_json['sdfg'], request_json['type'], request_json['parent'], + request_json['edge_a'] + ) + + @daemon.route('/remove_sdfg_elements', methods=['POST']) + def _remove_sdfg_elements(): + request_json = request.get_json() + return editing.remove_sdfg_elements( + request_json['sdfg'], request_json['uuids'] + ) + @daemon.route('/get_metadata', methods=['GET']) def _get_metadata(): return get_property_metdata() @@ -826,6 +312,6 @@ def _get_metadata(): args = parser.parse_args() if (args.transformations): - get_transformations(None) + transformations.get_transformations(None) else: run_daemon(args.port) diff --git a/media/components/sdfv/breakpoints/breakpoints.js b/media/components/sdfv/breakpoints/breakpoints.js index 02430ec..9048b0e 100644 --- a/media/components/sdfv/breakpoints/breakpoints.js +++ b/media/components/sdfv/breakpoints/breakpoints.js @@ -25,13 +25,15 @@ class BreakpointIndicator extends daceGenericSDFGOverlay { daceRenderer; constructor(daceRenderer) { - - super(daceRenderer.overlay_manager, daceRenderer, daceGenericSDFGOverlay.OVERLAY_TYPE.BREAKPOINTS); + super( + daceRenderer.overlay_manager, daceRenderer, + daceGenericSDFGOverlay.OVERLAY_TYPE.BREAKPOINTS + ); this.daceRenderer = daceRenderer; this.breakpoints = new Map(); vscode.postMessage({ type: 'bp_handler.get_saved_nodes', - sdfg_name: this.daceRenderer.sdfg.attributes.name + sdfg_name: this.daceRenderer.sdfg.attributes.name, }); this.refresh(); } @@ -44,12 +46,10 @@ class BreakpointIndicator extends daceGenericSDFGOverlay { if (element instanceof NestedSDFG) { sdfg_id = element.data.node.attributes.sdfg.sdfg_list_id; - } - else if (element instanceof State) { + } else if (element instanceof State) { sdfg_id = element.sdfg.sdfg_list_id; state_id = element.id; - } - else if (element instanceof SDFGNode) { + } else if (element instanceof SDFGNode) { sdfg_id = element.sdfg.sdfg_list_id; state_id = element.parent_id; node_id = element.id; @@ -59,7 +59,7 @@ class BreakpointIndicator extends daceGenericSDFGOverlay { return { sdfg_id: sdfg_id, state_id: state_id, - node_id: node_id + node_id: node_id, }; else return ( @@ -151,7 +151,9 @@ class BreakpointIndicator extends daceGenericSDFGOverlay { erase_breakpoint(node, ctx) { // Draw on top of the Breakpoint - let background = node.getCssProperty(daceRenderer, '--state-background-color'); + let background = node.getCssProperty( + daceRenderer, '--state-background-color' + ); this.draw_breakpoint_circle(node, ctx, background, background); } @@ -164,14 +166,16 @@ class BreakpointIndicator extends daceGenericSDFGOverlay { ctx.beginPath(); (node instanceof State) ? ctx.arc(topleft.x + 10, topleft.y + 20, 4, 0, 2 * Math.PI) : - ctx.arc(topleft.x - 10, topleft.y + node.height / 2.0, 4, 0, 2 * Math.PI); + ctx.arc(topleft.x - 10, topleft.y + node.height / 2.0, 4, 0, + 2 * Math.PI); ctx.stroke(); ctx.fill(); } draw_tooltip(node, msg) { if (this.daceRenderer.mousepos && - node.intersect(this.daceRenderer.mousepos.x, this.daceRenderer.mousepos.y)) { + node.intersect(this.daceRenderer.mousepos.x, + this.daceRenderer.mousepos.y)) { this.daceRenderer.tooltip = () => { this.daceRenderer.tooltip_container.innerText = (msg); this.daceRenderer.tooltip_container.className = 'sdfvtooltip'; @@ -236,9 +240,11 @@ class BreakpointIndicator extends daceGenericSDFGOverlay { node.state_id + '/' + node.node_id ); + if (this.breakpoints.has(elem_uuid)) { this.breakpoints.set(elem_uuid, BreakpointEnum.UNBOUND); } + this.draw(); this.daceRenderer.draw_async(); } diff --git a/media/components/sdfv/index.html b/media/components/sdfv/index.html index 9740308..ea2c37c 100644 --- a/media/components/sdfv/index.html +++ b/media/components/sdfv/index.html @@ -34,30 +34,137 @@
-