Skip to content

Commit

Permalink
Merge branch 'spcl:master' into additions
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose authored Oct 15, 2024
2 parents 40de8f6 + 073b613 commit a7ab6c3
Show file tree
Hide file tree
Showing 92 changed files with 4,790 additions and 1,583 deletions.
1 change: 1 addition & 0 deletions dace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .frontend.operations import reduce, elementwise

from . import data, hooks, subsets
from .codegen.compiled_sdfg import CompiledSDFG
from .config import Config
from .sdfg import SDFG, SDFGState, InterstateEdge, nodes
from .sdfg.propagation import propagate_memlets_sdfg, propagate_memlet
Expand Down
220 changes: 220 additions & 0 deletions dace/cli/sdfg_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
""" SDFG diff tool. """

import argparse
from hashlib import sha256
import json
import os
import platform
import tempfile
from typing import Dict, Literal, Set, Tuple, Union

import jinja2
import dace
from dace import memlet as mlt
from dace.sdfg import nodes as nd
from dace.sdfg.graph import Edge, MultiConnectorEdge
from dace.sdfg.sdfg import InterstateEdge
from dace.sdfg.state import ControlFlowBlock
import dace.serialize


DiffableT = Union[ControlFlowBlock, nd.Node, MultiConnectorEdge[mlt.Memlet], Edge[InterstateEdge]]
DiffSetsT = Tuple[Set[str], Set[str], Set[str]]


def _print_diff(sdfg_A: dace.SDFG, sdfg_B: dace.SDFG, diff_sets: DiffSetsT) -> None:
all_id_elements_A: Dict[str, DiffableT] = dict()
all_id_elements_B: Dict[str, DiffableT] = dict()

all_id_elements_A[sdfg_A.guid] = sdfg_A
for n, _ in sdfg_A.all_nodes_recursive():
all_id_elements_A[n.guid] = n
for e, _ in sdfg_A.all_edges_recursive():
all_id_elements_A[e.data.guid] = e

all_id_elements_B[sdfg_B.guid] = sdfg_B
for n, _ in sdfg_B.all_nodes_recursive():
all_id_elements_B[n.guid] = n
for e, _ in sdfg_B.all_edges_recursive():
all_id_elements_B[e.data.guid] = e

no_removed = True
no_added = True
no_changed = True
if len(diff_sets[0]) > 0:
print('Removed elements:')
for k in diff_sets[0]:
print(all_id_elements_A[k])
no_removed = False
if len(diff_sets[1]) > 0:
if not no_removed:
print('')
print('Added elements:')
for k in diff_sets[1]:
print(all_id_elements_B[k])
no_added = False
if len(diff_sets[2]) > 0:
if not no_removed or not no_added:
print('')
print('Changed elements:')
for k in diff_sets[2]:
print(all_id_elements_B[k])
no_changed = False

if no_removed and no_added and no_changed:
print('SDFGs are identical')


def _sdfg_diff(sdfg_A: dace.SDFG, sdfg_B: dace.SDFG, eq_strategy = Union[Literal['hash', '==']]) -> DiffSetsT:
all_id_elements_A: Dict[str, DiffableT] = dict()
all_id_elements_B: Dict[str, DiffableT] = dict()

all_id_elements_A[sdfg_A.guid] = sdfg_A
for n, _ in sdfg_A.all_nodes_recursive():
all_id_elements_A[n.guid] = n
for e, _ in sdfg_A.all_edges_recursive():
all_id_elements_A[e.data.guid] = e

all_id_elements_B[sdfg_B.guid] = sdfg_B
for n, _ in sdfg_B.all_nodes_recursive():
all_id_elements_B[n.guid] = n
for e, _ in sdfg_B.all_edges_recursive():
all_id_elements_B[e.data.guid] = e

a_keys = set(all_id_elements_A.keys())
b_keys = set(all_id_elements_B.keys())

added_keys = b_keys - a_keys
removed_keys = a_keys - b_keys
changed_keys = set()

remaining_keys = a_keys - removed_keys
if remaining_keys != b_keys - added_keys:
raise RuntimeError(
'The sets of remaining keys between graphs A and B after accounting for added and removed keys do not match'
)
for k in remaining_keys:
el_a = all_id_elements_A[k]
el_b = all_id_elements_B[k]

if eq_strategy == 'hash':
try:
if isinstance(el_a, Edge):
attr_a = dace.serialize.all_properties_to_json(el_a.data)
else:
attr_a = dace.serialize.all_properties_to_json(el_a)
hash_a = sha256(json.dumps(attr_a).encode('utf-8')).hexdigest()
except KeyError:
hash_a = None
try:
if isinstance(el_b, Edge):
attr_b = dace.serialize.all_properties_to_json(el_b.data)
else:
attr_b = dace.serialize.all_properties_to_json(el_b)
hash_b = sha256(json.dumps(attr_b).encode('utf-8')).hexdigest()
except KeyError:
hash_b = None

