GPU code generation: User-specified block/thread/warp location #1358

Merged · 10 commits · Nov 3, 2023
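In short, this PR lets users pin individual tasklets inside GPU kernels to specific threads, warps, or blocks through the tasklet's `location` dictionary, and teaches the CUDA code generator to emit the corresponding guard conditions. A minimal usage sketch, adapted from the new test in tests/cuda_block_test.py below (illustrative only, not part of the diff itself):

    import dace

    @dace.program
    def tester(A: dace.float64[200]):
        for i in dace.map[0:200:32]:      # becomes the GPU grid after the transformation
            for bi in dace.map[0:32]:     # becomes the thread block
                with dace.tasklet:
                    a >> A[i + bi]
                    a = 2

    sdfg = tester.to_sdfg()
    sdfg.apply_gpu_transformations(sequential_innermaps=False)

    # Location keys recognized by the CUDA backend in this PR: 'gpu_thread', 'gpu_warp',
    # 'gpu_block'. Values may be a single index, a symbolic expression, or a
    # one-dimensional range string.
    tasklet = next(n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, dace.nodes.Tasklet))
    tasklet.location['gpu_thread'] = dace.subsets.Range.from_string('2:9:3')  # threads 2, 5, 8
    tasklet.location['gpu_block'] = 1                                         # second block only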
85 changes: 79 additions & 6 deletions dace/codegen/targets/cuda.py
@@ -445,7 +445,7 @@ def node_dispatch_predicate(self, sdfg, state, node):
        if hasattr(node, 'schedule'):  # NOTE: Works on nodes and scopes
            if node.schedule in dtypes.GPU_SCHEDULES:
                return True
-       if isinstance(node, nodes.NestedSDFG) and CUDACodeGen._in_device_code:
+       if CUDACodeGen._in_device_code:
            return True
        return False

@@ -1324,11 +1324,11 @@ def generate_devicelevel_state(self, sdfg, state, function_stream, callsite_stream):

            if write_scope == 'grid':
                callsite_stream.write("if (blockIdx.x == 0 "
-                                     "&& threadIdx.x == 0) "
-                                     "{ // sub-graph begin", sdfg, state.node_id)
+                                     "&& threadIdx.x == 0) "
+                                     "{ // sub-graph begin", sdfg, state.node_id)
            elif write_scope == 'block':
                callsite_stream.write("if (threadIdx.x == 0) "
-                                     "{ // sub-graph begin", sdfg, state.node_id)
+                                     "{ // sub-graph begin", sdfg, state.node_id)
            else:
                callsite_stream.write("{ // subgraph begin", sdfg, state.node_id)
        else:
@@ -2519,8 +2519,9 @@ def generate_devicelevel_scope(self, sdfg, dfg_scope, state_id, function_stream, callsite_stream):
    def generate_node(self, sdfg, dfg, state_id, node, function_stream, callsite_stream):
        if self.node_dispatch_predicate(sdfg, dfg, node):
            # Dynamically obtain node generator according to class name
-           gen = getattr(self, '_generate_' + type(node).__name__)
-           gen(sdfg, dfg, state_id, node, function_stream, callsite_stream)
+           gen = getattr(self, '_generate_' + type(node).__name__, False)
+           if gen is not False:  # Not every node type has a code generator here
+               gen(sdfg, dfg, state_id, node, function_stream, callsite_stream)
            return

        if not CUDACodeGen._in_device_code:
@@ -2591,6 +2592,78 @@ def _generate_MapExit(self, sdfg, dfg, state_id, node, function_stream, callsite_stream):

        self._cpu_codegen._generate_MapExit(sdfg, dfg, state_id, node, function_stream, callsite_stream)

    def _get_thread_id(self) -> str:
        result = 'threadIdx.x'
        if self._block_dims[1] != 1:
            result += f' + ({sym2cpp(self._block_dims[0])}) * threadIdx.y'
        if self._block_dims[2] != 1:
            result += f' + ({sym2cpp(self._block_dims[0] * self._block_dims[1])}) * threadIdx.z'
        return result

    def _get_warp_id(self) -> str:
        return f'(({self._get_thread_id()}) / warpSize)'

    def _get_block_id(self) -> str:
        result = 'blockIdx.x'
        if self._block_dims[1] != 1:
            result += f' + gridDim.x * blockIdx.y'
        if self._block_dims[2] != 1:
            result += f' + gridDim.x * gridDim.y * blockIdx.z'
        return result

    def _generate_condition_from_location(self, name: str, index_expr: str, node: nodes.Tasklet,
                                          callsite_stream: CodeIOStream) -> int:
        if name not in node.location:
            return 0

        location: Union[int, str, subsets.Range] = node.location[name]
        if isinstance(location, str) and ':' in location:
            location = subsets.Range.from_string(location)
        elif symbolic.issymbolic(location):
            location = sym2cpp(location)

        if isinstance(location, subsets.Range):
            # Range of indices
            if len(location) != 1:
                raise ValueError(f'Only one-dimensional ranges are allowed for {name} specialization, {location} given')
            begin, end, stride = location[0]
            rb, re, rs = sym2cpp(begin), sym2cpp(end), sym2cpp(stride)
            cond = ''
            cond += f'(({index_expr}) >= {rb}) && (({index_expr}) <= {re})'
            if stride != 1:
                cond += f' && ((({index_expr}) - {rb}) % {rs} == 0)'

            callsite_stream.write(f'if ({cond}) {{')
        else:
            # Single-element
            callsite_stream.write(f'if (({index_expr}) == {location}) {{')

        return 1

    def _generate_Tasklet(self, sdfg: SDFG, dfg, state_id: int, node: nodes.Tasklet, function_stream: CodeIOStream,
                          callsite_stream: CodeIOStream):
        generated_preamble_scopes = 0
        if self._in_device_code:
            # If location dictionary prescribes that the code should run on a certain group of threads/blocks,
            # add condition
            generated_preamble_scopes += self._generate_condition_from_location('gpu_thread', self._get_thread_id(),
                                                                                node, callsite_stream)
            generated_preamble_scopes += self._generate_condition_from_location('gpu_warp', self._get_warp_id(), node,
                                                                                callsite_stream)
            generated_preamble_scopes += self._generate_condition_from_location('gpu_block', self._get_block_id(), node,
                                                                                callsite_stream)

        # Call standard tasklet generation
        old_codegen = self._cpu_codegen.calling_codegen
        self._cpu_codegen.calling_codegen = self
        self._cpu_codegen._generate_Tasklet(sdfg, dfg, state_id, node, function_stream, callsite_stream)
        self._cpu_codegen.calling_codegen = old_codegen

        if generated_preamble_scopes > 0:
            # Generate appropriate postamble
            for i in range(generated_preamble_scopes):
                callsite_stream.write('}', sdfg, state_id, node)

    def make_ptr_vector_cast(self, *args, **kwargs):
        return cpp.make_ptr_vector_cast(*args, **kwargs)

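To make the generated expressions concrete: for a one-dimensional thread block the flattened IDs above reduce to plain threadIdx.x and blockIdx.x, and the warp ID is that value divided by warpSize. The guard that _generate_condition_from_location emits for a range location can be previewed with a standalone sketch of its range branch (preview_condition is a hypothetical helper that mirrors the logic above rather than calling DaCe):

    def preview_condition(index_expr: str, begin: int, end: int, stride: int) -> str:
        """Builds the same C guard string as the range branch above (sketch only)."""
        cond = f'(({index_expr}) >= {begin}) && (({index_expr}) <= {end})'
        if stride != 1:
            cond += f' && ((({index_expr}) - {begin}) % {stride} == 0)'
        return f'if ({cond}) {{'

    # For the thread range '2:9:3' used in the new test (inclusive end 8) on a 1D block:
    #   if (((threadIdx.x) >= 2) && ((threadIdx.x) <= 8) && (((threadIdx.x) - 2) % 3 == 0)) {
    print(preview_condition('threadIdx.x', 2, 8, 3))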
38 changes: 38 additions & 0 deletions tests/cuda_block_test.py
@@ -10,8 +10,10 @@

@dace.program(dace.float64[N], dace.float64[N])
def cudahello(V, Vout):

    @dace.mapscope(_[0:N:32])
    def multiplication(i):

        @dace.map(_[0:32])
        def mult_block(bi):
            in_V << V[i + bi]
@@ -55,6 +57,7 @@ def test_gpu():

@pytest.mark.gpu
def test_different_block_sizes_nesting():

    @dace.program
    def nested(V: dace.float64[34], v1: dace.float64[1]):
        with dace.tasklet:
@@ -105,6 +108,7 @@ def diffblocks(V: dace.float64[130], v1: dace.float64[4], v2: dace.float64[128]):

@pytest.mark.gpu
def test_custom_block_size_onemap():

    @dace.program
    def tester(A: dace.float64[400, 300]):
        for i, j in dace.map[0:400, 0:300]:
@@ -132,6 +136,7 @@ def tester(A: dace.float64[400, 300]):

@pytest.mark.gpu
def test_custom_block_size_twomaps():

    @dace.program
    def tester(A: dace.float64[400, 300, 2, 32]):
        for i, j in dace.map[0:400, 0:300]:
@@ -154,9 +159,42 @@ def tester(A: dace.float64[400, 300, 2, 32]):
    sdfg.compile()


@pytest.mark.gpu
def test_block_thread_specialization():

    @dace.program
    def tester(A: dace.float64[200]):
        for i in dace.map[0:200:32]:
            for bi in dace.map[0:32]:
                with dace.tasklet:
                    a >> A[i + bi]
                    a = 1
                with dace.tasklet:  # Tasklet to be specialized
                    a >> A[i + bi]
                    a = 2

    sdfg = tester.to_sdfg()
    sdfg.apply_gpu_transformations(sequential_innermaps=False)
    tasklet = next(n for n, _ in sdfg.all_nodes_recursive()
                   if isinstance(n, dace.nodes.Tasklet) and '2' in n.code.as_string)
    tasklet.location['gpu_thread'] = dace.subsets.Range.from_string('2:9:3')
    tasklet.location['gpu_block'] = 1

    code = sdfg.generate_code()[1].clean_code  # Get GPU code (second file)
    assert '>= 2' in code and '<= 8' in code
    assert ' == 1' in code

    a = np.random.rand(200)
    ref = np.ones_like(a)
    ref[32:64][2:9:3] = 2
    sdfg(a)
    assert np.allclose(a, ref)


if __name__ == "__main__":
    test_cpu()
    test_gpu()
    test_different_block_sizes_nesting()
    test_custom_block_size_onemap()
    test_custom_block_size_twomaps()
    test_block_thread_specialization()
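A quick sanity check of the reference used in test_block_thread_specialization: with 32-thread blocks, gpu_block = 1 selects elements 32..63 of A, and the thread range '2:9:3' selects threads 2, 5, and 8 within that block, so exactly elements 34, 37, and 40 should end up as 2. A NumPy sketch of that expectation (not part of the test itself):

    import numpy as np

    ref = np.ones(200)
    ref[32:64][2:9:3] = 2            # same assignment as in the test
    print(np.flatnonzero(ref == 2))  # -> [34 37 40]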