Skip to content

Commit

Permalink
Add bounds checking for enqueue operations to the validation layer.
Browse files Browse the repository at this point in the history
This is accomplished with the various size queries for buffers, images
and USM allocations. Since not all adapters have these queries
implemented the bounds checking isn't entirely comprehensive on all
platforms just yet.
  • Loading branch information
aarongreig committed Nov 17, 2023
1 parent 534071e commit 8a548cd
Show file tree
Hide file tree
Showing 15 changed files with 433 additions and 164 deletions.
6 changes: 3 additions & 3 deletions include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -6510,7 +6510,7 @@ urEnqueueMemUnmap(
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hQueue`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == ptr`
/// + `NULL == pMem`
/// + `NULL == pPattern`
/// - ::UR_RESULT_ERROR_INVALID_QUEUE
/// - ::UR_RESULT_ERROR_INVALID_EVENT
Expand All @@ -6530,7 +6530,7 @@ urEnqueueMemUnmap(
UR_APIEXPORT ur_result_t UR_APICALL
urEnqueueUSMFill(
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
void *ptr, ///< [in] pointer to USM memory object
void *pMem, ///< [in] pointer to USM memory object
size_t patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
///< than or equal to width.
const void *pPattern, ///< [in] pointer with the bytes of the pattern to set.
Expand Down Expand Up @@ -9246,7 +9246,7 @@ typedef struct ur_enqueue_mem_unmap_params_t {
/// allowing the callback the ability to modify the parameter's value
typedef struct ur_enqueue_usm_fill_params_t {
ur_queue_handle_t *phQueue;
void **pptr;
void **ppMem;
size_t *ppatternSize;
const void **ppPattern;
size_t *psize;
Expand Down
2 changes: 1 addition & 1 deletion scripts/core/enqueue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -985,7 +985,7 @@ params:
name: hQueue
desc: "[in] handle of the queue object"
- type: void*
name: ptr
name: pMem
desc: "[in] pointer to USM memory object"
- type: size_t
name: patternSize
Expand Down
37 changes: 37 additions & 0 deletions scripts/templates/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -1300,3 +1300,40 @@ def get_create_retain_release_functions(specs, namespace, tags):
)

return {"create": create_funcs, "retain": retain_funcs, "release": release_funcs}

"""
Public:
returns a dictionary with enqueue entry-points and calls to thir bounds
checking helper functions
"""
def get_bounds_checks(specs, namespace, tags):
single_buffer_check = "boundsCheckBuffer(hQueue, hBuffer, size, offset)"
double_buffer_check = "boundsCheckBuffer(hQueue, hBufferSrc, hBufferDst, size, srcOffset, dstOffset)"
single_rect_check = "boundsCheckBufferRect(hQueue, hBuffer, region, bufferOrigin)"
double_rect_check = "boundsCheckBufferRect(hQueue, hBufferSrc, hBufferDst, region, srcOrigin, dstOrigin)"
single_image_check = "boundsCheckImage(hQueue, hImage, region, origin)"
double_image_check = "boundsCheckImage(hQueue, hImageSrc, hImageDst, region, srcOrigin, dstOrigin)"
single_usm_check = "boundsCheckUSMAllocation(hQueue, pMem, size)"
double_usm_check = "boundsCheckUSMAllocation(hQueue, pSrc, pDst, size, size)"
single_usm_2d_check = "boundsCheckUSMAllocation(hQueue, pMem, pitch * height)"
double_usm_2d_check = "boundsCheckUSMAllocation(hQueue, pSrc, pDst, srcPitch * height, dstPitch * height)"
bounds_check_dict = {
"urEnqueueMemBufferRead" : single_buffer_check,
"urEnqueueMemBufferWrite" : single_buffer_check,
"urEnqueueMemBufferFill" : single_buffer_check,
"urEnqueueMemBufferMap" : single_buffer_check,
"urEnqueueMemBufferCopy" : double_buffer_check,
"urEnqueueMemBufferReadRect" : single_rect_check,
"urEnqueueMemBufferWriteRect" : single_rect_check,
"urEnqueueMemBufferCopyRect" : double_rect_check,
"urEnqueueMemImageRead" : single_image_check,
"urEnqueueMemImageWrite" : single_image_check,
"urEnqueueMemImageCopy" : double_image_check,
"urEnqueueUSMFill" : single_usm_check,
"urEnqueueUSMPrefetch" : single_usm_check,
"urEnqueueUSMAdvise" : single_usm_check,
"urEnqueueUSMMemcpy" : double_usm_check,
"urEnqueueUSMFill2D" : single_usm_2d_check,
"urEnqueueUSMMemcpy2D" : double_usm_2d_check
}
return bounds_check_dict
8 changes: 8 additions & 0 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ from templates import helper as th
x=tags['$x']
X=x.upper()
create_retain_release_funcs=th.get_create_retain_release_functions(specs, n, tags)
bounds_check_dict=th.get_bounds_checks(specs, n, tags)
%>/*
*
* Copyright (C) 2023 Intel Corporation
Expand Down Expand Up @@ -60,6 +61,13 @@ namespace ur_validation_layer

%endfor
%endfor

<% bounds_check = bounds_check_dict.get(func_name) %>
%if bounds_check:
if(auto boundsErr = ${bounds_check}; boundsErr != ${X}_RESULT_SUCCESS) {
return boundsErr;
}
%endif
}

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/hip/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,6 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
#endif
return ReturnValue(UR_USM_TYPE_UNKNOWN);
}
case UR_USM_ALLOC_INFO_BASE_PTR:
case UR_USM_ALLOC_INFO_SIZE:
return UR_RESULT_ERROR_INVALID_VALUE;
case UR_USM_ALLOC_INFO_DEVICE: {
// get device index associated with this pointer
UR_CHECK_ERROR(hipPointerGetAttributes(&hipPointerAttributeType, pMem));
Expand Down Expand Up @@ -222,6 +219,9 @@ urUSMGetMemAllocInfo(ur_context_handle_t hContext, const void *pMem,
}
return ReturnValue(Pool);
}
case UR_USM_ALLOC_INFO_BASE_PTR:
case UR_USM_ALLOC_INFO_SIZE:
return UR_RESULT_ERROR_UNSUPPORTED_ENUMERATION;
default:
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
Expand Down
4 changes: 2 additions & 2 deletions source/adapters/null/ur_nullddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3488,7 +3488,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueMemUnmap(
/// @brief Intercept function for urEnqueueUSMFill
__urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill(
ur_queue_handle_t hQueue, ///< [in] handle of the queue object
void *ptr, ///< [in] pointer to USM memory object
void *pMem, ///< [in] pointer to USM memory object
size_t
patternSize, ///< [in] the size in bytes of the pattern. Must be a power of 2 and less
///< than or equal to width.
Expand All @@ -3511,7 +3511,7 @@ __urdlllocal ur_result_t UR_APICALL urEnqueueUSMFill(
// if the driver has created a custom function, then call it instead of using the generic path
auto pfnUSMFill = d_context.urDdiTable.Enqueue.pfnUSMFill;
if (nullptr != pfnUSMFill) {
result = pfnUSMFill(hQueue, ptr, patternSize, pPattern, size,
result = pfnUSMFill(hQueue, pMem, patternSize, pPattern, size,
numEventsInWaitList, phEventWaitList, phEvent);
} else {
// generic implementation
Expand Down
Loading

0 comments on commit 8a548cd

Please sign in to comment.