Skip to content

Commit

Permalink
Merge pull request #1121 from kswiecicki/val-use-after-free
Browse files Browse the repository at this point in the history
[UR] Add lifetime validation to validation layer
  • Loading branch information
pbalcer committed Feb 9, 2024
2 parents 32e2533 + c0f0a70 commit 186bfb9
Show file tree
Hide file tree
Showing 12 changed files with 1,312 additions and 122 deletions.
2 changes: 2 additions & 0 deletions scripts/core/INTRO.rst
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,8 @@ Layers currently included with the runtime are as follows:
- Enables non-adapter-specific parameter validation (e.g. checking for null values).
* - UR_LAYER_LEAK_CHECKING
- Performs some leak checking for API calls involving object creation/destruction.
* - UR_LAYER_LIFETIME_VALIDATION
- Performs lifetime validation on objects (check if it was used within the scope of its creation and destruction) used in API calls. Automatically enables UR_LAYER_LEAK_CHECKING.
* - UR_LAYER_FULL_VALIDATION
- Enables UR_LAYER_PARAMETER_VALIDATION and UR_LAYER_LEAK_CHECKING.
* - UR_LAYER_TRACING
Expand Down
84 changes: 59 additions & 25 deletions scripts/templates/helper.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright (C) 2022-2023 Intel Corporation
Copyright (C) 2022-2024 Intel Corporation
Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
See LICENSE.TXT
Expand Down Expand Up @@ -1486,45 +1486,79 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta):

return epilogue


def get_event_wait_list_functions(specs, namespace, tags):
funcs = []
for s in specs:
for obj in s['objects']:
if re.match(r"function", obj['type']):
if any(x['name'] == 'phEventWaitList' for x in obj['params']) and any(
x['name'] == 'numEventsInWaitList' for x in obj['params']):
funcs.append(make_func_name(namespace, tags, obj))
return funcs


"""
Public:
returns a dictionary with lists of create, retain and release functions
Private:
returns a dictionary with lists of create, get, retain and release functions
"""
def get_create_retain_release_functions(specs, namespace, tags):
def _get_create_get_retain_release_functions(specs, namespace, tags):
funcs = []
for s in specs:
for obj in s['objects']:
if re.match(r"function", obj['type']):
funcs.append(make_func_name(namespace, tags, obj))

create_suffixes = r"(Create[A-Za-z]*){1}"
retain_suffixes = r"(Retain){1}"
release_suffixes = r"(Release){1}"
create_suffixes = r"(Create[A-Za-z]*){1}$"
get_suffixes = r"(Get){1}$"
retain_suffixes = r"(Retain){1}$"
release_suffixes = r"(Release){1}$"
common_prefix = r"^" + namespace

create_exp = namespace + r"([A-Za-z]+)" + create_suffixes
retain_exp = namespace + r"([A-Za-z]+)" + retain_suffixes
release_exp = namespace + r"([A-Za-z]+)" + release_suffixes
create_exp = common_prefix + r"[A-Za-z]+" + create_suffixes
get_exp = common_prefix + r"[A-Za-z]+" + get_suffixes
retain_exp = common_prefix + r"[A-Za-z]+" + retain_suffixes
release_exp = common_prefix + r"[A-Za-z]+" + release_suffixes

