diff --git a/dace/codegen/targets/intel_fpga.py b/dace/codegen/targets/intel_fpga.py index a9a05604f5..9437dccbe3 100644 --- a/dace/codegen/targets/intel_fpga.py +++ b/dace/codegen/targets/intel_fpga.py @@ -571,8 +571,9 @@ def generate_module(self, sdfg, cfg, state, kernel_name, module_name, subgraph, arg = self.make_kernel_argument(p, pname, is_output, True) if arg is not None: - #change c type long long to opencl type long - arg = arg.replace("long long", "long") + #change c type to opencl type + if arg in dtypes._CTYPES_TO_OCLTYPES: + arg = dtypes._CTYPES_TO_OCLTYPES[arg] kernel_args_opencl.append(arg) kernel_args_host.append(p.as_arg(True, name=pname)) @@ -770,8 +771,9 @@ def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): ptrname = cpp.ptr(in_memlet.data, desc, sdfg, self._frame) defined_type, defined_ctype = self._dispatcher.defined_vars.get(ptrname, 1) - #change c type long long to opencl type long - defined_ctype = defined_ctype.replace("long long", "long") + #change c type to opencl type + if defined_ctype in dtypes._CTYPES_TO_OCLTYPES: + defined_ctype = dtypes._CTYPES_TO_OCLTYPES[defined_ctype] if isinstance(desc, dace.data.Array) and (desc.storage == dtypes.StorageType.FPGA_Global or desc.storage == dtypes.StorageType.FPGA_Local): @@ -823,9 +825,9 @@ def generate_nsdfg_arguments(self, sdfg, cfg, dfg, state, node): ptrname = cpp.ptr(out_memlet.data, desc, sdfg, self._frame) defined_type, defined_ctype = self._dispatcher.defined_vars.get(ptrname, 1) - #change c type long long to opencl type long - if defined_ctype.__contains__("long long"): - defined_ctype = defined_ctype.replace("long long", "long") + #change c type to opencl type + if defined_ctype in dtypes._CTYPES_TO_OCLTYPES: + defined_ctype = dtypes._CTYPES_TO_OCLTYPES[defined_ctype] if isinstance(desc, dace.data.Array) and (desc.storage == dtypes.StorageType.FPGA_Global or desc.storage == dtypes.StorageType.FPGA_Local): diff --git a/dace/dtypes.py b/dace/dtypes.py index 69497974e7..8be7c78b8d 100644 --- a/dace/dtypes.py +++ b/dace/dtypes.py @@ -286,6 +286,26 @@ class TilingType(aenum.AutoNumberEnum): numpy.complex128: "complex double", } +_CTYPES_TO_OCLTYPES = { + "void": "void", + "int": "int", + "float": "float", + "double": "double", + "dace::complex64": "complex float", + "dace::complex128": "complex double", + "bool": "bool", + "char": "char", + "short": "short", + "int": "int", + "int64_t": "long", + "uint8_t": "uchar", + "uint16_t": "ushort", + "uint32_t": "uint", + "dace::uint": "uint", + "uint64_t": "ulong", + "dace::float16": "half", +} + # Translation of types to OpenCL vector types _OCL_VECTOR_TYPES = { numpy.int8: "char",