Skip to content

Commit

Permalink
Address reviews
Browse files Browse the repository at this point in the history
  • Loading branch information
uditagarwal97 committed Sep 11, 2024
1 parent b6f1f7a commit 7262fd1
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 112 deletions.
2 changes: 1 addition & 1 deletion clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1153,7 +1153,7 @@ class BinaryWrapper {
Twine(OffloadKindTag) + Twine(ImgId) + Twine(".data"), Kind,
Img.Tgt);

// Change image format to compressed_non.
// Change image format to compressed_none.
Ffmt = ConstantInt::get(Type::getInt8Ty(C),
BinaryImageFormat::compressed_none);
}
Expand Down
10 changes: 7 additions & 3 deletions sycl/source/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@ function(add_sycl_rt_library LIB_NAME LIB_OBJ_NAME)
set(CMAKE_FIND_DEBUG_MODE 1)

# Need zstd for device image compression.
find_package(zstd REQUIRED)
target_link_libraries(${LIB_NAME} PRIVATE ${zstd_STATIC_LIBRARY})
target_include_directories(${LIB_OBJ_NAME} PRIVATE ${zstd_INCLUDE_DIR})
find_package(zstd)
if (NOT zstd_FOUND)
target_compile_definitions(${LIB_OBJ_NAME} PRIVATE SYCL_RT_ZSTD_NOT_AVAIABLE)
else()
target_link_libraries(${LIB_NAME} PRIVATE ${zstd_STATIC_LIBRARY})
target_include_directories(${LIB_OBJ_NAME} PRIVATE ${zstd_INCLUDE_DIR})
endif()

target_include_directories(${LIB_OBJ_NAME} PRIVATE ${BOOST_UNORDERED_INCLUDE_DIRS})

Expand Down
120 changes: 65 additions & 55 deletions sycl/source/detail/compression.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
//===----------------------------------------------------------------------===//
#pragma once

#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE

#include <sycl/exception.hpp>

