Review notes (patch tools skip text that precedes the first "diff --git" line):
- Line breaks and indentation were reconstructed from a whitespace-collapsed
  copy of this patch; blank context lines were inferred from the hunk counts.
- "&region" was restored where the copy contained a mis-decoded character.
- NOTE(review): explicit template arguments (cl_adapter::cast, static_cast,
  reinterpret_cast) appear to be missing from this copy - restore them from the
  target files before applying.
- The "index" blob hashes were kept from the original and no longer describe
  the edited content.

diff --git a/source/adapters/opencl/enqueue.cpp b/source/adapters/opencl/enqueue.cpp
index 5dff7066ae..ad6eaec88f 100644
--- a/source/adapters/opencl/enqueue.cpp
+++ b/source/adapters/opencl/enqueue.cpp
@@ -25,6 +25,91 @@ cl_map_flags convertURMapFlagsToCL(ur_map_flags_t URFlags) {
   return CLFlags;
 }
 
+// Checks that the byte range [Origin, Origin + Size) fits inside Buffer.
+// Returns UR_RESULT_ERROR_INVALID_SIZE when it does not and UR_RESULT_SUCCESS
+// when it does; a failing CL_MEM_SIZE query is handled by
+// CL_RETURN_ON_FAILURE.
+ur_result_t ValidateBufferSize(ur_mem_handle_t Buffer, size_t Size,
+                               size_t Origin) {
+  size_t BufferSize = 0;
+  CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast(Buffer),
+                                          CL_MEM_SIZE, sizeof(BufferSize),
+                                          &BufferSize, nullptr));
+  // Compare by subtraction so that Size + Origin cannot wrap around.
+  if (Origin > BufferSize || Size > BufferSize - Origin)
+    return UR_RESULT_ERROR_INVALID_SIZE;
+  return UR_RESULT_SUCCESS;
+}
+
+// Bounds check for a rectangular buffer access. Rejects any offset component
+// that is not below the size queried via CL_MEM_SIZE, and any access whose
+// (extent + offset) product over the three dimensions exceeds that size.
+// NOTE(review): row and slice pitches are not considered here - confirm that
+// this is the intended check.
+ur_result_t ValidateBufferRectSize(ur_mem_handle_t Buffer,
+                                   ur_rect_region_t Region,
+                                   ur_rect_offset_t Offset) {
+  size_t BufferSize = 0;
+  CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast(Buffer),
+                                          CL_MEM_SIZE, sizeof(BufferSize),
+                                          &BufferSize, nullptr));
+  if (Offset.x >= BufferSize || Offset.y >= BufferSize ||
+      Offset.z >= BufferSize) {
+    return UR_RESULT_ERROR_INVALID_SIZE;
+  }
+
+  if ((Region.width + Offset.x) * (Region.height + Offset.y) *
+          (Region.depth + Offset.z) >
+      BufferSize) {
+    return UR_RESULT_ERROR_INVALID_SIZE;
+  }
+
+  return UR_RESULT_SUCCESS;
+}
+
+// Checks Region and Origin against the width, height and depth that
+// clGetImageInfo reports for Image. Returns UR_RESULT_ERROR_INVALID_SIZE as
+// soon as one dimension is exceeded and UR_RESULT_SUCCESS otherwise. The
+// comparisons use subtraction so that extent + origin cannot wrap around.
+ur_result_t ValidateImageSize(ur_mem_handle_t Image, ur_rect_region_t Region,
+                              ur_rect_offset_t Origin) {
+  size_t Width = 0;
+  CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast(Image),
+                                      CL_IMAGE_WIDTH, sizeof(Width), &Width,
+                                      nullptr));
+  if (Origin.x > Width || Region.width > Width - Origin.x) {
+    return UR_RESULT_ERROR_INVALID_SIZE;
+  }
+
+  size_t Height = 0;
+  CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast(Image),
+                                      CL_IMAGE_HEIGHT, sizeof(Height), &Height,
+                                      nullptr));
+
+  // CL returns a height and depth of 0 for images that don't have those
+  // dimensions, but regions for enqueue operations must set these to 1, so we
+  // need to make this adjustment to validate.
+  if (Height == 0)
+    Height = 1;
+
+  if (Origin.y > Height || Region.height > Height - Origin.y) {
+    return UR_RESULT_ERROR_INVALID_SIZE;
+  }
+
+  size_t Depth = 0;
+  CL_RETURN_ON_FAILURE(clGetImageInfo(cl_adapter::cast(Image),
+                                      CL_IMAGE_DEPTH, sizeof(Depth), &Depth,
+                                      nullptr));
+  if (Depth == 0)
+    Depth = 1;
+
+  if (Origin.z > Depth || Region.depth > Depth - Origin.z) {
+    return UR_RESULT_ERROR_INVALID_SIZE;
+  }
+
+  return UR_RESULT_SUCCESS;
+}
+
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
     ur_queue_handle_t hQueue, ur_kernel_handle_t hKernel, uint32_t workDim,
     const size_t *pGlobalWorkOffset, const size_t *pGlobalWorkSize,
