Skip to content

Commit

Permalink
[OpenCL] Add bounds checking to the Enqueue memory operations.
Browse files Browse the repository at this point in the history
This allows us to return UR_ERROR_INVALID_SIZE when we should. Extra
checks are only performed on a non-success error code.
  • Loading branch information
aarongreig committed Nov 7, 2023
1 parent 39eec0c commit 27349dc
Showing 1 changed file with 143 additions and 37 deletions.
180 changes: 143 additions & 37 deletions source/adapters/opencl/enqueue.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,77 @@ cl_map_flags convertURMapFlagsToCL(ur_map_flags_t URFlags) {
return CLFlags;
}

ur_result_t ValidateBufferSize(ur_mem_handle_t Buffer, size_t Size,
size_t Origin) {
size_t BufferSize = 0;
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast<cl_mem>(Buffer),
CL_MEM_SIZE, sizeof(BufferSize),
&BufferSize, nullptr));
if (Size + Origin > BufferSize)
return UR_RESULT_ERROR_INVALID_SIZE;
return UR_RESULT_SUCCESS;
}

ur_result_t ValidateBufferRectSize(ur_mem_handle_t Buffer,
ur_rect_region_t Region,
ur_rect_offset_t Offset) {
size_t BufferSize = 0;
CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast<cl_mem>(Buffer),
CL_MEM_SIZE, sizeof(BufferSize),
&BufferSize, nullptr));
if (Offset.x >= BufferSize || Offset.y >= BufferSize ||
Offset.z >= BufferSize) {
return UR_RESULT_ERROR_INVALID_SIZE;
}

if ((Region.width + Offset.x) * (Region.height + Offset.y) *
(Region.depth + Offset.z) >
BufferSize) {
return UR_RESULT_ERROR_INVALID_SIZE;
}

return UR_RESULT_SUCCESS;
}

ur_result_t ValidateImageSize(ur_mem_handle_t Image, ur_rect_region_t Region,
ur_rect_offset_t Origin) {
size_t Width = 0;
CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast<cl_mem>(Image),
CL_IMAGE_WIDTH, sizeof(Width), &Width,
nullptr));
if (Region.width + Origin.x > Width) {
return UR_RESULT_ERROR_INVALID_SIZE;
}

size_t Height = 0;
CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast<cl_mem>(Image),
CL_IMAGE_HEIGHT, sizeof(Height), &Height,
nullptr));

// CL returns a height and depth of 0 for images that don't have those
// dimensions, but regions for enqueue operations must set these to 1, so we
// need to make this adjustment to validate.
if (Height == 0)
Height = 1;

if (Region.height + Origin.y > Height) {
return UR_RESULT_ERROR_INVALID_SIZE;
}

size_t Depth = 0;
CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast<cl_mem>(Image),
CL_IMAGE_DEPTH, sizeof(Depth), &Depth,
nullptr));
if (Depth == 0)
Depth = 1;

if (Region.depth + Origin.z > Depth) {
return UR_RESULT_ERROR_INVALID_SIZE;
}

