-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Tobias Aeschbacher <72465460+tobiasae@users.noreply.github.com> Co-authored-by: Tal Ben-Nun <tbennun@users.noreply.github.com>
- Loading branch information
1 parent
10e3e2b
commit cbe45d0
Showing
17 changed files
with
1,372 additions
and
742 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,3 +5,4 @@ node_modules | |
*.code-workspace | ||
.vscode | ||
.env | ||
__pycache__/ |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors. | ||
# All rights reserved. | ||
|
||
import ast | ||
from astunparse import unparse | ||
from dace.symbolic import pystr_to_symbolic | ||
from dace.sdfg import propagation | ||
from dace.libraries.blas import MatMul, Transpose | ||
from dace.libraries.standard import Reduce | ||
from dace import nodes, dtypes | ||
from sympy import function as spf | ||
|
||
from dace_vscode.utils import get_uuid, load_sdfg_from_json | ||
|
||
|
||
def symeval(val, symbols): | ||
first_replacement = { | ||
pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) | ||
for k in symbols.keys() | ||
} | ||
second_replacement = { | ||
pystr_to_symbolic('__REPLSYM_' + k): v | ||
for k, v in symbols.items() | ||
} | ||
return val.subs(first_replacement).subs(second_replacement) | ||
|
||
|
||
def evaluate_symbols(base, new): | ||
result = {} | ||
for k, v in new.items(): | ||
result[k] = symeval(v, base) | ||
return result | ||
|
||
|
||
def count_matmul(node, symbols, state): | ||
A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') | ||
B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') | ||
C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') | ||
result = 2 # Multiply, add | ||
# Batch | ||
if len(C_memlet.data.subset) == 3: | ||
result *= symeval(C_memlet.data.subset.size()[0], symbols) | ||
# M*N | ||
result *= symeval(C_memlet.data.subset.size()[-2], symbols) | ||
result *= symeval(C_memlet.data.subset.size()[-1], symbols) | ||
# K | ||
result *= symeval(A_memlet.data.subset.size()[-1], symbols) | ||
return result | ||
|
||
|
||
def count_reduce(node, symbols, state): | ||
result = 0 | ||
if node.wcr is not None: | ||
result += count_arithmetic_ops_code(node.wcr) | ||
in_memlet = None | ||
in_edges = state.in_edges(node) | ||
if in_edges is not None and len(in_edges) == 1: | ||
in_memlet = in_edges[0] | ||
if in_memlet is not None and in_memlet.data.volume is not None: | ||
result *= in_memlet.data.volume | ||
else: | ||
result = 0 | ||
return result | ||
|
||
|
||
bigo = spf.Function('bigo') | ||
PYFUNC_TO_ARITHMETICS = { | ||
'float': 0, | ||
'math.exp': 1, | ||
'math.tanh': 1, | ||
'math.sqrt': 1, | ||
'min': 0, | ||
'max': 0, | ||
'ceiling': 0, | ||
'floor': 0, | ||
} | ||
LIBNODES_TO_ARITHMETICS = { | ||
MatMul: count_matmul, | ||
Transpose: lambda *args: 0, | ||
Reduce: count_reduce, | ||
} | ||
|
||
|
||
class ArithmeticCounter(ast.NodeVisitor): | ||
|
||
def __init__(self): | ||
self.count = 0 | ||
|
||
def visit_BinOp(self, node): | ||
if isinstance(node.op, ast.MatMult): | ||
raise NotImplementedError('MatMult op count requires shape ' | ||
'inference') | ||
self.count += 1 | ||
return self.generic_visit(node) | ||
|
||
def visit_UnaryOp(self, node): | ||
self.count += 1 | ||
return self.generic_visit(node) | ||
|
||
def visit_Call(self, node): | ||
fname = unparse(node.func)[:-1] | ||
if fname not in PYFUNC_TO_ARITHMETICS: | ||
print('WARNING: Unrecognized python function "%s"' % fname) | ||
return self.generic_visit(node) | ||
self.count += PYFUNC_TO_ARITHMETICS[fname] | ||
return self.generic_visit(node) | ||
|
||
def visit_AugAssign(self, node): | ||
return self.visit_BinOp(node) | ||
|
||
def visit_For(self, node): | ||
raise NotImplementedError | ||
|
||
def visit_While(self, node): | ||
raise NotImplementedError | ||
|
||
|
||
def count_arithmetic_ops_code(code): | ||
ctr = ArithmeticCounter() | ||
if isinstance(code, (tuple, list)): | ||
for stmt in code: | ||
ctr.visit(stmt) | ||
elif isinstance(code, str): | ||
ctr.visit(ast.parse(code)) | ||
else: | ||
ctr.visit(code) | ||
return ctr.count | ||
|
||
|
||
def create_arith_ops_map_state(state, arith_map, symbols): | ||
scope_tree_root = state.scope_tree()[None] | ||
scope_dict = state.scope_children() | ||
|
||
def traverse(scope): | ||
repetitions = 1 | ||
traversal_result = 0 | ||
if scope.entry is not None: | ||
repetitions = scope.entry.map.range.num_elements() | ||
for node in scope_dict[scope.entry]: | ||
node_result = 0 | ||
if isinstance(node, nodes.NestedSDFG): | ||
nested_syms = {} | ||
nested_syms.update(symbols) | ||
nested_syms.update( | ||
evaluate_symbols(symbols, node.symbol_mapping)) | ||
node_result += create_arith_ops_map(node.sdfg, arith_map, | ||
nested_syms) | ||
elif isinstance(node, nodes.LibraryNode): | ||
node_result += LIBNODES_TO_ARITHMETICS[type(node)](node, | ||
symbols, | ||
state) | ||
elif isinstance(node, nodes.Tasklet): | ||
if node.code.language == dtypes.Language.CPP: | ||
for oedge in state.out_edges(node): | ||
node_result += bigo(oedge.data.num_accesses) | ||
else: | ||
node_result += count_arithmetic_ops_code(node.code.code) | ||
elif isinstance(node, nodes.MapEntry): | ||
map_scope = None | ||
for child_scope in scope.children: | ||
if child_scope.entry == node: | ||
map_scope = child_scope | ||
break | ||
map_result = 0 | ||
if map_scope is not None: | ||
map_result = traverse(map_scope) | ||
node_result += map_result | ||
elif isinstance(node, nodes.MapExit): | ||
# Don't do anything for map exists. | ||
pass | ||
elif isinstance(node, | ||
(nodes.CodeNode, nodes.AccessNode)): | ||
for oedge in state.out_edges(node): | ||
if oedge.data.wcr is not None: | ||
node_result += count_arithmetic_ops_code( | ||
oedge.data.wcr) | ||
|
||
arith_map[get_uuid(node, state)] = str(node_result) | ||
traversal_result += node_result | ||
return repetitions * traversal_result | ||
|
||
state_result = traverse(scope_tree_root) | ||
|
||
if state.executions is not None: | ||
if (state.dynamic_executions is not None and state.dynamic_executions | ||
and state.executions == 0): | ||
state_result = 0 | ||
else: | ||
state_result *= state.executions | ||
|
||
arith_map[get_uuid(state)] = str(state_result) | ||
return state_result | ||
|
||
|
||
def create_arith_ops_map(sdfg, arith_map, symbols): | ||
sdfg_ops = 0 | ||
for state in sdfg.nodes(): | ||
sdfg_ops += create_arith_ops_map_state(state, arith_map, symbols) | ||
arith_map[get_uuid(sdfg)] = str(sdfg_ops) | ||
|
||
# Replace any operations that math.js does not understand. | ||
for uuid in arith_map: | ||
arith_map[uuid] = arith_map[uuid].replace('**', '^') | ||
|
||
return sdfg_ops | ||
|
||
|
||
def get_arith_ops(sdfg_json): | ||
loaded = load_sdfg_from_json(sdfg_json) | ||
if loaded['error'] is not None: | ||
return loaded['error'] | ||
sdfg = loaded['sdfg'] | ||
|
||
propagation.propagate_memlets_sdfg(sdfg) | ||
|
||
arith_map = {} | ||
create_arith_ops_map(sdfg, arith_map, {}) | ||
return { | ||
'arith_ops_map': arith_map, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
# Copyright 2020-2021 ETH Zurich and the DaCe-VSCode authors. | ||
# All rights reserved. | ||
|
||
from dace import ( | ||
serialize, nodes, SDFG, SDFGState, InterstateEdge, Memlet, dtypes | ||
) | ||
from dace_vscode.utils import ( | ||
load_sdfg_from_json, | ||
find_graph_element_by_uuid, | ||
get_uuid, | ||
) | ||
from pydoc import locate | ||
|
||
|
||
def remove_sdfg_elements(sdfg_json, uuids): | ||
from dace.sdfg.graph import Edge | ||
|
||
old_meta = serialize.JSON_STORE_METADATA | ||
serialize.JSON_STORE_METADATA = False | ||
|
||
loaded = load_sdfg_from_json(sdfg_json) | ||
if loaded['error'] is not None: | ||
return loaded['error'] | ||
sdfg = loaded['sdfg'] | ||
|
||
elements = [] | ||
for uuid in uuids: | ||
elements.append(find_graph_element_by_uuid(sdfg, uuid)) | ||
|
||
for element_ret in elements: | ||
element = element_ret['element'] | ||
parent = element_ret['parent'] | ||
|
||
if parent is not None and element is not None: | ||
if isinstance(element, Edge): | ||
parent.remove_edge(element) | ||
else: | ||
parent.remove_node(element) | ||
else: | ||
return { | ||
'error': { | ||
'message': 'Failed to delete element', | ||
'details': 'Element or parent not found', | ||
}, | ||
} | ||
|
||
new_sdfg = sdfg.to_json() | ||
serialize.JSON_STORE_METADATA = old_meta | ||
|
||
return { | ||
'sdfg': new_sdfg, | ||
} | ||
|
||
|
||
def insert_sdfg_element(sdfg_str, type, parent_uuid, edge_a_uuid): | ||
sdfg_answer = load_sdfg_from_json(sdfg_str) | ||
sdfg = sdfg_answer['sdfg'] | ||
uuid = 'error' | ||
ret = find_graph_element_by_uuid(sdfg, parent_uuid) | ||
parent = ret['element'] | ||
|
||
libname = None | ||
if type is not None and isinstance(type, str): | ||
split_type = type.split('|') | ||
if len(split_type) == 2: | ||
type = split_type[0] | ||
libname = split_type[1] | ||
|
||
if type == 'SDFGState': | ||
if parent is None: | ||
parent = sdfg | ||
elif isinstance(parent, nodes.NestedSDFG): | ||
parent = parent.sdfg | ||
state = parent.add_state() | ||
uuid = [get_uuid(state)] | ||
elif type == 'AccessNode': | ||
arrays = list(parent.parent.arrays.keys()) | ||
if len(arrays) == 0: | ||
parent.parent.add_array('tmp', [1], dtype=dtypes.float64) | ||
arrays = list(parent.parent.arrays.keys()) | ||
node = parent.add_access(arrays[0]) | ||
uuid = [get_uuid(node, parent)] | ||
elif type == 'Map': | ||
map_entry, map_exit = parent.add_map('map', dict(i='0:1')) | ||
uuid = [get_uuid(map_entry, parent), get_uuid(map_exit, parent)] | ||
elif type == 'Consume': | ||
consume_entry, consume_exit = parent.add_consume('consume', ('i', '1')) | ||
uuid = [get_uuid(consume_entry, parent), get_uuid(consume_exit, parent)] | ||
elif type == 'Tasklet': | ||
tasklet = parent.add_tasklet( | ||
name='placeholder', | ||
inputs={'in'}, | ||
outputs={'out'}, | ||
code='') | ||
uuid = [get_uuid(tasklet, parent)] | ||
elif type == 'NestedSDFG': | ||
sub_sdfg = SDFG('nested_sdfg') | ||
sub_sdfg.add_array('in', [1], dtypes.float32) | ||
sub_sdfg.add_array('out', [1], dtypes.float32) | ||
|
||
nsdfg = parent.add_nested_sdfg(sub_sdfg, sdfg, {'in'}, {'out'}) | ||
uuid = [get_uuid(nsdfg, parent)] | ||
elif type == 'LibraryNode': | ||
if libname is None: | ||
return { | ||
'error': { | ||
'message': 'Failed to add library node', | ||
'details': 'Must provide a valid library node type', | ||
}, | ||
} | ||
libnode_class = locate(libname) | ||
libnode = libnode_class() | ||
parent.add_node(libnode) | ||
uuid = [get_uuid(libnode, parent)] | ||
elif type == 'Edge': | ||
edge_start_ret = find_graph_element_by_uuid(sdfg, edge_a_uuid) | ||
edge_start = edge_start_ret['element'] | ||
edge_parent = edge_start_ret['parent'] | ||
if edge_start is not None: | ||
if edge_parent is None: | ||
edge_parent = sdfg | ||
|
||
if isinstance(edge_parent, SDFGState): | ||
if not (isinstance(edge_start, nodes.Node) and | ||
isinstance(parent, nodes.Node)): | ||
return { | ||
'error': { | ||
'message': 'Failed to add edge', | ||
'details': 'Must connect two nodes or two states', | ||
}, | ||
} | ||
memlet = Memlet() | ||
edge_parent.add_edge(edge_start, None, parent, None, memlet) | ||
elif isinstance(edge_parent, SDFG): | ||
if not (isinstance(edge_start, SDFGState) and | ||
isinstance(parent, SDFGState)): | ||
return { | ||
'error': { | ||
'message': 'Failed to add edge', | ||
'details': 'Must connect two nodes or two states', | ||
}, | ||
} | ||
isedge = InterstateEdge() | ||
edge_parent.add_edge(edge_start, parent, isedge) | ||
uuid = ['NONE'] | ||
else: | ||
raise ValueError('No edge starting point provided') | ||
|
||
old_meta = serialize.JSON_STORE_METADATA | ||
serialize.JSON_STORE_METADATA = False | ||
new_sdfg_str = sdfg.to_json() | ||
serialize.JSON_STORE_METADATA = old_meta | ||
|
||
return { | ||
'sdfg': new_sdfg_str, | ||
'uuid': uuid, | ||
} |
Oops, something went wrong.