Skip to content

Commit

Permalink
Refactor Level Zero Adapter to use new logger
Browse files Browse the repository at this point in the history
This commit refactors the Level Zero adapter to adopt the new logger
introduced in d9cd223 (Integrate logger with library, 2023-02-03).

Signed-off-by: Łukasz Plewa <lukasz.plewa@intel.com>
  • Loading branch information
lplewa authored and kbenzie committed Apr 22, 2024
1 parent 717791b commit 5dccce5
Show file tree
Hide file tree
Showing 18 changed files with 322 additions and 240 deletions.
60 changes: 47 additions & 13 deletions source/adapters/level_zero/adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,21 @@ ur_adapter_handle_t_ *GlobalAdapter = new ur_adapter_handle_t_();
ur_adapter_handle_t_ *GlobalAdapter;
#endif

class ur_legacy_sink : public logger::Sink {
public:
ur_legacy_sink(std::string logger_name = "", bool skip_prefix = true)
: Sink(std::move(logger_name), skip_prefix) {
this->ostream = &std::cerr;
}

virtual void print([[maybe_unused]] logger::Level level,
const std::string &msg) override {
fprintf(stderr, "%s", msg.c_str());
}

~ur_legacy_sink() = default;
};

ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {
uint32_t ZeDriverCount = 0;
ZE2UR_CALL(zeDriverGet, (&ZeDriverCount, nullptr));
Expand All @@ -44,7 +59,18 @@ ur_result_t initPlatforms(PlatformVec &platforms) noexcept try {

ur_result_t adapterStateInit() { return UR_RESULT_SUCCESS; }

ur_adapter_handle_t_::ur_adapter_handle_t_() {
ur_adapter_handle_t_::ur_adapter_handle_t_()
: logger(logger::get_logger("level_zero")) {

if (UrL0Debug & UR_L0_DEBUG_BASIC) {
logger.setLegacySink(std::make_unique<ur_legacy_sink>());
};

if (UrL0Debug & UR_L0_DEBUG_VALIDATION) {
setEnvVar("ZE_ENABLE_VALIDATION_LAYER", "1");
setEnvVar("ZE_ENABLE_PARAMETER_VALIDATION", "1");
}

PlatformCache.Compute = [](Result<PlatformVec> &result) {
static std::once_flag ZeCallCountInitialized;
try {
Expand All @@ -68,7 +94,7 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
}

if (getenv("SYCL_ENABLE_PCI") != nullptr) {
urPrint(
logger::warning(
"WARNING: SYCL_ENABLE_PCI is deprecated and no longer needed.\n");
}

Expand All @@ -91,8 +117,9 @@ ur_adapter_handle_t_::ur_adapter_handle_t_() {
return;
}
if (*GlobalAdapter->ZeResult != ZE_RESULT_SUCCESS) {
urPrint("zeInit: Level Zero initialization failure\n");
logger::error("zeInit: Level Zero initialization failure\n");
result = ze2urResult(*GlobalAdapter->ZeResult);

return;
}

Expand Down Expand Up @@ -157,10 +184,10 @@ ur_result_t adapterStateTeardown() {
// zeMemAllocShared = 0 \---> zeMemFree = 1
//
// clang-format on

fprintf(stderr, "Check balance of create/destroy calls\n");
fprintf(stderr,
"----------------------------------------------------------\n");
// TODO: use logger to print this messages
std::cerr << "Check balance of create/destroy calls\n";
std::cerr << "----------------------------------------------------------\n";
std::stringstream ss;
for (const auto &Row : CreateDestroySet) {
int diff = 0;
for (auto I = Row.begin(); I != Row.end();) {
Expand All @@ -171,23 +198,30 @@ ur_result_t adapterStateTeardown() {
bool Last = (++I == Row.end());

if (Last) {
fprintf(stderr, " \\--->");
ss << " \\--->";
diff -= ZeCount;
} else {
diff += ZeCount;
if (!First) {
fprintf(stderr, " | \n");
ss << " | ";
std::cerr << ss.str() << "\n";
ss.str("");
ss.clear();
}
}

fprintf(stderr, "%30s = %-5d", ZeName, ZeCount);
ss << std::setw(30) << std::right << ZeName;
ss << " = ";
ss << std::setw(5) << std::left << ZeCount;
}

if (diff) {
LeakFound = true;
fprintf(stderr, " ---> LEAK = %d", diff);
ss << " ---> LEAK = " << diff;
}
fprintf(stderr, "\n");

std::cerr << ss.str() << '\n';
ss.str("");
ss.clear();
}

ZeCallCount->clear();
Expand Down
4 changes: 4 additions & 0 deletions source/adapters/level_zero/adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//
//===----------------------------------------------------------------------===//

#include "logger/ur_logger.hpp"
#include <atomic>
#include <mutex>
#include <optional>
Expand All @@ -16,13 +17,16 @@

using PlatformVec = std::vector<std::unique_ptr<ur_platform_handle_t_>>;

class ur_legacy_sink;

struct ur_adapter_handle_t_ {
ur_adapter_handle_t_();
std::atomic<uint32_t> RefCount = 0;
std::mutex Mutex;

std::optional<ze_result_t> ZeResult;
ZeCache<Result<PlatformVec>> PlatformCache;
logger::Logger &logger;
};

extern ur_adapter_handle_t_ *GlobalAdapter;
62 changes: 30 additions & 32 deletions source/adapters/level_zero/command_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//
//===----------------------------------------------------------------------===//
#include "command_buffer.hpp"
#include "logger/ur_logger.hpp"
#include "ur_level_zero.hpp"

/* L0 Command-buffer Extension Doc see:
Expand Down Expand Up @@ -140,16 +141,16 @@ ur_result_t calculateKernelWorkDimensions(
while (GlobalWorkSize3D[I] % GroupSize[I]) {
--GroupSize[I];
}
if (GlobalWorkSize3D[I] / GroupSize[I] > UINT32_MAX) {
urPrint("calculateKernelWorkDimensions: can't find a WG size "
"suitable for global work size > UINT32_MAX\n");
if (GlobalWorkSize[I] / GroupSize[I] > UINT32_MAX) {
logger::debug("calculateKernelWorkDimensions: can't find a WG size "
"suitable for global work size > UINT32_MAX");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
WG[I] = GroupSize[I];
}
urPrint("calculateKernelWorkDimensions: using computed WG size = {%d, "
"%d, %d}\n",
WG[0], WG[1], WG[2]);
logger::debug("calculateKernelWorkDimensions: using computed WG "
"size = {{{}, {}, {}}}",
WG[0], WG[1], WG[2]);
}
}

Expand Down Expand Up @@ -177,30 +178,27 @@ ur_result_t calculateKernelWorkDimensions(
break;

default:
urPrint("calculateKernelWorkDimensions: unsupported work_dim\n");
logger::error("calculateKernelWorkDimensions: unsupported work_dim");
return UR_RESULT_ERROR_INVALID_VALUE;
}

// Error handling for non-uniform group size case
if (GlobalWorkSize3D[0] !=
size_t(ZeThreadGroupDimensions.groupCountX) * WG[0]) {
urPrint("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a "
"multiple of the group size in the 1st dimension\n");
logger::error("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 1st dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[1] !=
size_t(ZeThreadGroupDimensions.groupCountY) * WG[1]) {
urPrint("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a "
"multiple of the group size in the 2nd dimension\n");
logger::error("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 2nd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}
if (GlobalWorkSize3D[2] !=
size_t(ZeThreadGroupDimensions.groupCountZ) * WG[2]) {
urPrint("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a "
"multiple of the group size in the 3rd dimension\n");
logger::error("calculateKernelWorkDimensions: invalid work_dim. The range "
"is not a multiple of the group size in the 3rd dimension");
return UR_RESULT_ERROR_INVALID_WORK_GROUP_SIZE;
}

Expand Down Expand Up @@ -268,9 +266,9 @@ static ur_result_t enqueueCommandBufferMemCopyHelper(
(CommandBuffer->ZeCommandList, Dst, Src, Size,
LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data()));

urPrint("calling zeCommandListAppendMemoryCopy() with"
" ZeEvent %#" PRIxPTR "\n",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
logger::debug("calling zeCommandListAppendMemoryCopy() with"
" ZeEvent {}",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -335,9 +333,9 @@ static ur_result_t enqueueCommandBufferMemCopyRectHelper(
DstSlicePitch, Src, &ZeSrcRegion, SrcPitch, SrcSlicePitch,
LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data()));

urPrint("calling zeCommandListAppendMemoryCopyRegion() with"
" ZeEvent %#" PRIxPTR "\n",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
logger::debug("calling zeCommandListAppendMemoryCopyRegion() with"
" ZeEvent {}",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -378,9 +376,9 @@ static ur_result_t enqueueCommandBufferFillHelper(
(CommandBuffer->ZeCommandList, Ptr, Pattern, PatternSize, Size,
LaunchEvent->ZeEvent, ZeEventList.size(), ZeEventList.data()));

urPrint("calling zeCommandListAppendMemoryFill() with"
" ZeEvent %#lx\n",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
logger::debug("calling zeCommandListAppendMemoryFill() with"
" ZeEvent {}",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -519,7 +517,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
if (GlobalWorkOffset != NULL) {
if (!CommandBuffer->Context->getPlatform()
->ZeDriverGlobalOffsetExtensionFound) {
urPrint("No global offset extension found on this driver\n");
logger::debug("No global offset extension found on this driver");
return UR_RESULT_ERROR_INVALID_VALUE;
}

Expand Down Expand Up @@ -606,9 +604,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
&ZeThreadGroupDimensions, LaunchEvent->ZeEvent,
ZeEventList.size(), ZeEventList.data()));

urPrint("calling zeCommandListAppendLaunchKernel() with"
" ZeEvent %#" PRIxPTR "\n",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));
logger::debug("calling zeCommandListAppendLaunchKernel() with"
" ZeEvent {}",
ur_cast<std::uintptr_t>(LaunchEvent->ZeEvent));

return UR_RESULT_SUCCESS;
}
Expand Down Expand Up @@ -1068,7 +1066,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
if (NewGlobalWorkOffset && Dim > 0) {
if (!CommandBuffer->Context->getPlatform()
->ZeDriverGlobalOffsetExtensionFound) {
urPrint("No global offset extension found on this driver\n");
logger::error("No global offset extension found on this driver");
return UR_RESULT_ERROR_INVALID_VALUE;
}
auto MutableGroupOffestDesc =
Expand Down Expand Up @@ -1277,8 +1275,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
// Zero.
continue;
} else {
urPrint("urCommandBufferUpdateKernelLaunchExp: unsupported name of "
"execution attribute.\n");
logger::error("urCommandBufferUpdateKernelLaunchExp: unsupported name of "
"execution attribute.");
return UR_RESULT_ERROR_INVALID_VALUE;
}
}
Expand Down
18 changes: 5 additions & 13 deletions source/adapters/level_zero/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
//===----------------------------------------------------------------------===//

#include "common.hpp"
#include "logger/ur_logger.hpp"
#include "usm.hpp"

ur_result_t ze2urResult(ze_result_t ZeResult) {
Expand Down Expand Up @@ -65,15 +66,6 @@ ur_result_t ze2urResult(ze_result_t ZeResult) {
}
}

void urPrint(const char *Format, ...) {
if (UrL0Debug & UR_L0_DEBUG_BASIC) {
va_list Args;
va_start(Args, Format);
vfprintf(stderr, Format, Args);
va_end(Args);
}
}

usm::DisjointPoolAllConfigs DisjointPoolConfigInstance =
InitializeDisjointPoolConfig();

Expand All @@ -86,8 +78,8 @@ bool setEnvVar(const char *name, const char *value) {
int Res = setenv(name, value, 1);
#endif
if (Res != 0) {
urPrint("UR L0 Adapter was unable to set the environment variable: %s\n",
name);
logger::debug(
"UR L0 Adapter was unable to set the environment variable: {}", name);
return false;
}
return true;
Expand Down Expand Up @@ -149,7 +141,7 @@ inline void zeParseError(ze_result_t ZeError, const char *&ErrorString) {

ze_result_t ZeCall::doCall(ze_result_t ZeResult, const char *ZeName,
const char *ZeArgs, bool TraceError) {
urPrint("ZE ---> %s%s\n", ZeName, ZeArgs);
logger::debug("ZE ---> {}{}", ZeName, ZeArgs);

if (UrL0LeaksDebug) {
++(*ZeCallCount)[ZeName];
Expand All @@ -158,7 +150,7 @@ ze_result_t ZeCall::doCall(ze_result_t ZeResult, const char *ZeName,
if (ZeResult && TraceError) {
const char *ErrorString = "Unknown";
zeParseError(ZeResult, ErrorString);
urPrint("Error (%s) in %s\n", ErrorString, ZeName);
logger::error("Error ({}) in {}", ErrorString, ZeName);
}
return ZeResult;
}
Expand Down
10 changes: 2 additions & 8 deletions source/adapters/level_zero/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,10 +174,10 @@ static auto getUrResultString = [](ur_result_t Result) {
#define UR_CALL(Call) \
{ \
if (PrintTrace) \
fprintf(stderr, "UR ---> %s\n", #Call); \
logger::always("UR ---> {}", #Call); \
ur_result_t Result = (Call); \
if (PrintTrace) \
fprintf(stderr, "UR <--- %s(%s)\n", #Call, getUrResultString(Result)); \
logger::always("UR <--- {}({})", #Call, getUrResultString(Result)); \
if (Result != UR_RESULT_SUCCESS) \
return Result; \
}
Expand Down Expand Up @@ -268,9 +268,6 @@ class ZeCall {
// setting environment variables.
bool setEnvVar(const char *name, const char *value);

// Prints to stderr if UR_L0_DEBUG allows it
void urPrint(const char *Format, ...);

// Helper for one-liner validation
#define UR_ASSERT(condition, error) \
if (!(condition)) \
Expand Down Expand Up @@ -301,9 +298,6 @@ template <class T> struct ZesStruct : public T {
// setting environment variables.
bool setEnvVar(const char *name, const char *value);

// Prints to stderr if UR_L0_DEBUG allows it
void urPrint(const char *Format, ...);

// Helper for one-liner validation
#define UR_ASSERT(condition, error) \
if (!(condition)) \
Expand Down
Loading

0 comments on commit 5dccce5

Please sign in to comment.