Skip to content

Commit

Permalink
[SYCL] Make queue fill use native functions (#12702)
Browse files Browse the repository at this point in the history
This PR changes the `queue.fill()` implementation to use the native fill
functions of the specific backend. It also unifies that implementation with
the one for memset, since memset is just the special case of fill with an
8-bit (1-byte) pattern.

In the CUDA case, both memset and fill now call `urEnqueueUSMFill`, which,
depending on the size of the fill pattern, calls one of `cuMemsetD8Async`,
`cuMemsetD16Async`, `cuMemsetD32Async` or `commonMemSetLargePattern`. Before
this patch, memset already went through the same entry point, but it always
set patternSize to 1 byte, which resulted in a call to `cuMemsetD8Async`.
The behaviour of the other backends is analogous.

Previously, the fill method just invoked a `parallel_for` to write the
pattern into memory, which made this operation quite slow.

This PR depends on:
- oneapi-src/unified-runtime#1395
- oneapi-src/unified-runtime#1412
  • Loading branch information
konradkusiak97 authored May 2, 2024
1 parent 1a5595f commit 46e49ec
Show file tree
Hide file tree
Showing 41 changed files with 211 additions and 156 deletions.
2 changes: 2 additions & 0 deletions sycl/doc/design/CommandGraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ The types of commands which are unsupported, and lead to this exception are:
This corresponds to a memory buffer write command.
* `handler::copy(src, dest)` or `handler::memcpy(dest, src)` - Where both `src` and
`dest` are USM pointers. This corresponds to a USM copy command.
* `handler::fill(ptr, pattern, count)` - This corresponds to a USM memory
fill command.
* `handler::memset(ptr, value, numBytes)` - This corresponds to a USM memory
fill command.
* `handler::prefetch()`.
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ class CGFillUSM : public CG {
MPattern(std::move(Pattern)), MDst(DstPtr), MLength(Length) {}
void *getDst() { return MDst; }
size_t getLength() { return MLength; }
int getFill() { return MPattern[0]; }
const std::vector<char> &getPattern() { return MPattern; }
};

/// "Prefetch USM" command group class.
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/pi.def
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ _PI_API(piextUSMHostAlloc)
_PI_API(piextUSMDeviceAlloc)
_PI_API(piextUSMSharedAlloc)
_PI_API(piextUSMFree)
_PI_API(piextUSMEnqueueMemset)
_PI_API(piextUSMEnqueueFill)
_PI_API(piextUSMEnqueueMemcpy)
_PI_API(piextUSMEnqueuePrefetch)
_PI_API(piextUSMEnqueueMemAdvise)
Expand Down
27 changes: 14 additions & 13 deletions sycl/include/sycl/detail/pi.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@
// - PI_EXT_ONEAPI_DEVICE_INFO_BINDLESS_SAMPLED_IMAGE_FETCH_2D
// - PI_EXT_ONEAPI_DEVICE_INFO_BINDLESS_SAMPLED_IMAGE_FETCH_3D_USM
// - PI_EXT_ONEAPI_DEVICE_INFO_BINDLESS_SAMPLED_IMAGE_FETCH_3D
// 16.51 Replaced piextUSMEnqueueMemset with piextUSMEnqueueFill

#define _PI_H_VERSION_MAJOR 15
#define _PI_H_VERSION_MINOR 50
#define _PI_H_VERSION_MAJOR 16
#define _PI_H_VERSION_MINOR 51

#define _PI_STRING_HELPER(a) #a
#define _PI_CONCAT(a, b) _PI_STRING_HELPER(a.b)
Expand Down Expand Up @@ -2060,22 +2061,22 @@ __SYCL_EXPORT pi_result piextUSMPitchedAlloc(
/// \param ptr is the memory to be freed
__SYCL_EXPORT pi_result piextUSMFree(pi_context context, void *ptr);

/// USM Memset API
/// USM Fill API
///
/// \param queue is the queue to submit to
/// \param ptr is the ptr to memset
/// \param value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// \param count is the size in bytes to memset
/// \param ptr is the ptr to fill
/// \param pattern is the ptr with the bytes of the pattern to set
/// \param patternSize is the size in bytes of the pattern to set
/// \param count is the size in bytes to fill
/// \param num_events_in_waitlist is the number of events to wait on
/// \param events_waitlist is an array of events to wait on
/// \param event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueMemset(pi_queue queue, void *ptr,
pi_int32 value, size_t count,
pi_uint32 num_events_in_waitlist,
const pi_event *events_waitlist,
pi_event *event);
__SYCL_EXPORT pi_result piextUSMEnqueueFill(pi_queue queue, void *ptr,
const void *pattern,
size_t patternSize, size_t count,
pi_uint32 num_events_in_waitlist,
const pi_event *events_waitlist,
pi_event *event);

/// USM Memcpy API
///
Expand Down
10 changes: 4 additions & 6 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2870,14 +2870,9 @@ class __SYCL_EXPORT handler {
/// device copyable.
/// \param Count is the number of times to fill Pattern into Ptr.
template <typename T> void fill(void *Ptr, const T &Pattern, size_t Count) {
throwIfActionIsCreated();
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
static_assert(is_device_copyable<T>::value,
"Pattern must be device copyable");
parallel_for<__usmfill<T>>(range<1>(Count), [=](id<1> Index) {
T *CastedPtr = static_cast<T *>(Ptr);
CastedPtr[Index] = Pattern;
});
this->fill_impl(Ptr, &Pattern, sizeof(T), Count);
}

/// Prevents any commands submitted afterward to this queue from executing
Expand Down Expand Up @@ -3577,6 +3572,9 @@ class __SYCL_EXPORT handler {
});
}

