
Commit

Merge branch 'spcl:master' into master
hodelcl authored Dec 20, 2023
2 parents f8a2a9d + 7c06755 commit 0bd0927
Showing 12 changed files with 305 additions and 31 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/fpga-ci.yml
@@ -5,6 +5,8 @@ on:
branches: [ master, ci-fix ]
pull_request:
branches: [ master, ci-fix ]
merge_group:
branches: [ master, ci-fix ]

jobs:
test-fpga:
2 changes: 2 additions & 0 deletions .github/workflows/general-ci.yml
@@ -5,6 +5,8 @@ on:
branches: [ master, ci-fix ]
pull_request:
branches: [ master, ci-fix ]
merge_group:
branches: [ master, ci-fix ]

jobs:
test:
2 changes: 2 additions & 0 deletions .github/workflows/gpu-ci.yml
@@ -5,6 +5,8 @@ on:
branches: [ master, ci-fix ]
pull_request:
branches: [ master, ci-fix ]
merge_group:
branches: [ master, ci-fix ]

jobs:
test-gpu:
2 changes: 2 additions & 0 deletions .github/workflows/heterogeneous-ci.yml
@@ -5,6 +5,8 @@ on:
branches: [ master, ci-fix ]
pull_request:
branches: [ master, ci-fix ]
merge_group:
branches: [ master, ci-fix ]

jobs:
test-heterogeneous:
16 changes: 14 additions & 2 deletions dace/codegen/targets/cuda.py
@@ -1132,10 +1132,22 @@ def _emit_copy(self, state_id, src_node, src_storage, dst_node, dst_storage, dst
func=funcname,
type=dst_node.desc(sdfg).dtype.ctype,
bdims=', '.join(_topy(self._block_dims)),
is_async='true' if state_dfg.out_degree(dst_node) > 0 else 'true',
is_async='true' if state_dfg.out_degree(dst_node) == 0 else 'false',
accum=accum,
args=', '.join([src_expr] + _topy(src_strides) + [dst_expr] + custom_reduction +
_topy(dst_strides) + _topy(copy_shape))), sdfg, state_id, [src_node, dst_node])
elif funcname == 'dace::SharedToGlobal1D':
# special case: use a new template struct that provides functions for copy and reduction
callsite_stream.write(
(' {func}<{type}, {bdims}, {copysize}, {is_async}>{accum}({args});').format(
func=funcname,
type=dst_node.desc(sdfg).dtype.ctype,
bdims=', '.join(_topy(self._block_dims)),
copysize=', '.join(_topy(copy_shape)),
is_async='true' if state_dfg.out_degree(dst_node) == 0 else 'false',
accum=accum or '::Copy',
args=', '.join([src_expr] + _topy(src_strides) + [dst_expr] + _topy(dst_strides) + custom_reduction)), sdfg,
state_id, [src_node, dst_node])
else:
callsite_stream.write(
(' {func}<{type}, {bdims}, {copysize}, ' +
@@ -1145,7 +1157,7 @@ def _emit_copy(self, state_id, src_node, src_storage, dst_node, dst_storage, dst
bdims=', '.join(_topy(self._block_dims)),
copysize=', '.join(_topy(copy_shape)),
dststrides=', '.join(_topy(dst_strides)),
is_async='true' if state_dfg.out_degree(dst_node) > 0 else 'true',
is_async='true' if state_dfg.out_degree(dst_node) == 0 else 'false',
accum=accum,
args=', '.join([src_expr] + _topy(src_strides) + [dst_expr] + custom_reduction)), sdfg,
state_id, [src_node, dst_node])
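
The change above fixes the is_async flag, which previously evaluated to 'true' in both branches of the conditional, and adds a dedicated branch that emits the new dace::SharedToGlobal1D template. The Python sketch below is an editorial illustration of that emission logic, not the generator itself; the helper name emit_shared_to_global_1d and its arguments are hypothetical stand-ins for the values computed in _emit_copy.

# Sketch: a copy may only be asynchronous (skip the trailing __syncthreads)
# when nothing else in the state consumes the destination access node.
def emit_shared_to_global_1d(ctype, block_dims, copy_shape, out_degree, accum, args):
    is_async = 'true' if out_degree == 0 else 'false'
    return ('    dace::SharedToGlobal1D<{type}, {bdims}, {copysize}, {is_async}>'
            '{accum}({args});').format(
                type=ctype,
                bdims=', '.join(block_dims),
                copysize=', '.join(copy_shape),
                is_async=is_async,
                accum=accum or '::Copy',   # plain copy when no write-conflict resolution
                args=', '.join(args))

# A destination that is still read later in the state (out_degree > 0)
# yields a synchronous call using the ::Copy member:
print(emit_shared_to_global_1d('double', ['32', '1', '1'], ['32'],
                               out_degree=1, accum=None,
                               args=['smem', '1', 'gmem', '1']))
# prints (indented): dace::SharedToGlobal1D<double, 32, 1, 1, 32, false>::Copy(smem, 1, gmem, 1);
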
11 changes: 4 additions & 7 deletions dace/memlet.py
@@ -169,8 +169,7 @@ def to_json(self):
attrs['is_data_src'] = self._is_data_src

# Fill in legacy (DEPRECATED) values for backwards compatibility
attrs['num_accesses'] = \
str(self.volume) if not self.dynamic else -1
attrs['num_accesses'] = str(self.volume) if not self.dynamic else -1

return {"type": "Memlet", "attributes": attrs}

@@ -421,13 +420,11 @@ def from_array(dataname, datadesc, wcr=None):
return Memlet.simple(dataname, rng, wcr_str=wcr)

def __hash__(self):
return hash((self.volume, self.src_subset, self.dst_subset, str(self.wcr)))
return hash((self.data, self.volume, self.src_subset, self.dst_subset, str(self.wcr)))

def __eq__(self, other):
return all([
self.volume == other.volume, self.src_subset == other.src_subset, self.dst_subset == other.dst_subset,
self.wcr == other.wcr
])
return all((self.data == other.data, self.volume == other.volume, self.src_subset == other.src_subset,
self.dst_subset == other.dst_subset, self.wcr == other.wcr))

def replace(self, repl_dict):
"""
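
With the data name included in both __hash__ and __eq__, two memlets that cover the same range of different arrays no longer compare or hash equal. A minimal sketch of the new behavior, assuming the string form of the Memlet constructor (dace.Memlet('A[0:10]')):

import dace

m_a = dace.Memlet('A[0:10]')
m_b = dace.Memlet('B[0:10]')   # same volume and subset, different array

# Before this change, equality and hashing ignored the data name, so these
# two memlets compared equal and collapsed to a single set element.
assert m_a != m_b
assert len({m_a, m_b}) == 2
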
53 changes: 35 additions & 18 deletions dace/runtime/include/dace/cuda/copy.cuh
@@ -736,60 +736,77 @@ namespace dace
int COPY_XLEN, bool ASYNC>
struct SharedToGlobal1D
{
template <typename WCR>
static DACE_DFI void Accum(const T *smem, int src_xstride, T *ptr, int DST_XSTRIDE, WCR wcr)
static constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
static constexpr int TOTAL = COPY_XLEN;
static constexpr int WRITES = TOTAL / BLOCK_SIZE;
static constexpr int REM_WRITES = TOTAL % BLOCK_SIZE;

static DACE_DFI void Copy(const T *smem, int src_xstride, T *ptr, int dst_xstride)
{
// Linear thread ID
int ltid = GetLinearTID<BLOCK_WIDTH, BLOCK_HEIGHT, BLOCK_DEPTH>();

#pragma unroll
for (int i = 0; i < WRITES; ++i) {
*(ptr + (ltid + i * BLOCK_SIZE) * dst_xstride) =
*(smem + (ltid + i * BLOCK_SIZE) * src_xstride);
}

if (REM_WRITES != 0 && ltid < REM_WRITES) {
*(ptr + (ltid + WRITES*BLOCK_SIZE)* dst_xstride) =
*(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride);
}

if (!ASYNC)
__syncthreads();
}

template <typename WCR>
static DACE_DFI void Accum(const T *smem, int src_xstride, T *ptr, int dst_xstride, WCR wcr)
{
// Linear thread ID
int ltid = GetLinearTID<BLOCK_WIDTH, BLOCK_HEIGHT, BLOCK_DEPTH>();
constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
constexpr int TOTAL = COPY_XLEN;
constexpr int WRITES = TOTAL / BLOCK_SIZE;
constexpr int REM_WRITES = TOTAL % BLOCK_SIZE;

#pragma unroll
for (int i = 0; i < WRITES; ++i) {
wcr_custom<T>::template reduce(
wcr, ptr + (ltid + i * BLOCK_SIZE) * DST_XSTRIDE,
wcr, ptr + (ltid + i * BLOCK_SIZE) * dst_xstride,
*(smem + (ltid + i * BLOCK_SIZE) * src_xstride));
}

if (REM_WRITES != 0) {
if (ltid < REM_WRITES)
wcr_custom<T>::template reduce(
ptr + (ltid + WRITES * BLOCK_SIZE)* DST_XSTRIDE,
ptr + (ltid + WRITES * BLOCK_SIZE)* dst_xstride,
*(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride));
}

if (!ASYNC)
__syncthreads();
}

template <ReductionType REDTYPE>
static DACE_DFI void Accum(const T *smem, int src_xstride, T *ptr, int DST_XSTRIDE)
static DACE_DFI void Accum(const T *smem, int src_xstride, T *ptr, int dst_xstride)
{
if (!ASYNC)
__syncthreads();

// Linear thread ID
int ltid = GetLinearTID<BLOCK_WIDTH, BLOCK_HEIGHT, BLOCK_DEPTH>();
constexpr int BLOCK_SIZE = BLOCK_WIDTH * BLOCK_HEIGHT * BLOCK_DEPTH;
constexpr int TOTAL = COPY_XLEN;
constexpr int WRITES = TOTAL / BLOCK_SIZE;
constexpr int REM_WRITES = TOTAL % BLOCK_SIZE;

#pragma unroll
for (int i = 0; i < WRITES; ++i) {
wcr_fixed<REDTYPE, T>::template reduce_atomic(
ptr + (ltid + i * BLOCK_SIZE) * DST_XSTRIDE,
ptr + (ltid + i * BLOCK_SIZE) * dst_xstride,
*(smem + (ltid + i * BLOCK_SIZE) * src_xstride));
}

if (REM_WRITES != 0) {
if (ltid < REM_WRITES)
wcr_fixed<REDTYPE, T>::template reduce_atomic(
ptr + (ltid + WRITES*BLOCK_SIZE)* DST_XSTRIDE,
ptr + (ltid + WRITES*BLOCK_SIZE)* dst_xstride,
*(smem + (ltid + WRITES * BLOCK_SIZE) * src_xstride));
}

if (!ASYNC)
__syncthreads();
}
};

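The new Copy member (and the reworked Accum overloads) distribute a one-dimensional copy of COPY_XLEN elements across the thread block: every thread handles one element per round for WRITES = COPY_XLEN / BLOCK_SIZE full rounds, and the first REM_WRITES = COPY_XLEN % BLOCK_SIZE threads pick up the remainder. The Python model below is an editorial sketch of that indexing, not the CUDA code:

# Which element indices does thread `ltid` of a `block_size`-thread block write?
def copied_indices(block_size, copy_xlen, ltid):
    writes = copy_xlen // block_size        # full rounds shared by all threads
    rem_writes = copy_xlen % block_size     # leftover elements
    idx = [ltid + i * block_size for i in range(writes)]
    if ltid < rem_writes:                   # first rem_writes threads take one more
        idx.append(ltid + writes * block_size)
    return idx

# 100 elements copied by a 32-thread block: indices 0..99 are covered exactly once.
covered = sorted(i for t in range(32) for i in copied_indices(32, 100, t))
assert covered == list(range(100))
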
80 changes: 79 additions & 1 deletion dace/transformation/passes/analysis.py
@@ -2,7 +2,7 @@

from collections import defaultdict
from dace.transformation import pass_pipeline as ppl
from dace import SDFG, SDFGState, properties, InterstateEdge
from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt
from dace.sdfg.graph import Edge
from dace.sdfg import nodes as nd
from dace.sdfg.analysis import cfg
@@ -505,3 +505,81 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
del result[desc][write]
top_result[sdfg.sdfg_id] = result
return top_result


@properties.make_properties
class AccessRanges(ppl.Pass):
"""
For each data descriptor, finds all memlets used to access it (read/write ranges).
"""

CATEGORY: str = 'Analysis'

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Nothing

def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & ppl.Modifies.Memlets

def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]:
"""
:return: A dictionary mapping each data descriptor name to a set of memlets.
"""
top_result: Dict[int, Dict[str, Set[Memlet]]] = dict()

for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[str, Set[Memlet]] = defaultdict(set)
for state in sdfg.states():
for anode in state.data_nodes():
for e in state.all_edges(anode):
if e.dst is anode and e.dst_conn == 'set': # Skip reference sets
continue
if e.data.is_empty(): # Skip empty memlets
continue
# Find (hopefully propagated) root memlet
e = state.memlet_tree(e).root().edge
result[anode.data].add(e.data)
top_result[sdfg.sdfg_id] = result
return top_result


@properties.make_properties
class FindReferenceSources(ppl.Pass):
"""
For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used
to set the reference, the Tasklet is given as a source.
"""

CATEGORY: str = 'Analysis'

def modifies(self) -> ppl.Modifies:
return ppl.Modifies.Nothing

def should_reapply(self, modified: ppl.Modifies) -> bool:
return modified & ppl.Modifies.Memlets

def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]]:
"""
:return: A dictionary mapping each Reference data descriptor name to the set of memlets or code nodes used to set it.
"""
top_result: Dict[int, Dict[str, Set[Union[Memlet, nd.CodeNode]]]] = dict()

for sdfg in top_sdfg.all_sdfgs_recursive():
result: Dict[str, Set[Memlet]] = defaultdict(set)
reference_descs = set(k for k, v in sdfg.arrays.items() if isinstance(v, dt.Reference))
for state in sdfg.states():
for anode in state.data_nodes():
if anode.data not in reference_descs:
continue
for e in state.in_edges(anode):
if e.dst_conn != 'set':
continue
true_src = state.memlet_path(e)[0].src
if isinstance(true_src, nd.CodeNode):
# Code -> Reference
result[anode.data].add(true_src)
else:
# Array -> Reference
result[anode.data].add(e.data)
top_result[sdfg.sdfg_id] = result
return top_result
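
Both new passes are read-only analyses (Modifies.Nothing) whose results are keyed first by SDFG id and then by data descriptor name. A minimal usage sketch, assuming a simple dace program such as the one below (the printed ranges are illustrative):

import dace
from dace.transformation.passes.analysis import AccessRanges, FindReferenceSources

@dace.program
def scale(A: dace.float64[20], B: dace.float64[20]):
    B[:] = 2.0 * A[:]

sdfg = scale.to_sdfg()

# All (propagated) memlet ranges used to read or write each array.
ranges = AccessRanges().apply_pass(sdfg, {})
for name, memlets in ranges[sdfg.sdfg_id].items():
    print(name, [str(m) for m in memlets])   # e.g. A -> ['A[0:20]'], B -> ['B[0:20]']

# Memlets or tasklets that set each Reference descriptor; this SDFG has no
# Reference arrays, so the per-array mapping is empty.
refs = FindReferenceSources().apply_pass(sdfg, {})
print(refs[sdfg.sdfg_id])
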
2 changes: 1 addition & 1 deletion setup.py
@@ -73,7 +73,7 @@
},
include_package_data=True,
install_requires=[
'numpy', 'networkx >= 2.5', 'astunparse', 'sympy<=1.9', 'pyyaml', 'ply', 'websockets', 'jinja2',
'numpy', 'networkx >= 2.5', 'astunparse', 'sympy >= 1.9', 'pyyaml', 'ply', 'websockets', 'jinja2',
'fparser >= 0.1.3', 'aenum >= 3.1', 'dataclasses; python_version < "3.7"', 'dill',
'pyreadline;platform_system=="Windows"', 'typing-compat; python_version < "3.8"'
] + cmake_requires,
84 changes: 84 additions & 0 deletions tests/codegen/cuda_memcopy_test.py
@@ -0,0 +1,84 @@
""" Tests code generation for array copy on GPU target. """
import dace
from dace.transformation.auto import auto_optimize

import pytest
import re

# this test requires cupy module
cp = pytest.importorskip("cupy")

# initialize random number generator
rng = cp.random.default_rng(42)


@pytest.mark.gpu
def test_gpu_shared_to_global_1D():
M = 32
N = dace.symbol('N')

@dace.program
def transpose_shared_to_global(A: dace.float64[M, N], B: dace.float64[N, M]):
for i in dace.map[0:N]:
local_gather = dace.define_local([M], A.dtype, storage=dace.StorageType.GPU_Shared)
for j in dace.map[0:M]:
local_gather[j] = A[j, i]
B[i, :] = local_gather


sdfg = transpose_shared_to_global.to_sdfg()
auto_optimize.apply_gpu_storage(sdfg)

size_M = M
size_N = 128

A = rng.random((size_M, size_N,))
B = rng.random((size_N, size_M,))

ref = A.transpose()

sdfg(A, B, N=size_N)
assert cp.allclose(ref, B)

code = sdfg.generate_code()[1].clean_code # Get GPU code (second file)
m = re.search('dace::SharedToGlobal1D<.+>::Copy', code)
assert m is not None


@pytest.mark.gpu
def test_gpu_shared_to_global_1D_accumulate():
M = 32
N = dace.symbol('N')

@dace.program
def transpose_and_add_shared_to_global(A: dace.float64[M, N], B: dace.float64[N, M]):
for i in dace.map[0:N]:
local_gather = dace.define_local([M], A.dtype, storage=dace.StorageType.GPU_Shared)
for j in dace.map[0:M]:
local_gather[j] = A[j, i]
local_gather[:] >> B(M, lambda x, y: x + y)[i, :]


sdfg = transpose_and_add_shared_to_global.to_sdfg()
auto_optimize.apply_gpu_storage(sdfg)

size_M = M
size_N = 128

A = rng.random((size_M, size_N,))
B = rng.random((size_N, size_M,))

ref = A.transpose() + B

sdfg(A, B, N=size_N)
assert cp.allclose(ref, B)

code = sdfg.generate_code()[1].clean_code # Get GPU code (second file)
m = re.search('dace::SharedToGlobal1D<.+>::template Accum', code)
assert m is not None


if __name__ == '__main__':
test_gpu_shared_to_global_1D()
test_gpu_shared_to_global_1D_accumulate()

