Skip to content

Commit

Permalink
Refactor cuda adapter to new logger
Browse files Browse the repository at this point in the history
This commit refactors the Cuda adapter to adopt the new logger
introduced in d9cd223 (Integrate logger with library, 2023-02-03).
Signed-off-by: Łukasz Plewa <lukasz.plewa@intel.com>
  • Loading branch information
lplewa authored and kbenzie committed Apr 12, 2024
1 parent 68e525a commit 281e3aa
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 34 deletions.
28 changes: 28 additions & 0 deletions source/adapters/cuda/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,42 @@
#include <ur_api.h>

#include "common.hpp"
#include "logger/ur_logger.hpp"
#include "tracing.hpp"

struct ur_adapter_handle_t_ {
std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;
struct cuda_tracing_context_t_ *TracingCtx = nullptr;
logger::Logger &logger;
ur_adapter_handle_t_();
};

class ur_legacy_sink : public logger::Sink {
public:
ur_legacy_sink(std::string logger_name = "", bool skip_prefix = true)
: Sink(std::move(logger_name), skip_prefix) {
this->ostream = &std::cerr;
}

virtual void print([[maybe_unused]] logger::Level level,
const std::string &msg) override {
std::cerr << msg << std::endl;
}

~ur_legacy_sink() = default;
};
ur_adapter_handle_t_::ur_adapter_handle_t_()
: logger(logger::get_logger("cuda")) {

if (std::getenv("UR_LOG_CUDA") != nullptr)
return;

if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") != nullptr ||
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") != nullptr) {
logger.setLegacySink(std::make_unique<ur_legacy_sink>());
}
}
ur_adapter_handle_t_ adapter{};

UR_APIEXPORT ur_result_t UR_APICALL
Expand Down
5 changes: 3 additions & 2 deletions source/adapters/cuda/command_buffer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <ur_api.h>

#include "context.hpp"
#include "logger/ur_logger.hpp"
#include <cuda.h>
#include <memory>

Expand Down Expand Up @@ -169,10 +170,10 @@ static inline const char *getUrResultString(ur_result_t Result) {
#define UR_CALL(Call, Result) \
{ \
if (PrintTrace) \
fprintf(stderr, "UR ---> %s\n", #Call); \
logger::always("UR ---> {}", #Call); \
Result = (Call); \
if (PrintTrace) \
fprintf(stderr, "UR <--- %s(%s)\n", #Call, getUrResultString(Result)); \
logger::always("UR <--- {}({})", #Call, getUrResultString(Result)); \
}

// Handle to a kernel command.
Expand Down
50 changes: 19 additions & 31 deletions source/adapters/cuda/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "common.hpp"
#include "logger/ur_logger.hpp"

#include <cuda.h>

Expand Down Expand Up @@ -41,22 +42,18 @@ void checkErrorUR(CUresult Result, const char *Function, int Line,
return;
}

if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") == nullptr &&
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") == nullptr) {
const char *ErrorString = nullptr;
const char *ErrorName = nullptr;
cuGetErrorName(Result, &ErrorName);
cuGetErrorString(Result, &ErrorString);
std::stringstream SS;
SS << "\nUR CUDA ERROR:"
<< "\n\tValue: " << Result
<< "\n\tName: " << ErrorName
<< "\n\tDescription: " << ErrorString
<< "\n\tFunction: " << Function << "\n\tSource Location: " << File
<< ":" << Line << "\n"
<< std::endl;
std::cerr << SS.str();
}
const char *ErrorString = nullptr;
const char *ErrorName = nullptr;
cuGetErrorName(Result, &ErrorName);
cuGetErrorString(Result, &ErrorString);
std::stringstream SS;
SS << "\nUR CUDA ERROR:"
<< "\n\tValue: " << Result
<< "\n\tName: " << ErrorName
<< "\n\tDescription: " << ErrorString
<< "\n\tFunction: " << Function << "\n\tSource Location: " << File
<< ":" << Line << "\n";
logger::error("{}", SS.str());

if (std::getenv("PI_CUDA_ABORT") != nullptr ||
std::getenv("UR_CUDA_ABORT") != nullptr) {
Expand All @@ -72,16 +69,11 @@ void checkErrorUR(ur_result_t Result, const char *Function, int Line,
return;
}

if (std::getenv("SYCL_PI_SUPPRESS_ERROR_MESSAGE") == nullptr &&
std::getenv("UR_SUPPRESS_ERROR_MESSAGE") == nullptr) {
std::stringstream SS;
SS << "\nUR ERROR:"
<< "\n\tValue: " << Result
<< "\n\tFunction: " << Function << "\n\tSource Location: " << File
<< ":" << Line << "\n"
<< std::endl;
std::cerr << SS.str();
}
std::stringstream SS;
SS << "\nUR ERROR:"
<< "\n\tValue: " << Result << "\n\tFunction: " << Function
<< "\n\tSource Location: " << File << ":" << Line << "\n";
logger::error("{}", SS.str());

if (std::getenv("PI_CUDA_ABORT") != nullptr) {
std::abort();
Expand All @@ -101,7 +93,7 @@ std::string getCudaVersionString() {
}

void detail::ur::die(const char *Message) {
std::cerr << "ur_die: " << Message << std::endl;
logger::always("ur_die:{}", Message);
std::terminate();
}

Expand All @@ -110,10 +102,6 @@ void detail::ur::assertion(bool Condition, const char *Message) {
die(Message);
}

void detail::ur::cuPrint(const char *Message) {
std::cerr << "ur_print: " << Message << std::endl;
}

// Global variables for ZER_EXT_RESULT_ADAPTER_SPECIFIC_ERROR
thread_local ur_result_t ErrorMessageCode = UR_RESULT_SUCCESS;
thread_local char ErrorMessage[MaxMessageSize];
Expand Down
3 changes: 2 additions & 1 deletion source/adapters/cuda/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "adapter.hpp"
#include "context.hpp"
#include "device.hpp"
#include "logger/ur_logger.hpp"
#include "platform.hpp"
#include "ur_util.hpp"

Expand Down Expand Up @@ -293,7 +294,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
std::getenv("UR_CUDA_ENABLE_IMAGE_SUPPORT") != nullptr) {
Enabled = true;
} else {
detail::ur::cuPrint(
logger::always(
"Images are not fully supported by the CUDA BE, their support is "
"disabled by default. Their partial support can be activated by "
"setting SYCL_PI_CUDA_ENABLE_IMAGE_SUPPORT environment variable at "
Expand Down

0 comments on commit 281e3aa

Please sign in to comment.