Skip to content

Commit

Permalink
Fix urProgramCompileExp, urProgramBuildExp, and urProgramLinkExp definition to match spec
Browse files Browse the repository at this point in the history

Signed-off-by: Spruit, Neil R <neil.r.spruit@intel.com>
  • Loading branch information
nrspruit committed Nov 17, 2023
1 parent 4b5e559 commit 0790bf8
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 108 deletions.
6 changes: 2 additions & 4 deletions source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,16 +226,14 @@ urProgramCompile(ur_context_handle_t hContext, ur_program_handle_t hProgram,
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_context_handle_t,
ur_program_handle_t,
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_context_handle_t,
ur_program_handle_t,
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
Expand Down
6 changes: 2 additions & 4 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,14 @@ urProgramCompile(ur_context_handle_t hContext, ur_program_handle_t hProgram,
return urProgramBuild(hContext, hProgram, pOptions);
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_context_handle_t,
ur_program_handle_t,
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_context_handle_t,
ur_program_handle_t,
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
Expand Down
123 changes: 61 additions & 62 deletions source/adapters/level_zero/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,16 +113,16 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(
const char *Options ///< [in][optional] pointer to build options
///< null-terminated string.
) {
return urProgramBuildExp(Context, Program, 1, Context->Devices.data(),
Options);
return urProgramBuildExp(Program, 1, Context->Devices.data(), Options);
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
ur_context_handle_t Context, ///< [in] handle of the context instance.
ur_program_handle_t Program, ///< [in] Handle of the program to build.
uint32_t numDevices, ur_device_handle_t *phDevices,
const char *Options ///< [in][optional] pointer to build options
///< null-terminated string.
ur_program_handle_t hProgram, ///< [in] Handle of the program to build.
uint32_t numDevices, ///< [in] number of devices
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to
///< array of device handles
const char *pOptions ///< [in][optional] pointer to build options
///< null-terminated string.
) {
// TODO
// Check if device belongs to associated context.
Expand All @@ -131,43 +131,42 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
// UR_RESULT_ERROR_INVALID_VALUE);

// We should have either IL or native device code.
UR_ASSERT(Program->Code, UR_RESULT_ERROR_INVALID_PROGRAM);
UR_ASSERT(hProgram->Code, UR_RESULT_ERROR_INVALID_PROGRAM);

// It is legal to build a program created from either IL or from native
// device code.
if (Program->State != ur_program_handle_t_::IL &&
Program->State != ur_program_handle_t_::Native) {
if (hProgram->State != ur_program_handle_t_::IL &&
hProgram->State != ur_program_handle_t_::Native) {
return UR_RESULT_ERROR_INVALID_OPERATION;
}

std::scoped_lock<ur_shared_mutex> Guard(Program->Mutex);
std::scoped_lock<ur_shared_mutex> Guard(hProgram->Mutex);

// Ask Level Zero to build and load the native code onto the device.
ZeStruct<ze_module_desc_t> ZeModuleDesc;
ur_program_handle_t_::SpecConstantShim Shim(Program);
ZeModuleDesc.format = (Program->State == ur_program_handle_t_::IL)
ur_program_handle_t_::SpecConstantShim Shim(hProgram);
ZeModuleDesc.format = (hProgram->State == ur_program_handle_t_::IL)
? ZE_MODULE_FORMAT_IL_SPIRV
: ZE_MODULE_FORMAT_NATIVE;
ZeModuleDesc.inputSize = Program->CodeLength;
ZeModuleDesc.pInputModule = Program->Code.get();
ZeModuleDesc.pBuildFlags = Options;
ZeModuleDesc.inputSize = hProgram->CodeLength;
ZeModuleDesc.pInputModule = hProgram->Code.get();
ZeModuleDesc.pBuildFlags = pOptions;
ZeModuleDesc.pConstants = Shim.ze();

ze_device_handle_t ZeDevice = phDevices[0]->ZeDevice;
ze_context_handle_t ZeContext = Program->Context->ZeContext;
std::ignore = Context;
ze_context_handle_t ZeContext = hProgram->Context->ZeContext;
std::ignore = numDevices;
ze_module_handle_t ZeModule = nullptr;

ur_result_t Result = UR_RESULT_SUCCESS;
Program->State = ur_program_handle_t_::Exe;
hProgram->State = ur_program_handle_t_::Exe;
ze_result_t ZeResult =
ZE_CALL_NOCHECK(zeModuleCreate, (ZeContext, ZeDevice, &ZeModuleDesc,
&ZeModule, &Program->ZeBuildLog));
&ZeModule, &hProgram->ZeBuildLog));
if (ZeResult != ZE_RESULT_SUCCESS) {
// We adjust ur_program below to avoid attempting to release zeModule when
// RT calls urProgramRelease().
Program->State = ur_program_handle_t_::Invalid;
hProgram->State = ur_program_handle_t_::Invalid;
Result = ze2urResult(ZeResult);
if (ZeModule) {
ZE_CALL_NOCHECK(zeModuleDestroy, (ZeModule));
Expand All @@ -179,9 +178,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
// call to zeModuleDynamicLink. However, modules created with
// urProgramBuild are supposed to be fully linked and ready to use.
// Therefore, do an extra check now for unresolved symbols.
ZeResult = checkUnresolvedSymbols(ZeModule, &Program->ZeBuildLog);
ZeResult = checkUnresolvedSymbols(ZeModule, &hProgram->ZeBuildLog);
if (ZeResult != ZE_RESULT_SUCCESS) {
Program->State = ur_program_handle_t_::Invalid;
hProgram->State = ur_program_handle_t_::Invalid;
Result = (ZeResult == ZE_RESULT_ERROR_MODULE_LINK_FAILURE)
? UR_RESULT_ERROR_PROGRAM_BUILD_FAILURE
: ze2urResult(ZeResult);
Expand All @@ -193,22 +192,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
}

// We no longer need the IL / native code.
Program->Code.reset();
Program->ZeModule = ZeModule;
hProgram->Code.reset();
hProgram->ZeModule = ZeModule;
return Result;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(
ur_context_handle_t Context, ///< [in] handle of the context instance.
ur_program_handle_t
Program, ///< [in][out] handle of the program to compile.
uint32_t numDevices, ur_device_handle_t *phDevices,
const char *Options ///< [in][optional] pointer to build options
///< null-terminated string.
hProgram, ///< [in][out] handle of the program to compile.
uint32_t numDevices, ///< [in] number of devices
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to
///< array of device handles
const char *pOptions ///< [in][optional] pointer to build options
///< null-terminated string.
) {
std::ignore = numDevices;
std::ignore = phDevices;
return urProgramCompile(Context, Program, Options);
return urProgramCompile(hProgram->Context, hProgram, pOptions);
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCompile(
Expand Down Expand Up @@ -251,38 +251,37 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLink(
ur_program_handle_t
*Program ///< [out] pointer to handle of program object created.
) {
return urProgramLinkExp(Context, Count, Programs, 1, Context->Devices.data(),
return urProgramLinkExp(Context, Count, Context->Devices.data(), 1, Programs,
Options, Program);
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t Context, ///< [in] handle of the context instance.
ur_context_handle_t hContext, ///< [in] handle of the context instance.
uint32_t numDevices, ///< [in] number of devices
ur_device_handle_t *phDevices, ///< [in][range(0, numDevices)] pointer to
///< array of device handles
uint32_t Count, ///< [in] number of program handles in `phPrograms`.
const ur_program_handle_t *Programs, ///< [in][range(0, count)] pointer to
///< array of program handles.
const char *Options, ///< [in][optional] pointer to linker options
///< null-terminated string.
uint32_t count, ///< [in] number of program handles in `phPrograms`.
const ur_program_handle_t *phPrograms, ///< [in][range(0, count)] pointer to
///< array of program handles.
const char *pOptions, ///< [in][optional] pointer to linker options
///< null-terminated string.
ur_program_handle_t
*Program ///< [out] pointer to handle of program object created.
*phProgram ///< [out] pointer to handle of program object created.
) {
std::ignore = numDevices;

UR_ASSERT(Context->isValidDevice(phDevices[0]),
UR_ASSERT(hContext->isValidDevice(phDevices[0]),
UR_RESULT_ERROR_INVALID_DEVICE);

// We do not support any link flags at this time because the Level Zero API
// does not have any way to pass flags that are specific to linking.
if (Options && *Options != '\0') {
if (pOptions && *pOptions != '\0') {
std::string ErrorMessage(
"Level Zero does not support kernel link flags: \"");
ErrorMessage.append(Options);
ErrorMessage.append(pOptions);
ErrorMessage.push_back('\"');
ur_program_handle_t_ *UrProgram = new ur_program_handle_t_(
ur_program_handle_t_::Invalid, Context, ErrorMessage);
*Program = reinterpret_cast<ur_program_handle_t>(UrProgram);
ur_program_handle_t_::Invalid, hContext, ErrorMessage);
*phProgram = reinterpret_cast<ur_program_handle_t>(UrProgram);
return UR_RESULT_ERROR_PROGRAM_LINK_FAILURE;
}

Expand All @@ -299,11 +298,11 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
// potential if there was some other code that holds more than one of these
// locks simultaneously with "exclusive" access. However, there is no such
// code like that, so this is also not a danger.
std::vector<std::shared_lock<ur_shared_mutex>> Guards(Count);
for (uint32_t I = 0; I < Count; I++) {
std::shared_lock<ur_shared_mutex> Guard(Programs[I]->Mutex);
std::vector<std::shared_lock<ur_shared_mutex>> Guards(count);
for (uint32_t I = 0; I < count; I++) {
std::shared_lock<ur_shared_mutex> Guard(phPrograms[I]->Mutex);
Guards[I].swap(Guard);
if (Programs[I]->State != ur_program_handle_t_::Object) {
if (phPrograms[I]->State != ur_program_handle_t_::Object) {
return UR_RESULT_ERROR_INVALID_OPERATION;
}
}
Expand All @@ -316,23 +315,23 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
// Construct a ze_module_program_exp_desc_t which contains information about
// all of the modules that will be linked together.
ZeStruct<ze_module_program_exp_desc_t> ZeExtModuleDesc;
std::vector<size_t> CodeSizes(Count);
std::vector<const uint8_t *> CodeBufs(Count);
std::vector<const char *> BuildFlagPtrs(Count);
std::vector<const ze_module_constants_t *> SpecConstPtrs(Count);
std::vector<size_t> CodeSizes(count);
std::vector<const uint8_t *> CodeBufs(count);
std::vector<const char *> BuildFlagPtrs(count);
std::vector<const ze_module_constants_t *> SpecConstPtrs(count);
std::vector<ur_program_handle_t_::SpecConstantShim> SpecConstShims;
SpecConstShims.reserve(Count);
SpecConstShims.reserve(count);

for (uint32_t I = 0; I < Count; I++) {
ur_program_handle_t Program = Programs[I];
for (uint32_t I = 0; I < count; I++) {
ur_program_handle_t Program = phPrograms[I];
CodeSizes[I] = Program->CodeLength;
CodeBufs[I] = Program->Code.get();
BuildFlagPtrs[I] = Program->BuildFlags.c_str();
SpecConstShims.emplace_back(Program);
SpecConstPtrs[I] = SpecConstShims[I].ze();
}

ZeExtModuleDesc.count = Count;
ZeExtModuleDesc.count = count;
ZeExtModuleDesc.inputSizes = CodeSizes.data();
ZeExtModuleDesc.pInputModules = CodeBufs.data();
ZeExtModuleDesc.pBuildFlags = BuildFlagPtrs.data();
Expand Down Expand Up @@ -366,8 +365,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
//
// TODO: Remove this workaround when the driver is fixed.
if (!phDevices[0]->Platform->ZeDriverModuleProgramExtensionFound ||
(Count == 1)) {
if (Count == 1) {
(count == 1)) {
if (count == 1) {
ZeModuleDesc.pNext = nullptr;
ZeModuleDesc.inputSize = ZeExtModuleDesc.inputSizes[0];
ZeModuleDesc.pInputModule = ZeExtModuleDesc.pInputModules[0];
Expand All @@ -382,7 +381,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(

// Call the Level Zero API to compile, link, and create the module.
ze_device_handle_t ZeDevice = phDevices[0]->ZeDevice;
ze_context_handle_t ZeContext = Context->ZeContext;
ze_context_handle_t ZeContext = hContext->ZeContext;
ze_module_handle_t ZeModule = nullptr;
ze_module_build_log_handle_t ZeBuildLog = nullptr;
ze_result_t ZeResult =
Expand Down Expand Up @@ -420,8 +419,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
? ur_program_handle_t_::Exe
: ur_program_handle_t_::Invalid;
ur_program_handle_t_ *UrProgram =
new ur_program_handle_t_(State, Context, ZeModule, ZeBuildLog);
*Program = reinterpret_cast<ur_program_handle_t>(UrProgram);
new ur_program_handle_t_(State, hContext, ZeModule, ZeBuildLog);
*phProgram = reinterpret_cast<ur_program_handle_t>(UrProgram);
} catch (const std::bad_alloc &) {
return UR_RESULT_ERROR_OUT_OF_HOST_MEMORY;
} catch (...) {
Expand Down
38 changes: 0 additions & 38 deletions source/ur/ur.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,41 +295,3 @@ class UrReturnHelper {
void *param_value;
size_t *param_value_size_ret;
};

// Prototype for urProgramBuildExp, the program-build entry point that takes an
// explicit device list. Needed to have compatibility with piProgramBuild
// when passing a specific list of devices.
// See: https://github.com/oneapi-src/unified-runtime/issues/912
//
// NOTE(review): this prototype takes hContext as its first parameter, while
// the updated adapter definitions of urProgramBuildExp (cuda, hip, level_zero)
// start with the program handle -- confirm which form matches the spec.
UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(
ur_context_handle_t hContext, ///< [in] handle of the context instance.
ur_program_handle_t hProgram, ///< [in] Handle of the program to build.
// numDevices: [in] number of devices.
// phDevices:  [in][range(0, numDevices)] pointer to array of device handles.
uint32_t numDevices, ur_device_handle_t *phDevices,
const char *pOptions ///< [in][optional] pointer to build options
///< null-terminated string.
);

// Prototype for urProgramCompileExp, the program-compile entry point that
// takes an explicit device list. Needed to have compatibility with
// piProgramCompile when passing a specific list of devices.
// See: https://github.com/oneapi-src/unified-runtime/issues/912
//
// NOTE(review): this prototype takes Context as its first parameter, while
// the updated adapter definitions of urProgramCompileExp (cuda, hip,
// level_zero) start with the program handle -- confirm which form matches the
// spec.
UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(
ur_context_handle_t Context, ///< [in] handle of the context instance.
ur_program_handle_t
Program, ///< [in][out] handle of the program to compile.
// numDevices: [in] number of devices.
// phDevices:  [in][range(0, numDevices)] pointer to array of device handles.
uint32_t numDevices, ur_device_handle_t *phDevices,
const char *Options ///< [in][optional] pointer to build options
///< null-terminated string.
);

// Prototype for urProgramLinkExp, the program-link entry point that takes an
// explicit device list. Needed to have compatibility with piProgramLink
// when passing a specific list of devices.
// See: https://github.com/oneapi-src/unified-runtime/issues/912
//
// NOTE(review): the parameter order here is (Context, Count, Programs,
// numDevices, phDevices, Options, Program), whereas the updated level_zero
// definition of urProgramLinkExp orders them (hContext, numDevices, phDevices,
// count, phPrograms, pOptions, phProgram). Both orders place a uint32_t and a
// pointer in the same slots, so a caller written against the wrong order may
// silently swap the two counts -- confirm against the spec.
UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t Context, ///< [in] handle of the context instance.
uint32_t Count, ///< [in] number of program handles in `Programs`.
const ur_program_handle_t *Programs, ///< [in][range(0, Count)] pointer to
///< array of program handles.
// numDevices: [in] number of devices.
// phDevices:  [in][range(0, numDevices)] pointer to array of device handles.
uint32_t numDevices, ur_device_handle_t *phDevices,
const char *Options, ///< [in][optional] pointer to linker options
///< null-terminated string.
ur_program_handle_t
*Program ///< [out] pointer to handle of program object created.
);

0 comments on commit 0790bf8

Please sign in to comment.