Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Make queue fill use native functions #12702

Merged
merged 12 commits into from
May 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sycl/doc/design/CommandGraph.md
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,8 @@ The types of commands which are unsupported, and lead to this exception are:
This corresponds to a memory buffer write command.
* `handler::copy(src, dest)` or `handler::memcpy(dest, src)` - Where both `src` and
`dest` are USM pointers. This corresponds to a USM copy command.
* `handler::fill(ptr, pattern, count)` - This corresponds to a USM memory
fill command.
* `handler::memset(ptr, value, numBytes)` - This corresponds to a USM memory
fill command.
* `handler::prefetch()`.
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/cg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ class CGFillUSM : public CG {
MPattern(std::move(Pattern)), MDst(DstPtr), MLength(Length) {}
void *getDst() { return MDst; }
size_t getLength() { return MLength; }
int getFill() { return MPattern[0]; }
const std::vector<char> &getPattern() { return MPattern; }
};

/// "Prefetch USM" command group class.
Expand Down
2 changes: 1 addition & 1 deletion sycl/include/sycl/detail/pi.def
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ _PI_API(piextUSMHostAlloc)
_PI_API(piextUSMDeviceAlloc)
_PI_API(piextUSMSharedAlloc)
_PI_API(piextUSMFree)
_PI_API(piextUSMEnqueueMemset)
_PI_API(piextUSMEnqueueFill)
_PI_API(piextUSMEnqueueMemcpy)
_PI_API(piextUSMEnqueuePrefetch)
_PI_API(piextUSMEnqueueMemAdvise)
Expand Down
27 changes: 14 additions & 13 deletions sycl/include/sycl/detail/pi.h
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,10 @@
// - PI_EXT_ONEAPI_DEVICE_INFO_BINDLESS_SAMPLED_IMAGE_FETCH_2D
// - PI_EXT_ONEAPI_DEVICE_INFO_BINDLESS_SAMPLED_IMAGE_FETCH_3D_USM
// - PI_EXT_ONEAPI_DEVICE_INFO_BINDLESS_SAMPLED_IMAGE_FETCH_3D
// 16.51 Replaced piextUSMEnqueueMemset with piextUSMEnqueueFill

#define _PI_H_VERSION_MAJOR 15
#define _PI_H_VERSION_MINOR 50
#define _PI_H_VERSION_MAJOR 16
#define _PI_H_VERSION_MINOR 51

#define _PI_STRING_HELPER(a) #a
#define _PI_CONCAT(a, b) _PI_STRING_HELPER(a.b)
Expand Down Expand Up @@ -2060,22 +2061,22 @@ __SYCL_EXPORT pi_result piextUSMPitchedAlloc(
/// \param ptr is the memory to be freed
__SYCL_EXPORT pi_result piextUSMFree(pi_context context, void *ptr);

/// USM Memset API
/// USM Fill API
///
/// \param queue is the queue to submit to
/// \param ptr is the ptr to memset
/// \param value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// \param count is the size in bytes to memset
/// \param ptr is the ptr to fill
/// \param pattern is the ptr with the bytes of the pattern to set
/// \param patternSize is the size in bytes of the pattern to set
/// \param count is the size in bytes to fill
/// \param num_events_in_waitlist is the number of events to wait on
/// \param events_waitlist is an array of events to wait on
/// \param event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueMemset(pi_queue queue, void *ptr,
pi_int32 value, size_t count,
pi_uint32 num_events_in_waitlist,
const pi_event *events_waitlist,
pi_event *event);
__SYCL_EXPORT pi_result piextUSMEnqueueFill(pi_queue queue, void *ptr,
const void *pattern,
size_t patternSize, size_t count,
pi_uint32 num_events_in_waitlist,
const pi_event *events_waitlist,
pi_event *event);

/// USM Memcpy API
///
Expand Down
10 changes: 4 additions & 6 deletions sycl/include/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2867,14 +2867,9 @@ class __SYCL_EXPORT handler {
/// device copyable.
/// \param Count is the number of times to fill Pattern into Ptr.
template <typename T> void fill(void *Ptr, const T &Pattern, size_t Count) {
throwIfActionIsCreated();
setUserFacingNodeType(ext::oneapi::experimental::node_type::memfill);
static_assert(is_device_copyable<T>::value,
"Pattern must be device copyable");
parallel_for<__usmfill<T>>(range<1>(Count), [=](id<1> Index) {
T *CastedPtr = static_cast<T *>(Ptr);
CastedPtr[Index] = Pattern;
});
this->fill_impl(Ptr, &Pattern, sizeof(T), Count);
EwanC marked this conversation as resolved.
Show resolved Hide resolved
}

/// Prevents any commands submitted afterward to this queue from executing
Expand Down Expand Up @@ -3574,6 +3569,9 @@ class __SYCL_EXPORT handler {
});
}

// Implementation of USM fill using command for native fill.
void fill_impl(void *Dest, const void *Value, size_t ValueSize, size_t Count);