@@ -70,13 +155,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferRead(
     size_t offset, size_t size, void *pDst, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueReadBuffer(
+  auto ClErr = clEnqueueReadBuffer(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hBuffer), blockingRead, offset, size, pDst,
       numEventsInWaitList, cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
@@ -84,13 +172,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWrite(
     size_t offset, size_t size, const void *pSrc, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueWriteBuffer(
+  auto ClErr = clEnqueueWriteBuffer(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hBuffer), blockingWrite, offset, size, pSrc,
       numEventsInWaitList, cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
@@ -101,7 +192,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueReadBufferRect(
+  auto ClErr = clEnqueueReadBufferRect(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hBuffer), blockingRead,
       cl_adapter::cast(&bufferOrigin),
@@ -109,9 +200,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferReadRect(
       cl_adapter::cast(&region), bufferRowPitch,
       bufferSlicePitch, hostRowPitch, hostSlicePitch, pDst, numEventsInWaitList,
       cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBuffer, region, bufferOrigin));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
@@ -122,7 +216,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueWriteBufferRect(
+  auto ClErr = clEnqueueWriteBufferRect(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hBuffer), blockingWrite,
       cl_adapter::cast(&bufferOrigin),
@@ -130,9 +224,12 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferWriteRect(
       cl_adapter::cast(&region), bufferRowPitch,
       bufferSlicePitch, hostRowPitch, hostSlicePitch, pSrc, numEventsInWaitList,
       cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBuffer, region, bufferOrigin));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
@@ -141,14 +238,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopy(
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueCopyBuffer(
+  auto ClErr = clEnqueueCopyBuffer(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hBufferSrc),
       cl_adapter::cast(hBufferDst), srcOffset, dstOffset, size,
       numEventsInWaitList, cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferSize(hBufferSrc, size, srcOffset));
+    UR_RETURN_ON_FAILURE(ValidateBufferSize(hBufferDst, size, dstOffset));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
@@ -159,7 +260,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueCopyBufferRect(
+  auto ClErr = clEnqueueCopyBufferRect(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hBufferSrc),
       cl_adapter::cast(hBufferDst),
@@ -168,9 +269,13 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
       cl_adapter::cast(&region), srcRowPitch, srcSlicePitch,
       dstRowPitch, dstSlicePitch, numEventsInWaitList,
       cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBufferSrc, region, srcOrigin));
+    UR_RETURN_ON_FAILURE(ValidateBufferRectSize(hBufferDst, region, dstOrigin));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
@@ -181,13 +286,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
   // CL FillBuffer only allows pattern sizes up to the largest CL type:
   // long16/double16
   if (patternSize <= 128) {
-    CL_RETURN_ON_FAILURE(
-        clEnqueueFillBuffer(cl_adapter::cast(hQueue),
-                            cl_adapter::cast(hBuffer), pPattern,
-                            patternSize, offset, size, numEventsInWaitList,
-                            cl_adapter::cast(phEventWaitList),
-                            cl_adapter::cast(phEvent)));
-    return UR_RESULT_SUCCESS;
+    auto ClErr = (clEnqueueFillBuffer(
+        cl_adapter::cast(hQueue),
+        cl_adapter::cast(hBuffer), pPattern, patternSize, offset, size,
+        numEventsInWaitList,
+        cl_adapter::cast(phEventWaitList),
+        cl_adapter::cast(phEvent)));
+    if (ClErr != CL_SUCCESS) {
+      UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
+    }
+    return mapCLErrorToUR(ClErr);
   }
 
   auto NumValues = size / sizeof(uint64_t);
@@ -205,6 +313,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferFill(
       &WriteEvent);
   if (ClErr != CL_SUCCESS) {
     delete[] HostBuffer;
+    UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
     CL_RETURN_ON_FAILURE(ClErr);
   }
 
@@ -237,15 +346,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageRead(
     size_t slicePitch, void *pDst, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueReadImage(
+  auto ClErr = clEnqueueReadImage(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hImage), blockingRead,
       cl_adapter::cast(&origin),
       cl_adapter::cast(&region), rowPitch, slicePitch, pDst,
       numEventsInWaitList, cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateImageSize(hImage, region, origin));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
@@ -254,15 +366,18 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageWrite(
     size_t slicePitch, void *pSrc, uint32_t numEventsInWaitList,
     const ur_event_handle_t *phEventWaitList, ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueWriteImage(
+  auto ClErr = clEnqueueWriteImage(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hImage), blockingWrite,
       cl_adapter::cast(&origin),
       cl_adapter::cast(&region), rowPitch, slicePitch, pSrc,
       numEventsInWaitList, cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateImageSize(hImage, region, origin));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
@@ -272,16 +387,20 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemImageCopy(
     uint32_t numEventsInWaitList, const ur_event_handle_t *phEventWaitList,
     ur_event_handle_t *phEvent) {
 
-  CL_RETURN_ON_FAILURE(clEnqueueCopyImage(
+  auto ClErr = clEnqueueCopyImage(
       cl_adapter::cast(hQueue),
       cl_adapter::cast(hImageSrc), cl_adapter::cast(hImageDst),
       cl_adapter::cast(&srcOrigin),
       cl_adapter::cast(&dstOrigin),
       cl_adapter::cast(&region), numEventsInWaitList,
       cl_adapter::cast(phEventWaitList),
-      cl_adapter::cast(phEvent)));
+      cl_adapter::cast(phEvent));
 
-  return UR_RESULT_SUCCESS;
+  if (ClErr == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateImageSize(hImageSrc, region, srcOrigin));
+    UR_RETURN_ON_FAILURE(ValidateImageSize(hImageDst, region, dstOrigin));
+  }
+  return mapCLErrorToUR(ClErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
@@ -298,9 +417,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferMap(
       cl_adapter::cast(phEventWaitList),
       cl_adapter::cast(phEvent), &Err);
 
-  CL_RETURN_ON_FAILURE(Err);
-
-  return UR_RESULT_SUCCESS;
+  if (Err == CL_INVALID_VALUE) {
+    UR_RETURN_ON_FAILURE(ValidateBufferSize(hBuffer, size, offset));
+  }
+  return mapCLErrorToUR(Err);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemUnmap(
diff --git a/source/adapters/opencl/memory.cpp b/source/adapters/opencl/memory.cpp
index be9b266f3d..1a77754c57 100644
--- a/source/adapters/opencl/memory.cpp
+++ b/source/adapters/opencl/memory.cpp
@@ -319,9 +319,17 @@ UR_APIEXPORT ur_result_t UR_APICALL urMemBufferPartition(
   *phMem = reinterpret_cast(clCreateSubBuffer(
       cl_adapter::cast(hBuffer), static_cast(flags),
       BufferCreateType, &BufferRegion, cl_adapter::cast(&RetErr)));
-  CL_RETURN_ON_FAILURE(RetErr);
 
-  return UR_RESULT_SUCCESS;
+  if (RetErr == CL_INVALID_VALUE) {
+    size_t BufferSize = 0;
+    CL_RETURN_ON_FAILURE(clGetMemObjectInfo(cl_adapter::cast(hBuffer),
+                                            CL_MEM_SIZE, sizeof(BufferSize),
+                                            &BufferSize, nullptr));
+    if (BufferRegion.origin > BufferSize ||
+        BufferRegion.size > BufferSize - BufferRegion.origin)
+      return UR_RESULT_ERROR_INVALID_BUFFER_SIZE;
+  }
+  return mapCLErrorToUR(RetErr);
 }
 
 UR_APIEXPORT ur_result_t UR_APICALL