// Implementation of USM fill using command for native fill.
void fill_impl(void *Dest, const void *Value, size_t ValueSize, size_t Count);

// Implementation of ext_oneapi_memcpy2d using command for native 2D memcpy.
void ext_oneapi_memcpy2d_impl(void *Dest, size_t DestPitch, const void *Src,
size_t SrcPitch, size_t Width, size_t Height);
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/cuda/pi_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,12 +902,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/hip/pi_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
23 changes: 11 additions & 12 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -929,23 +929,22 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

/// USM Memset API
/// USM Fill API
///
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to memset
/// @param Value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// @param Count is the size in bytes to memset
/// @param Ptr is the ptr to fill
/// @param Pattern is the ptr with the bytes of the pattern to set
/// @param PatternSize is the size in bytes of the pattern to set
/// @param Count is the size in bytes to fill
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/native_cpu/pi_native_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -864,12 +864,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
14 changes: 7 additions & 7 deletions sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3889,11 +3889,12 @@ inline pi_result piEnqueueMemBufferFill(pi_queue Queue, pi_mem Buffer,
return PI_SUCCESS;
}

inline pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,
pi_int32 Value, size_t Count,
pi_uint32 NumEventsInWaitList,
const pi_event *EventsWaitList,
pi_event *OutEvent) {
inline pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr,
const void *Pattern, size_t PatternSize,
size_t Count,
pi_uint32 NumEventsInWaitList,
const pi_event *EventsWaitList,
pi_event *OutEvent) {
PI_ASSERT(Queue, PI_ERROR_INVALID_QUEUE);
if (!Ptr) {
return PI_ERROR_INVALID_VALUE;
Expand All @@ -3905,8 +3906,7 @@ inline pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,

ur_event_handle_t *UREvent = reinterpret_cast<ur_event_handle_t *>(OutEvent);

size_t PatternSize = 1;
HANDLE_ERRORS(urEnqueueUSMFill(UrQueue, Ptr, PatternSize, &Value, Count,
HANDLE_ERRORS(urEnqueueUSMFill(UrQueue, Ptr, PatternSize, Pattern, Count,
NumEventsInWaitList, UrEventsWaitList,
UREvent));

Expand Down
36 changes: 18 additions & 18 deletions sycl/plugins/unified_runtime/pi_unified_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,24 +437,24 @@ __SYCL_EXPORT pi_result piQueueGetInfo(pi_queue Queue, pi_queue_info ParamName,
ParamValueSizeRet);
}

/// USM Memset API
/// USM Fill API
///
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to memset
/// @param Value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// @param Count is the size in bytes to memset
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,
pi_int32 Value, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to fill
/// @param Pattern is the ptr with the bytes of the pattern to set
/// @param PatternSize is the size in bytes of the pattern to set
/// @param Count is the size in bytes to fill
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr,
const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

__SYCL_EXPORT pi_result piEnqueueMemBufferCopyRect(
Expand Down Expand Up @@ -1490,7 +1490,7 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
_PI_API(piEnqueueMemBufferMap)
_PI_API(piEnqueueMemUnmap)
_PI_API(piEnqueueMemBufferFill)
_PI_API(piextUSMEnqueueMemset)
_PI_API(piextUSMEnqueueFill)
_PI_API(piEnqueueMemBufferCopyRect)
_PI_API(piEnqueueMemBufferCopy)
_PI_API(piextUSMEnqueueMemcpy)
Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,10 @@ class node_impl {
sycl::detail::CGFillUSM *FillUSM =
static_cast<sycl::detail::CGFillUSM *>(MCommandGroup.get());
Stream << "Dst: " << FillUSM->getDst()
<< " Length: " << FillUSM->getLength()
<< " Pattern: " << FillUSM->getFill() << "\\n";
<< " Length: " << FillUSM->getLength() << " Pattern: ";
for (auto byte : FillUSM->getPattern())
Stream << byte;
Stream << "\\n";
}
break;
case sycl::detail::CG::CGTYPE::PrefetchUSM:
Expand Down
34 changes: 24 additions & 10 deletions sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -947,7 +947,7 @@ void MemoryManager::copy_usm(const void *SrcMem, QueueImplPtr SrcQueue,
sycl::detail::pi::PiEvent *OutEvent,
const detail::EventImplPtr &OutEventImpl) {
assert(!SrcQueue->getContextImplPtr()->is_host() &&
"Host queue not supported in fill_usm.");
"Host queue not supported in copy_usm.");

if (!Len) { // no-op, but ensure DepEvents will still be waited on
if (!DepEvents.empty()) {
Expand Down Expand Up @@ -983,7 +983,7 @@ void MemoryManager::copy_usm(const void *SrcMem, QueueImplPtr SrcQueue,
}

void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
int Pattern,
const std::vector<char> &Pattern,
std::vector<sycl::detail::pi::PiEvent> DepEvents,
sycl::detail::pi::PiEvent *OutEvent,
const detail::EventImplPtr &OutEventImpl) {
Expand All @@ -1006,17 +1006,31 @@ void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
if (OutEventImpl != nullptr)
OutEventImpl->setHostEnqueueTime();
const PluginPtr &Plugin = Queue->getPlugin();
Plugin->call<PiApiKind::piextUSMEnqueueMemset>(
Queue->getHandleRef(), Mem, Pattern, Length, DepEvents.size(),
DepEvents.data(), OutEvent);
Plugin->call<PiApiKind::piextUSMEnqueueFill>(
Queue->getHandleRef(), Mem, Pattern.data(), Pattern.size(), Length,
DepEvents.size(), DepEvents.data(), OutEvent);
}

// TODO: This function will remain until ABI-breaking change
//
// Legacy overload taking the fill value as an int. It is kept only so that
// callers built against the previous ABI keep working; it forwards to the
// pattern-vector overload of fill_usm.
//
// Mem          - USM pointer to fill.
// Queue        - queue the fill is submitted to.
// Length       - size in bytes to fill.
// Pattern      - value to set; only its low 8 bits are used (see below).
// DepEvents    - events the fill must wait on.
// OutEvent     - receives the event that represents this operation.
// OutEventImpl - forwarded unchanged to the pattern-vector overload.
void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
int Pattern,
std::vector<sycl::detail::pi::PiEvent> DepEvents,
sycl::detail::pi::PiEvent *OutEvent,
const detail::EventImplPtr &OutEventImpl) {
// The int-based entry point had memset semantics: the value was interpreted
// as an 8-bit quantity and the pattern size was always 1 byte. Forwarding all
// sizeof(int) bytes would instead replicate a 4-byte pattern (e.g. AB 00 00 00
// on a little-endian host), so build a single-byte pattern from the low 8
// bits to preserve the legacy behaviour.
const unsigned char Byte = static_cast<unsigned char>(Pattern);
std::vector<char> vecPattern(1);
std::memcpy(vecPattern.data(), &Byte, sizeof(Byte));
MemoryManager::fill_usm(Mem, Queue, Length, vecPattern, DepEvents, OutEvent,
OutEventImpl);
}

// TODO: This function will remain until ABI-breaking change
void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
int Pattern,
std::vector<sycl::detail::pi::PiEvent> DepEvents,
sycl::detail::pi::PiEvent *OutEvent) {
MemoryManager::fill_usm(Mem, Queue, Length, Pattern, DepEvents, OutEvent,
std::vector<char> vecPattern(sizeof(Pattern));
std::memcpy(vecPattern.data(), &Pattern, sizeof(Pattern));
MemoryManager::fill_usm(Mem, Queue, Length, vecPattern, DepEvents, OutEvent,
nullptr); // OutEventImpl);
}

Expand Down Expand Up @@ -1682,18 +1696,18 @@ void MemoryManager::ext_oneapi_copy_usm_cmd_buffer(
void MemoryManager::ext_oneapi_fill_usm_cmd_buffer(
sycl::detail::ContextImplPtr Context,
sycl::detail::pi::PiExtCommandBuffer CommandBuffer, void *DstMem,
size_t Len, int Pattern, std::vector<sycl::detail::pi::PiExtSyncPoint> Deps,
size_t Len, const std::vector<char> &Pattern,
std::vector<sycl::detail::pi::PiExtSyncPoint> Deps,
sycl::detail::pi::PiExtSyncPoint *OutSyncPoint) {

if (!DstMem)
throw runtime_error("NULL pointer argument in memory fill operation.",
PI_ERROR_INVALID_VALUE);

const PluginPtr &Plugin = Context->getPlugin();
// Pattern is interpreted as an unsigned char so pattern size is always 1.
size_t PatternSize = 1;

Plugin->call<PiApiKind::piextCommandBufferFillUSM>(
CommandBuffer, DstMem, &Pattern, PatternSize, Len, Deps.size(),
CommandBuffer, DstMem, Pattern.data(), Pattern.size(), Len, Deps.size(),
Deps.data(), OutSyncPoint);
}

Expand Down
Loading

0 comments on commit 46e49ec

Please sign in to comment.