Skip to content

Commit

Permalink
Bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Dec 3, 2024
1 parent c4eef0c commit 82cdfde
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 23 deletions.
14 changes: 8 additions & 6 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,13 @@ def declare_array(self,
declaration_stream.write(f'{nodedesc.dtype.ctype} *{name} = nullptr;\n', cfg, state_id, node)
self._dispatcher.declared_arrays.add(name, DefinedType.Pointer, ctypedef)

size_arr_name = sdfg.arrays[name].size_desc_name
size_arr_desc = sdfg.arrays[size_arr_name]
size_ctypedef = dtypes.pointer(size_arr_desc.dtype).ctype

self._dispatcher.declared_arrays.add(size_arr_name, DefinedType.Pointer, size_ctypedef)
# Size desc is defined only for transient arrays
if nodedesc.transient and nodedesc.storage == dtypes.StorageType.CPU_Heap:
size_desc_name = sdfg.arrays[name].size_desc_name
if size_desc_name is not None:
size_desc = sdfg.arrays[size_desc_name]
size_ctypedef = dtypes.pointer(size_desc.dtype).ctype
self._dispatcher.declared_arrays.add(size_desc_name, DefinedType.Pointer, size_ctypedef)
return
elif nodedesc.storage is dtypes.StorageType.CPU_ThreadLocal:
# Define pointer once
Expand Down Expand Up @@ -1070,7 +1072,7 @@ def check_dace_defer(elements):
deferred_size_names.append(f"__{memlet.data}_dim{i}_size" if desc.storage == dtypes.StorageType.GPU_Global else f"{desc.size_desc_name}[{i}]")
else:
deferred_size_names.append(elem)
return deferred_size_names if len(deferred_size_names) > 0 else None
return deferred_size_names if deferred_size_names is not None and len(deferred_size_names) > 0 else None

def process_out_memlets(self,
sdfg: SDFG,
Expand Down
18 changes: 9 additions & 9 deletions dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,10 +610,12 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
if not declared:
result_decl.write('%s %s;\n' % (ctypedef, dataname))
size_str = ",".join(["0" if cpp.sym2cpp(dim).startswith("__dace_defer") else cpp.sym2cpp(dim) for dim in nodedesc.shape])
size_desc_name = nodedesc.size_desc_name
size_nodedesc = sdfg.arrays[size_desc_name]
result_decl.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n')
self._dispatcher.defined_vars.add(size_desc_name, DefinedType.Pointer, size_nodedesc.dtype.ctype)
if nodedesc.transient:
size_desc_name = nodedesc.size_desc_name
if size_desc_name is not None:
size_nodedesc = sdfg.arrays[size_desc_name]
result_decl.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n')
self._dispatcher.defined_vars.add(size_desc_name, DefinedType.Pointer, size_nodedesc.dtype.ctype)
self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer, ctypedef)


Expand Down Expand Up @@ -1481,7 +1483,6 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub
if isinstance(node, nodes.AccessNode):
nsdfg: SDFG = parent.parent
desc = node.desc(nsdfg)
sizedesc = nsdfg.arrays[desc.size_desc_name]
if (nsdfg, node.data) in visited:
continue
visited.add((nsdfg, node.data))
Expand Down Expand Up @@ -1584,10 +1585,9 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub

if aname in sdfg.arrays:
size_arr_name = data_desc.size_desc_name
size_arr = sdfg.arrays[data_desc.size_desc_name]
size_arr_len = size_arr.shape[0]
size_arr_dtype = size_arr.dtype.ctype
host_size_args[size_arr_name] = size_arr
if size_arr_name is not None:
size_arr = sdfg.arrays[data_desc.size_desc_name]
host_size_args[size_arr_name] = size_arr

kernel_args_typed = [('const ' if k in const_params else '') + v.as_arg(name=k)
for k, v in prototype_kernel_args.items()]
Expand Down
16 changes: 8 additions & 8 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,13 +189,13 @@ def parse_dace_program(name: str,
Parses a ``@dace.program`` function into an SDFG.
:param src_ast: The AST of the Python program to parse.
:param visitor: A ProgramVisitor object returned from
:param visitor: A ProgramVisitor object returned from
``preprocess_dace_program``.
:param closure: An object that contains the @dace.program closure.
:param simplify: If True, simplification pass will be performed.
:param save: If True, saves source mapping data for this SDFG.
:param progress: If True, prints a progress bar of the parsing process.
If None (default), prints after 5 seconds of parsing.
:param progress: If True, prints a progress bar of the parsing process.
If None (default), prints after 5 seconds of parsing.
If False, never prints progress.
:return: A 2-tuple of SDFG and its reduced (used) closure.
"""
Expand Down Expand Up @@ -1466,8 +1466,8 @@ def _parse_subprogram(self, name, node, is_tasklet=False, extra_symbols=None, ex
def _symbols_from_params(self, params: List[Tuple[str, Union[str, dtypes.typeclass]]],
memlet_inputs: Dict[str, Memlet]) -> Dict[str, symbolic.symbol]:
"""
Returns a mapping between symbol names to their type, as a symbol
object to maintain compatibility with global symbols. Used to maintain
Returns a mapping between symbol names to their type, as a symbol
object to maintain compatibility with global symbols. Used to maintain
typed symbols in SDFG scopes (e.g., map, consume).
"""
from dace.codegen.tools.type_inference import infer_expr_type
Expand Down Expand Up @@ -1900,7 +1900,7 @@ def _parse_map_inputs(self, name: str, params: List[Tuple[str, str]],

def _parse_consume_inputs(self, node: ast.FunctionDef) -> Tuple[str, str, Tuple[str, str], str, str]:
""" Parse consume parameters from AST.
:return: A 5-tuple of Stream name, internal stream name,
(PE index, number of PEs), condition, chunk size.
"""
Expand Down Expand Up @@ -2179,7 +2179,7 @@ def _add_dependencies(self,
state.add_nedge(internal_node, exit_node, dace.Memlet())

def _add_nested_symbols(self, nsdfg_node: nodes.NestedSDFG):
"""
"""
Adds symbols from nested SDFG mapping values (if appear as globals)
to current SDFG.
"""
Expand Down Expand Up @@ -4769,7 +4769,7 @@ def visit_With(self, node: ast.With, is_async=False):
evald = astutils.evalnode(node.items[0].context_expr, self.globals)
if hasattr(evald, "name"):
named_region_name: str = evald.name
else:
else:
named_region_name = f"Named Region {node.lineno}"
named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo)
self.cfg_target.add_node(named_region)
Expand Down

0 comments on commit 82cdfde

Please sign in to comment.