Skip to content

Commit

Permalink
Fix temporary transient counter during Python parsing of nested calls (
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun authored Nov 13, 2024
1 parent 02c9c37 commit 17e4a88
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 15 deletions.
3 changes: 3 additions & 0 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -3940,6 +3940,9 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no
for arg in args_to_remove:
args.remove(arg)

# Refresh temporary transient counter of the nested SDFG
sdfg.refresh_temp_transients()

# Change connector names
updated_args = []
arrays_before = list(sdfg.arrays.items())
Expand Down
14 changes: 14 additions & 0 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1952,6 +1952,20 @@ def temp_data_name(self):
self._temp_transients += 1
return name

def refresh_temp_transients(self):
"""
Updates the temporary transient counter of this SDFG by querying the maximum number among the
``__tmp###`` data descriptors.
"""
temp_transients = [k[5:] for k in self.arrays.keys() if k.startswith('__tmp')]
max_temp_transient = 0
for arr_suffix in temp_transients:
try:
max_temp_transient = max(max_temp_transient, int(arr_suffix))
except ValueError: # Not of the form __tmp###
continue
self._temp_transients = max_temp_transient + 1

def add_temp_transient(self,
shape,
dtype,
Expand Down
2 changes: 1 addition & 1 deletion tests/npbench/misc/mandelbrot1_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def run_mandelbrot1(device_type: dace.dtypes.DeviceType):
Z, N = sdfg(xmin, xmax, ymin, ymax, maxiter, horizon)

# Compute ground truth and validate
Z_ref, N_ref = ground_truth(xmin, xmax, ymin, ymax, xn, yn, maxiter)
Z_ref, N_ref = ground_truth(xmin, xmax, ymin, ymax, XN, YN, maxiter)
assert np.allclose(Z, Z_ref)
assert np.allclose(N, N_ref)
return sdfg
Expand Down
2 changes: 1 addition & 1 deletion tests/npbench/misc/mandelbrot2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def run_mandelbrot2(device_type: dace.dtypes.DeviceType):
Z, N = sdfg(xmin, xmax, ymin, ymax, maxiter, horizon)

# Compute ground truth and validate
Z_ref, N_ref = ground_truth(xmin, xmax, ymin, ymax, xn, yn, maxiter)
Z_ref, N_ref = ground_truth(xmin, xmax, ymin, ymax, XN, YN, maxiter)
assert np.allclose(Z, Z_ref)
assert np.allclose(N, N_ref)
return sdfg
Expand Down
70 changes: 57 additions & 13 deletions tests/python_frontend/nested_name_accesses_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved.
import dace as dc
import numpy as np
import os
Expand Down Expand Up @@ -30,6 +30,7 @@ def test_nested_name_accesses():


def test_nested_offset_access():

@dc.program
def nested_offset_access(inp: dc.float64[6, 5, 5]):
out = np.zeros((5, 5, 5), np.float64)
Expand All @@ -46,6 +47,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):


def test_nested_offset_access_dappy():

@dc.program
def nested_offset_access(inp: dc.float64[6, 5, 5]):
out = np.zeros((5, 5, 5), np.float64)
Expand All @@ -66,6 +68,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):


def test_nested_multi_offset_access():

@dc.program
def nested_offset_access(inp: dc.float64[6, 5, 10]):
out = np.zeros((5, 5, 10), np.float64)
Expand All @@ -83,6 +86,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 10]):


def test_nested_multi_offset_access_dappy():

@dc.program
def nested_offset_access(inp: dc.float64[6, 5, 10]):
out = np.zeros((5, 5, 10), np.float64)
Expand All @@ -104,6 +108,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 10]):


def test_nested_dec_offset_access():

@dc.program
def nested_offset_access(inp: dc.float64[6, 5, 5]):
out = np.zeros((5, 5, 5), np.float64)
Expand All @@ -120,6 +125,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):


def test_nested_dec_offset_access_dappy():

@dc.program
def nested_offset_access(inp: dc.float64[6, 5, 5]):
out = np.zeros((5, 5, 5), np.float64)
Expand All @@ -140,6 +146,7 @@ def nested_offset_access(inp: dc.float64[6, 5, 5]):


def test_nested_offset_access_nested_dependency():

@dc.program
def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 5]):
out = np.zeros((5, 5, 5), np.float64)
Expand All @@ -161,6 +168,7 @@ def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 5]):


def test_nested_offset_access_nested_dependency_dappy():

