diff --git a/sycl/include/sycl/detail/cg_types.hpp b/sycl/include/sycl/detail/cg_types.hpp index cb7b066a256b1..3f571a5868497 100644 --- a/sycl/include/sycl/detail/cg_types.hpp +++ b/sycl/include/sycl/detail/cg_types.hpp @@ -192,10 +192,18 @@ class HostKernel : public HostKernelBase { std::is_same_v>) { constexpr bool HasOffset = std::is_same_v>; - KernelArgType Item = IDBuilder::createItem( - InitializedVal::template get<1>(), - InitializedVal::template get<0>()); - runKernelWithArg(MKernel, Item); + if constexpr (!HasOffset) { + KernelArgType Item = IDBuilder::createItem( + InitializedVal::template get<1>(), + InitializedVal::template get<0>()); + runKernelWithArg(MKernel, Item); + } else { + KernelArgType Item = IDBuilder::createItem( + InitializedVal::template get<1>(), + InitializedVal::template get<0>(), + InitializedVal::template get<0>()); + runKernelWithArg(MKernel, Item); + } } else if constexpr (std::is_same_v>) { sycl::range Range = InitializedVal::template get<1>(); sycl::id ID = InitializedVal::template get<0>(); diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index ad389ab898940..105f33ee68dff 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -767,130 +767,6 @@ class __SYCL_EXPORT handler { &DynamicParamBase, int ArgIndex); - /* The kernel passed to StoreLambda can take an id, an item or an nd_item as - * its argument. Since esimd adapter directly invokes the kernel (doesn’t use - * urKernelSetArg), the kernel argument type must be known to the adapter. - * However, passing kernel argument type to the adapter requires changing ABI - * in HostKernel class. To overcome this problem, helpers below wrap the - * “original” kernel with a functor that always takes an nd_item as argument. - * A functor is used instead of a lambda because extractArgsAndReqsFromLambda - * needs access to the “original” kernel and keeps references to its internal - * data, i.e. the kernel passed as argument cannot be local in scope. The - * functor itself is again encapsulated in a std::function since functor’s - * type is unknown to the adapter. - */ - - // For 'id, item w/wo offset, nd_item' kernel arguments - template - KernelType *ResetHostKernelHelper(const KernelType &KernelFunc) { - NormalizedKernelType NormalizedKernel(KernelFunc); - auto NormalizedKernelFunc = - std::function &)>(NormalizedKernel); - auto HostKernelPtr = new detail::HostKernel, Dims>( - std::move(NormalizedKernelFunc)); - MHostKernel.reset(HostKernelPtr); - return &HostKernelPtr->MKernel.template target() - ->MKernelFunc; - } - - // For 'sycl::id' kernel argument - template - std::enable_if_t>, KernelType *> - ResetHostKernel(const KernelType &KernelFunc) { - struct NormalizedKernelType { - KernelType MKernelFunc; - NormalizedKernelType(const KernelType &KernelFunc) - : MKernelFunc(KernelFunc) {} - void operator()(const nd_item &Arg) { - detail::runKernelWithArg(MKernelFunc, Arg.get_global_id()); - } - }; - return ResetHostKernelHelper( - KernelFunc); - } - - // For 'sycl::nd_item' kernel argument - template - std::enable_if_t>, KernelType *> - ResetHostKernel(const KernelType &KernelFunc) { - struct NormalizedKernelType { - KernelType MKernelFunc; - NormalizedKernelType(const KernelType &KernelFunc) - : MKernelFunc(KernelFunc) {} - void operator()(const nd_item &Arg) { - detail::runKernelWithArg(MKernelFunc, Arg); - } - }; - return ResetHostKernelHelper( - KernelFunc); - } - - // For 'sycl::item' kernel argument - template - std::enable_if_t>, KernelType *> - ResetHostKernel(const KernelType &KernelFunc) { - struct NormalizedKernelType { - KernelType MKernelFunc; - NormalizedKernelType(const KernelType &KernelFunc) - : MKernelFunc(KernelFunc) {} - void operator()(const nd_item &Arg) { - sycl::item Item = detail::Builder::createItem( - Arg.get_global_range(), Arg.get_global_id()); - detail::runKernelWithArg(MKernelFunc, Item); - } - }; - return ResetHostKernelHelper( - KernelFunc); - } - - // For 'sycl::item' kernel argument - template - std::enable_if_t>, KernelType *> - ResetHostKernel(const KernelType &KernelFunc) { - struct NormalizedKernelType { - KernelType MKernelFunc; - NormalizedKernelType(const KernelType &KernelFunc) - : MKernelFunc(KernelFunc) {} - void operator()(const nd_item &Arg) { - sycl::item Item = detail::Builder::createItem( - Arg.get_global_range(), Arg.get_global_id(), Arg.get_offset()); - detail::runKernelWithArg(MKernelFunc, Item); - } - }; - return ResetHostKernelHelper( - KernelFunc); - } - - // For 'void' kernel argument (single_task) - template - typename std::enable_if_t, KernelType *> - ResetHostKernel(const KernelType &KernelFunc) { - struct NormalizedKernelType { - KernelType MKernelFunc; - NormalizedKernelType(const KernelType &KernelFunc) - : MKernelFunc(KernelFunc) {} - void operator()(const nd_item &Arg) { - (void)Arg; - detail::runKernelWithoutArg(MKernelFunc); - } - }; - return ResetHostKernelHelper( - KernelFunc); - } - - // For 'sycl::group' kernel argument - // 'wrapper'-based approach using 'NormalizedKernelType' struct is not used - // for 'void(sycl::group)' since 'void(sycl::group)' is not - // supported in ESIMD. - template - std::enable_if_t>, KernelType *> - ResetHostKernel(const KernelType &KernelFunc) { - MHostKernel.reset( - new detail::HostKernel(KernelFunc)); - return (KernelType *)(MHostKernel->getPtr()); - } - /// Verifies the kernel bundle to be used if any is set. This throws a /// sycl::exception with error code errc::kernel_not_supported if the used /// kernel bundle does not contain a suitable device image with the requested @@ -918,8 +794,8 @@ class __SYCL_EXPORT handler { detail::KernelLambdaHasKernelHandlerArgT::value; - KernelType *KernelPtr = - ResetHostKernel(KernelFunc); + MHostKernel = std::make_unique< + detail::HostKernel>(KernelFunc); constexpr bool KernelHasName = detail::getKernelName() != nullptr && @@ -950,7 +826,7 @@ class __SYCL_EXPORT handler { if (KernelHasName) { // TODO support ESIMD in no-integration-header case too. clearArgs(); - extractArgsAndReqsFromLambda(reinterpret_cast(KernelPtr), + extractArgsAndReqsFromLambda(MHostKernel->getPtr(), detail::getKernelParamDescs(), detail::isKernelESIMD()); MKernelName = detail::getKernelName();