Skip to content

Commit

Permalink
SDFG Editing (#29)
Browse files Browse the repository at this point in the history
Co-authored-by: Tobias Aeschbacher <72465460+tobiasae@users.noreply.github.com>
Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
  • Loading branch information
3 people authored Jul 6, 2021
1 parent 10e3e2b commit cbe45d0
Show file tree
Hide file tree
Showing 17 changed files with 1,372 additions and 742 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ node_modules
*.code-workspace
.vscode
.env
__pycache__/
Empty file added backend/dace_vscode/__init__.py
Empty file.
220 changes: 220 additions & 0 deletions backend/dace_vscode/arith_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors.
# All rights reserved.

import ast
from astunparse import unparse
from dace.symbolic import pystr_to_symbolic
from dace.sdfg import propagation
from dace.libraries.blas import MatMul, Transpose
from dace.libraries.standard import Reduce
from dace import nodes, dtypes
from sympy import function as spf

from dace_vscode.utils import get_uuid, load_sdfg_from_json


def symeval(val, symbols):
first_replacement = {
pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k)
for k in symbols.keys()
}
second_replacement = {
pystr_to_symbolic('__REPLSYM_' + k): v
for k, v in symbols.items()
}
return val.subs(first_replacement).subs(second_replacement)


def evaluate_symbols(base, new):
result = {}
for k, v in new.items():
result[k] = symeval(v, base)
return result


def count_matmul(node, symbols, state):
A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a')
B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b')
C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c')
result = 2 # Multiply, add
# Batch
if len(C_memlet.data.subset) == 3:
result *= symeval(C_memlet.data.subset.size()[0], symbols)
# M*N
result *= symeval(C_memlet.data.subset.size()[-2], symbols)
result *= symeval(C_memlet.data.subset.size()[-1], symbols)
# K
result *= symeval(A_memlet.data.subset.size()[-1], symbols)
return result


def count_reduce(node, symbols, state):
result = 0
if node.wcr is not None:
result += count_arithmetic_ops_code(node.wcr)
in_memlet = None
in_edges = state.in_edges(node)
if in_edges is not None and len(in_edges) == 1:
in_memlet = in_edges[0]
if in_memlet is not None and in_memlet.data.volume is not None:
result *= in_memlet.data.volume
else:
result = 0
return result


bigo = spf.Function('bigo')
PYFUNC_TO_ARITHMETICS = {
'float': 0,
'math.exp': 1,
'math.tanh': 1,
'math.sqrt': 1,
'min': 0,
'max': 0,
'ceiling': 0,
'floor': 0,
}
LIBNODES_TO_ARITHMETICS = {
MatMul: count_matmul,
Transpose: lambda *args: 0,
Reduce: count_reduce,
}


class ArithmeticCounter(ast.NodeVisitor):

def __init__(self):
self.count = 0

def visit_BinOp(self, node):
if isinstance(node.op, ast.MatMult):
raise NotImplementedError('MatMult op count requires shape '
'inference')
self.count += 1
return self.generic_visit(node)

def visit_UnaryOp(self, node):
self.count += 1
return self.generic_visit(node)

def visit_Call(self, node):
fname = unparse(node.func)[:-1]
if fname not in PYFUNC_TO_ARITHMETICS:
print('WARNING: Unrecognized python function "%s"' % fname)
return self.generic_visit(node)
self.count += PYFUNC_TO_ARITHMETICS[fname]
return self.generic_visit(node)

def visit_AugAssign(self, node):
return self.visit_BinOp(node)

def visit_For(self, node):
raise NotImplementedError

def visit_While(self, node):
raise NotImplementedError


def count_arithmetic_ops_code(code):
ctr = ArithmeticCounter()
if isinstance(code, (tuple, list)):
for stmt in code:
ctr.visit(stmt)
elif isinstance(code, str):
ctr.visit(ast.parse(code))
else:
ctr.visit(code)
return ctr.count


