From 7262fd1275f63fbeaaed398954eebb8cd18edb72 Mon Sep 17 00:00:00 2001 From: "Agarwal, Udit" Date: Wed, 11 Sep 2024 12:48:28 -0700 Subject: [PATCH] Address reviews --- .../ClangOffloadWrapper.cpp | 2 +- sycl/source/CMakeLists.txt | 10 +- sycl/source/detail/compression.hpp | 120 ++++++++++-------- sycl/source/detail/device_binary_image.cpp | 40 +++--- sycl/source/detail/device_binary_image.hpp | 4 +- .../program_manager/program_manager.cpp | 22 ++-- .../program_manager/program_manager.hpp | 3 - .../compression/CompressionTests.cpp | 28 ++-- 8 files changed, 117 insertions(+), 112 deletions(-) diff --git a/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp b/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp index 7d1fc238ed436..ef3d4fc372da3 100644 --- a/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp +++ b/clang/tools/clang-offload-wrapper/ClangOffloadWrapper.cpp @@ -1153,7 +1153,7 @@ class BinaryWrapper { Twine(OffloadKindTag) + Twine(ImgId) + Twine(".data"), Kind, Img.Tgt); - // Change image format to compressed_non. + // Change image format to compressed_none. Ffmt = ConstantInt::get(Type::getInt8Ty(C), BinaryImageFormat::compressed_none); } diff --git a/sycl/source/CMakeLists.txt b/sycl/source/CMakeLists.txt index f16d6899de285..f432203836855 100644 --- a/sycl/source/CMakeLists.txt +++ b/sycl/source/CMakeLists.txt @@ -74,9 +74,13 @@ function(add_sycl_rt_library LIB_NAME LIB_OBJ_NAME) set(CMAKE_FIND_DEBUG_MODE 1) # Need zstd for device image compression. - find_package(zstd REQUIRED) - target_link_libraries(${LIB_NAME} PRIVATE ${zstd_STATIC_LIBRARY}) - target_include_directories(${LIB_OBJ_NAME} PRIVATE ${zstd_INCLUDE_DIR}) + find_package(zstd) + if (NOT zstd_FOUND) + target_compile_definitions(${LIB_OBJ_NAME} PRIVATE SYCL_RT_ZSTD_NOT_AVAIABLE) + else() + target_link_libraries(${LIB_NAME} PRIVATE ${zstd_STATIC_LIBRARY}) + target_include_directories(${LIB_OBJ_NAME} PRIVATE ${zstd_INCLUDE_DIR}) + endif() target_include_directories(${LIB_OBJ_NAME} PRIVATE ${BOOST_UNORDERED_INCLUDE_DIRS}) diff --git a/sycl/source/detail/compression.hpp b/sycl/source/detail/compression.hpp index c106b7861ada1..fe997d08159d8 100644 --- a/sycl/source/detail/compression.hpp +++ b/sycl/source/detail/compression.hpp @@ -7,6 +7,10 @@ //===----------------------------------------------------------------------===// #pragma once +#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE + +#include + #include #include #include @@ -21,23 +25,7 @@ namespace detail { // Singleton class to handle ZSTD compression and decompression. class ZSTDCompressor { private: - // Initialize ZSTD context and error code. - ZSTDCompressor() { - m_ZSTD_compression_ctx = static_cast(ZSTD_createCCtx()); - m_ZSTD_decompression_ctx = static_cast(ZSTD_createDCtx()); - - if (!m_ZSTD_compression_ctx || !m_ZSTD_decompression_ctx) { - std::cerr << "Error creating ZSTD contexts. \n"; - } - - m_lastError = 0; - } - - // Free ZSTD contexts. - ~ZSTDCompressor() { - ZSTD_freeCCtx(static_cast(m_ZSTD_compression_ctx)); - ZSTD_freeDCtx(static_cast(m_ZSTD_decompression_ctx)); - } + ZSTDCompressor() {} ZSTDCompressor(const ZSTDCompressor &) = delete; ZSTDCompressor &operator=(const ZSTDCompressor &) = delete; @@ -50,73 +38,93 @@ class ZSTDCompressor { // Public APIs public: - // Return 0 is last (de)compression was successful, otherwise return error - // code. - static int GetLastError() { return GetSingletonInstance().m_lastError; } - - // Returns a string representation of the error code. - // If the error code is 0, it returns "No error detected". - static std::string GetErrorString(int code) { - return ZSTD_getErrorName(code); - } - // Blob (de)compression do not assume format/structure of the input buffer. static std::unique_ptr CompressBlob(const char *src, size_t srcSize, size_t &dstSize, int level) { auto &instance = GetSingletonInstance(); + // Lazy initialize compression context. + if (!instance.m_ZSTD_compression_ctx) { + + // Call ZSTD_createCCtx() and ZSTD_freeCCtx() to create and free the + // context. + instance.m_ZSTD_compression_ctx = + std::unique_ptr(ZSTD_createCCtx(), + ZSTD_freeCCtx); + if (!instance.m_ZSTD_compression_ctx) { + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Failed to create ZSTD compression context"); + } + } + // Get maximum size of the compressed buffer and allocate it. auto dstBufferSize = ZSTD_compressBound(srcSize); auto dstBuffer = std::unique_ptr(new char[dstBufferSize]); + if (!dstBuffer) + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Failed to allocate memory for compressed data"); + // Compress the input buffer. - dstSize = ZSTD_compressCCtx( - static_cast(instance.m_ZSTD_compression_ctx), - static_cast(dstBuffer.get()), dstBufferSize, - static_cast(src), srcSize, level); + dstSize = + ZSTD_compressCCtx(instance.m_ZSTD_compression_ctx.get(), + static_cast(dstBuffer.get()), dstBufferSize, + static_cast(src), srcSize, level); // Store the error code if compression failed. if (ZSTD_isError(dstSize)) - instance.m_lastError = dstSize; - else - instance.m_lastError = 0; + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + ZSTD_getErrorName(dstSize)); // Pass ownership of the buffer to the caller. return dstBuffer; } - static std::unique_ptr - DecompressBlob(const char *src, size_t srcSize, size_t &dstSize) { + static std::unique_ptr DecompressBlob(const char *src, size_t srcSize, + size_t &dstSize) { auto &instance = GetSingletonInstance(); + // Lazy initialize decompression context. + if (!instance.m_ZSTD_decompression_ctx) { + + // Call ZSTD_createDCtx() and ZSTD_freeDCtx() to create and free the + // context. + instance.m_ZSTD_decompression_ctx = + std::unique_ptr(ZSTD_createDCtx(), + ZSTD_freeDCtx); + if (!instance.m_ZSTD_decompression_ctx) { + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Failed to create ZSTD decompression context"); + } + } + // Size of decompressed image can be larger than what we can allocate // on heap. In that case, we need to use streaming decompression. - // TODO: Throw if the decompression size is too large. auto dstBufferSize = ZSTD_getFrameContentSize(src, srcSize); if (dstBufferSize == ZSTD_CONTENTSIZE_UNKNOWN || dstBufferSize == ZSTD_CONTENTSIZE_ERROR) { - - std::cerr << "Error determining size of uncompressed data\n"; - dstSize = 0; - instance.m_lastError = dstBufferSize; - return nullptr; + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Error determining size of uncompressed data."); } // Allocate buffer for decompressed data. - auto dstBuffer = - std::unique_ptr(new unsigned char[dstBufferSize]); + auto dstBuffer = std::unique_ptr(new char[dstBufferSize]); - dstSize = ZSTD_decompressDCtx( - static_cast(instance.m_ZSTD_decompression_ctx), - static_cast(dstBuffer.get()), dstBufferSize, - static_cast(src), srcSize); + if (!dstBuffer) + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Failed to allocate memory for decompressed data"); + + dstSize = + ZSTD_decompressDCtx(instance.m_ZSTD_decompression_ctx.get(), + static_cast(dstBuffer.get()), dstBufferSize, + static_cast(src), srcSize); // In case of decompression error, return the error message and set dstSize // to 0. if (ZSTD_isError(dstSize)) { - instance.m_lastError = dstSize; - dstSize = 0; + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + ZSTD_getErrorName(dstSize)); } // Pass ownership of the buffer to the caller. @@ -125,12 +133,14 @@ class ZSTDCompressor { // Data fields private: - int m_lastError; - // ZSTD context. Reusing ZSTD context speeds up subsequent (de)compression. - // Storing as void* to avoid including ZSTD headers in this file. - void *m_ZSTD_compression_ctx; - void *m_ZSTD_decompression_ctx; + // ZSTD contexts. Reusing ZSTD context speeds up subsequent (de)compression. + std::unique_ptr m_ZSTD_compression_ctx{ + nullptr, nullptr}; + std::unique_ptr m_ZSTD_decompression_ctx{ + nullptr, nullptr}; }; } // namespace detail } // namespace _V1 } // namespace sycl + +#endif // SYCL_RT_ZSTD_NOT_AVAIABLE \ No newline at end of file diff --git a/sycl/source/detail/device_binary_image.cpp b/sycl/source/detail/device_binary_image.cpp index 5259f65234d76..e0815f16c9a2f 100644 --- a/sycl/source/detail/device_binary_image.cpp +++ b/sycl/source/detail/device_binary_image.cpp @@ -172,28 +172,27 @@ void RTDeviceBinaryImage::init(sycl_device_binary Bin) { if (Format == SYCL_DEVICE_BINARY_TYPE_NONE) // try to determine the format; may remain "NONE" - Format = ur::getBinaryImageFormat(this->Bin->BinaryStart, getSize()); + Format = ur::getBinaryImageFormat(Bin->BinaryStart, getSize()); - SpecConstIDMap.init(this->Bin, __SYCL_PROPERTY_SET_SPEC_CONST_MAP); + SpecConstIDMap.init(Bin, __SYCL_PROPERTY_SET_SPEC_CONST_MAP); SpecConstDefaultValuesMap.init( - this->Bin, __SYCL_PROPERTY_SET_SPEC_CONST_DEFAULT_VALUES_MAP); - DeviceLibReqMask.init(this->Bin, __SYCL_PROPERTY_SET_DEVICELIB_REQ_MASK); - KernelParamOptInfo.init(this->Bin, __SYCL_PROPERTY_SET_KERNEL_PARAM_OPT_INFO); - AssertUsed.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_ASSERT_USED); - ProgramMetadata.init(this->Bin, __SYCL_PROPERTY_SET_PROGRAM_METADATA); + Bin, __SYCL_PROPERTY_SET_SPEC_CONST_DEFAULT_VALUES_MAP); + DeviceLibReqMask.init(Bin, __SYCL_PROPERTY_SET_DEVICELIB_REQ_MASK); + KernelParamOptInfo.init(Bin, __SYCL_PROPERTY_SET_KERNEL_PARAM_OPT_INFO); + AssertUsed.init(Bin, __SYCL_PROPERTY_SET_SYCL_ASSERT_USED); + ProgramMetadata.init(Bin, __SYCL_PROPERTY_SET_PROGRAM_METADATA); // Convert ProgramMetadata into the UR format for (const auto &Prop : ProgramMetadata) { ProgramMetadataUR.push_back( ur::mapDeviceBinaryPropertyToProgramMetadata(Prop)); } - ExportedSymbols.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS); - ImportedSymbols.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_IMPORTED_SYMBOLS); - DeviceGlobals.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_GLOBALS); - DeviceRequirements.init(this->Bin, - __SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS); - HostPipes.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES); - VirtualFunctions.init(this->Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS); + ExportedSymbols.init(Bin, __SYCL_PROPERTY_SET_SYCL_EXPORTED_SYMBOLS); + ImportedSymbols.init(Bin, __SYCL_PROPERTY_SET_SYCL_IMPORTED_SYMBOLS); + DeviceGlobals.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_GLOBALS); + DeviceRequirements.init(Bin, __SYCL_PROPERTY_SET_SYCL_DEVICE_REQUIREMENTS); + HostPipes.init(Bin, __SYCL_PROPERTY_SET_SYCL_HOST_PIPES); + VirtualFunctions.init(Bin, __SYCL_PROPERTY_SET_SYCL_VIRTUAL_FUNCTIONS); ImageId = ImageCounter++; } @@ -231,6 +230,7 @@ DynRTDeviceBinaryImage::~DynRTDeviceBinaryImage() { Bin = nullptr; } +#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage( sycl_device_binary CompressedBin) : RTDeviceBinaryImage() { @@ -243,18 +243,13 @@ CompressedRTDeviceBinaryImage::CompressedRTDeviceBinaryImage( reinterpret_cast(CompressedBin->BinaryStart), compressedDataSize, DecompressedSize); - if (!m_DecompressedData) { - throw sycl::exception( - sycl::make_error_code(sycl::errc::runtime), - "Failed to decompress device binary image. " + - ZSTDCompressor::GetErrorString(ZSTDCompressor::GetLastError())); - } - Bin = new sycl_device_binary_struct(*CompressedBin); - Bin->BinaryStart = m_DecompressedData.get(); + Bin->BinaryStart = (const unsigned char *)(m_DecompressedData.get()); Bin->BinaryEnd = Bin->BinaryStart + DecompressedSize; // Set the new format to none and let RT determine the format. + // TODO: Add support for automatically detecting compressed + // binary format. Bin->Format = SYCL_DEVICE_BINARY_TYPE_NONE; init(Bin); @@ -265,6 +260,7 @@ CompressedRTDeviceBinaryImage::~CompressedRTDeviceBinaryImage() { delete Bin; Bin = nullptr; } +#endif // SYCL_RT_ZSTD_NOT_AVAIABLE } // namespace detail } // namespace _V1 diff --git a/sycl/source/detail/device_binary_image.hpp b/sycl/source/detail/device_binary_image.hpp index acf3265b0099b..62dc0afce90fd 100644 --- a/sycl/source/detail/device_binary_image.hpp +++ b/sycl/source/detail/device_binary_image.hpp @@ -276,6 +276,7 @@ class DynRTDeviceBinaryImage : public RTDeviceBinaryImage { std::unique_ptr Data; }; +#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE // Compressed device binary image. It decompresses the binary image on // construction and stores the decompressed data as RTDeviceBinaryImage. // Also, frees the decompressed data in destructor. @@ -290,8 +291,9 @@ class CompressedRTDeviceBinaryImage : public RTDeviceBinaryImage { } private: - std::unique_ptr m_DecompressedData; + std::unique_ptr m_DecompressedData; }; +#endif // SYCL_RT_ZSTD_NOT_AVAIABLE } // namespace detail } // namespace _V1 diff --git a/sycl/source/detail/program_manager/program_manager.cpp b/sycl/source/detail/program_manager/program_manager.cpp index 498b104796fb3..e94eabfa86eac 100644 --- a/sycl/source/detail/program_manager/program_manager.cpp +++ b/sycl/source/detail/program_manager/program_manager.cpp @@ -1529,6 +1529,13 @@ getDeviceLibPrograms(const ContextImplPtr Context, return Programs; } +// Check if device image is compressed. +static inline bool isDeviceImageCompressed(sycl_device_binary Bin) { + + auto currFormat = static_cast(Bin->Format); + return currFormat == SYCL_DEVICE_BINARY_TYPE_COMPRESSED_NONE; +} + ProgramManager::ProgramPtr ProgramManager::build( ProgramPtr Program, const ContextImplPtr Context, const std::string &CompileOptions, const std::string &LinkOptions, @@ -1660,7 +1667,14 @@ void ProgramManager::addImages(sycl_device_binaries DeviceBinary) { std::unique_ptr Img; if (isDeviceImageCompressed(RawImg)) +#ifndef SYCL_RT_ZSTD_NOT_AVAIABLE Img = std::make_unique(RawImg); +#else + throw sycl::exception(sycl::make_error_code(sycl::errc::runtime), + "Recieved a compressed device image, but " + "SYCL RT was built without ZSTD support." + "Aborting. "); +#endif else Img = std::make_unique(RawImg); @@ -2808,14 +2822,6 @@ ur_kernel_handle_t ProgramManager::getOrCreateMaterializedKernel( return UrKernel; } -// Check if device image is compressed. -inline bool -ProgramManager::isDeviceImageCompressed(sycl_device_binary Bin) const { - - auto currFormat = static_cast(Bin->Format); - return currFormat == SYCL_DEVICE_BINARY_TYPE_COMPRESSED_NONE; -} - bool doesDevSupportDeviceRequirements(const device &Dev, const RTDeviceBinaryImage &Img) { return !checkDevSupportDeviceRequirements(Dev, Img).has_value(); diff --git a/sycl/source/detail/program_manager/program_manager.hpp b/sycl/source/detail/program_manager/program_manager.hpp index c1a6a208417c3..c38cc7babd370 100644 --- a/sycl/source/detail/program_manager/program_manager.hpp +++ b/sycl/source/detail/program_manager/program_manager.hpp @@ -320,9 +320,6 @@ class ProgramManager { collectDependentDeviceImagesForVirtualFunctions( const RTDeviceBinaryImage &Img, device Dev); - // Returns whether the device image is compressed or not. - inline bool isDeviceImageCompressed(sycl_device_binary Bin) const; - /// The three maps below are used during kernel resolution. Any kernel is /// identified by its name. using RTDeviceBinaryImageUPtr = std::unique_ptr; diff --git a/sycl/unittests/compression/CompressionTests.cpp b/sycl/unittests/compression/CompressionTests.cpp index 77a577cec2a2f..0114cfeac5551 100644 --- a/sycl/unittests/compression/CompressionTests.cpp +++ b/sycl/unittests/compression/CompressionTests.cpp @@ -38,15 +38,6 @@ TEST(CompressionTest, SimpleCompression) { // Check if decompressed data is same as original data. std::string decompressedStr((char *)decompressedData.get(), decompressedSize); ASSERT_EQ(data, decompressedStr); - - // Check that error code is 0 after successful decompression. - int errorCode = ZSTDCompressor::GetLastError(); - ASSERT_EQ(errorCode, 0); - - // Check that error string is "No error detected" after successful - // decompression. - std::string errorString = ZSTDCompressor::GetErrorString(errorCode); - ASSERT_EQ(errorString, "No error detected"); } // Test getting error code and error string. @@ -55,14 +46,15 @@ TEST(CompressionTest, SimpleCompression) { TEST(CompressionTest, NegativeErrorTest) { std::string input = "Hello, World!"; size_t decompressedSize = 0; - auto compressedData = ZSTDCompressor::DecompressBlob( - input.c_str(), input.size(), decompressedSize); - - int errorCode = ZSTDCompressor::GetLastError(); - ASSERT_NE(errorCode, 0); - - std::string errorString = ZSTDCompressor::GetErrorString(errorCode); - ASSERT_NE(errorString, "No error detected"); + bool threwException = false; + try { + auto compressedData = ZSTDCompressor::DecompressBlob( + input.c_str(), input.size(), decompressedSize); + } catch (sycl::exception &e) { + threwException = true; + } + + ASSERT_TRUE(threwException); } // Test passing empty input to (de)compress. @@ -75,7 +67,6 @@ TEST(CompressionTest, EmptyInputTest) { ASSERT_NE(compressedData, nullptr); ASSERT_GT(compressedSize, 0); - ASSERT_EQ(ZSTDCompressor::GetLastError(), 0); size_t decompressedSize = 0; auto decompressedData = ZSTDCompressor::DecompressBlob( @@ -83,7 +74,6 @@ TEST(CompressionTest, EmptyInputTest) { ASSERT_NE(decompressedData, nullptr); ASSERT_EQ(decompressedSize, 0); - ASSERT_EQ(ZSTDCompressor::GetLastError(), 0); std::string decompressedStr((char *)decompressedData.get(), decompressedSize); ASSERT_EQ(input, decompressedStr);