diff --git a/source/adapters/null/ur_null.cpp b/source/adapters/null/ur_null.cpp index 5653ca57db..d200712178 100644 --- a/source/adapters/null/ur_null.cpp +++ b/source/adapters/null/ur_null.cpp @@ -142,5 +142,68 @@ context_t::context_t() { } return UR_RESULT_SUCCESS; }; + + ////////////////////////////////////////////////////////////////////////// + urDdiTable.USM.pfnHostAlloc = + [](ur_context_handle_t hContext, const ur_usm_desc_t *pUSMDesc, + ur_usm_pool_handle_t pool, size_t size, void **ppMem) { + if (size == 0) { + *ppMem = nullptr; + return UR_RESULT_ERROR_UNSUPPORTED_SIZE; + } + *ppMem = malloc(size); + if (ppMem == nullptr) { + return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return UR_RESULT_SUCCESS; + }; + + ////////////////////////////////////////////////////////////////////////// + urDdiTable.USM.pfnDeviceAlloc = + [](ur_context_handle_t hContext, ur_device_handle_t hDevice, + const ur_usm_desc_t *pUSMDesc, ur_usm_pool_handle_t pool, + size_t size, void **ppMem) { + if (size == 0) { + *ppMem = nullptr; + return UR_RESULT_ERROR_UNSUPPORTED_SIZE; + } + *ppMem = malloc(size); + if (ppMem == nullptr) { + return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY; + } + return UR_RESULT_SUCCESS; + }; + + ////////////////////////////////////////////////////////////////////////// + urDdiTable.USM.pfnFree = [](ur_context_handle_t hContext, void *pMem) { + free(pMem); + return UR_RESULT_SUCCESS; + }; + + ////////////////////////////////////////////////////////////////////////// + urDdiTable.USM.pfnGetMemAllocInfo = + [](ur_context_handle_t hContext, const void *pMem, + ur_usm_alloc_info_t propName, size_t propSize, void *pPropValue, + size_t *pPropSizeRet) { + switch (propName) { + case UR_USM_ALLOC_INFO_TYPE: + *reinterpret_cast(pPropValue) = + pMem ? UR_USM_TYPE_DEVICE : UR_USM_TYPE_UNKNOWN; + if (pPropSizeRet != nullptr) { + *pPropSizeRet = sizeof(ur_usm_type_t); + } + break; + case UR_USM_ALLOC_INFO_SIZE: + *reinterpret_cast(pPropValue) = pMem ? SIZE_MAX : 0; + if (pPropSizeRet != nullptr) { + *pPropSizeRet = sizeof(size_t); + } + break; + default: + pPropValue = nullptr; + break; + } + return UR_RESULT_SUCCESS; + }; } } // namespace driver diff --git a/test/conformance/CMakeLists.txt b/test/conformance/CMakeLists.txt index 2b2c5238c6..0eedd8add1 100644 --- a/test/conformance/CMakeLists.txt +++ b/test/conformance/CMakeLists.txt @@ -44,6 +44,23 @@ function(add_conformance_test_with_platform_environment name) target_compile_definitions("test-${name}" PRIVATE PLATFORM_ENVIRONMENT) endfunction() +function(add_fuzz_test name) + set(TEST_TARGET_NAME fuzztest-${name}) + add_executable(${TEST_TARGET_NAME} + ${ARGN}) + target_link_libraries(${TEST_TARGET_NAME} + PRIVATE + ${PROJECT_NAME}::loader + ${PROJECT_NAME}::headers + ${PROJECT_NAME}::common + -fsanitize=fuzzer) + add_test(NAME ${TEST_TARGET_NAME} + COMMAND ${TEST_TARGET_NAME} -max_total_time=10 -seed=1 -shrink=1 + WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}) + set_tests_properties(${TEST_TARGET_NAME} PROPERTIES LABELS "fuzz") + target_compile_options(${TEST_TARGET_NAME} PRIVATE -g -fsanitize=fuzzer) +endfunction() + add_subdirectory(testing) add_subdirectory(platform) @@ -74,3 +91,7 @@ if(DEFINED UR_DPCXX) add_subdirectory(program) add_subdirectory(enqueue) endif() + +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_subdirectory(fuzz) +endif() diff --git a/test/conformance/fuzz/CMakeLists.txt b/test/conformance/fuzz/CMakeLists.txt new file mode 100644 index 0000000000..ab8af96776 --- /dev/null +++ b/test/conformance/fuzz/CMakeLists.txt @@ -0,0 +1,7 @@ +# Copyright (C) 2023 Intel Corporation +# Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +# See LICENSE.TXT +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +add_fuzz_test(conformance + urConformanceFuzz.cpp) diff --git a/test/conformance/fuzz/urConformanceFuzz.cpp b/test/conformance/fuzz/urConformanceFuzz.cpp new file mode 100644 index 0000000000..7c2d2ebb28 --- /dev/null +++ b/test/conformance/fuzz/urConformanceFuzz.cpp @@ -0,0 +1,142 @@ +// Copyright (C) 2023 Intel Corporation +// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM Exceptions. +// See LICENSE.TXT +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include + +#include "ur_api.h" + +extern "C" int LLVMFuzzerTestOneInput(uint8_t *data, size_t size) { + uint16_t platformCountFuzz; + uint16_t deviceCountFuzz; + uint8_t deviceTypeFuzz; + uint32_t deviceAllocSizeFuzz; + uint32_t hostAllocSizeFuzz; + + //// Parse fuzzer data + // Make sure fuzzer data size is sufficient to store all variables + if (size < sizeof(platformCountFuzz) + sizeof(deviceCountFuzz) + + sizeof(deviceTypeFuzz) + sizeof(deviceAllocSizeFuzz) + + sizeof(hostAllocSizeFuzz)) { + return -1; + } + + uint8_t *data_ptr = data; + memcpy(&platformCountFuzz, data_ptr, sizeof(platformCountFuzz)); + // Limit the max number of platforms to avoid allocating too much memory for a vector + if (platformCountFuzz > 1024) { + return -1; + } + data_ptr += sizeof(platformCountFuzz); + + memcpy(&deviceCountFuzz, data_ptr, sizeof(deviceCountFuzz)); + data_ptr += sizeof(deviceCountFuzz); + + memcpy(&deviceTypeFuzz, data_ptr, sizeof(deviceTypeFuzz)); + // Pass only integers which can be a valid device type + if (deviceTypeFuzz > 7) { + return -1; + } + data_ptr += sizeof(deviceTypeFuzz); + + memcpy(&deviceAllocSizeFuzz, data_ptr, sizeof(deviceAllocSizeFuzz)); + // Limit the max size of allocations + if (deviceAllocSizeFuzz > 1 * 1024 * 1024) { + return -1; + } + data_ptr += sizeof(deviceAllocSizeFuzz); + + memcpy(&hostAllocSizeFuzz, data_ptr, sizeof(hostAllocSizeFuzz)); + // Limit the max size of allocations + if (hostAllocSizeFuzz > 1 * 1024 * 1024) { + return -1; + } + + //// API calls + ur_result_t res = UR_RESULT_SUCCESS; + + res = urInit(0); + if (res != UR_RESULT_SUCCESS) { + return 0; + } + + // Get valid platforms + std::vector platforms; + uint32_t platformCount = 0; + + res = urPlatformGet(platformCountFuzz, nullptr, &platformCount); + if (res != UR_RESULT_SUCCESS) { + return 0; + } + platformCount = platformCountFuzz % platformCount; + platforms.resize(platformCount); + res = urPlatformGet(platformCount, platforms.data(), nullptr); + if (res != UR_RESULT_SUCCESS || platformCount == 0) { + return 0; + } + + // Get valid devices of a random platform + ur_platform_handle_t hPlatform = + platforms[platformCountFuzz % platforms.size()]; + std::vector devices; + uint32_t deviceCount = 0; + ur_device_type_t deviceType = static_cast(deviceTypeFuzz); + + res = urDeviceGet(hPlatform, deviceType, deviceCountFuzz, nullptr, + &deviceCount); + if (res != UR_RESULT_SUCCESS) { + return 0; + } + deviceCount = deviceCountFuzz % deviceCount; + devices.resize(deviceCount); + res = urDeviceGet(hPlatform, deviceType, devices.size(), devices.data(), + nullptr); + if (res != UR_RESULT_SUCCESS || deviceCount == 0) { + return 0; + } + + // Test API + ur_context_handle_t context; + void *host_ptr = nullptr; + void *device_ptr = nullptr; + size_t usm_alloc_info_size = 0; + ur_device_handle_t device = devices[deviceCountFuzz % devices.size()]; + ur_usm_type_t device_type = UR_USM_TYPE_UNKNOWN; + + urContextCreate(devices.size(), devices.data(), nullptr, &context); + urUSMHostAlloc(context, nullptr, nullptr, hostAllocSizeFuzz, &host_ptr); + if (hostAllocSizeFuzz != 0) { + memset(host_ptr, 'H', hostAllocSizeFuzz); + } else { + assert(host_ptr == nullptr); + } + urUSMDeviceAlloc(context, device, nullptr, nullptr, deviceAllocSizeFuzz, + &device_ptr); + urUSMGetMemAllocInfo(context, device_ptr, UR_USM_ALLOC_INFO_SIZE, + sizeof(usm_alloc_info_size), + static_cast(&usm_alloc_info_size), nullptr); + if (deviceAllocSizeFuzz != 0) { + assert(usm_alloc_info_size >= deviceAllocSizeFuzz); + } else { + assert(usm_alloc_info_size == 0); + } + urUSMGetMemAllocInfo(context, device_ptr, UR_USM_ALLOC_INFO_TYPE, + sizeof(device_type), static_cast(&device_type), + nullptr); + if (deviceAllocSizeFuzz != 0) { + assert(device_type == UR_USM_TYPE_DEVICE); + } else { + assert(device_type == UR_USM_TYPE_UNKNOWN); + } + + urUSMFree(context, host_ptr); + urUSMFree(context, device_ptr); + urContextRelease(context); + for (auto &device : devices) { + urDeviceRelease(device); + } + return 0; +}