def create_arith_ops_map_state(state, arith_map, symbols):
scope_tree_root = state.scope_tree()[None]
scope_dict = state.scope_children()

def traverse(scope):
repetitions = 1
traversal_result = 0
if scope.entry is not None:
repetitions = scope.entry.map.range.num_elements()
for node in scope_dict[scope.entry]:
node_result = 0
if isinstance(node, nodes.NestedSDFG):
nested_syms = {}
nested_syms.update(symbols)
nested_syms.update(
evaluate_symbols(symbols, node.symbol_mapping))
node_result += create_arith_ops_map(node.sdfg, arith_map,
nested_syms)
elif isinstance(node, nodes.LibraryNode):
node_result += LIBNODES_TO_ARITHMETICS[type(node)](node,
symbols,
state)
elif isinstance(node, nodes.Tasklet):
if node.code.language == dtypes.Language.CPP:
for oedge in state.out_edges(node):
node_result += bigo(oedge.data.num_accesses)
else:
node_result += count_arithmetic_ops_code(node.code.code)
elif isinstance(node, nodes.MapEntry):
map_scope = None
for child_scope in scope.children:
if child_scope.entry == node:
map_scope = child_scope
break
map_result = 0
if map_scope is not None:
map_result = traverse(map_scope)
node_result += map_result
elif isinstance(node, nodes.MapExit):
# Don't do anything for map exists.
pass
elif isinstance(node,
(nodes.CodeNode, nodes.AccessNode)):
for oedge in state.out_edges(node):
if oedge.data.wcr is not None:
node_result += count_arithmetic_ops_code(
oedge.data.wcr)

arith_map[get_uuid(node, state)] = str(node_result)
traversal_result += node_result
return repetitions * traversal_result

state_result = traverse(scope_tree_root)

if state.executions is not None:
if (state.dynamic_executions is not None and state.dynamic_executions
and state.executions == 0):
state_result = 0
else:
state_result *= state.executions

arith_map[get_uuid(state)] = str(state_result)
return state_result


def create_arith_ops_map(sdfg, arith_map, symbols):
sdfg_ops = 0
for state in sdfg.nodes():
sdfg_ops += create_arith_ops_map_state(state, arith_map, symbols)
arith_map[get_uuid(sdfg)] = str(sdfg_ops)

# Replace any operations that math.js does not understand.
for uuid in arith_map:
arith_map[uuid] = arith_map[uuid].replace('**', '^')

return sdfg_ops


def get_arith_ops(sdfg_json):
loaded = load_sdfg_from_json(sdfg_json)
if loaded['error'] is not None:
return loaded['error']
sdfg = loaded['sdfg']

propagation.propagate_memlets_sdfg(sdfg)

arith_map = {}
create_arith_ops_map(sdfg, arith_map, {})
return {
'arith_ops_map': arith_map,
}
157 changes: 157 additions & 0 deletions backend/dace_vscode/editing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors.
# All rights reserved.

from dace import (
serialize, nodes, SDFG, SDFGState, InterstateEdge, Memlet, dtypes
)
from dace_vscode.utils import (
load_sdfg_from_json,
find_graph_element_by_uuid,
get_uuid,
)
from pydoc import locate


def remove_sdfg_elements(sdfg_json, uuids):
from dace.sdfg.graph import Edge

old_meta = serialize.JSON_STORE_METADATA
serialize.JSON_STORE_METADATA = False

loaded = load_sdfg_from_json(sdfg_json)
if loaded['error'] is not None:
return loaded['error']
sdfg = loaded['sdfg']

elements = []
for uuid in uuids:
elements.append(find_graph_element_by_uuid(sdfg, uuid))

for element_ret in elements:
element = element_ret['element']
parent = element_ret['parent']

if parent is not None and element is not None:
if isinstance(element, Edge):
parent.remove_edge(element)
else:
parent.remove_node(element)
else:
return {
'error': {
'message': 'Failed to delete element',
'details': 'Element or parent not found',
},
}

