Skip to content

Commit

Permalink
Major fixes regarding name changes etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Dec 3, 2024
1 parent 08cb50c commit ac90c86
Show file tree
Hide file tree
Showing 11 changed files with 144 additions and 120 deletions.
13 changes: 8 additions & 5 deletions dace/codegen/targets/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,6 @@ def declare_array(self,

name = node.root_data
ptrname = cpp.ptr(name, nodedesc, sdfg, self._frame)
print("D2", name, nodedesc)

if nodedesc.transient is False:
return
Expand All @@ -344,7 +343,7 @@ def declare_array(self,
if nodedesc.transient and nodedesc.storage == dtypes.StorageType.CPU_Heap:
size_desc_name = sdfg.arrays[name].size_desc_name
if size_desc_name is not None:
size_desc = sdfg.size_arrays[size_desc_name]
size_desc = sdfg.arrays[size_desc_name]
size_ctypedef = dtypes.pointer(size_desc.dtype).ctype
self._dispatcher.declared_arrays.add(size_desc_name, DefinedType.Pointer, size_ctypedef)
return
Expand Down Expand Up @@ -513,9 +512,13 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
declaration_stream.write(f'{nodedesc.dtype.ctype} *{name};\n', cfg, state_id, node)
# Initialize size array
size_str = ",".join(["0" if cpp.sym2cpp(dim).startswith("__dace_defer") else cpp.sym2cpp(dim) for dim in nodedesc.shape])
size_desc_name = nodedesc.size_desc_name
size_nodedesc = sdfg.size_arrays[size_desc_name]
declaration_stream.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n', cfg, state_id, node)
if (nodedesc.transient and (
nodedesc.storage == dtypes.StorageType.CPU_Heap or
nodedesc.storage == dtypes.StorageType.GPU_Global)
):
size_desc_name = nodedesc.size_desc_name
size_nodedesc = sdfg.arrays[size_desc_name]
declaration_stream.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n', cfg, state_id, node)
if deferred_allocation:
allocation_stream.write(
"%s = nullptr; // Deferred Allocation" %
Expand Down
4 changes: 2 additions & 2 deletions dace/codegen/targets/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ def allocate_array(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgraphV
if nodedesc.transient:
size_desc_name = nodedesc.size_desc_name
if size_desc_name is not None:
size_nodedesc = sdfg.size_arrays[size_desc_name]
size_nodedesc = sdfg.arrays[size_desc_name]
result_decl.write(f'{size_nodedesc.dtype.ctype} {size_desc_name}[{size_nodedesc.shape[0]}]{{{size_str}}};\n')
self._dispatcher.defined_vars.add(size_desc_name, DefinedType.Pointer, size_nodedesc.dtype.ctype)
self._dispatcher.defined_vars.add(dataname, DefinedType.Pointer, ctypedef)
Expand Down Expand Up @@ -1586,7 +1586,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub
if aname in sdfg.arrays:
size_arr_name = data_desc.size_desc_name
if size_arr_name is not None:
size_arr = sdfg.size_arrays[data_desc.size_desc_name]
size_arr = sdfg.arrays[data_desc.size_desc_name]
host_size_args[size_arr_name] = size_arr

kernel_args_typed = [('const ' if k in const_params else '') + v.as_arg(name=k)
Expand Down
5 changes: 2 additions & 3 deletions dace/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ def _transient_setter(self, value):
default=dtypes.AllocationLifetime.Scope)
location = DictProperty(key_type=str, value_type=str, desc='Full storage location identifier (e.g., rank, GPU ID)')
debuginfo = DebugInfoProperty(allow_none=True)
size_desc_name = Property(dtype=str, default=None, allow_none=True)

def __init__(self, dtype, shape, transient, storage, location, lifetime, debuginfo):
self.dtype = dtype
Expand All @@ -192,6 +193,7 @@ def __init__(self, dtype, shape, transient, storage, location, lifetime, debugin
self.location = location if location is not None else {}
self.lifetime = lifetime
self.debuginfo = debuginfo
self.size_desc_name = None
self._validate()

def __call__(self):
Expand Down Expand Up @@ -1385,9 +1387,6 @@ class Array(Data):
'it is inferred by other properties and the OptionalArrayInference pass.')
pool = Property(dtype=bool, default=False, desc='Hint to the allocator that using a memory pool is preferred')

size_desc_name = Property(dtype=str, default=None, allow_none=True, desc='The name of the size array (1D, length is the shape of thte current array)'
'Of the array (usually <name>_size)')

def __init__(self,
dtype,
shape,
Expand Down
40 changes: 20 additions & 20 deletions dace/memlet.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ def __init__(self,
debuginfo: Optional[dtypes.DebugInfo] = None,
wcr_nonatomic: bool = False,
allow_oob: bool = False):
"""
"""
Constructs a Memlet.
:param expr: A string expression of the this memlet, given as an ease
of use API. Must follow one of the following forms:
1. ``ARRAY``,
Expand All @@ -82,7 +82,7 @@ def __init__(self,
:param subset: The subset to take from the data attached to the edge,
represented either as a string or a Subset object.
:param other_subset: The subset to offset into the other side of the
memlet, represented either as a string or a Subset
memlet, represented either as a string or a Subset
object.
:param volume: The exact number of elements moved using this
memlet, or the maximum number of elements if
Expand All @@ -91,14 +91,14 @@ def __init__(self,
is runtime-defined and unbounded.
:param dynamic: If True, the number of elements moved in this memlet
is defined dynamically at runtime.
:param wcr: A lambda function (represented as a string or Python AST)
:param wcr: A lambda function (represented as a string or Python AST)
specifying how write-conflicts are resolved. The syntax
of the lambda function receives two elements: ``current``
value and `new` value, and returns the value after
of the lambda function receives two elements: ``current``
value and `new` value, and returns the value after
resolution. For example, summation is represented by
``'lambda cur, new: cur + new'``.
:param debuginfo: Line information from the generating source code.
:param wcr_nonatomic: If True, overrides the automatic code generator
:param wcr_nonatomic: If True, overrides the automatic code generator
decision and treat all write-conflict resolution
operations as non-atomic, which might cause race
conditions in the general case.
Expand Down Expand Up @@ -225,16 +225,16 @@ def __deepcopy__(self, memo):
return node

def is_empty(self) -> bool:
"""
"""
Returns True if this memlet carries no data. Memlets without data are
primarily used for connecting nodes to scopes without transferring
data to them.
primarily used for connecting nodes to scopes without transferring
data to them.
"""
return (self.data is None and self.subset is None and self.other_subset is None)

@property
def num_accesses(self):
"""
"""
Returns the total memory movement volume (in elements) of this memlet.
"""
return self.volume
Expand All @@ -255,7 +255,7 @@ def simple(data,
"""
DEPRECATED: Constructs a Memlet from string-based expressions.
:param data: The data object or name to access.
:param data: The data object or name to access.
:param subset_str: The subset of `data` that is going to
be accessed in string format. Example: '0:N'.
:param wcr_str: A lambda function (as a string) specifying
Expand Down Expand Up @@ -335,7 +335,7 @@ def _parse_from_subexpr(self, expr: str):
# [subset] syntax
if expr.startswith('['):
return None, SubsetProperty.from_string(expr[1:-1])

# array[subset] syntax
arrname, subset_str = expr[:-1].split('[')
if not dtypes.validate_name(arrname):
Expand Down Expand Up @@ -385,8 +385,8 @@ def _parse_memlet_from_str(self, expr: str):

def try_initialize(self, sdfg: 'dace.sdfg.SDFG', state: 'dace.sdfg.SDFGState',
edge: 'dace.sdfg.graph.MultiConnectorEdge'):
"""
Tries to initialize the internal fields of the memlet (e.g., src/dst
"""
Tries to initialize the internal fields of the memlet (e.g., src/dst
subset) once it is added to an SDFG as an edge.
"""
from dace.sdfg.nodes import AccessNode, CodeNode # Avoid import loops
Expand Down Expand Up @@ -435,7 +435,7 @@ def get_dst_subset(self, edge: 'dace.sdfg.graph.MultiConnectorEdge', state: 'dac

@staticmethod
def from_array(dataname, datadesc, wcr=None):
"""
"""
Constructs a Memlet that transfers an entire array's contents.
:param dataname: The name of the data descriptor in the SDFG.
Expand All @@ -456,7 +456,7 @@ def __eq__(self, other):
def replace(self, repl_dict):
"""
Substitute a given set of symbols with a different set of symbols.
:param repl_dict: A dict of string symbol names to symbols with
which to replace them.
"""
Expand Down Expand Up @@ -538,8 +538,8 @@ def validate(self, sdfg, state):

def used_symbols(self, all_symbols: bool, edge=None) -> Set[str]:
"""
Returns a set of symbols used in this edge's properties.
Returns a set of symbols used in this edge's properties.
:param all_symbols: If False, only returns the set of symbols that will be used
in the generated code and are needed as arguments.
:param edge: If given, provides richer context-based tests for the case
Expand Down Expand Up @@ -606,7 +606,7 @@ def get_free_symbols_by_indices(self, indices_src: List[int], indices_dst: List[

def get_stride(self, sdfg: 'dace.sdfg.SDFG', map: 'dace.sdfg.nodes.Map', dim: int = -1) -> 'dace.symbolic.SymExpr':
""" Returns the stride of the underlying memory when traversing a Map.
:param sdfg: The SDFG in which the memlet resides.
:param map: The map in which the memlet resides.
:param dim: The dimension that is incremented. By default it is the innermost.
Expand Down
12 changes: 6 additions & 6 deletions dace/sdfg/infer_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def infer_out_connector_type(sdfg: SDFG, state: SDFGState, node: nodes.CodeNode,


def infer_connector_types(sdfg: SDFG):
"""
"""
Infers connector types throughout an SDFG and its nested SDFGs in-place.
:param sdfg: The SDFG to infer.
"""
# Loop over states, and in a topological sort over each state's nodes
Expand Down Expand Up @@ -125,13 +125,13 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E
use_parent_schedule: bool = False,
state: SDFGState = None,
child_nodes: Dict[nodes.Node, List[nodes.Node]] = None):
"""
"""
Sets default storage and schedule types throughout SDFG in-place.
Replaces ``ScheduleType.Default`` and ``StorageType.Default``
with the corresponding types according to the parent scope's schedule.
with the corresponding types according to the parent scope's schedule.
The defaults for storage types are determined by the
``dtypes.SCOPEDEFAULT_STORAGE`` dictionary (for example, a GPU device
``dtypes.SCOPEDEFAULT_STORAGE`` dictionary (for example, a GPU device
schedule, by default, will allocate containers on the shared memory).
Following storage type inference for a scope, nested scopes (e.g., map entry, nested SDFG)
are evaluated using the ``dtypes.STORAGEDEFAULT_SCHEDULE`` dictionary (for example, a
Expand Down
4 changes: 2 additions & 2 deletions dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def _replsym(symlist, symrepl):
def replace_dict(subgraph: 'StateSubgraphView',
repl: Dict[str, str],
symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None):
"""
"""
Finds and replaces all occurrences of a set of symbols/arrays in the given subgraph.
:param subgraph: The given graph or subgraph to replace in.
Expand Down Expand Up @@ -86,7 +86,7 @@ def replace_dict(subgraph: 'StateSubgraphView',
def replace(subgraph: 'StateSubgraphView', name: str, new_name: str):
"""
Finds and replaces all occurrences of a symbol or array in the given subgraph.
:param subgraph: The given graph or subgraph to replace in.
:param name: Name to find.
:param new_name: Name to replace.
Expand Down
Loading

0 comments on commit ac90c86

Please sign in to comment.