Skip to content

Commit

Permalink
update CUDA adapter to new logger
Browse files Browse the repository at this point in the history
  • Loading branch information
lplewa committed Jan 9, 2024
1 parent 710646e commit 7fa7f28
Show file tree
Hide file tree
Showing 7 changed files with 50 additions and 41 deletions.
5 changes: 3 additions & 2 deletions source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,17 @@
//
//===----------------------------------------------------------------------===//

#include <ur_api.h>

#include "common.hpp"
#include "logger/ur_logger.hpp"
#include <ur_api.h>

void enableCUDATracing();
void disableCUDATracing();

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;
logger::Logger &logger = logger::get_logger("cuda");
};

ur_adapter_handle_t_ adapter{};
Expand Down
5 changes: 3 additions & 2 deletions source/adapters/cuda/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ur_api.h>

#include "context.hpp"
#include "logger/ur_logger.hpp"
#include <cuda.h>
#include <memory>

Expand Down Expand Up @@ -169,10 +170,10 @@ static inline const char *getUrResultString(ur_result_t Result) {
#define UR_CALL(Call, Result) \
{ \
if (PrintTrace) \
fprintf(stderr, "UR ---> %s\n", #Call); \
logger::debug("UR ---> {}", #Call); \
Result = (Call); \
if (PrintTrace) \
fprintf(stderr, "UR <--- %s(%s)\n", #Call, getUrResultString(Result)); \
logger::debug("UR <--- {}({})", #Call, getUrResultString(Result)); \
}

struct ur_exp_command_buffer_handle_t_ {
Expand Down
60 changes: 34 additions & 26 deletions source/adapters/cuda/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "common.hpp"
#include "logger/ur_logger.hpp"

#include <cuda.h>

Expand Down Expand Up @@ -37,29 +38,34 @@ ur_result_t mapErrorUR(CUresult Result) {

void checkErrorUR(CUresult Result, const char *Function, int Line,
const char *File) {
static bool suppress_error =
std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") == nullptr &&
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") == nullptr;

static bool abort = std::getenv("PI_CUDA_ABORT") != nullptr ||
std::getenv("UR_CUDA_ABORT") != nullptr;

if (Result == CUDA_SUCCESS || Result == CUDA_ERROR_DEINITIALIZED) {
return;
}

if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") == nullptr &&
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") == nullptr) {
if (suppress_error) {
const char *ErrorString = nullptr;
const char *ErrorName = nullptr;
cuGetErrorName(Result, &ErrorName);
cuGetErrorString(Result, &ErrorString);
std::stringstream SS;
SS << "\nUR CUDA ERROR:"
<< "\n\tValue: " << Result
<< "\n\tName: " << ErrorName
<< "\n\tDescription: " << ErrorString
<< "\n\tFunction: " << Function << "\n\tSource Location: " << File
<< ":" << Line << "\n"
<< std::endl;
std::cerr << SS.str();
SS << std::endl
<< "CUDA Error" << std::endl
<< "\tValue: " << Result << std::endl
<< "\tName: " << ErrorName << std::endl
<< "\tDescription: " << ErrorString << std::endl
<< "\tFunction: " << Function << std::endl
<< "\tSource Location: " << File << ":" << Line << std::endl;
logger::error("{}", SS.str());
}

if (std::getenv("PI_CUDA_ABORT") != nullptr ||
std::getenv("UR_CUDA_ABORT") != nullptr) {
if (abort) {
std::abort();
}

Expand All @@ -68,22 +74,28 @@ void checkErrorUR(CUresult Result, const char *Function, int Line,

void checkErrorUR(ur_result_t Result, const char *Function, int Line,
const char *File) {
static bool suppress_error =
std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") == nullptr &&
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") == nullptr;

static bool abort = std::getenv("PI_CUDA_ABORT") != nullptr ||
std::getenv("UR_CUDA_ABORT") != nullptr;

if (Result == UR_RESULT_SUCCESS) {
return;
}

if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") == nullptr &&
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") == nullptr) {
if (suppress_error) {
std::stringstream SS;
SS << "\nUR ERROR:"
<< "\n\tValue: " << Result
<< "\n\tFunction: " << Function << "\n\tSource Location: " << File
<< ":" << Line << "\n"
<< std::endl;
std::cerr << SS.str();
SS << std::endl
<< "UR ERROR:" << std::endl
<< "\tValue: " << Result << std::endl
<< "\tFunction: " << Function << std::endl
<< "\tSource Location: " << File << ":" << Line << std::endl;
logger::error("{}", SS.str());
}

if (std::getenv("PI_CUDA_ABORT") != nullptr) {
if (abort) {
std::abort();
}

Expand All @@ -101,7 +113,7 @@ std::string getCudaVersionString() {
}

void detail::ur::die(const char *Message) {
std::cerr << "ur_die: " << Message << std::endl;
logger::error("{}", Message);
std::terminate();
}

Expand All @@ -110,10 +122,6 @@ void detail::ur::assertion(bool Condition, const char *Message) {
die(Message);
}

void detail::ur::cuPrint(const char *Message) {
std::cerr << "ur_print: " << Message << std::endl;
}

// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
thread_local char ErrorMessage[MaxMessageSize];
Expand Down
3 changes: 0 additions & 3 deletions source/adapters/cuda/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ namespace ur {
//
[[noreturn]] void die(const char *Message);

// Reports error messages
void cuPrint(const char *Message);

void assertion(bool Condition, const char *Message = nullptr);

} // namespace ur
Expand Down
3 changes: 2 additions & 1 deletion source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "adapter.hpp"
#include "context.hpp"
#include "device.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"
#include "ur_util.hpp"

Expand Down Expand Up @@ -279,7 +280,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
std::getenv("UR_CUDA_ENABLE_IMAGE_SUPPORT") != nullptr) {
Enabled = true;
} else {
detail::ur::cuPrint(
logger::info(
"Images are not fully supported by the CUDA BE, their support is "
"disabled by default. Their partial support can be activated by "
"setting SYCL_PI_CUDA_ENABLE_IMAGE_SUPPORT environment variable at "
Expand Down
6 changes: 3 additions & 3 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,9 +419,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
(CommandBuffer->ZeCommandList, Ptr, Pattern, PatternSize, Size,
LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data()));

urPrint("calling zeCommandListAppendMemoryFill() with"
" ZeEvent %#lx\n",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
logger::debug("calling zeCommandListAppendMemoryFill() with"
" ZeEvent {}",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}
Expand Down
9 changes: 5 additions & 4 deletions source/adapters/level_zero/virtual_mem.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "device.hpp"
#include "physical_mem.hpp"
#include "ur_level_zero.hpp"
#include "logger/ur_logger.hpp"

UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemGranularityGetInfo(
ur_context_handle_t hContext, ur_device_handle_t hDevice,
Expand All @@ -31,8 +32,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemGranularityGetInfo(
return ReturnValue(PageSize);
}
default:
urPrint("Unsupported propName in urQueueGetInfo: propName=%d(0x%x)\n",
propName, propName);
logger::error("Unsupported propName in urQueueGetInfo: propName={}",
propName);
return UR_RESULT_ERROR_INVALID_VALUE;
}
return UR_RESULT_SUCCESS;
Expand Down Expand Up @@ -111,8 +112,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urVirtualMemGetInfo(
return ReturnValue(RetFlags);
}
default:
urPrint("Unsupported propName in urQueueGetInfo: propName=%d(0x%x)\n",
propName, propName);
logger::error("Unsupported propName in urQueueGetInfo: propName={}",
propName);
return UR_RESULT_ERROR_INVALID_VALUE;
}

Expand Down

0 comments on commit 7fa7f28

Please sign in to comment.