diff --git a/sycl/source/queue.cpp b/sycl/source/queue.cpp index 579782950d4e1..3fee25b7236a2 100644 --- a/sycl/source/queue.cpp +++ b/sycl/source/queue.cpp @@ -238,7 +238,11 @@ event queue::ext_oneapi_submit_barrier(const detail::code_location &CodeLoc) { /// group is being enqueued on. event queue::ext_oneapi_submit_barrier(const std::vector &WaitList, const detail::code_location &CodeLoc) { - if (is_in_order() && WaitList.empty()) { + bool AllEventsEmpty = std::all_of( + begin(WaitList), end(WaitList), [&](const event &Event) -> bool { + return !detail::getSyclObjImpl(Event)->isContextInitialized(); + }); + if (is_in_order() && AllEventsEmpty) { // The last command recorded in the graph is not tracked by the queue but by // the graph itself. We must therefore search for the last node/event in the // graph. diff --git a/sycl/test-e2e/Regression/ext_oneapi_barrier_opt.cpp b/sycl/test-e2e/Regression/ext_oneapi_barrier_opt.cpp new file mode 100644 index 0000000000000..bc06e85a00f94 --- /dev/null +++ b/sycl/test-e2e/Regression/ext_oneapi_barrier_opt.cpp @@ -0,0 +1,45 @@ +// RUN: %{build} %threads_lib -o %t.out +// RUN: %{run} %t.out + +// Check that ext_oneapi_submit_barrier works fine in the scenarios +// when provided waitlist consists of only empty events. + +#include +#include +#include +#include + +static constexpr int niter = 1024; +static constexpr int nthreads = 2; + +std::array mutexes; +std::array, nthreads> events; + +void threadFunction(int tid) { + sycl::device dev; + std::cout << dev.get_info() << std::endl; + sycl::context ctx{dev}; + sycl::queue q1{ctx, dev, {sycl::property::queue::in_order()}}; + sycl::queue q2{ctx, dev, {sycl::property::queue::in_order()}}; + for (int i = 0; i < niter; i++) { + sycl::event ev1 = q1.ext_oneapi_submit_barrier(); + q2.ext_oneapi_submit_barrier({ev1}); + sycl::event ev2 = q2.ext_oneapi_submit_barrier(); + q1.ext_oneapi_submit_barrier({ev2}); + } +} + +int main() { + std::array threads; + + for (int i = 0; i < nthreads; i++) { + threads[i] = std::thread{threadFunction, i}; + } + + for (int i = 0; i < nthreads; i++) { + threads[i].join(); + } + std::cout << "All threads have finished." << std::endl; + + return 0; +}