#include <iostream>
#include <memory>
#include <zstd.h>
Expand All @@ -21,23 +25,7 @@ namespace detail {
// Singleton class to handle ZSTD compression and decompression.
class ZSTDCompressor {
private:
// Initialize ZSTD context and error code.
ZSTDCompressor() {
m_ZSTD_compression_ctx = static_cast<void *>(ZSTD_createCCtx());
m_ZSTD_decompression_ctx = static_cast<void *>(ZSTD_createDCtx());

if (!m_ZSTD_compression_ctx || !m_ZSTD_decompression_ctx) {
std::cerr << "Error creating ZSTD contexts. \n";
}

m_lastError = 0;
}

// Free ZSTD contexts.
~ZSTDCompressor() {
ZSTD_freeCCtx(static_cast<ZSTD_CCtx *>(m_ZSTD_compression_ctx));
ZSTD_freeDCtx(static_cast<ZSTD_DCtx *>(m_ZSTD_decompression_ctx));
}
ZSTDCompressor() {}

ZSTDCompressor(const ZSTDCompressor &) = delete;
ZSTDCompressor &operator=(const ZSTDCompressor &) = delete;
Expand All @@ -50,73 +38,93 @@ class ZSTDCompressor {

// Public APIs
public:
// Return 0 is last (de)compression was successful, otherwise return error
// code.
static int GetLastError() { return GetSingletonInstance().m_lastError; }

// Returns a string representation of the error code.
// If the error code is 0, it returns "No error detected".
static std::string GetErrorString(int code) {
return ZSTD_getErrorName(code);
}

// Blob (de)compression do not assume format/structure of the input buffer.
static std::unique_ptr<char> CompressBlob(const char *src, size_t srcSize,
size_t &dstSize, int level) {
auto &instance = GetSingletonInstance();

// Lazy initialize compression context.
if (!instance.m_ZSTD_compression_ctx) {

// Call ZSTD_createCCtx() and ZSTD_freeCCtx() to create and free the
// context.
instance.m_ZSTD_compression_ctx =
std::unique_ptr<ZSTD_CCtx, size_t (*)(ZSTD_CCtx *)>(ZSTD_createCCtx(),
ZSTD_freeCCtx);
if (!instance.m_ZSTD_compression_ctx) {
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Failed to create ZSTD compression context");
}
}

// Get maximum size of the compressed buffer and allocate it.
auto dstBufferSize = ZSTD_compressBound(srcSize);
auto dstBuffer = std::unique_ptr<char>(new char[dstBufferSize]);

if (!dstBuffer)
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Failed to allocate memory for compressed data");

// Compress the input buffer.
dstSize = ZSTD_compressCCtx(
static_cast<ZSTD_CCtx *>(instance.m_ZSTD_compression_ctx),
static_cast<void *>(dstBuffer.get()), dstBufferSize,
static_cast<const void *>(src), srcSize, level);
dstSize =
ZSTD_compressCCtx(instance.m_ZSTD_compression_ctx.get(),
static_cast<void *>(dstBuffer.get()), dstBufferSize,
static_cast<const void *>(src), srcSize, level);

// Store the error code if compression failed.
if (ZSTD_isError(dstSize))
instance.m_lastError = dstSize;
else
instance.m_lastError = 0;
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
ZSTD_getErrorName(dstSize));

// Pass ownership of the buffer to the caller.
return dstBuffer;
}

static std::unique_ptr<unsigned char>
DecompressBlob(const char *src, size_t srcSize, size_t &dstSize) {
static std::unique_ptr<char> DecompressBlob(const char *src, size_t srcSize,
size_t &dstSize) {
auto &instance = GetSingletonInstance();

// Lazy initialize decompression context.
if (!instance.m_ZSTD_decompression_ctx) {

// Call ZSTD_createDCtx() and ZSTD_freeDCtx() to create and free the
// context.
instance.m_ZSTD_decompression_ctx =
std::unique_ptr<ZSTD_DCtx, size_t (*)(ZSTD_DCtx *)>(ZSTD_createDCtx(),
ZSTD_freeDCtx);
if (!instance.m_ZSTD_decompression_ctx) {
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Failed to create ZSTD decompression context");
}
}

// Size of decompressed image can be larger than what we can allocate
// on heap. In that case, we need to use streaming decompression.
// TODO: Throw if the decompression size is too large.
auto dstBufferSize = ZSTD_getFrameContentSize(src, srcSize);

if (dstBufferSize == ZSTD_CONTENTSIZE_UNKNOWN ||
dstBufferSize == ZSTD_CONTENTSIZE_ERROR) {

std::cerr << "Error determining size of uncompressed data\n";
dstSize = 0;
instance.m_lastError = dstBufferSize;
return nullptr;
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Error determining size of uncompressed data.");
}

// Allocate buffer for decompressed data.
auto dstBuffer =
std::unique_ptr<unsigned char>(new unsigned char[dstBufferSize]);
auto dstBuffer = std::unique_ptr<char>(new char[dstBufferSize]);

dstSize = ZSTD_decompressDCtx(
static_cast<ZSTD_DCtx *>(instance.m_ZSTD_decompression_ctx),
static_cast<void *>(dstBuffer.get()), dstBufferSize,
static_cast<const void *>(src), srcSize);
if (!dstBuffer)
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Failed to allocate memory for decompressed data");

dstSize =
ZSTD_decompressDCtx(instance.m_ZSTD_decompression_ctx.get(),
static_cast<void *>(dstBuffer.get()), dstBufferSize,
static_cast<const void *>(src), srcSize);

// In case of decompression error, return the error message and set dstSize
// to 0.
if (ZSTD_isError(dstSize)) {
instance.m_lastError = dstSize;
dstSize = 0;
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
ZSTD_getErrorName(dstSize));
}

// Pass ownership of the buffer to the caller.
Expand All @@ -125,12 +133,14 @@ class ZSTDCompressor {

// Data fields
private:
int m_lastError;
// ZSTD context. Reusing ZSTD context speeds up subsequent (de)compression.
// Storing as void* to avoid including ZSTD headers in this file.
void *m_ZSTD_compression_ctx;
void *m_ZSTD_decompression_ctx;
// ZSTD contexts. Reusing ZSTD context speeds up subsequent (de)compression.
std::unique_ptr<ZSTD_CCtx, size_t (*)(ZSTD_CCtx *)> m_ZSTD_compression_ctx{
nullptr, nullptr};
std::unique_ptr<ZSTD_DCtx, size_t (*)(ZSTD_DCtx *)> m_ZSTD_decompression_ctx{
nullptr, nullptr};
};
} // namespace detail
} // namespace _V1
} // namespace sycl

