Skip to content

Commit

Permalink
Manually protect cfitsio in non-reentrant mode
Browse files Browse the repository at this point in the history
Signed-off-by: Szymon Karpiński <skarpinski@nvidia.com>
  • Loading branch information
szkarpinski committed Dec 18, 2023
1 parent f44a607 commit d6f4663
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 29 deletions.
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/fits_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ void FitsLoaderCPU::ReadDataFromHDU(const fits::FitsHandle& current_file,
int status = 0, anynul = 0, nulval = 0;
Index nelem = header.size();

fits::FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval,
FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval,
static_cast<uint8_t*>(target.data[output_idx].raw_mutable_data()),
&anynul, &status));
}
Expand Down
9 changes: 2 additions & 7 deletions dali/operators/reader/loader/fits_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ class FitsLoader : public FileLoader<Backend, Target> {

DALI_ENFORCE(hdu_indices_.size() == dtypes_.size(),
"Number of extensions does not match the number of provided types");

DALI_ENFORCE(fits_is_reentrant(),
"Loaded instance of cfitsio library does not support multithreading. "
"Please recompile cfitsio in reentrant mode (--enable-reentrant) "
"or use cfitsio delivered in DALI_deps");
}

void PrepareEmpty(Target& target) override {
Expand All @@ -83,14 +78,14 @@ class FitsLoader : public FileLoader<Backend, Target> {

auto path = filesystem::join_path(file_root_, filename);
auto current_file = fits::FitsHandle::OpenFile(path.c_str(), READONLY);
fits::FITS_CALL(fits_get_num_hdus(current_file, &num_hdus, &status));
FITS_CALL(fits_get_num_hdus(current_file, &num_hdus, &status));

// resize ouput vector according to the number of HDUs
ResizeTarget(target, hdu_indices_.size());

for (size_t output_idx = 0; output_idx < hdu_indices_.size(); output_idx++) {
// move to appropiate hdu
fits::FITS_CALL(fits_movabs_hdu(current_file, hdu_indices_[output_idx], NULL, &status));
FITS_CALL(fits_movabs_hdu(current_file, hdu_indices_[output_idx], NULL, &status));

// read the header
fits::HeaderData header;
Expand Down
2 changes: 1 addition & 1 deletion dali/operators/reader/loader/fits_loader_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void FitsLoaderGPU::ReadDataFromHDU(const fits::FitsHandle& current_file,
target.data[output_idx].Resize(header.shape, header.type());

// copy the image to host memory
fits::FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval,
FITS_CALL(fits_read_img(current_file, header.datatype_code, 1, nelem, &nulval,
static_cast<uint8_t*>(target.data[output_idx].raw_mutable_data()),
&anynul, &status));
}
Expand Down
43 changes: 29 additions & 14 deletions dali/util/fits.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -36,12 +36,6 @@ std::string GetFitsErrorMessage(int status) {
return status_str;
}

void HandleFitsError(int status) {
if (status) {
DALI_FAIL(GetFitsErrorMessage(status));
}
}

int ImgTypeToDatatypeCode(int img_type) {
switch (img_type) {
case SBYTE_IMG:
Expand Down Expand Up @@ -154,8 +148,9 @@ inline void ExtractData<0>(fitsfile* fptr, std::vector<std::vector<uint8_t>>& ra

unsigned char charnull = 0;
raw_data[*size].resize(nelemll / sizeof(unsigned char));
fits_read_col(fptr, TBYTE, (fptr->Fptr)->cn_compressed, irow, 1, static_cast<int64>(nelemll),
&charnull, raw_data[*size].data(), nullptr, status);
FITS_CALL(fits_read_col(fptr, TBYTE, (fptr->Fptr)->cn_compressed, irow, 1,
static_cast<int64>(nelemll),
&charnull, raw_data[*size].data(), nullptr, status));

++(*size);
*sum_nelemll += nelemll;
Expand All @@ -164,10 +159,6 @@ inline void ExtractData<0>(fitsfile* fptr, std::vector<std::vector<uint8_t>>& ra

} // namespace

void FITS_CALL(int status) {
return HandleFitsError(status);
}

void ParseHeader(HeaderData& parsed_header, fitsfile* src) {
int32_t hdu_type, img_type, n_dims, status = 0;

Expand All @@ -186,7 +177,10 @@ void ParseHeader(HeaderData& parsed_header, fitsfile* src) {
parsed_header.hdu_type = hdu_type;
parsed_header.datatype_code = ImgTypeToDatatypeCode(img_type);
parsed_header.type_info = &TypeFromFitsDatatypeCode(parsed_header.datatype_code);
parsed_header.compressed = (fits_is_compressed_image(src, &status) == 1);
{
FitsLock lock;
parsed_header.compressed = (fits_is_compressed_image(src, &status) == 1);
}

if (parsed_header.compressed) {
FITS_CALL(fits_get_num_rows(src, &parsed_header.rows, &status)); /*get NROW value */
Expand Down Expand Up @@ -285,5 +279,26 @@ size_t HeaderData::nbytes() const {
return type_info ? type_info->size() * size() : 0_uz;
}

void HandleFitsError(int status) {
if (status) {
DALI_FAIL(GetFitsErrorMessage(status));
}
}

FitsLock::FitsLock() : lock_(mutex(), std::defer_lock) {
if (!fits_is_reentrant()) {
DALI_WARN_ONCE("Loaded instance of CFITSIO library does not support multithreading. "
"Please recompile CFITSIO in reentrant mode (--enable-reentrant) "
"or use CFITSIO delivered in DALI_deps. Using non-reentrant version "
"of CFITSIO may degrade the performance.");
lock_.lock();
}
}

std::mutex& fits::FitsLock::mutex() {
static std::mutex mutex = {};
return mutex;
}

} // namespace fits
} // namespace dali
28 changes: 22 additions & 6 deletions dali/util/fits.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -35,6 +35,25 @@
namespace dali {
namespace fits {

class DLL_PUBLIC FitsLock {
public:
FitsLock();

private:
std::mutex &mutex();
std::unique_lock<std::mutex> lock_;
};

DLL_PUBLIC void HandleFitsError(int status);

/** @brief Wrapper that automatically handles cfitsio error checking.*/
#define FITS_CALL(code) \
do { \
fits::FitsLock lock; \
fits::HandleFitsError(code); \
} while (0)


const std::set<DALIDataType> supportedTypes = {DALI_UINT8, DALI_UINT16, DALI_UINT32, DALI_UINT64,
DALI_INT8, DALI_INT16, DALI_INT32, DALI_INT64,
DALI_FLOAT16, DALI_FLOAT, DALI_FLOAT64};
Expand Down Expand Up @@ -86,7 +105,7 @@ class DLL_PUBLIC FitsHandle : public UniqueHandle<fitsfile *, FitsHandle> {
int status = 0;
fitsfile *ff = nullptr;

fits_open_file(&ff, path, mode, &status);
FITS_CALL(fits_open_file(&ff, path, mode, &status));
DALI_ENFORCE(status == 0,
make_string("Failed to open a file: ", path, " Make sure it exists!"));

Expand All @@ -97,15 +116,12 @@ class DLL_PUBLIC FitsHandle : public UniqueHandle<fitsfile *, FitsHandle> {
/** @brief Calls fits_close_file on the file handle */
static void DestroyHandle(fitsfile *ff) {
int status = 0;
fits_close_file(ff, &status);
FITS_CALL(fits_close_file(ff, &status));
DALI_ENFORCE(status == 0,
make_string("Failed while executing fits_close_file! Status code: ", status));
}
};

/** @brief Wrapper that automatically handles cfitsio error checking.*/
DLL_PUBLIC void FITS_CALL(int status);

} // namespace fits
} // namespace dali

Expand Down

0 comments on commit d6f4663

Please sign in to comment.