if hash_a != hash_b:
changed_keys.add(k)
else:
if isinstance(el_a, Edge):
attr_a = dace.serialize.all_properties_to_json(el_a.data)
else:
attr_a = dace.serialize.all_properties_to_json(el_a)
if isinstance(el_b, Edge):
attr_b = dace.serialize.all_properties_to_json(el_b.data)
else:
attr_b = dace.serialize.all_properties_to_json(el_b)

if attr_a != attr_b:
changed_keys.add(k)

return removed_keys, added_keys, changed_keys


def main():
# Command line options parser
parser = argparse.ArgumentParser(description='SDFG diff tool.')

# Required argument for SDFG file path
parser.add_argument('sdfg_A_path', help='<PATH TO FIRST SDFG FILE>', type=str)
parser.add_argument('sdfg_B_path', help='<PATH TO SECOND SDFG FILE>', type=str)

parser.add_argument('-g',
'--graphical',
dest='graphical',
action='store_true',
help="If set, visualize the difference graphically",
default=False)
parser.add_argument('-o',
'--output',
dest='output',
help="The output filename to generate",
type=str)
parser.add_argument('-H',
'--hash',
dest='hash',
action='store_true',
help="If set, use the hash of JSON serialized properties for change checks instead of " +
"Python's dictionary equivalence checks. This makes changes order sensitive.",
default=False)

args = parser.parse_args()

if not os.path.isfile(args.sdfg_A_path):
print('SDFG file', args.sdfg_A_path, 'not found')
exit(1)

if not os.path.isfile(args.sdfg_B_path):
print('SDFG file', args.sdfg_B_path, 'not found')
exit(1)

sdfg_A = dace.SDFG.from_file(args.sdfg_A_path)
sdfg_B = dace.SDFG.from_file(args.sdfg_B_path)

eq_strategy = 'hash' if args.hash else '=='

diff_sets = _sdfg_diff(sdfg_A, sdfg_B, eq_strategy)

if args.graphical:
basepath = os.path.join(os.path.dirname(os.path.realpath(dace.__file__)), 'viewer')
template_loader = jinja2.FileSystemLoader(searchpath=os.path.join(basepath, 'templates'))
template_env = jinja2.Environment(loader=template_loader)
template = template_env.get_template('sdfv_diff_view.html')

# if we are serving, the base path should just be root
html = template.render(sdfgA=json.dumps(dace.serialize.dumps(sdfg_A.to_json())),
sdfgB=json.dumps(dace.serialize.dumps(sdfg_B.to_json())),
removedKeysList=json.dumps(list(diff_sets[0])),
addedKeysList=json.dumps(list(diff_sets[1])),
changedKeysList=json.dumps(list(diff_sets[2])),
dir=basepath + '/')

if args.output:
fd = None
html_filename = args.output
else:
fd, html_filename = tempfile.mkstemp(suffix=".sdfg.html")

with open(html_filename, 'w') as f:
f.write(html)

if fd is not None:
os.close(fd)

system = platform.system()

if system == 'Windows':
os.system(html_filename)
elif system == 'Darwin':
os.system('open %s' % html_filename)
else:
os.system('xdg-open %s' % html_filename)
else:
_print_diff(sdfg_A, sdfg_B, diff_sets)


if __name__ == '__main__':
main()
3 changes: 3 additions & 0 deletions dace/codegen/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.

from dace.codegen.compiled_sdfg import CompiledSDFG
4 changes: 3 additions & 1 deletion dace/codegen/compiled_sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import sympy as sp

from dace import data as dt, dtypes, hooks, symbolic
from dace.codegen import exceptions as cgx, common
from dace.codegen import exceptions as cgx
from dace.config import Config
from dace.frontend import operations

Expand Down Expand Up @@ -369,6 +369,7 @@ def finalize(self):
f'An error was detected after running "{self._sdfg.name}": {self._get_error_text(res)}')

def _get_error_text(self, result: Union[str, int]) -> str:
from dace.codegen import common # Circular import
if self.has_gpu_code:
if isinstance(result, int):
result = common.get_gpu_runtime().get_error_string(result)
Expand Down Expand Up @@ -428,6 +429,7 @@ def fast_call(
:note: You may use `_construct_args()` to generate the processed arguments.
"""
from dace.codegen import common # Circular import
try:
# Call initializer function if necessary, then SDFG
if self._initialized is False:
Expand Down
Loading

0 comments on commit a7ab6c3

Please sign in to comment.