#endif // SYCL_RT_ZSTD_NOT_AVAIABLE
40 changes: 18 additions & 22 deletions sycl/source/detail/device_binary_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,28 +172,27 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) {

if (Format == SYCL_DEVICE_BINARY_TYPE_NONE)
// try to determine the format; may remain "NONE"
Format = ur::getBinaryImageFormat(this->Bin->BinaryStart, getSize());
Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize());

SpecConstIDMap.init(this->Bin, __SYCL_PROPERTY_SET_SPEC_CONST_MAP);
SpecConstIDMap.init(Bin, __SYCL_PROPERTY_SET_SPEC_CONST_MAP);
SpecConstDefaultValuesMap.init(
this->Bin, __SYCL_PROPERTY_SET_SPEC_CONST_DEFAULT_VALUES_MAP);
DeviceLibReqMask.init(this->Bin, __SYCL_PROPERTY_SET_DEVICELIB_REQ_MASK);
KernelParamOptInfo.init(this->Bin, __SYCL_PROPERTY_SET_KERNEL_PARAM_OPT_INFO);
AssertUsed.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_ASSERT_USED);
ProgramMetadata.init(this->Bin, __SYCL_PROPERTY_SET_PROGRAM_METADATA);
Bin, __SYCL_PROPERTY_SET_SPEC_CONST_DEFAULT_VALUES_MAP);
DeviceLibReqMask.init(Bin, __SYCL_PROPERTY_SET_DEVICELIB_REQ_MASK);
KernelParamOptInfo.init(Bin, __SYCL_PROPERTY_SET_KERNEL_PARAM_OPT_INFO);
AssertUsed.init(Bin, __SYCL_PROPERTY_SET_SYCL_ASSERT_USED);
ProgramMetadata.init(Bin, __SYCL_PROPERTY_SET_PROGRAM_METADATA);
// Convert ProgramMetadata into the UR format
for (const auto &Prop : ProgramMetadata) {
ProgramMetadataUR.push_back(
ur::mapDeviceBinaryPropertyToProgramMetadata(Prop));
}

ExportedSymbols.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS);
ImportedSymbols.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_IMPORTED_SYMBOLS);
DeviceGlobals.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_GLOBALS);
DeviceRequirements.init(this->Bin,
__SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS);
HostPipes.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES);
VirtualFunctions.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS);
ExportedSymbols.init(Bin, __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS);
ImportedSymbols.init(Bin, __SYCL_PROPERTY_SET_SYCL_IMPORTED_SYMBOLS);
DeviceGlobals.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_GLOBALS);
DeviceRequirements.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS);
HostPipes.init(Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES);
VirtualFunctions.init(Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS);

ImageId = ImageCounter++;
}
Expand Down Expand Up @@ -231,6 +230,7 @@ DynRTDeviceBinaryImage::~DynRTDeviceBinaryImage() {
Bin = nullptr;
}