// Implementation of ext_oneapi_memcpy2d using command for native 2D memcpy.
void ext_oneapi_memcpy2d_impl(void *Dest, size_t DestPitch, const void *Src,
size_t SrcPitch, size_t Width, size_t Height);
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/cuda/pi_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -902,12 +902,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/hip/pi_hip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
23 changes: 11 additions & 12 deletions sycl/plugins/level_zero/pi_level_zero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -929,23 +929,22 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

/// USM Memset API
/// USM Fill API
///
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to memset
/// @param Value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// @param Count is the size in bytes to memset
/// @param Ptr is the ptr to fill
/// \param Pattern is the ptr with the bytes of the pattern to set
/// \param PatternSize is the size in bytes of the pattern to set
/// @param Count is the size in bytes to fill
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/native_cpu/pi_native_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -905,12 +905,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
12 changes: 6 additions & 6 deletions sycl/plugins/opencl/pi_opencl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -864,12 +864,12 @@ pi_result piextKernelSetArgPointer(pi_kernel Kernel, pi_uint32 ArgIndex,
return pi2ur::piextKernelSetArgPointer(Kernel, ArgIndex, ArgSize, ArgValue);
}

pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr, pi_int32 Value,
size_t Count, pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr, const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist, pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

pi_result piextUSMEnqueueMemcpy(pi_queue Queue, pi_bool Blocking, void *DstPtr,
Expand Down
14 changes: 7 additions & 7 deletions sycl/plugins/unified_runtime/pi2ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3889,11 +3889,12 @@ inline pi_result piEnqueueMemBufferFill(pi_queue Queue, pi_mem Buffer,
return PI_SUCCESS;
}

inline pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,
pi_int32 Value, size_t Count,
pi_uint32 NumEventsInWaitList,
const pi_event *EventsWaitList,
pi_event *OutEvent) {
inline pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr,
const void *Pattern, size_t PatternSize,
size_t Count,
pi_uint32 NumEventsInWaitList,
const pi_event *EventsWaitList,
pi_event *OutEvent) {
PI_ASSERT(Queue, PI_ERROR_INVALID_QUEUE);
if (!Ptr) {
return PI_ERROR_INVALID_VALUE;
Expand All @@ -3905,8 +3906,7 @@ inline pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,

ur_event_handle_t *UREvent = reinterpret_cast<ur_event_handle_t *>(OutEvent);

size_t PatternSize = 1;
HANDLE_ERRORS(urEnqueueUSMFill(UrQueue, Ptr, PatternSize, &Value, Count,
HANDLE_ERRORS(urEnqueueUSMFill(UrQueue, Ptr, PatternSize, Pattern, Count,
NumEventsInWaitList, UrEventsWaitList,
UREvent));

Expand Down
36 changes: 18 additions & 18 deletions sycl/plugins/unified_runtime/pi_unified_runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -437,24 +437,24 @@ __SYCL_EXPORT pi_result piQueueGetInfo(pi_queue Queue, pi_queue_info ParamName,
ParamValueSizeRet);
}

/// USM Memset API
/// USM Fill API
///
/// @param Queue is the queue to submit to
/// @param Ptr is the ptr to memset
/// @param Value is value to set. It is interpreted as an 8-bit value and the
/// upper
/// 24 bits are ignored
/// @param Count is the size in bytes to memset
/// @param NumEventsInWaitlist is the number of events to wait on
/// @param EventsWaitlist is an array of events to wait on
/// @param Event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueMemset(pi_queue Queue, void *Ptr,
pi_int32 Value, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueMemset(
Queue, Ptr, Value, Count, NumEventsInWaitlist, EventsWaitlist, Event);
/// \param queue is the queue to submit to
/// \param ptr is the ptr to fill
/// \param pattern is the ptr with the bytes of the pattern to set
/// \param patternSize is the size in bytes of the pattern to set
/// \param count is the size in bytes to fill
/// \param num_events_in_waitlist is the number of events to wait on
/// \param events_waitlist is an array of events to wait on
/// \param event is the event that represents this operation
__SYCL_EXPORT pi_result piextUSMEnqueueFill(pi_queue Queue, void *Ptr,
const void *Pattern,
size_t PatternSize, size_t Count,
pi_uint32 NumEventsInWaitlist,
const pi_event *EventsWaitlist,
pi_event *Event) {
return pi2ur::piextUSMEnqueueFill(Queue, Ptr, Pattern, PatternSize, Count,
NumEventsInWaitlist, EventsWaitlist, Event);
}

__SYCL_EXPORT pi_result piEnqueueMemBufferCopyRect(
Expand Down Expand Up @@ -1490,7 +1490,7 @@ __SYCL_EXPORT pi_result piPluginInit(pi_plugin *PluginInit) {
_PI_API(piEnqueueMemBufferMap)
_PI_API(piEnqueueMemUnmap)
_PI_API(piEnqueueMemBufferFill)
_PI_API(piextUSMEnqueueMemset)
_PI_API(piextUSMEnqueueFill)
_PI_API(piEnqueueMemBufferCopyRect)
_PI_API(piEnqueueMemBufferCopy)
_PI_API(piextUSMEnqueueMemcpy)
Expand Down
6 changes: 4 additions & 2 deletions sycl/source/detail/graph_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,10 @@ class node_impl {
sycl::detail::CGFillUSM *FillUSM =
static_cast<sycl::detail::CGFillUSM *>(MCommandGroup.get());
Stream << "Dst: " << FillUSM->getDst()
<< " Length: " << FillUSM->getLength()
<< " Pattern: " << FillUSM->getFill() << "\\n";
<< " Length: " << FillUSM->getLength() << " Pattern: ";
for (auto byte : FillUSM->getPattern())
Stream << byte;
Stream << "\\n";
}
break;
case sycl::detail::CG::CGTYPE::PrefetchUSM:
Expand Down
34 changes: 24 additions & 10 deletions sycl/source/detail/memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,7 @@ void MemoryManager::copy_usm(const void *SrcMem, QueueImplPtr SrcQueue,
sycl::detail::pi::PiEvent *OutEvent,
const detail::EventImplPtr &OutEventImpl) {
assert(!SrcQueue->getContextImplPtr()->is_host() &&
"Host queue not supported in fill_usm.");
"Host queue not supported in copy_usm.");

if (!Len) { // no-op, but ensure DepEvents will still be waited on
if (!DepEvents.empty()) {
Expand Down Expand Up @@ -981,7 +981,7 @@ void MemoryManager::copy_usm(const void *SrcMem, QueueImplPtr SrcQueue,
}

void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
int Pattern,
const std::vector<char> &Pattern,
std::vector<sycl::detail::pi::PiEvent> DepEvents,
sycl::detail::pi::PiEvent *OutEvent,
const detail::EventImplPtr &OutEventImpl) {
Expand All @@ -1004,17 +1004,31 @@ void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
if (OutEventImpl != nullptr)
OutEventImpl->setHostEnqueueTime();
const PluginPtr &Plugin = Queue->getPlugin();
Plugin->call<PiApiKind::piextUSMEnqueueMemset>(
Queue->getHandleRef(), Mem, Pattern, Length, DepEvents.size(),
DepEvents.data(), OutEvent);
Plugin->call<PiApiKind::piextUSMEnqueueFill>(
Queue->getHandleRef(), Mem, Pattern.data(), Pattern.size(), Length,
DepEvents.size(), DepEvents.data(), OutEvent);
}

// TODO: This function will remain until ABI-breaking change
void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
int Pattern,
std::vector<sycl::detail::pi::PiEvent> DepEvents,
sycl::detail::pi::PiEvent *OutEvent,
const detail::EventImplPtr &OutEventImpl) {
std::vector<char> vecPattern(sizeof(Pattern));
std::memcpy(vecPattern.data(), &Pattern, sizeof(Pattern));
MemoryManager::fill_usm(Mem, Queue, Length, vecPattern, DepEvents, OutEvent,
OutEventImpl);
}

// TODO: This function will remain until ABI-breaking change
void MemoryManager::fill_usm(void *Mem, QueueImplPtr Queue, size_t Length,
int Pattern,
std::vector<sycl::detail::pi::PiEvent> DepEvents,
sycl::detail::pi::PiEvent *OutEvent) {
MemoryManager::fill_usm(Mem, Queue, Length, Pattern, DepEvents, OutEvent,
std::vector<char> vecPattern(sizeof(Pattern));
std::memcpy(vecPattern.data(), &Pattern, sizeof(Pattern));
MemoryManager::fill_usm(Mem, Queue, Length, vecPattern, DepEvents, OutEvent,
nullptr); // OutEventImpl);
}

Expand Down Expand Up @@ -1680,18 +1694,18 @@ void MemoryManager::ext_oneapi_copy_usm_cmd_buffer(
void MemoryManager::ext_oneapi_fill_usm_cmd_buffer(
sycl::detail::ContextImplPtr Context,
sycl::detail::pi::PiExtCommandBuffer CommandBuffer, void *DstMem,
size_t Len, int Pattern, std::vector<sycl::detail::pi::PiExtSyncPoint> Deps,
size_t Len, const std::vector<char> &Pattern,
std::vector<sycl::detail::pi::PiExtSyncPoint> Deps,
sycl::detail::pi::PiExtSyncPoint *OutSyncPoint) {

if (!DstMem)
throw runtime_error("NULL pointer argument in memory fill operation.",
PI_ERROR_INVALID_VALUE);

const PluginPtr &Plugin = Context->getPlugin();
// Pattern is interpreted as an unsigned char so pattern size is always 1.
size_t PatternSize = 1;

Plugin->call<PiApiKind::piextCommandBufferFillUSM>(
CommandBuffer, DstMem, &Pattern, PatternSize, Len, Deps.size(),
CommandBuffer, DstMem, Pattern.data(), Pattern.size(), Len, Deps.size(),
Deps.data(), OutSyncPoint);
}

Expand Down
Loading
Loading