From 6ba7b52318f78ee040055f548e465901da8e7ac0 Mon Sep 17 00:00:00 2001 From: Alexey Sachkov Date: Fri, 27 Oct 2023 19:49:33 +0200 Subject: [PATCH] [SYCL] Fix SYCL kernel lambda argument type detection (#11679) We have a helper which is used to extract a type of the first SYCL kernel lambda argument to do some error-checking and special handling based on that. That check, however, was missing a case when a kernel lambda is also accepting `kernel_handler` argument, always falling back to a suggested type in that case. This led to a situations where we couldn't compile code like: ```c++ sycl::queue q; q.parallel_for(sycl::range{1}, [=](sycl::item<1, false>, kernel_handler) {}); ``` This patch adds extra specializations of some internal helpers to fix the error. This is a follow-up from intel/llvm#11625 --- sycl/include/sycl/handler.hpp | 16 ++++++++++------ .../basic_tests/handler/parallel_for_args.cpp | 16 +++++++++++++++- 2 files changed, 25 insertions(+), 7 deletions(-) diff --git a/sycl/include/sycl/handler.hpp b/sycl/include/sycl/handler.hpp index 18dc90cf995e3..1315e3931bbf8 100644 --- a/sycl/include/sycl/handler.hpp +++ b/sycl/include/sycl/handler.hpp @@ -189,11 +189,15 @@ static Arg member_ptr_helper(RetType (Func::*)(Arg) const); template static Arg member_ptr_helper(RetType (Func::*)(Arg)); -// template -// static void member_ptr_helper(RetType (Func::*)() const); +// Version with two arguments to handle the case when kernel_handler is passed +// to a lambda +template +static Arg1 member_ptr_helper(RetType (Func::*)(Arg1, Arg2) const); -// template -// static void member_ptr_helper(RetType (Func::*)()); +// Non-const version of the above template to match functors whose 'operator()' +// is declared w/o the 'const' qualifier. +template +static Arg1 member_ptr_helper(RetType (Func::*)(Arg1, Arg2)); template decltype(member_ptr_helper(&F::operator())) argument_helper(int); @@ -1280,7 +1284,7 @@ class __SYCL_EXPORT handler { using KName = std::conditional_t::value, decltype(Wrapper), NameWT>; - kernel_parallel_for_wrapper, decltype(Wrapper), + kernel_parallel_for_wrapper(Wrapper); #ifndef __SYCL_DEVICE_ONLY__ // We are executing over the rounded range, but there are still @@ -1290,7 +1294,7 @@ class __SYCL_EXPORT handler { // of the user range, instead of the rounded range. detail::checkValueRange(UserRange); MNDRDesc.set(*RoundedRange); - StoreLambda>( + StoreLambda( std::move(Wrapper)); setType(detail::CG::Kernel); #endif diff --git a/sycl/test/basic_tests/handler/parallel_for_args.cpp b/sycl/test/basic_tests/handler/parallel_for_args.cpp index 11f6d4ac791db..fab3e0923c8bd 100644 --- a/sycl/test/basic_tests/handler/parallel_for_args.cpp +++ b/sycl/test/basic_tests/handler/parallel_for_args.cpp @@ -36,6 +36,14 @@ int main() { q.parallel_for(r2, [=](sycl::item<2> it) {}); q.parallel_for(r3, [=](sycl::item<3> it) {}); + q.parallel_for(r1, [=](sycl::item<1, false> it) {}); + q.parallel_for(r2, [=](sycl::item<2, false> it) {}); + q.parallel_for(r3, [=](sycl::item<3, false> it) {}); + + // int, size_t -> sycl::item + q.parallel_for(r1, [=](int it) {}); + q.parallel_for(r1, [=](size_t it) {}); + // sycl::item -> sycl::id q.parallel_for(r1, [=](sycl::id<1> it) {}); q.parallel_for(r2, [=](sycl::id<2> it) {}); @@ -51,6 +59,13 @@ int main() { q.parallel_for(r2, [=](sycl::item<2> it, sycl::kernel_handler kh) {}); q.parallel_for(r3, [=](sycl::item<3> it, sycl::kernel_handler kh) {}); + q.parallel_for(r1, [=](int it, sycl::kernel_handler kh) {}); + q.parallel_for(r1, [=](size_t it, sycl::kernel_handler kh) {}); + + q.parallel_for(r1, [=](sycl::item<1, false> it, sycl::kernel_handler kh) {}); + q.parallel_for(r2, [=](sycl::item<2, false> it, sycl::kernel_handler kh) {}); + q.parallel_for(r3, [=](sycl::item<3, false> it, sycl::kernel_handler kh) {}); + q.parallel_for(r1, [=](sycl::id<1> it, sycl::kernel_handler kh) {}); q.parallel_for(r2, [=](sycl::id<2> it, sycl::kernel_handler kh) {}); q.parallel_for(r3, [=](sycl::id<3> it, sycl::kernel_handler kh) {}); @@ -90,5 +105,4 @@ int main() { [=](ConvertibleFromNDItem<3> it, sycl::kernel_handler kh) {}); // TODO: consider adding test cases for hierarchical parallelism - // TODO: consider adding cases for sycl::item with offset }