return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
Expand Down Expand Up @@ -70,27 +141,33 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueReadBuffer(
auto ClErr = clEnqueueReadBuffer(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBuffer), blockingRead, offset, size, pDst,
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
ur_queue_handle_t hQueue, ur_mem_handle_t hBuffer, bool blockingWrite,
size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueWriteBuffer(
auto ClErr = clEnqueueWriteBuffer(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBuffer), blockingWrite, offset, size, pSrc,
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
Expand All @@ -101,17 +178,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueReadBufferRect(
auto ClErr = clEnqueueReadBufferRect(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBuffer), blockingRead,
cl_adapter::cast<const size_t *>(&bufferOrigin),
cl_adapter::cast<const size_t *>(&hostOrigin),
cl_adapter::cast<const size_t *>(&region), bufferRowPitch,
bufferSlicePitch, hostRowPitch, hostSlicePitch, pDst, numEventsInWaitList,
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBuffer, region, bufferOrigin));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
Expand All @@ -122,17 +202,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueWriteBufferRect(
auto ClErr = clEnqueueWriteBufferRect(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBuffer), blockingWrite,
cl_adapter::cast<const size_t *>(&bufferOrigin),
cl_adapter::cast<const size_t *>(&hostOrigin),
cl_adapter::cast<const size_t *>(&region), bufferRowPitch,
bufferSlicePitch, hostRowPitch, hostSlicePitch, pSrc, numEventsInWaitList,
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBuffer, region, bufferOrigin));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
Expand All @@ -141,14 +224,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueCopyBuffer(
auto ClErr = clEnqueueCopyBuffer(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBufferSrc),
cl_adapter::cast<cl_mem>(hBufferDst), srcOffset, dstOffset, size,
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBufferSrc, size, srcOffset));
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBufferDst, size, dstOffset));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
Expand All @@ -159,7 +246,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueCopyBufferRect(
auto ClErr = clEnqueueCopyBufferRect(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBufferSrc),
cl_adapter::cast<cl_mem>(hBufferDst),
Expand All @@ -168,9 +255,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
cl_adapter::cast<const size_t *>(&region), srcRowPitch, srcSlicePitch,
dstRowPitch, dstSlicePitch, numEventsInWaitList,
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBufferSrc, region, srcOrigin));
UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBufferDst, region, dstOrigin));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
Expand All @@ -181,13 +272,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
// CL FillBuffer only allows pattern sizes up to the largest CL type:
// long16/double16
if (patternSize <= 128) {
CL_RETURN_ON_FAILURE(
clEnqueueFillBuffer(cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBuffer), pPattern,
patternSize, offset, size, numEventsInWaitList,
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
return UR_RESULT_SUCCESS;
auto ClErr = (clEnqueueFillBuffer(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hBuffer), pPattern, patternSize, offset, size,
numEventsInWaitList,
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
if (ClErr != CL_SUCCESS) {
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
}
return mapCLErrorToUR(ClErr);
}

auto NumValues = size / sizeof(uint64_t);
Expand All @@ -205,6 +299,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
&WriteEvent);
if (ClErr != CL_SUCCESS) {
delete[] HostBuffer;
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, offset, size));
CL_RETURN_ON_FAILURE(ClErr);
}

Expand Down Expand Up @@ -237,15 +332,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
size_t slicePitch, void *pDst, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueReadImage(
auto ClErr = clEnqueueReadImage(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hImage), blockingRead,
cl_adapter::cast<const size_t *>(&origin),
cl_adapter::cast<const size_t *>(&region), rowPitch, slicePitch, pDst,
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateImageSize(hImage, region, origin));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
Expand All @@ -254,15 +352,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
size_t slicePitch, void *pSrc, uint32_t numEventsInWaitList,
const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueWriteImage(
auto ClErr = clEnqueueWriteImage(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hImage), blockingWrite,
cl_adapter::cast<const size_t *>(&origin),
cl_adapter::cast<const size_t *>(&region), rowPitch, slicePitch, pSrc,
numEventsInWaitList, cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateImageSize(hImage, region, origin));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
Expand All @@ -272,16 +373,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
ur_event_handle_t *phEvent) {

CL_RETURN_ON_FAILURE(clEnqueueCopyImage(
auto ClErr = clEnqueueCopyImage(
cl_adapter::cast<cl_command_queue>(hQueue),
cl_adapter::cast<cl_mem>(hImageSrc), cl_adapter::cast<cl_mem>(hImageDst),
cl_adapter::cast<const size_t *>(&srcOrigin),
cl_adapter::cast<const size_t *>(&dstOrigin),
cl_adapter::cast<const size_t *>(&region), numEventsInWaitList,
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent)));
cl_adapter::cast<cl_event *>(phEvent));

return UR_RESULT_SUCCESS;
if (ClErr == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateImageSize(hImageSrc, region, srcOrigin));
UR_RETURN_ON_FAILURE(ValidateImageSize(hImageDst, region, dstOrigin));
}
return mapCLErrorToUR(ClErr);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
Expand All @@ -298,9 +403,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
cl_adapter::cast<const cl_event *>(phEventWaitList),
cl_adapter::cast<cl_event *>(phEvent), &Err);

CL_RETURN_ON_FAILURE(Err);

return UR_RESULT_SUCCESS;
if (Err == CL_INVALID_VALUE) {
UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
}
return mapCLErrorToUR(Err);
}

UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
Expand Down

0 comments on commit 27349dc

Please sign in to comment.