create_funcs, retain_funcs, release_funcs = (
create_funcs, get_funcs, retain_funcs, release_funcs = (
list(filter(lambda f: re.match(create_exp, f), funcs)),
list(filter(lambda f: re.match(get_exp, f), funcs)),
list(filter(lambda f: re.match(retain_exp, f), funcs)),
list(filter(lambda f: re.match(release_exp, f), funcs)),
)

create_funcs, retain_funcs = (
list(filter(lambda f: re.sub(create_suffixes, "Release", f) in release_funcs, create_funcs)),
list(filter(lambda f: re.sub(retain_suffixes, "Release", f) in release_funcs, retain_funcs)),
)
return {"create": create_funcs, "get": get_funcs, "retain": retain_funcs, "release": release_funcs}

return {"create": create_funcs, "retain": retain_funcs, "release": release_funcs}

"""
Public:
returns a list of dictionaries containing handle types and the corresponding create, get, retain and release functions
"""
def get_handle_create_get_retain_release_functions(specs, namespace, tags):
# Handles without release function
excluded_handles = ["$x_platform_handle_t", "$x_native_handle_t"]
# Handles from experimental features
exp_prefix = "$x_exp"

funcs = _get_create_get_retain_release_functions(specs, namespace, tags)
records = []
for h in get_adapter_handles(specs):
if h['name'] in excluded_handles or h['name'].startswith(exp_prefix):
continue

def get_event_wait_list_functions(specs, namespace, tags):
funcs = []
for s in specs:
for obj in s['objects']:
if re.match(r"function", obj['type']):
if any(x['name'] == 'phEventWaitList' for x in obj['params']) and any(
x['name'] == 'numEventsInWaitList' for x in obj['params']):
funcs.append(make_func_name(namespace, tags, obj))
return funcs
class_type = subt(namespace, tags, h['class'])
create_funcs = list(filter(lambda f: class_type in f, funcs['create']))
get_funcs = list(filter(lambda f: class_type in f, funcs['get']))
retain_funcs = list(filter(lambda f: class_type in f, funcs['retain']))
release_funcs = list(filter(lambda f: class_type in f, funcs['release']))

record = {}
record['handle'] = subt(namespace, tags, h['name'])
record['create'] = create_funcs
record['get'] = get_funcs
record['retain'] = retain_funcs
record['release'] = release_funcs

records.append(record)

return records
66 changes: 42 additions & 24 deletions scripts/templates/valddi.cpp.mako
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@ from templates import helper as th
x=tags['$x']
X=x.upper()
create_retain_release_funcs=th.get_create_retain_release_functions(specs, n, tags)
handle_create_get_retain_release_funcs=th.get_handle_create_get_retain_release_functions(specs, n, tags)
%>/*
*
* Copyright (C) 2023 Intel Corporation
* Copyright (C) 2023-2024 Intel Corporation
*
* Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
* See LICENSE.TXT
Expand All @@ -27,11 +28,12 @@ namespace ur_validation_layer
%for obj in th.get_adapter_functions(specs):
<%
func_name=th.make_func_name(n, tags, obj)
object_param=th.make_param_lines(n, tags, obj, format=["name"])[-1]
object_param_type=th.make_param_lines(n, tags, obj, format=["type"])[-1]
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
first_errors = [X + "_RESULT_ERROR_INVALID_NULL_POINTER", X + "_RESULT_ERROR_INVALID_NULL_HANDLE"]
sorted_param_checks = sorted(param_checks, key=lambda pair: False if pair[0] in first_errors else True)
tracked_params = list(filter(lambda p: any(th.subt(n, tags, p['type']) in [hf['handle'], hf['handle'] + "*"] for hf in handle_create_get_retain_release_funcs), obj['params']))
%>
///////////////////////////////////////////////////////////////////////////////
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
Expand Down Expand Up @@ -72,39 +74,49 @@ namespace ur_validation_layer

}

%for tp in tracked_params:
<%
tp_input_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) == hf['handle'] and "[in]" in tp['desc']), {})
is_related_create_get_retain_release_func = any(func_name in funcs for funcs in tp_input_handle_funcs.values())
%>
%if tp_input_handle_funcs and not is_related_create_get_retain_release_func:
if (context.enableLifetimeValidation && !refCountContext.isReferenceValid(${tp['name']})) {
refCountContext.logInvalidReference(${tp['name']});
}
%endif
%endfor

${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );

%if func_name == n + "AdapterRelease":
%for tp in tracked_params:
<%
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
%>
%if func_name in tp_handle_funcs['create']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.decrementRefCount(${object_param}, true);
refCountContext.createRefCount(*${tp['name']});
}
%elif func_name == n + "AdapterRetain":
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
%elif func_name in tp_handle_funcs['get']:
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
{
refCountContext.incrementRefCount(${object_param}, true);
}
%elif func_name == n + "AdapterGet":
if( context.enableLeakChecking && phAdapters && result == UR_RESULT_SUCCESS )
{
refCountContext.createOrIncrementRefCount(*phAdapters, true);
}
%elif func_name in create_retain_release_funcs["create"]:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.createRefCount(*${object_param});
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
}
}
%elif func_name in create_retain_release_funcs["retain"]:
%elif func_name in tp_handle_funcs['retain']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.incrementRefCount(${object_param});
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%elif func_name in create_retain_release_funcs["release"]:
%elif func_name in tp_handle_funcs['release']:
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
{
refCountContext.decrementRefCount(${object_param});
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
}
%endif
%endfor

return result;
}
Expand Down Expand Up @@ -167,16 +179,22 @@ namespace ur_validation_layer
if (enabledLayerNames.count(nameFullValidation)) {
enableParameterValidation = true;
enableLeakChecking = true;
enableLifetimeValidation = true;
} else {
if (enabledLayerNames.count(nameParameterValidation)) {
enableParameterValidation = true;
}
if (enabledLayerNames.count(nameLeakChecking)) {
enableLeakChecking = true;
}
if (enabledLayerNames.count(nameLifetimeValidation)) {
// Handle lifetime validation requires leak checking feature.
enableLifetimeValidation = true;
enableLeakChecking = true;
}
}

if(!enableParameterValidation && !enableLeakChecking) {
if (!enableParameterValidation && !enableLeakChecking && !enableLifetimeValidation) {
return result;
}

Expand Down
70 changes: 51 additions & 19 deletions source/loader/layers/validation/ur_leak_check.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (C) 2023 Intel Corporation
// Copyright (C) 2023-2024 Intel Corporation
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions.
// See LICENSE.TXT
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Expand All @@ -9,6 +9,7 @@
#include "ur_validation_layer.hpp"

#include <mutex>
#include <typeindex>
#include <unordered_map>
#include <utility>

Expand All @@ -20,7 +21,12 @@ struct RefCountContext {
private:
struct RefRuntimeInfo {
int64_t refCount;
std::type_index type;
std::vector<BacktraceLine> backtrace;

RefRuntimeInfo(int64_t refCount, std::type_index type,
std::vector<BacktraceLine> backtrace)
: refCount(refCount), type(type), backtrace(backtrace) {}
};

enum RefCountUpdateType {
Expand All @@ -34,26 +40,32 @@ struct RefCountContext {
std::unordered_map<void *, struct RefRuntimeInfo> counts;
int64_t adapterCount = 0;

void updateRefCount(void *ptr, enum RefCountUpdateType type,
template <typename T>
void updateRefCount(T handle, enum RefCountUpdateType type,
bool isAdapterHandle = false) {
std::unique_lock<std::mutex> ulock(mutex);

void *ptr = static_cast<void *>(handle);
auto it = counts.find(ptr);

switch (type) {
case REFCOUNT_CREATE_OR_INCREASE:
if (it == counts.end()) {
counts[ptr] = {1, getCurrentBacktrace()};
std::tie(it, std::ignore) = counts.emplace(
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
getCurrentBacktrace()});
if (isAdapterHandle) {
adapterCount++;
}
} else {
counts[ptr].refCount++;
it->second.refCount++;
}
break;
case REFCOUNT_CREATE:
if (it == counts.end()) {
counts[ptr] = {1, getCurrentBacktrace()};
std::tie(it, std::ignore) = counts.emplace(
ptr, RefRuntimeInfo{1, std::type_index(typeid(handle)),
getCurrentBacktrace()});
} else {
context.logger.error("Handle {} already exists", ptr);
return;
Expand All @@ -65,29 +77,31 @@ struct RefCountContext {
"Attempting to retain nonexistent handle {}", ptr);
return;
} else {
counts[ptr].refCount++;
it->second.refCount++;
}
break;
case REFCOUNT_DECREASE:
if (it == counts.end()) {
counts[ptr] = {-1, getCurrentBacktrace()};
std::tie(it, std::ignore) = counts.emplace(
ptr, RefRuntimeInfo{-1, std::type_index(typeid(handle)),
getCurrentBacktrace()});
} else {
counts[ptr].refCount--;
it->second.refCount--;
}

if (counts[ptr].refCount < 0) {
if (it->second.refCount < 0) {
context.logger.error(
"Attempting to release nonexistent handle {}", ptr);
} else if (counts[ptr].refCount == 0 && isAdapterHandle) {
} else if (it->second.refCount == 0 && isAdapterHandle) {
adapterCount--;
}
break;
}

context.logger.debug("Reference count for handle {} changed to {}", ptr,
counts[ptr].refCount);
it->second.refCount);

if (counts[ptr].refCount == 0) {
if (it->second.refCount == 0) {
counts.erase(ptr);
}

Expand All @@ -99,22 +113,36 @@ struct RefCountContext {
}

public:
void createRefCount(void *ptr) { updateRefCount(ptr, REFCOUNT_CREATE); }
template <typename T> void createRefCount(T handle) {
updateRefCount<T>(handle, REFCOUNT_CREATE);
}

void incrementRefCount(void *ptr, bool isAdapterHandle = false) {
updateRefCount(ptr, REFCOUNT_INCREASE, isAdapterHandle);
template <typename T>
void incrementRefCount(T handle, bool isAdapterHandle = false) {
updateRefCount(handle, REFCOUNT_INCREASE, isAdapterHandle);
}

void decrementRefCount(void *ptr, bool isAdapterHandle = false) {
updateRefCount(ptr, REFCOUNT_DECREASE, isAdapterHandle);
template <typename T>
void decrementRefCount(T handle, bool isAdapterHandle = false) {
updateRefCount(handle, REFCOUNT_DECREASE, isAdapterHandle);
}

void createOrIncrementRefCount(void *ptr, bool isAdapterHandle = false) {
updateRefCount(ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
template <typename T>
void createOrIncrementRefCount(T handle, bool isAdapterHandle = false) {
updateRefCount(handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
}

void clear() { counts.clear(); }

template <typename T> bool isReferenceValid(T handle) {
auto it = counts.find(static_cast<void *>(handle));
if (it == counts.end() || it->second.refCount < 1) {
return false;
}

return (it->second.type == std::type_index(typeid(handle)));
}

void logInvalidReferences() {
for (auto &[ptr, refRuntimeInfo] : counts) {
context.logger.error("Retained {} reference(s) to handle {}",
Expand All @@ -128,6 +156,10 @@ struct RefCountContext {
}
}

void logInvalidReference(void *ptr) {
context.logger.error("There are no valid references to handle {}", ptr);
}

} refCountContext;

} // namespace ur_validation_layer
Expand Down
Loading

0 comments on commit 186bfb9

Please sign in to comment.