#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
sycl_device_binary CompressedBin)
: RTDeviceBinaryImage() {
Expand All @@ -243,18 +243,13 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage(
reinterpret_cast<const char *>(CompressedBin->BinaryStart),
compressedDataSize, DecompressedSize);

if (!m_DecompressedData) {
throw sycl::exception(
sycl::make_error_code(sycl::errc::runtime),
"Failed to decompress device binary image. " +
ZSTDCompressor::GetErrorString(ZSTDCompressor::GetLastError()));
}

Bin = new sycl_device_binary_struct(*CompressedBin);
Bin->BinaryStart = m_DecompressedData.get();
Bin->BinaryStart = (const unsigned char *)(m_DecompressedData.get());
Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize;

// Set the new format to none and let RT determine the format.
// TODO: Add support for automatically detecting compressed
// binary format.
Bin->Format = SYCL_DEVICE_BINARY_TYPE_NONE;

init(Bin);
Expand All @@ -265,6 +260,7 @@ CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() {
delete Bin;
Bin = nullptr;
}
#endif // SYCL_RT_ZSTD_NOT_AVAIABLE

} // namespace detail
} // namespace _V1
Expand Down
4 changes: 3 additions & 1 deletion sycl/source/detail/device_binary_image.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ class DynRTDeviceBinaryImage : public RTDeviceBinaryImage {
std::unique_ptr<char[]> Data;
};

#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
// Compressed device binary image. It decompresses the binary image on
// construction and stores the decompressed data as RTDeviceBinaryImage.
// Also, frees the decompressed data in destructor.
Expand All @@ -290,8 +291,9 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage {
}

private:
std::unique_ptr<unsigned char> m_DecompressedData;
std::unique_ptr<char> m_DecompressedData;
};
#endif // SYCL_RT_ZSTD_NOT_AVAIABLE

} // namespace detail
} // namespace _V1
Expand Down
22 changes: 14 additions & 8 deletions sycl/source/detail/program_manager/program_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,13 @@ getDeviceLibPrograms(const ContextImplPtr Context,
return Programs;
}

// Check if device image is compressed.
static inline bool isDeviceImageCompressed(sycl_device_binary Bin) {

auto currFormat = static_cast<ur::DeviceBinaryType>(Bin->Format);
return currFormat == SYCL_DEVICE_BINARY_TYPE_COMPRESSED_NONE;
}

ProgramManager::ProgramPtr ProgramManager::build(
ProgramPtr Program, const ContextImplPtr Context,
const std::string &CompileOptions, const std::string &LinkOptions,
Expand Down Expand Up @@ -1660,7 +1667,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) {

std::unique_ptr<RTDeviceBinaryImage> Img;
if (isDeviceImageCompressed(RawImg))
#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE
Img = std::make_unique<CompressedRTDeviceBinaryImage>(RawImg);
#else
throw sycl::exception(sycl::make_error_code(sycl::errc::runtime),
"Recieved a compressed device image, but "
"SYCL RT was built without ZSTD support."
"Aborting. ");
#endif
else
Img = std::make_unique<RTDeviceBinaryImage>(RawImg);

Expand Down Expand Up @@ -2808,14 +2822,6 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel(
return UrKernel;
}

// Check if device image is compressed.
inline bool
ProgramManager::isDeviceImageCompressed(sycl_device_binary Bin) const {

auto currFormat = static_cast<ur::DeviceBinaryType>(Bin->Format);
return currFormat == SYCL_DEVICE_BINARY_TYPE_COMPRESSED_NONE;
}

bool doesDevSupportDeviceRequirements(const device &Dev,
const RTDeviceBinaryImage &Img) {
return !checkDevSupportDeviceRequirements(Dev, Img).has_value();
Expand Down
3 changes: 0 additions & 3 deletions sycl/source/detail/program_manager/program_manager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,9 +320,6 @@ class ProgramManager {
collectDependentDeviceImagesForVirtualFunctions(
const RTDeviceBinaryImage &Img, device Dev);

// Returns whether the device image is compressed or not.
inline bool isDeviceImageCompressed(sycl_device_binary Bin) const;

/// The three maps below are used during kernel resolution. Any kernel is
/// identified by its name.
using RTDeviceBinaryImageUPtr = std::unique_ptr<RTDeviceBinaryImage>;
Expand Down
Loading

0 comments on commit 7262fd1

Please sign in to comment.