diff --git a/dace/codegen/compiled_sdfg.py b/dace/codegen/compiled_sdfg.py index 7de385cead..ff7bc6084e 100644 --- a/dace/codegen/compiled_sdfg.py +++ b/dace/codegen/compiled_sdfg.py @@ -33,7 +33,7 @@ def __init__(self, library_filename, program_name): """ self._stub_filename = os.path.join( os.path.dirname(os.path.realpath(library_filename)), - 'libdacestub_%s.%s' % (program_name, Config.get('compiler', 'library_extension'))) + f'libdacestub_{program_name}.{Config.get("compiler", "library_extension")}') self._library_filename = os.path.realpath(library_filename) self._stub = None self._lib = None @@ -47,7 +47,7 @@ def get_symbol(self, name, restype=ctypes.c_int): func = self._stub.get_symbol(self._lib, ctypes.c_char_p(name.encode())) if func is None: - raise KeyError('Function %s not found in library %s' % (name, os.path.basename(self._library_filename))) + raise KeyError(f'Function {name} not found in library {os.path.basename(self._library_filename)}') return ctypes.CFUNCTYPE(restype)(func) @@ -105,15 +105,14 @@ def load(self): is_loaded = self._stub.is_library_loaded(lib_cfilename) if is_loaded == 1: - warnings.warn('Library %s already loaded, renaming file' % self._library_filename) + warnings.warn(f'Library {self._library_filename} already loaded, renaming file') try: shutil.copyfile(self._library_filename, self._library_filename + '_') self._library_filename += '_' except shutil.Error: - raise cgx.DuplicateDLLError('Library %s is already loaded somewhere else ' % - os.path.basename(self._library_filename) + - 'and cannot be unloaded. Please use a different name ' + - 'for the SDFG/program.') + raise cgx.DuplicateDLLError(f'Library {os.path.basename(self._library_filename)}' + 'is already loaded somewhere else and cannot be unloaded. ' + 'Please use a different name for the SDFG/program.') # Actually load the library self._lib = ctypes.c_void_p(self._stub.load_library(lib_cfilename)) @@ -126,7 +125,7 @@ def load(self): result = subprocess.run(['ld', self._library_filename], stdout=subprocess.PIPE, stderr=subprocess.PIPE) stderr = result.stderr.decode('utf-8') reason = 'Reason:\n' + '\n'.join([l for l in stderr.split('\n') if '_start' not in l]) - raise RuntimeError('Could not load library %s. %s' % (os.path.basename(self._library_filename), reason)) + raise RuntimeError(f'Could not load library {os.path.basename(self._library_filename)}. {reason}') def unload(self): """ Unloads the internal library using the stub. """ @@ -160,8 +159,16 @@ def _array_interface_ptr(array: Any, storage: dtypes.StorageType) -> int: """ if hasattr(array, 'data_ptr'): return array.data_ptr() + if storage == dtypes.StorageType.GPU_Global: - return array.__cuda_array_interface__['data'][0] + try: + return array.__cuda_array_interface__['data'][0] + except AttributeError: + # Special case for CuPy with HIP + if hasattr(array, 'data') and hasattr(array.data, 'ptr'): + return array.data.ptr + raise + return array.__array_interface__['data'][0] @@ -212,6 +219,7 @@ def __init__(self, sdfg, lib: ReloadableDLL, argnames: List[str] = None): self.has_gpu_code = True break + def get_exported_function(self, name: str, restype=None) -> Optional[Callable[..., Any]]: """ Tries to find a symbol by name in the compiled SDFG, and convert it to a callable function @@ -225,6 +233,7 @@ def get_exported_function(self, name: str, restype=None) -> Optional[Callable[.. except KeyError: # Function not found return None + def get_state_struct(self) -> ctypes.Structure: """ Attempt to parse the SDFG source code and extract the state struct. This method will parse the first consecutive entries in the struct that are pointers. As soon as a non-pointer or other unparseable field is @@ -238,6 +247,7 @@ def get_state_struct(self) -> ctypes.Structure: return ctypes.cast(self._libhandle, ctypes.POINTER(self._try_parse_state_struct())).contents + def _try_parse_state_struct(self) -> Optional[Type[ctypes.Structure]]: from dace.codegen.targets.cpp import mangle_dace_state_struct_name # Avoid import cycle # the path of the main sdfg file containing the state struct @@ -365,25 +375,71 @@ def _get_error_text(self, result: Union[str, int]) -> str: else: return result + def __call__(self, *args, **kwargs): - # Update arguments from ordered list - if len(args) > 0 and self.argnames is not None: - kwargs.update({aname: arg for aname, arg in zip(self.argnames, args)}) + """ + Forwards the Python call to the compiled ``SDFG``. + + The order of the positional arguments is expected to be the same as in + the ``argnames`` member. The function will roughly perform the + following tasks: + - Change the order of the Python arguments into the one required by + the binary. + - Performing some basic sanity checks. + - Transforming the Python arguments into their ``C`` equivalents. + - Allocate the memory for the return values. + - Call the ``C` function. + + :note: The memory for the return values is only allocated the first + time this function is called. Thus, this function will always + return the same objects. To force the allocation of new memory + you can call ``clear_return_values()`` in advance. + """ + if self.argnames is None and len(args) != 0: + raise KeyError(f"Passed positional arguments to an SDFG that does not accept them.") + elif len(args) > 0 and self.argnames is not None: + kwargs.update( + # `_construct_args` will handle all of its arguments as kwargs. + {aname: arg for aname, arg in zip(self.argnames, args)} + ) + argtuple, initargtuple = self._construct_args(kwargs) # Missing arguments will be detected here. + # Return values are cached in `self._lastargs`. + return self.fast_call(argtuple, initargtuple, do_gpu_check=True) + + + def fast_call( + self, + callargs: Tuple[Any, ...], + initargs: Tuple[Any, ...], + do_gpu_check: bool = False, + ) -> Union[Tuple[Any, ...], Any]: + """ + Calls the underlying binary functions directly and bypassing + argument sanitation. - try: - argtuple, initargtuple = self._construct_args(kwargs) + This is a faster, but less user friendly version of ``__call__()``. + While ``__call__()`` will transforms its Python arguments such that + they can be forwarded, this function assumes that this processing + was already done by the user. + :param callargs: Arguments passed to the actual computation. + :param initargs: Arguments passed to the initialization function. + :param do_gpu_check: Check if errors happened on the GPU. + + :note: You may use `_construct_args()` to generate the processed arguments. + """ + try: # Call initializer function if necessary, then SDFG if self._initialized is False: self._lib.load() - self._initialize(initargtuple) + self._initialize(initargs) - with hooks.invoke_compiled_sdfg_call_hooks(self, argtuple): + with hooks.invoke_compiled_sdfg_call_hooks(self, callargs): if self.do_not_execute is False: - self._cfunc(self._libhandle, *argtuple) + self._cfunc(self._libhandle, *callargs) - if self.has_gpu_code: - # Optionally get errors from call + # Optionally get errors from call + if do_gpu_check and self.has_gpu_code: try: lasterror = common.get_gpu_runtime().get_last_error_string() except RuntimeError as ex: @@ -399,6 +455,7 @@ def __call__(self, *args, **kwargs): self._lib.unload() raise + def __del__(self): if self._initialized is True: self.finalize() @@ -406,24 +463,30 @@ def __del__(self): self._libhandle = ctypes.c_void_p(0) self._lib.unload() + def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: - """ Main function that controls argument construction for calling - the C prototype of the SDFG. + """ + Main function that controls argument construction for calling + the C prototype of the SDFG. - Organizes arguments first by `sdfg.arglist`, then data descriptors - by alphabetical order, then symbols by alphabetical order. + Organizes arguments first by ``sdfg.arglist``, then data descriptors + by alphabetical order, then symbols by alphabetical order. + + :note: If not initialized this function will initialize the memory for + the return values, however, it might also reallocate said memory. + :note: This function will also update the internal argument cache. """ - # Return value initialization (for values that have not been given) self._initialize_return_values(kwargs) + + # Add the return values to the arguments, since they are part of the C signature. for desc, arr in zip(self._retarray_shapes, self._return_arrays): kwargs[desc[0]] = arr - # Argument construction sig = self._sig typedict = self._typedict if len(kwargs) > 0: # Construct mapping from arguments to signature - arglist = [] + arglist = [] argtypes = [] argnames = [] for a in sig: @@ -433,38 +496,58 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: argnames.append(a) except KeyError: raise KeyError("Missing program argument \"{}\"".format(a)) + else: arglist = [] argtypes = [] argnames = [] sig = [] + # Type checking + no_view_arguments = not Config.get_bool('compiler', 'allow_view_arguments') for i, (a, arg, atype) in enumerate(zip(argnames, arglist, argtypes)): - if not dtypes.is_array(arg) and isinstance(atype, dt.Array): + is_array = dtypes.is_array(arg) + is_ndarray = isinstance(arg, np.ndarray) + is_dtArray = isinstance(atype, dt.Array) + if not is_array and is_dtArray: if isinstance(arg, list): - print('WARNING: Casting list argument "%s" to ndarray' % a) + print(f'WARNING: Casting list argument "{a}" to ndarray') elif arg is None: if atype.optional is False: # If array cannot be None raise TypeError(f'Passing a None value to a non-optional array in argument "{a}"') # Otherwise, None values are passed as null pointers below else: - raise TypeError('Passing an object (type %s) to an array in argument "%s"' % - (type(arg).__name__, a)) - elif dtypes.is_array(arg) and not isinstance(atype, dt.Array): + raise TypeError(f'Passing an object (type {type(arg).__name__}) to an array in argument "{a}"') + elif is_array and not is_dtArray: # GPU scalars and return values are pointers, so this is fine if atype.storage != dtypes.StorageType.GPU_Global and not a.startswith('__return'): - raise TypeError('Passing an array to a scalar (type %s) in argument "%s"' % (atype.dtype.ctype, a)) + raise TypeError(f'Passing an array to a scalar (type {atype.dtype.ctype}) in argument "{a}"') + elif (is_dtArray and is_ndarray and not isinstance(atype, dt.StructArray) + and atype.dtype.as_numpy_dtype() != arg.dtype): + # Make exception for vector types + if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): + pass + else: + print(f'WARNING: Passing {arg.dtype} array argument "{a}" to a {atype.dtype.type.__name__} array') + elif is_dtArray and is_ndarray and arg.base is not None and not '__return' in a and no_view_arguments: + raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe ' + 'programs is not allowed in order to retain analyzability. ' + 'Please make a copy with "numpy.copy(...)". If you know what ' + 'you are doing, you can override this error in the ' + 'configuration by setting compiler.allow_view_arguments ' + 'to True.') elif (not isinstance(atype, (dt.Array, dt.Structure)) and not isinstance(atype.dtype, dtypes.callback) and not isinstance(arg, (atype.dtype.type, sp.Basic)) and not (isinstance(arg, symbolic.symbol) and arg.dtype == atype.dtype)): - if isinstance(arg, int) and atype.dtype.type == np.int64: + is_int = isinstance(arg, int) + if is_int and atype.dtype.type == np.int64: pass - elif isinstance(arg, float) and atype.dtype.type == np.float64: + elif (is_int and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1): pass - elif (isinstance(arg, int) and atype.dtype.type == np.int32 and abs(arg) <= (1 << 31) - 1): + elif (is_int and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1): pass - elif (isinstance(arg, int) and atype.dtype.type == np.uint32 and arg >= 0 and arg <= (1 << 32) - 1): + elif isinstance(arg, float) and atype.dtype.type == np.float64: pass elif (isinstance(arg, str) or arg is None) and atype.dtype == dtypes.string: if arg is None: @@ -475,24 +558,7 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: else: warnings.warn(f'Casting scalar argument "{a}" from {type(arg).__name__} to {atype.dtype.type}') arglist[i] = atype.dtype.type(arg) - elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) and not isinstance(atype, dt.StructArray) - and atype.dtype.as_numpy_dtype() != arg.dtype): - # Make exception for vector types - if (isinstance(atype.dtype, dtypes.vector) and atype.dtype.vtype.as_numpy_dtype() == arg.dtype): - pass - else: - print('WARNING: Passing %s array argument "%s" to a %s array' % - (arg.dtype, a, atype.dtype.type.__name__)) - elif (isinstance(atype, dt.Array) and isinstance(arg, np.ndarray) and arg.base is not None - and not '__return' in a and not Config.get_bool('compiler', 'allow_view_arguments')): - raise TypeError(f'Passing a numpy view (e.g., sub-array or "A.T") "{a}" to DaCe ' - 'programs is not allowed in order to retain analyzability. ' - 'Please make a copy with "numpy.copy(...)". If you know what ' - 'you are doing, you can override this error in the ' - 'configuration by setting compiler.allow_view_arguments ' - 'to True.') - # Explicit casting for index, (arg, argtype) in enumerate(zip(arglist, argtypes)): # Call a wrapper function to make NumPy arrays from pointers. if isinstance(argtype.dtype, dtypes.callback): @@ -505,52 +571,47 @@ def _construct_args(self, kwargs) -> Tuple[Tuple[Any], Tuple[Any]]: arglist[index] = ctypes.c_void_p(0) # Retain only the element datatype for upcoming checks and casts - arg_ctypes = [t.dtype.as_ctypes() for t in argtypes] - - sdfg = self._sdfg + arg_ctypes = tuple(at.dtype.as_ctypes() for at in argtypes) + + constants = self.sdfg.constants + callparams = tuple( + (actype(arg.get()) + if isinstance(arg, symbolic.symbol) + else arg, actype, atype, aname + ) + for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames) + if not (symbolic.issymbolic(arg) and (hasattr(arg, 'name') and arg.name in constants)) + ) - # Obtain SDFG constants - constants = sdfg.constants - - # Remove symbolic constants from arguments - callparams = tuple((arg, actype, atype, aname) - for arg, actype, atype, aname in zip(arglist, arg_ctypes, argtypes, argnames) - if not symbolic.issymbolic(arg) or (hasattr(arg, 'name') and arg.name not in constants)) - - # Replace symbols with their values - callparams = tuple((actype(arg.get()) if isinstance(arg, symbolic.symbol) else arg, actype, atype, aname) - for arg, actype, atype, aname in callparams) - - # Construct init args, which only consist of the symbols symbols = self._free_symbols initargs = tuple( actype(arg) if not isinstance(arg, ctypes._SimpleCData) else arg - for arg, actype, atype, aname in callparams if aname in symbols) - - # Replace arrays with their base host/device pointers - newargs = tuple((ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)), actype, - atype) if dtypes.is_array(arg) else (arg, actype, atype) - for arg, actype, atype, _ in callparams) + for arg, actype, atype, aname in callparams + if aname in symbols + ) try: - newargs = tuple( - actype(arg) if not isinstance(arg, (ctypes._SimpleCData)) else arg - for arg, actype, atype in newargs) - except TypeError: - # Pinpoint bad argument - for i, (arg, actype, _) in enumerate(newargs): - try: - if not isinstance(arg, ctypes._SimpleCData): - actype(arg) - except TypeError as ex: - raise TypeError(f'Invalid type for scalar argument "{callparams[i][3]}": {ex}') + # Replace arrays with their base host/device pointers + newargs = [None] * len(callparams) + for i, (arg, actype, atype, _) in enumerate(callparams): + if dtypes.is_array(arg): + newargs[i] = ctypes.c_void_p(_array_interface_ptr(arg, atype.storage)) # `c_void_p` is subclass of `ctypes._SimpleCData`. + elif not isinstance(arg, (ctypes._SimpleCData)): + newargs[i] = actype(arg) + else: + newargs[i] = arg + + except TypeError as ex: + raise TypeError(f'Invalid type for scalar argument "{callparams[i][3]}": {ex}') self._lastargs = newargs, initargs return self._lastargs + def clear_return_values(self): self._create_new_arrays = True + def _create_array(self, _: str, dtype: np.dtype, storage: dtypes.StorageType, shape: Tuple[int], strides: Tuple[int], total_size: int): ndarray = np.ndarray @@ -575,10 +636,12 @@ def ndarray(*args, buffer=None, **kwargs): # Create an array with the properties of the SDFG array return ndarray(shape, dtype, buffer=zeros(total_size, dtype), strides=strides) + def _initialize_return_values(self, kwargs): # Obtain symbol values from arguments and constants syms = dict() - syms.update({k: v for k, v in kwargs.items() if k not in self.sdfg.arrays}) + sdfg_arrays = self.sdfg.arrays + syms.update({k: v for k, v in kwargs.items() if k not in sdfg_arrays}) syms.update(self.sdfg.constants) # Clear references from last call (allow garbage collection) @@ -624,6 +687,7 @@ def _initialize_return_values(self, kwargs): arr = self._create_array(*shape_desc) self._return_arrays.append(arr) + def _convert_return_values(self): # Return the values as they would be from a Python function if self._return_arrays is None or len(self._return_arrays) == 0: diff --git a/dace/codegen/cppunparse.py b/dace/codegen/cppunparse.py index e4456e3e18..18ee00721b 100644 --- a/dace/codegen/cppunparse.py +++ b/dace/codegen/cppunparse.py @@ -746,9 +746,9 @@ def _Repr(self, t): def _Num(self, t): t_n = t.value if sys.version_info >= (3, 8) else t.n repr_n = repr(t_n) - # For complex values, use DTYPE_TO_TYPECLASS dictionary + # For complex values, use ``dtype_to_typeclass`` if isinstance(t_n, complex): - dtype = dtypes.DTYPE_TO_TYPECLASS[complex] + dtype = dtypes.dtype_to_typeclass(complex) # Handle large integer values if isinstance(t_n, int): diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 0db4062976..7b6df55132 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -887,8 +887,9 @@ def generate_code(self, # NOTE: NestedSDFGs frequently contain tautologies in their symbol mapping, e.g., `'i': i`. Do not # redefine the symbols in such cases. - if (not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping - and str(sdfg.parent_nsdfg_node.symbol_mapping[isvarName]) == str(isvarName)): + # Additionally, do not redefine a symbol with its type if it was already defined + # as part of the function's arguments + if not is_top_level and isvarName in sdfg.parent_nsdfg_node.symbol_mapping: continue isvar = data.Scalar(isvarType) callsite_stream.write('%s;\n' % (isvar.as_arg(with_types=True, name=isvarName)), sdfg) diff --git a/dace/data.py b/dace/data.py index 199e7dabd4..cceaa4139c 100644 --- a/dace/data.py +++ b/dace/data.py @@ -73,9 +73,15 @@ def create_datadescriptor(obj, no_custom_desc=False): else: dtype = dtypes.typeclass(obj.dtype.type) return Array(dtype=dtype, strides=tuple(s // obj.itemsize for s in obj.strides), shape=obj.shape) - # special case for torch tensors. Maybe __array__ could be used here for a more - # general solution, but torch doesn't support __array__ for cuda tensors. + elif type(obj).__module__ == "cupy" and type(obj).__name__ == "ndarray": + # special case for CuPy and HIP, which does not support __cuda_array_interface__ + storage = dtypes.StorageType.GPU_Global + dtype = dtypes.typeclass(obj.dtype.type) + itemsize = obj.itemsize + return Array(dtype=dtype, shape=obj.shape, strides=tuple(s // itemsize for s in obj.strides), storage=storage) elif type(obj).__module__ == "torch" and type(obj).__name__ == "Tensor": + # special case for torch tensors. Maybe __array__ could be used here for a more + # general solution, but torch doesn't support __array__ for cuda tensors. try: # If torch is importable, define translations between typeclasses and torch types. These are reused by daceml. # conversion happens here in pytorch: diff --git a/dace/dtypes.py b/dace/dtypes.py index f0bac23958..a890668595 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -360,6 +360,7 @@ class typeclass(object): 2. Enabling declaration syntax: `dace.float32[M,N]` 3. Enabling extensions such as `dace.struct` and `dace.vector` """ + def __init__(self, wrapped_type, typename=None): # Convert python basic types if isinstance(wrapped_type, str): @@ -600,6 +601,7 @@ def result_type_of(lhs, *rhs): class opaque(typeclass): """ A data type for an opaque object, useful for C bindings/libnodes, i.e., MPI_Request. """ + def __init__(self, typename): self.type = typename self.ctype = typename @@ -635,6 +637,7 @@ class pointer(typeclass): Example use: `dace.pointer(dace.struct(x=dace.float32, y=dace.float32))`. """ + def __init__(self, wrapped_typeclass): self._typeclass = wrapped_typeclass self.type = wrapped_typeclass.type @@ -680,6 +683,7 @@ class vector(typeclass): Example use: `dace.vector(dace.float32, 4)` becomes float4. """ + def __init__(self, dtype: typeclass, vector_length: int): self.vtype = dtype self.type = dtype.type @@ -737,6 +741,7 @@ class stringtype(pointer): Python/generated code marshalling. Used internally when `str` types are given """ + def __init__(self): super().__init__(int8) @@ -756,6 +761,7 @@ class struct(typeclass): Example use: `dace.struct(a=dace.int32, b=dace.float64)`. """ + def __init__(self, name, **fields_and_types): # self._data = fields_and_types self.type = ctypes.Structure @@ -859,6 +865,7 @@ class pyobject(opaque): It cannot be used inside a DaCe program, but can be passed back to other Python callbacks. Use with caution, and ensure the value is not removed by the garbage collector or the program will crash. """ + def __init__(self): super().__init__('pyobject') self.bytes = ctypes.sizeof(ctypes.c_void_p) @@ -892,6 +899,7 @@ def example(A: dace.float64[20], constant: dace.compiletime): In the above code, ``constant`` will be replaced with its value at call time during parsing. """ + @staticmethod def __descriptor__(): raise ValueError('All compile-time arguments must be provided in order to compile the SDFG ahead-of-time.') @@ -914,6 +922,7 @@ class callback(typeclass): """ Looks like ``dace.callback([None, ], *types)`` """ + def __init__(self, return_types, *variadic_args): from dace import data if return_types is None: @@ -1240,31 +1249,39 @@ class Typeclasses(aenum.AutoNumberEnum): complex128 = complex128 -DTYPE_TO_TYPECLASS = { - bool: typeclass(bool), - int: typeclass(int), - float: typeclass(float), - complex: typeclass(complex), - numpy.bool_: bool_, - numpy.int8: int8, - numpy.int16: int16, - numpy.int32: int32, - numpy.int64: int64, - numpy.intc: int32, - numpy.uint8: uint8, - numpy.uint16: uint16, - numpy.uint32: uint32, - numpy.uint64: uint64, - numpy.uintc: uint32, - numpy.float16: float16, - numpy.float32: float32, - numpy.float64: float64, - numpy.complex64: complex64, - numpy.complex128: complex128, - # FIXME - numpy.longlong: int64, - numpy.ulonglong: uint64 -} +_bool = bool + + +def dtype_to_typeclass(dtype=None): + DTYPE_TO_TYPECLASS = { + _bool: typeclass(_bool), + int: typeclass(int), + float: typeclass(float), + complex: typeclass(complex), + numpy.bool_: bool_, + numpy.int8: int8, + numpy.int16: int16, + numpy.int32: int32, + numpy.int64: int64, + numpy.intc: int32, + numpy.uint8: uint8, + numpy.uint16: uint16, + numpy.uint32: uint32, + numpy.uint64: uint64, + numpy.uintc: uint32, + numpy.float16: float16, + numpy.float32: float32, + numpy.float64: float64, + numpy.complex64: complex64, + numpy.complex128: complex128, + # FIXME + numpy.longlong: int64, + numpy.ulonglong: uint64 + } + if dtype is None: + return DTYPE_TO_TYPECLASS + return DTYPE_TO_TYPECLASS[dtype] + # Since this overrides the builtin bool, this should be after the # DTYPE_TO_TYPECLASS dictionary @@ -1354,6 +1371,7 @@ def isallowed(var, allow_recursive=False): class DebugInfo: """ Source code location identifier of a node/edge in an SDFG. Used for IDE and debugging purposes. """ + def __init__(self, start_line, start_column=0, end_line=-1, end_column=0, filename=None): self.start_line = start_line self.end_line = end_line if end_line >= 0 else start_line @@ -1397,6 +1415,7 @@ def json_to_typeclass(obj, context=None): def paramdec(dec): """ Parameterized decorator meta-decorator. Enables using `@decorator`, `@decorator()`, and `@decorator(...)` with the same function. """ + @wraps(dec) def layer(*args, **kwargs): from dace import data @@ -1478,20 +1497,22 @@ def can_allocate(storage: StorageType, schedule: ScheduleType): # Host-only allocation if storage in [StorageType.CPU_Heap, StorageType.CPU_Pinned, StorageType.CPU_ThreadLocal]: return schedule in [ - ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.GPU_Default + ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, + ScheduleType.GPU_Default ] # GPU-global memory if storage is StorageType.GPU_Global: return schedule in [ - ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.GPU_Default + ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, + ScheduleType.GPU_Default ] # FPGA-global memory if storage is StorageType.FPGA_Global: return schedule in [ - ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, ScheduleType.FPGA_Device, - ScheduleType.GPU_Default + ScheduleType.CPU_Multicore, ScheduleType.CPU_Persistent, ScheduleType.Sequential, ScheduleType.MPI, + ScheduleType.FPGA_Device, ScheduleType.GPU_Default ] # FPGA-local memory @@ -1536,6 +1557,8 @@ def is_array(obj: Any) -> bool: return hasattr(obj, 'shape') and len(obj.shape) > 0 except TypeError: # PyTorch scalar objects define an attribute called shape that cannot be used return False + if hasattr(obj, 'data') and hasattr(obj.data, 'ptr'): # CuPy special case with HIP + return True return False @@ -1556,4 +1579,9 @@ def is_gpu_array(obj: Any) -> bool: # In PyTorch, accessing this attribute throws a runtime error for # variables that require grad, or KeyError when a boolean array is used return False + + if hasattr(obj, 'data') and hasattr(obj.data, 'ptr'): # CuPy special case with HIP + if hasattr(obj, 'device') and getattr(obj.device, 'id', -1) >= 0: + return True + return False diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 733c3c7f62..2f77bd430d 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -3240,7 +3240,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): raise DaceSyntaxError(self, target, 'Variable "{}" used before definition'.format(name)) new_data, rng = None, None - dtype_keys = tuple(dtypes.DTYPE_TO_TYPECLASS.keys()) + dtype_keys = tuple(dtypes.dtype_to_typeclass().keys()) if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or (isinstance(result, str) and result in self.sdfg.arrays)): raise DaceSyntaxError( @@ -4653,14 +4653,14 @@ def visit_Num(self, node: NumConstant): if isinstance(node.n, bool): return dace.bool_(node.n) if isinstance(node.n, (int, float, complex)): - return dtypes.DTYPE_TO_TYPECLASS[type(node.n)](node.n) + return dtypes.dtype_to_typeclass(type(node.n))(node.n) return node.n def visit_Constant(self, node: ast.Constant): if isinstance(node.value, bool): return dace.bool_(node.value) if isinstance(node.value, (int, float, complex)): - return dtypes.DTYPE_TO_TYPECLASS[type(node.value)](node.value) + return dtypes.dtype_to_typeclass(type(node.value))(node.value) if isinstance(node.value, (str, bytes)): return StringLiteral(node.value) return node.value @@ -4745,7 +4745,7 @@ def _gettype(self, opnode: ast.AST) -> List[Tuple[str, str]]: result.append((operand, type(self.sdfg.arrays[operand]))) elif isinstance(operand, str) and operand in self.scope_arrays: result.append((operand, type(self.scope_arrays[operand]))) - elif isinstance(operand, tuple(dtypes.DTYPE_TO_TYPECLASS.keys())): + elif isinstance(operand, tuple(dtypes.dtype_to_typeclass().keys())): if isinstance(operand, (bool, numpy.bool_)): result.append((operand, 'BoolConstant')) else: diff --git a/dace/frontend/python/replacements.py b/dace/frontend/python/replacements.py index f55a65eabb..2e34b3077d 100644 --- a/dace/frontend/python/replacements.py +++ b/dace/frontend/python/replacements.py @@ -289,7 +289,7 @@ def _numpy_full(pv: ProgramVisitor, """ is_data = False if isinstance(fill_value, (Number, np.bool_)): - vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)] + vtype = dtypes.dtype_to_typeclass(type(fill_value)) elif isinstance(fill_value, sp.Expr): vtype = _sym_type(fill_value) else: @@ -546,10 +546,10 @@ def _arange(pv: ProgramVisitor, sdfg: SDFG, state: SDFGState, *args, **kwargs): if 'dtype' in kwargs and kwargs['dtype'] != None: dtype = kwargs['dtype'] if not isinstance(dtype, dtypes.typeclass): - dtype = dtypes.DTYPE_TO_TYPECLASS[dtype] + dtype = dtypes.dtype_to_typeclass(dtype) outname, outarr = sdfg.add_temp_transient(shape, dtype) else: - dtype = dtypes.DTYPE_TO_TYPECLASS[type(shape[0])] + dtype = dtypes.dtype_to_typeclass(type(shape[0])) outname, outarr = sdfg.add_temp_transient(shape, dtype) state.add_mapped_tasklet(name="_numpy_arange_", @@ -1076,8 +1076,8 @@ def _array_array_where(visitor: ProgramVisitor, left_arr = sdfg.arrays.get(left_operand, None) right_arr = sdfg.arrays.get(right_operand, None) - left_type = left_arr.dtype if left_arr else dtypes.DTYPE_TO_TYPECLASS[type(left_operand)] - right_type = right_arr.dtype if right_arr else dtypes.DTYPE_TO_TYPECLASS[type(right_operand)] + left_type = left_arr.dtype if left_arr else dtypes.dtype_to_typeclass(type(left_operand)) + right_type = right_arr.dtype if right_arr else dtypes.dtype_to_typeclass(type(right_operand)) # Implicit Python coversion implemented as casting arguments = [cond_arr, left_arr or left_type, right_arr or right_type] @@ -1356,11 +1356,11 @@ def _np_result_type(nptypes): # Fix for np.result_type returning platform-dependent types, # e.g. np.longlong restype = np.result_type(*nptypes) - if restype.type not in dtypes.DTYPE_TO_TYPECLASS.keys(): - for k in dtypes.DTYPE_TO_TYPECLASS.keys(): + if restype.type not in dtypes.dtype_to_typeclass().keys(): + for k in dtypes.dtype_to_typeclass().keys(): if k == restype.type: - return dtypes.DTYPE_TO_TYPECLASS[k] - return dtypes.DTYPE_TO_TYPECLASS[restype.type] + return dtypes.dtype_to_typeclass(k) + return dtypes.dtype_to_typeclass(restype.type) def _sym_type(expr: Union[symbolic.symbol, sp.Basic]) -> dtypes.typeclass: @@ -1393,7 +1393,7 @@ def _result_type(arguments: Sequence[Union[str, Number, symbolic.symbol, sp.Basi datatypes.append(arg.dtype) dtypes_for_result.append(_representative_num(arg.dtype)) elif isinstance(arg, (Number, np.bool_)): - datatypes.append(dtypes.DTYPE_TO_TYPECLASS[type(arg)]) + datatypes.append(dtypes.dtype_to_typeclass(type(arg))) dtypes_for_result.append(arg) elif symbolic.issymbolic(arg): datatypes.append(_sym_type(arg)) @@ -1668,13 +1668,13 @@ def _array_const_binop(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, le left_shape = left_arr.shape storage = left_arr.storage right_arr = None - right_type = dtypes.DTYPE_TO_TYPECLASS[type(right_operand)] + right_type = dtypes.dtype_to_typeclass(type(right_operand)) right_shape = [1] arguments = [left_arr, right_operand] tasklet_args = ['__in1', f'({str(right_operand)})'] else: left_arr = None - left_type = dtypes.DTYPE_TO_TYPECLASS[type(left_operand)] + left_type = dtypes.dtype_to_typeclass(type(left_operand)) left_shape = [1] right_arr = sdfg.arrays[right_operand] right_type = right_arr.dtype @@ -2229,7 +2229,7 @@ def _matmult(visitor: ProgramVisitor, sdfg: SDFG, state: SDFGState, op1: str, op type1 = arr1.dtype.type type2 = arr2.dtype.type - restype = dace.DTYPE_TO_TYPECLASS[np.result_type(type1, type2).type] + restype = dace.dtype_to_typeclass(np.result_type(type1, type2).type) op3, arr3 = sdfg.add_temp_transient(output_shape, restype, arr1.storage) @@ -3517,7 +3517,7 @@ def implement_ufunc(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDFG, sta ufunc_impl['operator']) if 'dtype' in kwargs.keys(): dtype = kwargs['dtype'] - if dtype in dtypes.DTYPE_TO_TYPECLASS.keys(): + if dtype in dtypes.dtype_to_typeclass().keys(): result_type = dtype # Create output data (if needed) @@ -3709,7 +3709,7 @@ def implement_ufunc_reduce(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SD datadesc = sdfg.arrays[arg] result_type = datadesc.dtype elif isinstance(arg, (Number, np.bool_)): - result_type = dtypes.DTYPE_TO_TYPECLASS[type(arg)] + result_type = dtypes.dtype_to_typeclass(type(arg)) elif isinstance(arg, sp.Basic): result_type = _sym_type(arg) @@ -4018,7 +4018,7 @@ def implement_ufunc_outer(visitor: ProgramVisitor, ast_node: ast.Call, sdfg: SDF ufunc_impl['operator']) if 'dtype' in kwargs.keys(): dtype = kwargs['dtype'] - if dtype in dtypes.DTYPE_TO_TYPECLASS.keys(): + if dtype in dtypes.dtype_to_typeclass().keys(): result_type = dtype # Create output data (if needed) @@ -4412,9 +4412,9 @@ def _make_datatype_converter(typeclass: str): if typeclass == "bool": dtype = dace.bool elif typeclass in {"int", "float", "complex"}: - dtype = dtypes.DTYPE_TO_TYPECLASS[eval(typeclass)] + dtype = dtypes.dtype_to_typeclass(eval(typeclass)) else: - dtype = dtypes.DTYPE_TO_TYPECLASS[eval("np.{}".format(typeclass))] + dtype = dtypes.dtype_to_typeclass(eval("np.{}".format(typeclass))) @oprepo.replaces(typeclass) @oprepo.replaces("dace.{}".format(typeclass)) @@ -4711,7 +4711,7 @@ def _cupy_full(pv: ProgramVisitor, the fill value. """ if isinstance(fill_value, (Number, np.bool_)): - vtype = dtypes.DTYPE_TO_TYPECLASS[type(fill_value)] + vtype = dtypes.dtype_to_typeclass(type(fill_value)) elif isinstance(fill_value, sp.Expr): vtype = _sym_type(fill_value) else: diff --git a/dace/libraries/blas/nodes/gemm.py b/dace/libraries/blas/nodes/gemm.py index 83be99d78b..d78e54eb6e 100644 --- a/dace/libraries/blas/nodes/gemm.py +++ b/dace/libraries/blas/nodes/gemm.py @@ -30,12 +30,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str: cast_value = complex(value) return "dace.{type}({real}, {imag})".format( - type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(), + type=dace.dtype_to_typeclass(dtype).to_string(), real=cast_value.real, imag=cast_value.imag, ) else: - return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value) + return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value) @dace.library.expansion @@ -52,7 +52,7 @@ def make_sdfg(node, parent_state, parent_sdfg): dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type - dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type] + dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type) if node.transA: trans_shape_a = list(reversed(shape_a)) @@ -518,7 +518,7 @@ def expansion(node, parent_state, parent_sdfg, num_pes=32, tile_size_m=None): dtype_a = outer_array_a.dtype.type dtype_b = outer_array_b.dtype.type - dtype_c = dace.DTYPE_TO_TYPECLASS[np.result_type(dtype_a, dtype_b).type] + dtype_c = dace.dtype_to_typeclass(np.result_type(dtype_a, dtype_b).type) shape_c = (shape_a[0], shape_b[1]) if node.transA: raise NotImplementedError("GEMM FPGA expansion not implemented for transposed A.") diff --git a/dace/libraries/sparse/nodes/csrmm.py b/dace/libraries/sparse/nodes/csrmm.py index d5707b400d..b21867b0e9 100644 --- a/dace/libraries/sparse/nodes/csrmm.py +++ b/dace/libraries/sparse/nodes/csrmm.py @@ -28,12 +28,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str: cast_value = complex(value) return "dace.{type}({real}, {imag})".format( - type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(), + type=dace.dtype_to_typeclass(dtype).to_string(), real=cast_value.real, imag=cast_value.imag, ) else: - return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value) + return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value) def _get_csrmm_operands(node, diff --git a/dace/libraries/sparse/nodes/csrmv.py b/dace/libraries/sparse/nodes/csrmv.py index 7b69a7af00..cc3e98eec4 100644 --- a/dace/libraries/sparse/nodes/csrmv.py +++ b/dace/libraries/sparse/nodes/csrmv.py @@ -27,12 +27,12 @@ def _cast_to_dtype_str(value, dtype: dace.dtypes.typeclass) -> str: cast_value = complex(value) return "dace.{type}({real}, {imag})".format( - type=dace.DTYPE_TO_TYPECLASS[dtype].to_string(), + type=dace.dtype_to_typeclass(dtype).to_string(), real=cast_value.real, imag=cast_value.imag, ) else: - return "dace.{}({})".format(dace.DTYPE_TO_TYPECLASS[dtype].to_string(), value) + return "dace.{}({})".format(dace.dtype_to_typeclass(dtype).to_string(), value) def _get_csrmv_operands(node: dace.sdfg.nodes.LibraryNode, diff --git a/dace/libraries/standard/nodes/reduce.py b/dace/libraries/standard/nodes/reduce.py index dd026ea62c..4e04a656fe 100644 --- a/dace/libraries/standard/nodes/reduce.py +++ b/dace/libraries/standard/nodes/reduce.py @@ -1088,7 +1088,7 @@ class ExpandReduceGPUAuto(pm.ExpandTransformation): """ GPU implementation of the reduce node. This expansion aims to map the reduction inputs to an optimal GPU schedule. """ - environments = [CUDA] + environments = [] @staticmethod def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): @@ -1099,6 +1099,8 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): :param state: the state in which the node is in :param sdfg: the SDFG in which the node is in """ + from dace.codegen import common + node.validate(sdfg, state) inedge: graph.MultiConnectorEdge = state.in_edges(node)[0] outedge: graph.MultiConnectorEdge = state.out_edges(node)[0] @@ -1106,6 +1108,7 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): isqdim = insubset.squeeze() raw_input_data = sdfg.arrays[inedge.data.data] raw_output_data = sdfg.arrays[outedge.data.data] + warp_size = 64 if common.get_gpu_backend() == 'hip' else 32 in_type = raw_input_data.dtype @@ -1132,7 +1135,7 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): axes = [axis for axis in axes if axis in isqdim] # call the planner script - schedule = red_planner.get_reduction_schedule(raw_input_data, axes) + schedule = red_planner.get_reduction_schedule(raw_input_data, axes, warp_size=warp_size) if schedule.error: # return pure expansion if error @@ -1340,25 +1343,25 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): real_state = nested_sdfg.add_state('real_state') nested_sdfg.add_edge(start_state, real_state, - dace.InterstateEdge(f'_b1 + 32 * _g < {schedule.in_shape[-1]}')) + dace.InterstateEdge(f'_b1 + {warp_size} * _g < {schedule.in_shape[-1]}')) reset_outm = dace.Memlet(f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape))])}]') if len(schedule.out_shape) > 1: outm = dace.Memlet( - f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * 32 + _b]', + f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * {warp_size} + _b]', dynamic=True) outm_wcr = dace.Memlet( - f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * 32 + _b]', + f'_out[{",".join(["_o%d" % i for i in range(len(schedule.out_shape) - 1)])},_g * {warp_size} + _b]', dynamic=True, wcr=node.wcr) else: - outm = dace.Memlet(f'_out[_g * 32 + _b]', dynamic=True) - outm_wcr = dace.Memlet(f'_out[_g * 32 + _b]', dynamic=True, wcr=node.wcr) + outm = dace.Memlet(f'_out[_g * {warp_size} + _b]', dynamic=True) + outm_wcr = dace.Memlet(f'_out[_g * {warp_size} + _b]', dynamic=True, wcr=node.wcr) input_subset = input_subset[:-2] input_subset.append(f'0:{schedule.sequential[0]}') - input_subset.append('_g * 32 + _b1') + input_subset.append(f'_g * {warp_size} + _b1') inmm = dace.Memlet(f'_in[{",".join(input_subset)}]', dynamic=True) if schedule.multi_axes: @@ -1401,13 +1404,13 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): schedule=dtypes.ScheduleType.GPU_ThreadBlock) else: - bme1, bmx1 = nstate.add_map('block', {'_b': f'0:32'}, schedule=dtypes.ScheduleType.GPU_ThreadBlock) + bme1, bmx1 = nstate.add_map('block', {'_b': f'0:{warp_size}'}, schedule=dtypes.ScheduleType.GPU_ThreadBlock) bme2, bmx2 = nstate.add_map('block', {f'_b{i}': f'0:{sz}' for i, sz in enumerate(schedule.block)}, schedule=dtypes.ScheduleType.GPU_ThreadBlock) - # add shared memory of size 32 to outer sdfg + # add shared memory of warp size to outer sdfg nsdfg.add_array('s_mem', [schedule.shared_mem_size], nsdfg.arrays['_in'].dtype, dtypes.StorageType.GPU_Shared, @@ -1482,11 +1485,11 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): if mini_warps: cond_tasklet = nstate.add_tasklet( 'cond_write', {'_input'}, {'_output'}, - f'if _b + 32 * _g < {schedule.out_shape[-1]} and _bb == 0 and _mwid == 0: _output = _input') + f'if _b + {warp_size} * _g < {schedule.out_shape[-1]} and _bb == 0 and _mwid == 0: _output = _input') else: cond_tasklet = nstate.add_tasklet( 'cond_write', {'_input'}, {'_output'}, - f'if _b + 32 * _g < {schedule.out_shape[-1]} and _bb == 0: _output = _input') + f'if _b + {warp_size} * _g < {schedule.out_shape[-1]} and _bb == 0: _output = _input') # connect accumulator to identity tasklet real_state.add_memlet_path(accread, ime, id, dst_conn='a', memlet=dace.Memlet('acc[0]')) @@ -1511,8 +1514,8 @@ def expansion(node: 'Reduce', state: SDFGState, sdfg: SDFG): nstate.add_memlet_path(s_mem3, bme3, cond_tasklet, dst_conn='_input', memlet=dace.Memlet('s_mem[_b]')) else: bme3, bmx3 = nstate.add_map('block', { - '_bb': '0:16', - '_b': f'0:32' + '_bb': f'0:{512//warp_size}', + '_b': f'0:{warp_size}' }, schedule=dtypes.ScheduleType.GPU_ThreadBlock) nstate.add_memlet_path(s_mem3, bme3, cond_tasklet, dst_conn='_input', memlet=dace.Memlet('s_mem[_b]')) diff --git a/dace/libraries/standard/nodes/ttranspose.py b/dace/libraries/standard/nodes/ttranspose.py index e11012e3ad..6d142db81f 100644 --- a/dace/libraries/standard/nodes/ttranspose.py +++ b/dace/libraries/standard/nodes/ttranspose.py @@ -38,7 +38,10 @@ def expansion(node, parent_state, parent_sdfg): out_mem = dace.Memlet(expr=f"_out_tensor[{','.join([map_params[i] for i in node.axes])}]") inputs = {"_inp": inp_mem} outputs = {"_out": out_mem} - code = f"_out = {node.alpha} * _inp" + if node.alpha == 1: + code = "_out = _inp" + else: + code = f"_out = decltype(_inp)({node.alpha}) * _inp" if node.beta != 0: inputs["_inout"] = out_mem code = f"_out = {node.alpha} * _inp + {node.beta} * _inout" diff --git a/dace/runtime/include/dace/reduction.h b/dace/runtime/include/dace/reduction.h index 9d8c59997c..927bf449de 100644 --- a/dace/runtime/include/dace/reduction.h +++ b/dace/runtime/include/dace/reduction.h @@ -592,7 +592,9 @@ namespace dace { cub::TransformInputIterator itr(counting_iterator, conversion_op); return itr; } +#endif +#if defined(__CUDACC__) template struct warpReduce { static DACE_DFI T reduce(T v) @@ -610,6 +612,24 @@ namespace dace { return v; } }; +#elif defined(__HIPCC__) + template + struct warpReduce { + static DACE_DFI T reduce(T v) + { + for (int i = 1; i < warpSize; i = i * 2) + v = _wcr_fixed()(v, __shfl_xor(v, i)); + return v; + } + + template + static DACE_DFI T mini_reduce(T v) + { + for (int i = 1; i < NUM_MW; i = i * 2) + v = _wcr_fixed()(v, __shfl_xor(v, i)); + return v; + } + }; #endif } // namespace dace diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 2e35218a3d..eb43a99a54 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -769,7 +769,7 @@ def add_symbol(self, name, stype): if name in self.symbols: raise FileExistsError('Symbol "%s" already exists in SDFG' % name) if not isinstance(stype, dtypes.typeclass): - stype = dtypes.DTYPE_TO_TYPECLASS[stype] + stype = dtypes.dtype_to_typeclass(stype) self.symbols[name] = stype def remove_symbol(self, name): diff --git a/dace/serialize.py b/dace/serialize.py index ef07530905..4afaef69ee 100644 --- a/dace/serialize.py +++ b/dace/serialize.py @@ -47,7 +47,7 @@ def to_json(obj): return None try: - dtype_json = dace.dtypes.DTYPE_TO_TYPECLASS[obj.dtype.type].to_json() + dtype_json = dace.dtypes.dtype_to_typeclass(obj.dtype.type).to_json() except KeyError: dtype_json = str(obj.dtype) @@ -69,12 +69,19 @@ def to_json(obj): # All classes annotated with the make_properties decorator will register # themselves here. } -# Also register each of the basic types -_DACE_SERIALIZE_TYPES.update({v.to_string(): v for v in dace.dtypes.DTYPE_TO_TYPECLASS.values()}) def get_serializer(type_name): - return _DACE_SERIALIZE_TYPES[type_name] + if type_name in _DACE_SERIALIZE_TYPES: + return _DACE_SERIALIZE_TYPES[type_name] + + # Also try each of the basic types + basic_dtypes = {v.to_string(): v for v in dace.dtypes.dtype_to_typeclass().values()} + if type_name in basic_dtypes: + return basic_dtypes[type_name] + + raise KeyError(f'Serializer for type "{type_name}" was not found. Object type does not support serialization. ' + 'Please implement serialization by decorating the class with ``@serializable``.') # Decorator for objects that should be serializable, but don't call @@ -144,7 +151,7 @@ def from_json(obj, context=None, known_type=None): if t: try: - deserialized = _DACE_SERIALIZE_TYPES[t].from_json(obj, context=context) + deserialized = get_serializer(t).from_json(obj, context=context) except Exception as ex: if config.Config.get_bool('testing', 'deserialize_exception'): raise diff --git a/dace/symbolic.py b/dace/symbolic.py index f3dfcfb36d..8342725349 100644 --- a/dace/symbolic.py +++ b/dace/symbolic.py @@ -42,7 +42,7 @@ def __new__(cls, name=None, dtype=DEFAULT_SYMBOL_TYPE, **assumptions): if not isinstance(dtype, dtypes.typeclass): raise TypeError('dtype must be a DaCe type, got %s' % str(dtype)) - dkeys = [k for k, v in dtypes.DTYPE_TO_TYPECLASS.items() if v == dtype] + dkeys = [k for k, v in dtypes.dtype_to_typeclass().items() if v == dtype] is_integer = [issubclass(k, int) or issubclass(k, numpy.integer) for k in dkeys] if 'integer' in assumptions or not numpy.any(is_integer): # Using __xnew__ as the regular __new__ is cached, which leads diff --git a/tests/codegen/symbol_arguments_test.py b/tests/codegen/symbol_arguments_test.py index 3ca89ddd06..557c42f8c1 100644 --- a/tests/codegen/symbol_arguments_test.py +++ b/tests/codegen/symbol_arguments_test.py @@ -48,7 +48,21 @@ def tester(A: dace.float64[N, N]): assert 'N' in sdfg.arglist() +def test_nested_sdfg_redefinition(): + sdfg = dace.SDFG('tester') + nsdfg = dace.SDFG('nester') + state = sdfg.add_state() + nnode = state.add_nested_sdfg(nsdfg, None, {}, {}, symbol_mapping=dict(sym=0)) + + nstate = nsdfg.add_state() + nstate.add_tasklet('nothing', {}, {}, 'a = sym') + nstate2 = nsdfg.add_state() + nsdfg.add_edge(nstate, nstate2, dace.InterstateEdge(assignments=dict(sym=1))) + sdfg.compile() + + if __name__ == '__main__': test_global_sizes() test_global_sizes_used() test_global_sizes_multidim() + test_nested_sdfg_redefinition()