new_sdfg = sdfg.to_json()
serialize.JSON_STORE_METADATA = old_meta

return {
'sdfg': new_sdfg,
}


def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid):
sdfg_answer = load_sdfg_from_json(sdfg_str)
sdfg = sdfg_answer['sdfg']
uuid = 'error'
ret = find_graph_element_by_uuid(sdfg, parent_uuid)
parent = ret['element']

libname = None
if type is not None and isinstance(type, str):
split_type = type.split('|')
if len(split_type) == 2:
type = split_type[0]
libname = split_type[1]

if type == 'SDFGState':
if parent is None:
parent = sdfg
elif isinstance(parent, nodes.NestedSDFG):
parent = parent.sdfg
state = parent.add_state()
uuid = [get_uuid(state)]
elif type == 'AccessNode':
arrays = list(parent.parent.arrays.keys())
if len(arrays) == 0:
parent.parent.add_array('tmp', [1], dtype=dtypes.float64)
arrays = list(parent.parent.arrays.keys())
node = parent.add_access(arrays[0])
uuid = [get_uuid(node, parent)]
elif type == 'Map':
map_entry, map_exit = parent.add_map('map', dict(i='0:1'))
uuid = [get_uuid(map_entry, parent), get_uuid(map_exit, parent)]
elif type == 'Consume':
consume_entry, consume_exit = parent.add_consume('consume', ('i', '1'))
uuid = [get_uuid(consume_entry, parent), get_uuid(consume_exit, parent)]
elif type == 'Tasklet':
tasklet = parent.add_tasklet(
name='placeholder',
inputs={'in'},
outputs={'out'},
code='')
uuid = [get_uuid(tasklet, parent)]
elif type == 'NestedSDFG':
sub_sdfg = SDFG('nested_sdfg')
sub_sdfg.add_array('in', [1], dtypes.float32)
sub_sdfg.add_array('out', [1], dtypes.float32)

nsdfg = parent.add_nested_sdfg(sub_sdfg, sdfg, {'in'}, {'out'})
uuid = [get_uuid(nsdfg, parent)]
elif type == 'LibraryNode':
if libname is None:
return {
'error': {
'message': 'Failed to add library node',
'details': 'Must provide a valid library node type',
},
}
libnode_class = locate(libname)
libnode = libnode_class()
parent.add_node(libnode)
uuid = [get_uuid(libnode, parent)]
elif type == 'Edge':
edge_start_ret = find_graph_element_by_uuid(sdfg, edge_a_uuid)
edge_start = edge_start_ret['element']
edge_parent = edge_start_ret['parent']
if edge_start is not None:
if edge_parent is None:
edge_parent = sdfg

if isinstance(edge_parent, SDFGState):
if not (isinstance(edge_start, nodes.Node) and
isinstance(parent, nodes.Node)):
return {
'error': {
'message': 'Failed to add edge',
'details': 'Must connect two nodes or two states',
},
}
memlet = Memlet()
edge_parent.add_edge(edge_start, None, parent, None, memlet)
elif isinstance(edge_parent, SDFG):
if not (isinstance(edge_start, SDFGState) and
isinstance(parent, SDFGState)):
return {
'error': {
'message': 'Failed to add edge',
'details': 'Must connect two nodes or two states',
},
}
isedge = InterstateEdge()
edge_parent.add_edge(edge_start, parent, isedge)
uuid = ['NONE']
else:
raise ValueError('No edge starting point provided')

old_meta = serialize.JSON_STORE_METADATA
serialize.JSON_STORE_METADATA = False
new_sdfg_str = sdfg.to_json()
serialize.JSON_STORE_METADATA = old_meta

return {
'sdfg': new_sdfg_str,
'uuid': uuid,
}
Loading

0 comments on commit cbe45d0

Please sign in to comment.