Skip to content

Commit

Permalink
[CUDA][HIP][OpenCL][NATIVECPU] Fix multi-device compile
Browse files Browse the repository at this point in the history
Ensure that all adapters have the correct signatures for the
multi-device compile experimental feature entry points and that they
entry points exist even when returning
`UR_RESULT_ERROR_UNSUPPORTED_FEATURE`.
  • Loading branch information
kbenzie committed Nov 21, 2023
1 parent ce152a6 commit e001b98
Show file tree
Hide file tree
Showing 8 changed files with 68 additions and 13 deletions.
4 changes: 2 additions & 2 deletions source/adapters/cuda/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t hContext,
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, const ur_program_handle_t *, uint32_t,
ur_device_handle_t *, const char *, ur_program_handle_t *) {
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
6 changes: 3 additions & 3 deletions source/adapters/cuda/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -399,9 +399,9 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
if (UR_RESULT_SUCCESS != retVal) {
return retVal;
}
pDdiTable->pfnBuildExp = nullptr;
pDdiTable->pfnCompileExp = nullptr;
pDdiTable->pfnLinkExp = nullptr;
pDdiTable->pfnBuildExp = urProgramBuildExp;
pDdiTable->pfnCompileExp = urProgramCompileExp;
pDdiTable->pfnLinkExp = urProgramLinkExp;
return retVal;
}

Expand Down
4 changes: 2 additions & 2 deletions source/adapters/hip/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,8 +279,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urProgramBuild(ur_context_handle_t,
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, const ur_program_handle_t *, uint32_t,
ur_device_handle_t *, const char *, ur_program_handle_t *) {
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

Expand Down
6 changes: 3 additions & 3 deletions source/adapters/hip/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,9 +354,9 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
if (UR_RESULT_SUCCESS != retVal) {
return retVal;
}
pDdiTable->pfnBuildExp = nullptr;
pDdiTable->pfnCompileExp = nullptr;
pDdiTable->pfnLinkExp = nullptr;
pDdiTable->pfnBuildExp = urProgramBuildExp;
pDdiTable->pfnCompileExp = urProgramCompileExp;
pDdiTable->pfnLinkExp = urProgramLinkExp;
return retVal;
}

Expand Down
20 changes: 20 additions & 0 deletions source/adapters/native_cpu/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,26 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
DIE_NO_IMPLEMENTATION
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL
urProgramRetain(ur_program_handle_t hProgram) {
hProgram->incrementReferenceCount();
Expand Down
15 changes: 15 additions & 0 deletions source/adapters/native_cpu/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,19 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetVirtualMemProcAddrTable(
return retVal;
}

UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
ur_api_version_t version, ///< [in] API version requested
ur_program_exp_dditable_t
*pDdiTable ///< [in,out] pointer to table of DDI function pointers
) {
auto retVal = validateProcInputs(version, pDdiTable);
if (UR_RESULT_SUCCESS != retVal) {
return retVal;
}
pDdiTable->pfnBuildExp = urProgramBuildExp;
pDdiTable->pfnCompileExp = urProgramCompileExp;
pDdiTable->pfnLinkExp = urProgramLinkExp;
return retVal;
}

} // extern "C"
20 changes: 20 additions & 0 deletions source/adapters/opencl/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,26 @@ urProgramLink(ur_context_handle_t hContext, uint32_t count,
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramCompileExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramBuildExp(ur_program_handle_t,
uint32_t,
ur_device_handle_t *,
const char *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

UR_APIEXPORT ur_result_t UR_APICALL urProgramLinkExp(
ur_context_handle_t, uint32_t, ur_device_handle_t *, uint32_t,
const ur_program_handle_t *, const char *, ur_program_handle_t *) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}

static cl_int mapURProgramBuildInfoToCL(ur_program_build_info_t URPropName) {

switch (static_cast<uint32_t>(URPropName)) {
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/opencl/ur_interface_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,9 +390,9 @@ UR_DLLEXPORT ur_result_t UR_APICALL urGetProgramExpProcAddrTable(
if (UR_RESULT_SUCCESS != retVal) {
return retVal;
}
pDdiTable->pfnBuildExp = nullptr;
pDdiTable->pfnCompileExp = nullptr;
pDdiTable->pfnLinkExp = nullptr;
pDdiTable->pfnBuildExp = urProgramBuildExp;
pDdiTable->pfnCompileExp = urProgramCompileExp;
pDdiTable->pfnLinkExp = urProgramLinkExp;
return retVal;
}

Expand Down

0 comments on commit e001b98

Please sign in to comment.