@dc.program
def nested_offset_access_nested_dep(inp: dc.float64[6, 5, 10]):
out = np.zeros((5, 5, 10), np.float64)
Expand Down Expand Up @@ -188,19 +196,19 @@ def test_access_to_nested_transient():
NBLOCKS = 5

@dc.program
def small_wip(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
def small_wip(inp: dc.float64[KLEV + 1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
for jn in dc.map[0:NBLOCKS]:
tmp = np.zeros([KLEV+1, KLON])
tmp = np.zeros([KLEV + 1, KLON])
for jl in range(KLON):
for jk in range(KLEV):
tmp[jk, jl] = inp[jk, jl, jn] + inp[jk+1, jl, jn]
tmp[jk, jl] = inp[jk, jl, jn] + inp[jk + 1, jl, jn]

for jl in range(KLON):
for jk in range(KLEV):
out[jk, jl, jn] = tmp[jk, jl] + tmp[jk+1, jl]
out[jk, jl, jn] = tmp[jk, jl] + tmp[jk + 1, jl]

rng = np.random.default_rng(42)
inp = rng.random((KLEV+1, KLON, NBLOCKS))
inp = rng.random((KLEV + 1, KLON, NBLOCKS))
ref = np.zeros((KLEV, KLON, NBLOCKS))
val = np.zeros((KLEV, KLON, NBLOCKS))

Expand All @@ -217,27 +225,27 @@ def test_access_to_nested_transient_dappy():
NBLOCKS = 5

@dc.program
def small_wip_dappy(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
def small_wip_dappy(inp: dc.float64[KLEV + 1, KLON, NBLOCKS], out: dc.float64[KLEV, KLON, NBLOCKS]):
for jn in dc.map[0:NBLOCKS]:
tmp = np.zeros([KLEV+1, KLON])
tmp = np.zeros([KLEV + 1, KLON])
for jl in range(KLON):
for jk in range(KLEV):
with dc.tasklet():
in1 << inp[jk, jl, jn]
in2 << inp[jk+1, jl, jn]
in2 << inp[jk + 1, jl, jn]
out1 >> tmp[jk, jl]
out1 = in1 + in2

for jl in range(KLON):
for jk in range(KLEV):
with dc.tasklet():
in1 << tmp[jk, jl]
in2 << tmp[jk+1, jl]
in2 << tmp[jk + 1, jl]
out1 >> out[jk, jl, jn]
out1 = in1 + in2

rng = np.random.default_rng(42)
inp = rng.random((KLEV+1, KLON, NBLOCKS))
inp = rng.random((KLEV + 1, KLON, NBLOCKS))
ref = np.zeros((KLEV, KLON, NBLOCKS))
val = np.zeros((KLEV, KLON, NBLOCKS))

Expand All @@ -247,6 +255,41 @@ def small_wip_dappy(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV
assert np.allclose(val, ref)


def test_issue_1139():
"""
Regression test generated from issue #1139.
The origin of the bug was in the Python frontend: An SDFG parsed by the frontend kept
a number called ``_temp_transients`` that specifies how many ``__tmp*`` arrays have been created.
This number is used to avoid name clashes when inlining SDFGs (although unnecessary).
However, if a nested SDFG had already been simplified, where transformations may change the number
of transients (or add new ones via inlining, which is happening in this bug), the ``_temp_transients``
field becomes out of date and renaming the fields during inlining removes data descriptors.
"""
XN = dc.symbol('XN')
YN = dc.symbol('YN')
N = dc.symbol('N')

@dc.program
def nester(start: dc.float64, stop: dc.float64, X: dc.float64[N]):
dist = (stop - start) / (N - 1)
for i in dc.map[0:N]:
X[i] = start + i * dist

@dc.program
def tester(xmin: dc.float64, xmax: dc.float64):
a = np.ndarray((XN, YN), dtype=np.int64)
b = np.ndarray((XN, YN), dtype=np.int64)
c = np.ndarray((XN, ), dtype=np.float64)
nester(xmin, xmax, c)
return c

xmin = 0.123
xmax = 4.567
c = tester(xmin, xmax, XN=30, YN=40)
assert np.allclose(c, np.linspace(xmin, xmax, 30))


if __name__ == "__main__":
test_nested_name_accesses()
test_nested_offset_access()
Expand All @@ -259,3 +302,4 @@ def small_wip_dappy(inp: dc.float64[KLEV+1, KLON, NBLOCKS], out: dc.float64[KLEV
test_nested_offset_access_nested_dependency_dappy()
test_access_to_nested_transient()
test_access_to_nested_transient_dappy()
test_issue_1139()

0 comments on commit 17e4a88

Please sign in to comment.