From d6f4663f0498066bb1af47f8b3675f0cbc7dc54d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Szymon=20Karpi=C5=84ski?= Date: Mon, 18 Dec 2023 18:57:31 +0100 Subject: [PATCH] Manually protect cfitsio in non-reentrant mode MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Szymon KarpiƄski --- dali/operators/reader/loader/fits_loader.cc | 2 +- dali/operators/reader/loader/fits_loader.h | 9 +--- .../reader/loader/fits_loader_gpu.cc | 2 +- dali/util/fits.cc | 43 +++++++++++++------ dali/util/fits.h | 28 +++++++++--- 5 files changed, 55 insertions(+), 29 deletions(-) diff --git a/dali/operators/reader/loader/fits_loader.cc b/dali/operators/reader/loader/fits_loader.cc index e8af05f5e24..face36dc76e 100644 --- a/dali/operators/reader/loader/fits_loader.cc +++ b/dali/operators/reader/loader/fits_loader.cc @@ -34,7 +34,7 @@ void FitsLoaderCPU::ReadDataFromHDU(const fits::FitsHandle& current_file, int status = 0, anynul = 0, nulval = 0; Index nelem = header.size(); - fits::FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval, + FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval, static_cast(target.data[output_idx].raw_mutable_data()), &anynul, &status)); } diff --git a/dali/operators/reader/loader/fits_loader.h b/dali/operators/reader/loader/fits_loader.h index f9383b86c9a..f0ed30dccf4 100644 --- a/dali/operators/reader/loader/fits_loader.h +++ b/dali/operators/reader/loader/fits_loader.h @@ -58,11 +58,6 @@ class FitsLoader : public FileLoader { DALI_ENFORCE(hdu_indices_.size() == dtypes_.size(), "Number of extensions does not match the number of provided types"); - - DALI_ENFORCE(fits_is_reentrant(), - "Loaded instance of cfitsio library does not support multithreading. " - "Please recompile cfitsio in reentrant mode (--enable-reentrant) " - "or use cfitsio delivered in DALI_deps"); } void PrepareEmpty(Target& target) override { @@ -83,14 +78,14 @@ class FitsLoader : public FileLoader { auto path = filesystem::join_path(file_root_, filename); auto current_file = fits::FitsHandle::OpenFile(path.c_str(), READONLY); - fits::FITS_CALL(fits_get_num_hdus(current_file, &num_hdus, &status)); + FITS_CALL(fits_get_num_hdus(current_file, &num_hdus, &status)); // resize ouput vector according to the number of HDUs ResizeTarget(target, hdu_indices_.size()); for (size_t output_idx = 0; output_idx < hdu_indices_.size(); output_idx++) { // move to appropiate hdu - fits::FITS_CALL(fits_movabs_hdu(current_file, hdu_indices_[output_idx], NULL, &status)); + FITS_CALL(fits_movabs_hdu(current_file, hdu_indices_[output_idx], NULL, &status)); // read the header fits::HeaderData header; diff --git a/dali/operators/reader/loader/fits_loader_gpu.cc b/dali/operators/reader/loader/fits_loader_gpu.cc index 545b794e5c9..d1f23530b2e 100644 --- a/dali/operators/reader/loader/fits_loader_gpu.cc +++ b/dali/operators/reader/loader/fits_loader_gpu.cc @@ -48,7 +48,7 @@ void FitsLoaderGPU::ReadDataFromHDU(const fits::FitsHandle& current_file, target.data[output_idx].Resize(header.shape, header.type()); // copy the image to host memory - fits::FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval, + FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval, static_cast(target.data[output_idx].raw_mutable_data()), &anynul, &status)); } diff --git a/dali/util/fits.cc b/dali/util/fits.cc index a36cf87a52f..1cf7bc074d3 100644 --- a/dali/util/fits.cc +++ b/dali/util/fits.cc @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -36,12 +36,6 @@ std::string GetFitsErrorMessage(int status) { return status_str; } -void HandleFitsError(int status) { - if (status) { - DALI_FAIL(GetFitsErrorMessage(status)); - } -} - int ImgTypeToDatatypeCode(int img_type) { switch (img_type) { case SBYTE_IMG: @@ -154,8 +148,9 @@ inline void ExtractData<0>(fitsfile* fptr, std::vector>& ra unsigned char charnull = 0; raw_data[*size].resize(nelemll / sizeof(unsigned char)); - fits_read_col(fptr, TBYTE, (fptr->Fptr)->cn_compressed, irow, 1, static_cast(nelemll), - &charnull, raw_data[*size].data(), nullptr, status); + FITS_CALL(fits_read_col(fptr, TBYTE, (fptr->Fptr)->cn_compressed, irow, 1, + static_cast(nelemll), + &charnull, raw_data[*size].data(), nullptr, status)); ++(*size); *sum_nelemll += nelemll; @@ -164,10 +159,6 @@ inline void ExtractData<0>(fitsfile* fptr, std::vector>& ra } // namespace -void FITS_CALL(int status) { - return HandleFitsError(status); -} - void ParseHeader(HeaderData& parsed_header, fitsfile* src) { int32_t hdu_type, img_type, n_dims, status = 0; @@ -186,7 +177,10 @@ void ParseHeader(HeaderData& parsed_header, fitsfile* src) { parsed_header.hdu_type = hdu_type; parsed_header.datatype_code = ImgTypeToDatatypeCode(img_type); parsed_header.type_info = &TypeFromFitsDatatypeCode(parsed_header.datatype_code); - parsed_header.compressed = (fits_is_compressed_image(src, &status) == 1); + { + FitsLock lock; + parsed_header.compressed = (fits_is_compressed_image(src, &status) == 1); + } if (parsed_header.compressed) { FITS_CALL(fits_get_num_rows(src, &parsed_header.rows, &status)); /*get NROW value */ @@ -285,5 +279,26 @@ size_t HeaderData::nbytes() const { return type_info ? type_info->size() * size() : 0_uz; } +void HandleFitsError(int status) { + if (status) { + DALI_FAIL(GetFitsErrorMessage(status)); + } +} + +FitsLock::FitsLock() : lock_(mutex(), std::defer_lock) { + if (!fits_is_reentrant()) { + DALI_WARN_ONCE("Loaded instance of CFITSIO library does not support multithreading. " + "Please recompile CFITSIO in reentrant mode (--enable-reentrant) " + "or use CFITSIO delivered in DALI_deps. Using non-reentrant version " + "of CFITSIO may degrade the performance."); + lock_.lock(); + } +} + +std::mutex& fits::FitsLock::mutex() { + static std::mutex mutex = {}; + return mutex; +} + } // namespace fits } // namespace dali diff --git a/dali/util/fits.h b/dali/util/fits.h index c6d663a710c..4d92593720d 100644 --- a/dali/util/fits.h +++ b/dali/util/fits.h @@ -1,4 +1,4 @@ -// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -35,6 +35,25 @@ namespace dali { namespace fits { +class DLL_PUBLIC FitsLock { + public: + FitsLock(); + + private: + std::mutex &mutex(); + std::unique_lock lock_; +}; + +DLL_PUBLIC void HandleFitsError(int status); + +/** @brief Wrapper that automatically handles cfitsio error checking.*/ +#define FITS_CALL(code) \ + do { \ + fits::FitsLock lock; \ + fits::HandleFitsError(code); \ + } while (0) + + const std::set supportedTypes = {DALI_UINT8, DALI_UINT16, DALI_UINT32, DALI_UINT64, DALI_INT8, DALI_INT16, DALI_INT32, DALI_INT64, DALI_FLOAT16, DALI_FLOAT, DALI_FLOAT64}; @@ -86,7 +105,7 @@ class DLL_PUBLIC FitsHandle : public UniqueHandle { int status = 0; fitsfile *ff = nullptr; - fits_open_file(&ff, path, mode, &status); + FITS_CALL(fits_open_file(&ff, path, mode, &status)); DALI_ENFORCE(status == 0, make_string("Failed to open a file: ", path, " Make sure it exists!")); @@ -97,15 +116,12 @@ class DLL_PUBLIC FitsHandle : public UniqueHandle { /** @brief Calls fits_close_file on the file handle */ static void DestroyHandle(fitsfile *ff) { int status = 0; - fits_close_file(ff, &status); + FITS_CALL(fits_close_file(ff, &status)); DALI_ENFORCE(status == 0, make_string("Failed while executing fits_close_file! Status code: ", status)); } }; -/** @brief Wrapper that automatically handles cfitsio error checking.*/ -DLL_PUBLIC void FITS_CALL(int status); - } // namespace fits } // namespace dali