Skip to content

Commit

Permalink
Add extra checks for optional parameters and invalid flags
Browse files Browse the repository at this point in the history
- Refactors the error code generation script
- Adds checks for optional parameters
- Adds checks for invalid combinations of flags in urQueueCreate

Closes #856
  • Loading branch information
fabiomestre committed Sep 13, 2023
1 parent a346a30 commit 2baa8db
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 58 deletions.
20 changes: 19 additions & 1 deletion include/ur_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1717,6 +1717,7 @@ typedef struct ur_device_partition_properties_t {
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == pProperties`
/// + `NULL == pProperties->pProperties`
/// - ::UR_RESULT_ERROR_DEVICE_PARTITION_FAILED
/// - ::UR_RESULT_ERROR_INVALID_DEVICE_PARTITION_COUNT
UR_APIEXPORT ur_result_t UR_APICALL
Expand Down Expand Up @@ -2029,6 +2030,8 @@ typedef struct ur_context_properties_t {
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phDevices`
/// + `NULL == phContext`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pProperties && ::UR_CONTEXT_FLAGS_MASK & pProperties->flags`
/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY
/// - ::UR_RESULT_ERROR_OUT_OF_DEVICE_MEMORY
UR_APIEXPORT ur_result_t UR_APICALL
Expand Down Expand Up @@ -3273,6 +3276,8 @@ typedef struct ur_usm_pool_limits_desc_t {
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == ppMem`
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
Expand Down Expand Up @@ -3317,6 +3322,8 @@ urUSMHostAlloc(
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == ppMem`
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
Expand Down Expand Up @@ -3363,6 +3370,8 @@ urUSMDeviceAlloc(
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == ppMem`
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
Expand Down Expand Up @@ -3805,6 +3814,8 @@ typedef struct ur_physical_mem_properties_t {
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pProperties && ::UR_PHYSICAL_MEM_FLAGS_MASK & pProperties->flags`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phPhysicalMem`
/// - ::UR_RESULT_ERROR_INVALID_SIZE
Expand Down Expand Up @@ -4882,6 +4893,8 @@ typedef struct ur_kernel_arg_mem_obj_properties_t {
/// - ::UR_RESULT_ERROR_ADAPTER_SPECIFIC
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hKernel`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pProperties && ::UR_MEM_FLAGS_MASK & pProperties->memoryAccess`
/// - ::UR_RESULT_ERROR_INVALID_KERNEL_ARGUMENT_INDEX
UR_APIEXPORT ur_result_t UR_APICALL
urKernelSetArgMemObj(
Expand Down Expand Up @@ -5135,12 +5148,15 @@ typedef struct ur_queue_index_properties_t {
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pProperties && ::UR_QUEUE_FLAGS_MASK & pProperties->flags`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == phQueue`
/// - ::UR_RESULT_ERROR_INVALID_CONTEXT
/// - ::UR_RESULT_ERROR_INVALID_DEVICE
/// - ::UR_RESULT_ERROR_INVALID_VALUE
/// - ::UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES
/// + `pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_HIGH && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_LOW`
/// + `pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_BATCHED && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE`
/// - ::UR_RESULT_ERROR_OUT_OF_HOST_MEMORY
/// - ::UR_RESULT_ERROR_OUT_OF_RESOURCES
UR_APIEXPORT ur_result_t UR_APICALL
Expand Down Expand Up @@ -7069,6 +7085,8 @@ typedef struct ur_exp_layered_image_properties_t {
/// - ::UR_RESULT_ERROR_INVALID_NULL_HANDLE
/// + `NULL == hContext`
/// + `NULL == hDevice`
/// - ::UR_RESULT_ERROR_INVALID_ENUMERATION
/// + `NULL != pUSMDesc && ::UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints`
/// - ::UR_RESULT_ERROR_INVALID_NULL_POINTER
/// + `NULL == ppMem`
/// + `NULL == pResultPitch`
Expand Down
5 changes: 3 additions & 2 deletions scripts/core/queue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,9 @@ params:
returns:
- $X_RESULT_ERROR_INVALID_CONTEXT
- $X_RESULT_ERROR_INVALID_DEVICE
- $X_RESULT_ERROR_INVALID_VALUE
- $X_RESULT_ERROR_INVALID_QUEUE_PROPERTIES
- $X_RESULT_ERROR_INVALID_QUEUE_PROPERTIES:
- "`pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_HIGH && pProperties->flags & UR_QUEUE_FLAG_PRIORITY_LOW`"
- "`pProperties != NULL && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_BATCHED && pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE`"
- $X_RESULT_ERROR_OUT_OF_HOST_MEMORY
- $X_RESULT_ERROR_OUT_OF_RESOURCES
--- #--------------------------------------------------------------------------
Expand Down
100 changes: 48 additions & 52 deletions scripts/parse_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def __validate_ordinal(d):
ordinal = None

if ordinal != d['ordinal']:
raise Exception("'ordinal' invalid value: '%s'"%d['ordinal'])
raise Exception("'ordinal' invalid value: '%s'"%d['ordinal'])

def __validate_version(d, prefix="", base_version=default_version):
if 'version' in d:
Expand Down Expand Up @@ -333,7 +333,7 @@ def __validate_params(d, tags):

if item['type'].endswith("flag_t"):
raise Exception(prefix+"'type' must not be '*_flag_t': %s"%item['type'])

if type_traits.is_pointer(item['type']) and "_handle_t" in item['type'] and "[in]" in item['desc']:
if not param_traits.is_range(item):
raise Exception(prefix+"handle type must include a range(start, end) as part of 'desc'")
Expand All @@ -342,11 +342,11 @@ def __validate_params(d, tags):
if ver < max_ver:
raise Exception(prefix+"'version' must be increasing: %s"%item['version'])
max_ver = ver

def __validate_union_tag(d):
if d.get('tag') is None:
raise Exception(f"{d['name']} must include a 'tag' part of the union.")

try:
if 'type' not in d:
raise Exception("every document must have 'type'")
Expand Down Expand Up @@ -466,7 +466,7 @@ def __filter_desc(d):
return d

flt = []
type = d['type']
type = d['type']
if 'enum' == type:
for e in d['etors']:
ver = float(e.get('version', default_version))
Expand Down Expand Up @@ -706,58 +706,54 @@ def _append(lst, key, val):
if val and val not in rets[idx][key]:
rets[idx][key].append(val)

def append_nullchecks(param, accessor: str):
if type_traits.is_pointer(param['type']):
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`" % accessor)

elif type_traits.is_funcptr(param['type'], meta):
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`" % accessor)

elif type_traits.is_handle(param['type']) and not type_traits.is_ipc_handle(item['type']):
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_HANDLE", "`NULL == %s`" % accessor)

def append_enum_checks(param, accessor: str):
ptypename = type_traits.base(param['type'])

prefix = "`"
if param_traits.is_optional(item):
prefix = "`NULL != %s && " % item['name']

if re.match(r"stype", param['name']):
_append(rets, "$X_RESULT_ERROR_UNSUPPORTED_VERSION", prefix + "%s != %s`"%(re.sub(r"(\$\w)_(.*)_t.*", r"\1_STRUCTURE_TYPE_\2", typename).upper(), accessor))
else:
if type_traits.is_flags(param['type']) and 'bit_mask' in meta['enum'][ptypename].keys():
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", prefix + "%s & %s`"%(ptypename.upper()[:-2]+ "_MASK", accessor))
else:
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", prefix + "%s < %s`"%(meta['enum'][ptypename]['max'], accessor))

# generate results based on parameters
for item in obj['params']:
if param_traits.is_nocheck(item):
continue

if not param_traits.is_optional(item):
append_nullchecks(item, item['name'])

if type_traits.is_enum(item['type'], meta) and not type_traits.is_pointer(item['type']):
append_enum_checks(item, item['name'])

if type_traits.is_descriptor(item['type']) or type_traits.is_properties(item['type']):
typename = type_traits.base(item['type'])
# walk each entry in the desc for pointers and enums
for i, m in enumerate(meta['struct'][typename]['members']):
if param_traits.is_nocheck(m):
continue

if not param_traits.is_optional(m):
append_nullchecks(m, "%s->%s" % (item['name'], m['name']))

if type_traits.is_pointer(item['type']):
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`"%item['name'])

elif type_traits.is_funcptr(item['type'], meta):
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_POINTER", "`NULL == %s`"%item['name'])

elif type_traits.is_handle(item['type']) and not type_traits.is_ipc_handle(item['type']):
_append(rets, "$X_RESULT_ERROR_INVALID_NULL_HANDLE", "`NULL == %s`"%item['name'])

elif type_traits.is_enum(item['type'], meta):
if type_traits.is_flags(item['type']) and 'bit_mask' in meta['enum'][typename].keys():
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s & %s`"%(typename.upper()[:-2]+ "_MASK", item['name']))
else:
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s < %s`"%(meta['enum'][typename]['max'], item['name']))

if type_traits.is_descriptor(item['type']):
# walk each entry in the desc for pointers and enums
for i, m in enumerate(meta['struct'][typename]['members']):
if param_traits.is_nocheck(m):
continue
mtypename = type_traits.base(m['type'])

if type_traits.is_pointer(m['type']) and not param_traits.is_optional({'desc': m['desc']}):
_append(rets,
"$X_RESULT_ERROR_INVALID_NULL_POINTER",
"`NULL == %s->%s`"%(item['name'], m['name']))

elif type_traits.is_enum(m['type'], meta):
if re.match(r"stype", m['name']):
_append(rets, "$X_RESULT_ERROR_UNSUPPORTED_VERSION", "`%s != %s->stype`"%(re.sub(r"(\$\w)_(.*)_t.*", r"\1_STRUCTURE_TYPE_\2", typename).upper(), item['name']))
else:
if type_traits.is_flags(m['type']) and 'bit_mask' in meta['enum'][mtypename].keys():
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s & %s->%s`"%(mtypename.upper()[:-2]+ "_MASK", item['name'], m['name']))
else:
_append(rets, "$X_RESULT_ERROR_INVALID_ENUMERATION", "`%s < %s->%s`"%(meta['enum'][mtypename]['max'], item['name'], m['name']))

elif type_traits.is_properties(item['type']):
# walk each entry in the properties
for i, m in enumerate(meta['struct'][typename]['members']):
if param_traits.is_nocheck(m):
continue
if type_traits.is_enum(m['type'], meta):
if re.match(r"stype", m['name']):
_append(rets, "$X_RESULT_ERROR_UNSUPPORTED_VERSION", "`%s != %s->stype`"%(re.sub(r"(\$\w)_(.*)_t.*", r"\1_STRUCTURE_TYPE_\2", typename).upper(), item['name']))
if type_traits.is_enum(m['type'], meta) and not type_traits.is_pointer(m['type']):
append_enum_checks(m, "%s->%s" % (item['name'], m['name']))

# finally, append all user entries
for item in obj.get('returns', []):
Expand Down Expand Up @@ -823,7 +819,7 @@ def _refresh_enum_meta(obj, meta):
## remove the existing meta records
if obj.get('class'):
meta['class'][obj['class']]['enum'].remove(obj['name'])

if meta['enum'].get(obj['name']):
del meta['enum'][obj['name']]
## re-generate meta
Expand Down Expand Up @@ -851,13 +847,13 @@ def _extend_enums(enum_extensions, specs, meta):
if not _validate_ext_enum_range(extension, matching_enum):
raise Exception(f"Invalid enum values.")
matching_enum['etors'].extend(extension['etors'])

_refresh_enum_meta(matching_enum, meta)

## Sort the etors
value = -1
def sort_etors(x):
nonlocal value
nonlocal value
value = _get_etor_value(x.get('value'), value)
return value
matching_enum['etors'] = sorted(matching_enum['etors'], key=sort_etors)
Expand Down
50 changes: 50 additions & 0 deletions source/loader/layers/validation/ur_valddi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,10 @@ __urdlllocal ur_result_t UR_APICALL urDevicePartition(
if (NULL == pProperties) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL == pProperties->pProperties) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}
}

ur_result_t result = pfnPartition(hDevice, pProperties, NumDevices,
Expand Down Expand Up @@ -739,6 +743,10 @@ __urdlllocal ur_result_t UR_APICALL urContextCreate(
if (NULL == phContext) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pProperties && UR_CONTEXT_FLAGS_MASK & pProperties->flags) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
}

ur_result_t result =
Expand Down Expand Up @@ -1616,6 +1624,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMHostAlloc(
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

if (pUSMDesc && pUSMDesc->align != 0 &&
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
return UR_RESULT_ERROR_INVALID_VALUE;
Expand Down Expand Up @@ -1663,6 +1675,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMDeviceAlloc(
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

if (pUSMDesc && pUSMDesc->align != 0 &&
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
return UR_RESULT_ERROR_INVALID_VALUE;
Expand Down Expand Up @@ -1711,6 +1727,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMSharedAlloc(
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

if (pUSMDesc && pUSMDesc->align != 0 &&
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
return UR_RESULT_ERROR_INVALID_VALUE;
Expand Down Expand Up @@ -2236,6 +2256,11 @@ __urdlllocal ur_result_t UR_APICALL urPhysicalMemCreate(
if (NULL == phPhysicalMem) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pProperties &&
UR_PHYSICAL_MEM_FLAGS_MASK & pProperties->flags) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
}

ur_result_t result =
Expand Down Expand Up @@ -3208,6 +3233,11 @@ __urdlllocal ur_result_t UR_APICALL urKernelSetArgMemObj(
if (NULL == hKernel) {
return UR_RESULT_ERROR_INVALID_NULL_HANDLE;
}

if (NULL != pProperties &&
UR_MEM_FLAGS_MASK & pProperties->memoryAccess) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}
}

ur_result_t result =
Expand Down Expand Up @@ -3398,6 +3428,22 @@ __urdlllocal ur_result_t UR_APICALL urQueueCreate(
if (NULL == phQueue) {
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pProperties && UR_QUEUE_FLAGS_MASK & pProperties->flags) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

if (pProperties != NULL &&
pProperties->flags & UR_QUEUE_FLAG_PRIORITY_HIGH &&
pProperties->flags & UR_QUEUE_FLAG_PRIORITY_LOW) {
return UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES;
}

if (pProperties != NULL &&
pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_BATCHED &&
pProperties->flags & UR_QUEUE_FLAG_SUBMISSION_IMMEDIATE) {
return UR_RESULT_ERROR_INVALID_QUEUE_PROPERTIES;
}
}

ur_result_t result = pfnCreate(hContext, hDevice, pProperties, phQueue);
Expand Down Expand Up @@ -5556,6 +5602,10 @@ __urdlllocal ur_result_t UR_APICALL urUSMPitchedAllocExp(
return UR_RESULT_ERROR_INVALID_NULL_POINTER;
}

if (NULL != pUSMDesc && UR_USM_ADVICE_FLAGS_MASK & pUSMDesc->hints) {
return UR_RESULT_ERROR_INVALID_ENUMERATION;
}

if (pUSMDesc && pUSMDesc->align != 0 &&
((pUSMDesc->align & (pUSMDesc->align - 1)) != 0)) {
return UR_RESULT_ERROR_INVALID_VALUE;
Expand Down
Loading

0 comments on commit 2baa8db

Please sign in to comment.