diff --git a/cmake/Dependencies.common.cmake b/cmake/Dependencies.common.cmake index dd8244c2768..714b516633c 100644 --- a/cmake/Dependencies.common.cmake +++ b/cmake/Dependencies.common.cmake @@ -274,8 +274,9 @@ endif() ################################################################## set(DALI_INSTALL_REQUIRES_NVIMGCODEC "") if(BUILD_NVIMAGECODEC) - set(NVIMGCODEC_MIN_VERSION "0.3.0") - set(NVIMGCODEC_MAX_VERSION "0.4.0") + set(NVIMGCODEC_MIN_VERSION "0.4.0") + set(NVIMGCODEC_MAX_VERSION "0.5.0") + message(STATUS "nvImageCodec - requires version >=${NVIMGCODEC_MIN_VERSION}, <${NVIMGCODEC_MAX_VERSION}") if (WITH_DYNAMIC_NVIMGCODEC) message(STATUS "nvImageCodec - dynamic load") @@ -288,8 +289,8 @@ if(BUILD_NVIMAGECODEC) include(FetchContent) FetchContent_Declare( nvimgcodec_headers - URL https://developer.download.nvidia.com/compute/nvimgcodec/redist/nvimgcodec/linux-x86_64/nvimgcodec-linux-x86_64-0.3.0.5-archive.tar.xz - URL_HASH SHA512=259bff93305c301fb4325c6e2f71da93f3f6e0b38c7c8739913ca70b5a9c74cc898a608c5ac6e830dba1739878e53607ded03deaf2f23af3a9cc473463f100eb + URL https://developer.download.nvidia.com/compute/nvimgcodec/redist/nvimgcodec/linux-x86_64/nvimgcodec-linux-x86_64-0.4.0.9-archive.tar.xz + URL_HASH SHA512=d1dc489b6f6795548ec88293cc3e08034fc5bca636f134d622c9f2c3f54569b8779464ab7628ea3dfaad283ff631962ef0601354ff5e2c602a769beb44f19f00 ) FetchContent_Populate(nvimgcodec_headers) set(nvimgcodec_SEARCH_PATH "${nvimgcodec_headers_SOURCE_DIR}/${CUDA_VERSION_MAJOR}/include") @@ -321,7 +322,7 @@ if(BUILD_NVIMAGECODEC) ExternalProject_Add( nvImageCodec GIT_REPOSITORY https://github.com/NVIDIA/nvImageCodec.git - GIT_TAG v0.3.0 + GIT_TAG v0.4.0 GIT_SUBMODULES "external/pybind11" "external/NVTX" "external/googletest" diff --git a/conda/third_party/dali_nvimagecodec/recipe/meta.yaml b/conda/third_party/dali_nvimagecodec/recipe/meta.yaml index 4c67501779b..83a402ac214 100644 --- a/conda/third_party/dali_nvimagecodec/recipe/meta.yaml +++ b/conda/third_party/dali_nvimagecodec/recipe/meta.yaml @@ -13,7 +13,7 @@ # limitations under the License. 
-{% set build_version = "0.3.0" %} +{% set build_version = "0.4.0" %} package: name: nvidia-nvimagecodec-cuda{{ environ.get('CUDA_VERSION', '') | replace(".","") }} @@ -21,7 +21,7 @@ package: source: git_url: https://github.com/NVIDIA/nvImageCodec.git - git_rev: v0.3.0 + git_rev: v0.4.0 build: number: 0 diff --git a/dali/core/cuda_event_pool.cc b/dali/core/cuda_event_pool.cc index b90af023197..ee1f1cc7c6e 100644 --- a/dali/core/cuda_event_pool.cc +++ b/dali/core/cuda_event_pool.cc @@ -17,6 +17,7 @@ #include #include "dali/core/cuda_event_pool.h" #include "dali/core/cuda_error.h" +#include "dali/core/math_util.h" namespace dali { @@ -29,6 +30,16 @@ CUDAEventPool::CUDAEventPool(unsigned event_flags) { int num_devices = 0; CUDA_CALL(cudaGetDeviceCount(&num_devices)); dev_events_.resize(num_devices); + // Avoid creation of events during pipeline run + int evt_pool_init_sz = 2000; + if (const char *evt_pool_init_sz_env = std::getenv("DALI_EVENT_POOL_INITIAL_SIZE")) { + evt_pool_init_sz = clamp(atoi(evt_pool_init_sz_env), 0, 10000); + } + for (int device_id = 0; device_id < num_devices; device_id++) { + for (int i = 0; i < evt_pool_init_sz; i++) { + Put(CUDAEvent::CreateWithFlags(cudaEventDisableTiming), device_id); + } + } } CUDAEvent CUDAEventPool::Get(int device_id) { diff --git a/dali/operators/imgcodec/image_decoder.h b/dali/operators/imgcodec/image_decoder.h index aed89353e41..d0aa910115e 100644 --- a/dali/operators/imgcodec/image_decoder.h +++ b/dali/operators/imgcodec/image_decoder.h @@ -32,14 +32,13 @@ #include "dali/pipeline/operator/operator.h" #if not(WITH_DYNAMIC_NVIMGCODEC_ENABLED) -nvimgcodecStatus_t get_libjpeg_turbo_extension_desc(nvimgcodecExtensionDesc_t* ext_desc); -nvimgcodecStatus_t get_libtiff_extension_desc(nvimgcodecExtensionDesc_t* ext_desc); -nvimgcodecStatus_t get_opencv_extension_desc(nvimgcodecExtensionDesc_t* ext_desc); -nvimgcodecStatus_t get_nvjpeg_extension_desc(nvimgcodecExtensionDesc_t* ext_desc); -nvimgcodecStatus_t get_nvjpeg2k_extension_desc(nvimgcodecExtensionDesc_t* ext_desc); +nvimgcodecStatus_t get_libjpeg_turbo_extension_desc(nvimgcodecExtensionDesc_t *ext_desc); +nvimgcodecStatus_t get_libtiff_extension_desc(nvimgcodecExtensionDesc_t *ext_desc); +nvimgcodecStatus_t get_opencv_extension_desc(nvimgcodecExtensionDesc_t *ext_desc); +nvimgcodecStatus_t get_nvjpeg_extension_desc(nvimgcodecExtensionDesc_t *ext_desc); +nvimgcodecStatus_t get_nvjpeg2k_extension_desc(nvimgcodecExtensionDesc_t *ext_desc); #endif - #ifndef DALI_OPERATORS_IMGCODEC_IMAGE_DECODER_H_ #define DALI_OPERATORS_IMGCODEC_IMAGE_DECODER_H_ @@ -119,7 +118,7 @@ inline int static_dali_pinned_free(void *ctx, void *ptr, size_t size, cudaStream return cudaSuccess; } -inline void get_nvimgcodec_version(int* major, int *minor, int* patch) { +inline void get_nvimgcodec_version(int *major, int *minor, int *patch) { static int s_major = -1, s_minor = -1, s_patch = -1; auto version_check_f = [&] { nvimgcodecProperties_t properties{NVIMGCODEC_STRUCTURE_TYPE_PROPERTIES, @@ -148,7 +147,7 @@ class ImageDecoder : public StatelessOperator { ~ImageDecoder() override { #if not(WITH_DYNAMIC_NVIMGCODEC_ENABLED) decoder_.reset(); // first stop the decoder - for (auto& extension : extensions_) { + for (auto &extension : extensions_) { nvimgcodecExtensionDestroy(extension); } #endif @@ -229,11 +228,9 @@ class ImageDecoder : public StatelessOperator { if (std::is_same::value) { thread_pool_ = std::make_unique(num_threads_, device_id_, spec.GetArgument("affine"), "MixedDecoder"); - if (spec_.HasArgument("cache_size")) 
cache_ = std::make_unique(spec_); } - EnforceMinimumNvimgcodecVersion(); nvimgcodecDeviceAllocator_t *dev_alloc_ptr = nullptr; @@ -267,7 +264,7 @@ class ImageDecoder : public StatelessOperator { nullptr}; const char *log_lvl_env = std::getenv("DALI_NVIMGCODEC_LOG_LEVEL"); - int log_lvl = log_lvl_env ? clamp(atoi(log_lvl_env), 1, 5): 2; + int log_lvl = log_lvl_env ? clamp(atoi(log_lvl_env), 1, 5) : 2; instance_create_info.load_extension_modules = static_cast(WITH_DYNAMIC_NVIMGCODEC_ENABLED); instance_create_info.load_builtin_modules = static_cast(true); @@ -355,8 +352,7 @@ class ImageDecoder : public StatelessOperator { opts_.add_module_option("nvjpeg_cuda_decoder", "preallocate_buffers", true); // Batch size - opts_.add_module_option("nvjpeg_hw_decoder", "preallocate_batch_size", - std::max(1, max_batch_size_)); + opts_.add_module_option("nvjpeg_hw_decoder", "preallocate_batch_size", max_batch_size_); // Nvjpeg2k parallel tiles opts_.add_module_option("nvjpeg2k_cuda_decoder", "num_parallel_tiles", 16); @@ -367,32 +363,35 @@ class ImageDecoder : public StatelessOperator { backends_.clear(); backends_.reserve(4); if (nvimgcodec_device_id != NVIMGCODEC_DEVICE_CPU_ONLY) { - backends_.push_back( - nvimgcodecBackend_t{NVIMGCODEC_STRUCTURE_TYPE_BACKEND, - sizeof(nvimgcodecBackend_t), - nullptr, - NVIMGCODEC_BACKEND_KIND_HW_GPU_ONLY, - {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, - sizeof(nvimgcodecBackendParams_t), nullptr, hw_load}}); - backends_.push_back(nvimgcodecBackend_t{NVIMGCODEC_STRUCTURE_TYPE_BACKEND, - sizeof(nvimgcodecBackend_t), - nullptr, - NVIMGCODEC_BACKEND_KIND_GPU_ONLY, - {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, - sizeof(nvimgcodecBackendParams_t), nullptr, 1.0f}}); - backends_.push_back(nvimgcodecBackend_t{NVIMGCODEC_STRUCTURE_TYPE_BACKEND, - sizeof(nvimgcodecBackend_t), - nullptr, - NVIMGCODEC_BACKEND_KIND_HYBRID_CPU_GPU, - {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, - sizeof(nvimgcodecBackendParams_t), nullptr, 1.0f}}); + backends_.push_back(nvimgcodecBackend_t{ + NVIMGCODEC_STRUCTURE_TYPE_BACKEND, + sizeof(nvimgcodecBackend_t), + nullptr, + NVIMGCODEC_BACKEND_KIND_HW_GPU_ONLY, + {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, sizeof(nvimgcodecBackendParams_t), nullptr, + hw_load, NVIMGCODEC_LOAD_HINT_POLICY_FIXED}}); + backends_.push_back(nvimgcodecBackend_t{ + NVIMGCODEC_STRUCTURE_TYPE_BACKEND, + sizeof(nvimgcodecBackend_t), + nullptr, + NVIMGCODEC_BACKEND_KIND_GPU_ONLY, + {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, sizeof(nvimgcodecBackendParams_t), nullptr, + 1.0f, NVIMGCODEC_LOAD_HINT_POLICY_FIXED}}); + backends_.push_back(nvimgcodecBackend_t{ + NVIMGCODEC_STRUCTURE_TYPE_BACKEND, + sizeof(nvimgcodecBackend_t), + nullptr, + NVIMGCODEC_BACKEND_KIND_HYBRID_CPU_GPU, + {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, sizeof(nvimgcodecBackendParams_t), nullptr, + 1.0f, NVIMGCODEC_LOAD_HINT_POLICY_FIXED}}); } - backends_.push_back(nvimgcodecBackend_t{NVIMGCODEC_STRUCTURE_TYPE_BACKEND, - sizeof(nvimgcodecBackend_t), - nullptr, - NVIMGCODEC_BACKEND_KIND_CPU_ONLY, - {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, - sizeof(nvimgcodecBackendParams_t), nullptr, 1.0f}}); + backends_.push_back(nvimgcodecBackend_t{ + NVIMGCODEC_STRUCTURE_TYPE_BACKEND, + sizeof(nvimgcodecBackend_t), + nullptr, + NVIMGCODEC_BACKEND_KIND_CPU_ONLY, + {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, sizeof(nvimgcodecBackendParams_t), nullptr, 1.0f, + NVIMGCODEC_LOAD_HINT_POLICY_FIXED}}); exec_params_.backends = backends_.data(); exec_params_.num_backends = backends_.size(); @@ -401,13 +400,26 @@ class ImageDecoder : public 
StatelessOperator { exec_params_.executor = &executor_; exec_params_.max_num_cpu_threads = num_threads_; exec_params_.pre_init = 1; + exec_params_.skip_pre_sync = 1; // we are not doing stream allocations before decoding. decoder_ = NvImageCodecDecoder::Create(instance_, &exec_params_, opts_.to_string()); } - nvimgcodecStatus_t launch(int device_id, int sample_idx, void *task_context, - void (*task)(int thread_id, int sample_idx, void *task_context)) { + nvimgcodecStatus_t schedule(int device_id, int sample_idx, void *task_context, + void (*task)(int thread_id, int sample_idx, void *task_context)) { + assert(tp_); + tp_->AddWork([=](int tid) { task(tid, sample_idx, task_context); }, -(task_count_++), false); + return NVIMGCODEC_STATUS_SUCCESS; + } + + nvimgcodecStatus_t run(int device_id) { + assert(tp_); + tp_->RunAll(false); + return NVIMGCODEC_STATUS_SUCCESS; + } + + nvimgcodecStatus_t wait(int device_id) { assert(tp_); - tp_->AddWork([=](int tid) { task(tid, sample_idx, task_context); }, 0, true); + tp_->WaitForWork(); return NVIMGCODEC_STATUS_SUCCESS; } @@ -417,12 +429,22 @@ class ImageDecoder : public StatelessOperator { return num_threads_; } - static nvimgcodecStatus_t static_launch(void *instance, int device_id, int sample_idx, - void *task_context, - void (*task)(int thread_id, int sample_idx, - void *task_context)) { + static nvimgcodecStatus_t static_schedule(void *instance, int device_id, int sample_idx, + void *task_context, + void (*task)(int thread_id, int sample_idx, + void *task_context)) { auto *handle = static_cast *>(instance); - return handle->launch(device_id, sample_idx, task_context, task); + return handle->schedule(device_id, sample_idx, task_context, task); + } + + static nvimgcodecStatus_t static_run(void *instance, int device_id) { + auto *handle = static_cast *>(instance); + return handle->run(device_id); + } + + static nvimgcodecStatus_t static_wait(void *instance, int device_id) { + auto *handle = static_cast *>(instance); + return handle->wait(device_id); } static int static_get_num_threads(void *instance) { @@ -500,90 +522,8 @@ class ImageDecoder : public StatelessOperator { return std::is_same::value ? 
thread_pool_.get() : &ws.GetThreadPool(); } - - bool SetupImpl(std::vector &output_descs, const Workspace &ws) override { - DomainTimeRange tr("Setup", DomainTimeRange::kOrange); - tp_ = GetThreadPool(ws); - assert(tp_ != nullptr); - auto auto_cleanup = AtScopeExit([&] { - tp_ = nullptr; - }); - - output_descs.resize(1); - auto &input = ws.template Input(0); - int nsamples = input.num_samples(); - - SetupRoiGenerator(spec_, ws); - TensorListShape<> shapes; - shapes.resize(nsamples, 3); - while (static_cast(state_.size()) < nsamples) - state_.push_back(std::make_unique()); - rois_.resize(nsamples); - - const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8; - auto get_task = [&](int block_idx, int nblocks) { - return [&, block_idx, nblocks](int tid) { - int i_start = nsamples * block_idx / nblocks; - int i_end = nsamples * (block_idx + 1) / nblocks; - for (int i = i_start; i < i_end; i++) { - auto *st = state_[i].get(); - assert(st != nullptr); - const auto &input_sample = input[i]; - - auto src_info = input.GetMeta(i).GetSourceInfo(); - if (use_cache && cache_->IsInCache(src_info)) { - auto cached_shape = cache_->CacheImageShape(src_info); - auto roi = GetRoi(spec_, ws, i, cached_shape); - if (!roi.use_roi()) { - shapes.set_tensor_shape(i, cached_shape); - continue; - } - } - ParseSample(st->parsed_sample, - span{static_cast(input_sample.raw_data()), - volume(input_sample.shape())}); - st->out_shape = st->parsed_sample.dali_img_info.shape; - st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]); - if (use_orientation_ && - (st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) { - std::swap(st->out_shape[0], st->out_shape[1]); - } - ROI &roi = rois_[i] = GetRoi(spec_, ws, i, st->out_shape); - if (roi.use_roi()) { - auto roi_sh = roi.shape(); - if (roi.end.size() >= 2) { - DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] && - 0 <= roi.end[1] && roi.end[1] <= st->out_shape[1], - "ROI end must fit within the image bounds"); - } - if (roi.begin.size() >= 2) { - DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] && - 0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1], - "ROI begin must fit within the image bounds"); - } - st->out_shape[0] = roi_sh[0]; - st->out_shape[1] = roi_sh[1]; - } - shapes.set_tensor_shape(i, st->out_shape); - } - }; - }; - - int nblocks = tp_->NumThreads() + 1; - if (nsamples > nblocks * 4) { - int block_idx = 0; - for (; block_idx < tp_->NumThreads(); block_idx++) { - tp_->AddWork(get_task(block_idx, nblocks), -block_idx); - } - tp_->RunAll(false); // start work but not wait - get_task(block_idx, nblocks)(-1); // run last block - tp_->WaitForWork(); // wait for the other threads - } else { // not worth parallelizing - get_task(0, 1)(-1); // run all in current thread - } - - output_descs[0] = {std::move(shapes), dtype_}; - return true; + bool SetupImpl(std::vector &output_desc, const Workspace &ws) override { + return false; } /** @@ -607,9 +547,7 @@ class ImageDecoder : public StatelessOperator { return !version_at_least(0, 3, 0); } - template - void PrepareOutput(SampleState &st, SampleView out, const ROI &roi, - const Workspace &ws) { + void PrepareOutput(SampleState &st, const ROI &roi, const Workspace &ws) { // Make a copy of the parsed img info. We might modify it // (for example, request planar vs. 
interleaved, etc) st.image_info = st.parsed_sample.nvimgcodec_img_info; @@ -689,8 +627,6 @@ class ImageDecoder : public StatelessOperator { st.image_info.buffer = st.host_buf.get(); st.decode_out_cpu = {st.image_info.buffer, decode_shape, st.parsed_sample.orig_dtype}; } - } else { - st.image_info.buffer = static_cast(out.raw_mutable_data()); } st.image_info.num_planes = 1; @@ -700,7 +636,6 @@ class ImageDecoder : public StatelessOperator { st.image_info.plane_info[0].height = decode_shape[0]; st.image_info.plane_info[0].width = decode_shape[1]; st.image_info.plane_info[0].num_channels = decode_shape[2]; - st.image = NvImageCodecImage::Create(instance_, &st.image_info); } void RunImplImpl(Workspace &ws) { @@ -716,14 +651,12 @@ class ImageDecoder : public StatelessOperator { tp_ = nullptr; }); - bool has_any_roi = false; - for (auto &roi : rois_) - has_any_roi |= roi.use_roi(); - - nvimgcodecDecodeParams_t decode_params = {NVIMGCODEC_STRUCTURE_TYPE_DECODE_PARAMS, - sizeof(nvimgcodecDecodeParams_t), nullptr}; - decode_params.apply_exif_orientation = static_cast(use_orientation_); - decode_params.enable_roi = static_cast(has_any_roi); + SetupRoiGenerator(spec_, ws); + TensorListShape<> shapes; + shapes.resize(nsamples, 3); + while (static_cast(state_.size()) < nsamples) + state_.push_back(std::make_unique()); + rois_.resize(nsamples); assert(static_cast(state_.size()) >= nsamples); batch_encoded_streams_.clear(); @@ -732,125 +665,182 @@ class ImageDecoder : public StatelessOperator { batch_images_.reserve(nsamples); decode_sample_idxs_.clear(); decode_sample_idxs_.reserve(nsamples); + load_from_cache_.resize(nsamples); - // TODO(janton): consider extending cache to different dtype as well const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8; - if (use_cache) { - int samples_to_load = 0; - DomainTimeRange tr(make_string("CacheLoad"), DomainTimeRange::kOrange); - for (int orig_idx = 0; orig_idx < nsamples; orig_idx++) { - auto src_info = input.GetMeta(orig_idx).GetSourceInfo(); - // To simplify things, we do not allow caching ROIs - bool has_roi = rois_[orig_idx].use_roi(); - if (cache_->IsInCache(src_info) && !has_roi) { - cache_->DeferCacheLoad(src_info, output.template mutable_tensor(orig_idx)); - samples_to_load++; - } else { - decode_sample_idxs_.push_back(orig_idx); - } - } - if (samples_to_load > 0) - cache_->LoadDeferred(ws.stream()); - } else { - decode_sample_idxs_.resize(nsamples); - std::iota(decode_sample_idxs_.begin(), decode_sample_idxs_.end(), 0); - } + auto get_task = [&](int block_idx, int nblocks) { + return [&, block_idx, nblocks](int tid) { + int i_start = nsamples * block_idx / nblocks; + int i_end = nsamples * (block_idx + 1) / nblocks; + for (int i = i_start; i < i_end; i++) { + auto *st = state_[i].get(); + assert(st != nullptr); + const auto &input_sample = input[i]; - int decode_nsamples = decode_sample_idxs_.size(); - { - DomainTimeRange tr(make_string("Prepare descs"), DomainTimeRange::kOrange); - auto get_task = [&](int block_idx, int nblocks) { - return [&, block_idx, nblocks](int tid) { - int i_start = decode_nsamples * block_idx / nblocks; - int i_end = decode_nsamples * (block_idx + 1) / nblocks; - for (int i = i_start; i < i_end; i++) { - int orig_idx = decode_sample_idxs_[i]; - PrepareOutput(*state_[orig_idx], output[orig_idx], rois_[orig_idx], ws); + auto src_info = input.GetMeta(i).GetSourceInfo(); + if (use_cache && cache_->IsInCache(src_info)) { + auto cached_shape = cache_->CacheImageShape(src_info); + auto roi = GetRoi(spec_, 
ws, i, cached_shape); + if (!roi.use_roi()) { + load_from_cache_[i] = true; + shapes.set_tensor_shape(i, cached_shape); + continue; + } + } + + load_from_cache_[i] = false; + + ParseSample(st->parsed_sample, + span{static_cast(input_sample.raw_data()), + volume(input_sample.shape())}); + st->out_shape = st->parsed_sample.dali_img_info.shape; + st->out_shape[2] = NumberOfChannels(format_, st->out_shape[2]); + if (use_orientation_ && + (st->parsed_sample.nvimgcodec_img_info.orientation.rotated % 180 != 0)) { + std::swap(st->out_shape[0], st->out_shape[1]); + } + + ROI &roi = rois_[i] = GetRoi(spec_, ws, i, st->out_shape); + if (roi.use_roi()) { + auto roi_sh = roi.shape(); + if (roi.end.size() >= 2) { + DALI_ENFORCE(0 <= roi.end[0] && roi.end[0] <= st->out_shape[0] && 0 <= roi.end[1] && + roi.end[1] <= st->out_shape[1], + "ROI end must fit within the image bounds"); + } + if (roi.begin.size() >= 2) { + DALI_ENFORCE(0 <= roi.begin[0] && roi.begin[0] <= st->out_shape[0] && + 0 <= roi.begin[1] && roi.begin[1] <= st->out_shape[1], + "ROI begin must fit within the image bounds"); + } + st->out_shape[0] = roi_sh[0]; + st->out_shape[1] = roi_sh[1]; } - }; + shapes.set_tensor_shape(i, st->out_shape); + PrepareOutput(*state_[i], rois_[i], ws); + } }; + }; - int nblocks = tp_->NumThreads() + 1; - if (decode_nsamples > nblocks * 4) { + { + DomainTimeRange tr("Setup", DomainTimeRange::kOrange); + const int minimum_block_nsamples = 16; + const int max_num_blocks = std::max(1, nsamples / minimum_block_nsamples); + int nblocks = std::min(tp_->NumThreads() + 1, max_num_blocks); + if (nblocks == 1) { + get_task(0, 1)(-1); // run all in current thread + } else { int block_idx = 0; - for (; block_idx < tp_->NumThreads(); block_idx++) { + for (; block_idx < nblocks - 1; block_idx++) { tp_->AddWork(get_task(block_idx, nblocks), -block_idx); } - tp_->RunAll(false); // start work but not wait - get_task(block_idx, nblocks)(-1); // run last block - tp_->WaitForWork(); // wait for the other threads - } else { // not worth parallelizing - get_task(0, 1)(-1); // run all in current thread + tp_->RunAll(false); // start work but not wait + get_task(block_idx, nblocks)(-1); // run last block in current thread + tp_->WaitForWork(); // wait for the other threads } + } - for (int orig_idx : decode_sample_idxs_) { - auto &st = *state_[orig_idx]; - batch_encoded_streams_.push_back(st.parsed_sample.encoded_stream); - batch_images_.push_back(st.image); - } + { + DomainTimeRange tr("output resize", DomainTimeRange::kOrange); + output.Resize(shapes, dtype_); } - // This is a workaround for nvImageCodec <= 0.2 - auto any_need_processing = [&]() { - for (int orig_idx : decode_sample_idxs_) { - auto& st = state_[orig_idx]; - assert(ws.stream() == st->image_info.cuda_stream); // assuming this is true - if (st->need_processing) - return true; + int samples_to_load = 0; + bool any_need_processing = false; + for (int orig_idx = 0; orig_idx < nsamples; orig_idx++) { + auto &st = *state_[orig_idx]; + auto *data_ptr = output.raw_mutable_tensor(orig_idx); + if (load_from_cache_[orig_idx]) { + auto src_info = input.GetMeta(orig_idx).GetSourceInfo(); + cache_->DeferCacheLoad(src_info, static_cast(data_ptr)); + samples_to_load++; + } else { + decode_sample_idxs_.push_back(orig_idx); + any_need_processing |= st.need_processing; + if (!st.need_processing) + st.image_info.buffer = data_ptr; + assert(!ws.has_stream() || + ws.stream() == st.image_info.cuda_stream); // assuming this is true + st.image = NvImageCodecImage::Create(instance_, 
&st.image_info); } - return false; - }; - if (ws.has_stream() && need_host_sync_alloc() && any_need_processing()) { + } + + if (ws.has_stream() && need_host_sync_alloc() && any_need_processing) { DomainTimeRange tr("alloc sync", DomainTimeRange::kOrange); CUDA_CALL(cudaStreamSynchronize(ws.stream())); } - { - DomainTimeRange tr("Decode", DomainTimeRange::kOrange); + if (samples_to_load > 0) { + DomainTimeRange tr("LoadDeferred", DomainTimeRange::kOrange); + cache_->LoadDeferred(ws.stream()); + } + + size_t nsamples_decode = decode_sample_idxs_.size(); + if (nsamples_decode > 0) { nvimgcodecFuture_t future; - decode_status_.resize(decode_nsamples); - size_t status_size = 0; - CHECK_NVIMGCODEC(nvimgcodecDecoderDecode(decoder_, batch_encoded_streams_.data(), - batch_images_.data(), decode_nsamples, - &decode_params, &future)); - CHECK_NVIMGCODEC( - nvimgcodecFutureGetProcessingStatus(future, decode_status_.data(), &status_size)); - if (static_cast(status_size) != decode_nsamples) - throw std::logic_error("Failed to retrieve processing status"); - CHECK_NVIMGCODEC(nvimgcodecFutureDestroy(future)); - - for (int i = 0; i < decode_nsamples; i++) { - if (decode_status_[i] != NVIMGCODEC_PROCESSING_STATUS_SUCCESS) { - int orig_idx = decode_sample_idxs_[i]; + + bool has_any_roi = false; + for (size_t idx = 0; idx < nsamples_decode; idx++) { + size_t orig_idx = decode_sample_idxs_[idx]; + auto &st = *state_[orig_idx]; + has_any_roi |= rois_[orig_idx].use_roi(); + batch_encoded_streams_.push_back(st.parsed_sample.encoded_stream); + batch_images_.push_back(st.image); + } + std::vector decode_status(nsamples_decode); + size_t decode_status_size = 0; + nvimgcodecDecodeParams_t decode_params = {NVIMGCODEC_STRUCTURE_TYPE_DECODE_PARAMS, + sizeof(nvimgcodecDecodeParams_t), nullptr}; + decode_params.apply_exif_orientation = static_cast(use_orientation_); + decode_params.enable_roi = static_cast(has_any_roi); + + { + DomainTimeRange tr("nvimgcodecDecoderDecode", DomainTimeRange::kOrange); + CHECK_NVIMGCODEC(nvimgcodecDecoderDecode(decoder_, batch_encoded_streams_.data(), + batch_images_.data(), nsamples_decode, + &decode_params, &future)); + CHECK_NVIMGCODEC(nvimgcodecFutureWaitForAll(future)); + CHECK_NVIMGCODEC( + nvimgcodecFutureGetProcessingStatus(future, decode_status.data(), &decode_status_size)); + CHECK_NVIMGCODEC(nvimgcodecFutureDestroy(future)); + } + if (decode_status_size != nsamples_decode) + throw std::runtime_error("Failed to run decoder"); + for (size_t idx = 0; idx < nsamples_decode; idx++) { + size_t orig_idx = decode_sample_idxs_[idx]; + auto st_ptr = state_[orig_idx].get(); + if (decode_status[idx] != NVIMGCODEC_PROCESSING_STATUS_SUCCESS) { throw std::runtime_error(make_string("Failed to decode sample #", orig_idx, " : ", input.GetMeta(orig_idx).GetSourceInfo())); } } - } - - for (int orig_idx : decode_sample_idxs_) { - auto st_ptr = state_[orig_idx].get(); - if (st_ptr->need_processing) { - tp_->AddWork( - [&, out = output[orig_idx], st_ptr, orig_idx](int tid) { - DomainTimeRange tr(make_string("Convert #", orig_idx), DomainTimeRange::kOrange); - auto &st = *st_ptr; - if constexpr (std::is_same::value) { - ConvertGPU(out, st.req_layout, st.req_img_type, st.decode_out_gpu, st.req_layout, - st.orig_img_type, ws.stream(), ROI{}, nvimgcodecOrientation_t{}, - st.dyn_range_multiplier); - st.device_buf.reset(); - } else { - assert(st.dyn_range_multiplier == 1.0f); // TODO(janton): enable - ConvertCPU(out, st.req_layout, st.req_img_type, st.decode_out_cpu, st.req_layout, - st.orig_img_type, ROI{}, 
nvimgcodecOrientation_t{}); - st.host_buf.reset(); - } - }, - -orig_idx); + if (any_need_processing) { + for (size_t idx = 0; idx < nsamples_decode; idx++) { + size_t orig_idx = decode_sample_idxs_[idx]; + auto st_ptr = state_[orig_idx].get(); + if (st_ptr->need_processing) { + tp_->AddWork( + [&, out = output[orig_idx], st_ptr, orig_idx](int tid) { + DomainTimeRange tr(make_string("Convert #", orig_idx), DomainTimeRange::kOrange); + auto &st = *st_ptr; + if constexpr (std::is_same::value) { + ConvertGPU(out, st.req_layout, st.req_img_type, st.decode_out_gpu, + st.req_layout, st.orig_img_type, ws.stream(), ROI{}, + nvimgcodecOrientation_t{}, st.dyn_range_multiplier); + st.device_buf.reset(); + } else { + assert(st.dyn_range_multiplier == 1.0f); // TODO(janton): enable + ConvertCPU(out, st.req_layout, st.req_img_type, st.decode_out_cpu, + st.req_layout, st.orig_img_type, ROI{}, nvimgcodecOrientation_t{}); + st.host_buf.reset(); + } + }, + -idx); + } + } + tp_->RunAll(true); } } - tp_->RunAll(); if (use_cache) { DomainTimeRange tr(make_string("CacheStore"), DomainTimeRange::kOrange); @@ -882,8 +872,11 @@ class ImageDecoder : public StatelessOperator { sizeof(nvimgcodecExecutorDesc_t), nullptr, this, - &static_launch, + &static_schedule, + &static_run, + &static_wait, &static_get_num_threads}; + nvimgcodecDeviceAllocator_t dev_alloc_ = {}; nvimgcodecPinnedAllocator_t pinned_alloc_ = {}; std::vector backends_; @@ -899,7 +892,8 @@ class ImageDecoder : public StatelessOperator { bool use_orientation_ = true; int max_batch_size_ = 1; int num_threads_ = -1; - ThreadPool* tp_ = nullptr; + ThreadPool *tp_ = nullptr; + int64_t task_count_ = 0; std::vector> state_; std::vector batch_encoded_streams_; std::vector batch_images_; @@ -908,6 +902,8 @@ class ImageDecoder : public StatelessOperator { // In case of cache, the batch we send to the decoder might have fewer samples than the full batch // This vector is used to get the original index of the decoded samples std::vector decode_sample_idxs_; + // True if the sample is loaded from cache, false otherwise + std::vector load_from_cache_; // Manually loaded extensions std::vector extensions_descs_; diff --git a/docs/advanced_topics_performance_tuning.rst b/docs/advanced_topics_performance_tuning.rst index e72c6bbe016..c568dc168fc 100644 --- a/docs/advanced_topics_performance_tuning.rst +++ b/docs/advanced_topics_performance_tuning.rst @@ -133,6 +133,13 @@ In performance-critical applications, this can be avoided by preallocating the p .. autofunction:: PreallocateDeviceMemory .. autofunction:: PreallocatePinnedMemory +Event Pool Initial Size +----------------------- + +DALI utilizes a global CUDA event pool for certain operations. The application can borrow CUDA +events from the pool and return them for reuse once they are no longer needed. The initial size +of the event pool can be set using the environment variable ``DALI_EVENT_POOL_INITIAL_SIZE``. +By default, the pool is filled with 2000 CUDA events. 
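A minimal sketch of how such an initial size can be read and sanitized, mirroring the clamping added to cuda_event_pool.cc in this change (the 2000 default and the 0..10000 bounds come from that code; the helper name is illustrative only):

  #include <algorithm>
  #include <cstdlib>

  // Reads DALI_EVENT_POOL_INITIAL_SIZE and clamps it to a sane range; falls back to 2000.
  int GetInitialEventPoolSize() {
    int size = 2000;
    if (const char *env = std::getenv("DALI_EVENT_POOL_INITIAL_SIZE")) {
      size = std::clamp(std::atoi(env), 0, 10000);
    }
    return size;
  }
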
Freeing Memory Pools -------------------- diff --git a/internal_tools/stub_generator/nvimgcodec.json b/internal_tools/stub_generator/nvimgcodec.json index 4a6847ea4d8..ebeaf4aac67 100644 --- a/internal_tools/stub_generator/nvimgcodec.json +++ b/internal_tools/stub_generator/nvimgcodec.json @@ -7,9 +7,11 @@ "not_found_error":"NVIMGCODEC_STATUS_IMPLEMENTATION_UNSUPPORTED", "functions": { "nvimgcodecInstanceCreate": {}, + "nvimgcodecDecoderCanDecode": {}, "nvimgcodecCodeStreamCreateFromHostMem": {}, "nvimgcodecCodeStreamGetImageInfo": {}, "nvimgcodecDecoderCreate": {}, + "nvimgcodecFutureWaitForAll": {}, "nvimgcodecFutureGetProcessingStatus": {}, "nvimgcodecInstanceDestroy": {}, "nvimgcodecImageDestroy": {}, diff --git a/qa/TL0_python-self-test-base-cuda/test.sh b/qa/TL0_python-self-test-base-cuda/test.sh index 969a5e754e8..2a0060aa4e3 100644 --- a/qa/TL0_python-self-test-base-cuda/test.sh +++ b/qa/TL0_python-self-test-base-cuda/test.sh @@ -14,6 +14,7 @@ version_ge "$DALI_CUDA_MAJOR_VERSION" "11" && \ pip uninstall -y `pip list | grep nvidia-cufft | cut -d " " -f1` \ `pip list | grep nvidia-nvjpeg | cut -d " " -f1` \ `pip list | grep nvidia-nvjpeg2k | cut -d " " -f1` \ + `pip list | grep nvidia-nvtiff | cut -d " " -f1` \ `pip list | grep nvidia-npp | cut -d " " -f1` \ || true @@ -43,4 +44,5 @@ version_ge "$DALI_CUDA_MAJOR_VERSION" "11" && \ nvidia-npp-cu${DALI_CUDA_MAJOR_VERSION} \ nvidia-nvjpeg-cu${DALI_CUDA_MAJOR_VERSION} \ nvidia-nvjpeg2k-cu${DALI_CUDA_MAJOR_VERSION} \ + nvidia-nvtiff-cu${DALI_CUDA_MAJOR_VERSION} \ || true diff --git a/qa/test_template_impl.sh b/qa/test_template_impl.sh index 2ccc5b9cc82..c5eb0e18e69 100755 --- a/qa/test_template_impl.sh +++ b/qa/test_template_impl.sh @@ -159,6 +159,7 @@ do install_pip_pkg "pip install --upgrade nvidia-npp-cu${DALI_CUDA_MAJOR_VERSION}${NPP_VERSION} \ nvidia-nvjpeg-cu${DALI_CUDA_MAJOR_VERSION} \ nvidia-nvjpeg2k-cu${DALI_CUDA_MAJOR_VERSION} \ + nvidia-nvtiff-cu${DALI_CUDA_MAJOR_VERSION} \ nvidia-cufft-cu${DALI_CUDA_MAJOR_VERSION} \ -f /pip-packages" fi
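The executor descriptor in image_decoder.h now exposes three callbacks instead of a single launch: schedule() only enqueues a task, run() starts the queued work without blocking, and wait() blocks until it completes; DALI maps these onto ThreadPool::AddWork, RunAll(false) and WaitForWork. Below is a toy, thread-per-task illustration of that contract; it is an assumption for exposition only, not DALI's thread pool.

  #include <condition_variable>
  #include <functional>
  #include <mutex>
  #include <queue>
  #include <thread>
  #include <vector>

  // Toy executor honoring the schedule/run/wait split: nothing runs until Run() is called,
  // and Wait() blocks until every task started by Run() has finished.
  class ToyExecutor {
   public:
    void Schedule(std::function<void(int)> task) {
      pending_.push(std::move(task));
    }

    void Run() {  // start asynchronously, do not block
      while (!pending_.empty()) {
        auto task = std::move(pending_.front());
        pending_.pop();
        {
          std::lock_guard<std::mutex> lk(m_);
          ++in_flight_;
        }
        workers_.emplace_back([this, task = std::move(task)] {
          task(/*thread_id=*/0);
          std::lock_guard<std::mutex> lk(m_);
          --in_flight_;
          cv_.notify_all();
        });
      }
    }

    void Wait() {  // block until all started work is done
      {
        std::unique_lock<std::mutex> lk(m_);
        cv_.wait(lk, [this] { return in_flight_ == 0; });
      }
      for (auto &t : workers_) t.join();
      workers_.clear();
    }

   private:
    std::queue<std::function<void(int)>> pending_;
    std::vector<std::thread> workers_;
    std::mutex m_;
    std::condition_variable cv_;
    int in_flight_ = 0;
  };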
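The backend list is rebuilt because the nvImageCodec 0.4 nvimgcodecBackendParams_t carries an explicit load-hint policy. A small helper capturing the initializer pattern used in image_decoder.h above (purely illustrative; DALI builds the list inline):

  #include <nvimgcodec.h>

  // Builds one backend entry with an explicit load hint and a fixed load-hint policy,
  // matching the aggregate initializers used in the decoder setup above.
  inline nvimgcodecBackend_t MakeBackend(nvimgcodecBackendKind_t kind, float load_hint) {
    return nvimgcodecBackend_t{
        NVIMGCODEC_STRUCTURE_TYPE_BACKEND,
        sizeof(nvimgcodecBackend_t),
        nullptr,
        kind,
        {NVIMGCODEC_STRUCTURE_TYPE_BACKEND_PARAMS, sizeof(nvimgcodecBackendParams_t), nullptr,
         load_hint, NVIMGCODEC_LOAD_HINT_POLICY_FIXED}};
  }

  // Usage, mirroring the decoder setup: GPU backends are added only for GPU devices.
  // backends.push_back(MakeBackend(NVIMGCODEC_BACKEND_KIND_HW_GPU_ONLY, hw_load));
  // backends.push_back(MakeBackend(NVIMGCODEC_BACKEND_KIND_GPU_ONLY, 1.0f));
  // backends.push_back(MakeBackend(NVIMGCODEC_BACKEND_KIND_CPU_ONLY, 1.0f));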
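Shape parsing and output preparation now happen in RunImplImpl and are split into blocks of roughly 16 or more samples, with the last block executed on the calling thread while the pool handles the rest. The partitioning arithmetic, extracted for clarity (the constants are taken from the diff; the function names are illustrative):

  #include <algorithm>

  // Index range [begin, end) of the samples handled by a given block.
  struct Block {
    int begin, end;
  };

  inline Block GetBlock(int nsamples, int block_idx, int nblocks) {
    return {nsamples * block_idx / nblocks, nsamples * (block_idx + 1) / nblocks};
  }

  // Number of blocks: at most one per pool thread plus the calling thread, and never so many
  // that a block would receive fewer than ~16 samples.
  inline int NumBlocks(int nsamples, int pool_threads) {
    const int minimum_block_nsamples = 16;
    int max_num_blocks = std::max(1, nsamples / minimum_block_nsamples);
    return std::min(pool_threads + 1, max_num_blocks);
  }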
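Finally, the decode path now waits on the returned future explicitly before querying per-sample status, which is why nvimgcodecFutureWaitForAll is added to the stub generator list. A condensed sketch of the call order, assuming the decoder, code streams, images and decode params were created as in image_decoder.h; the status element type is taken from the nvImageCodec headers as I recall them, and error checking (CHECK_NVIMGCODEC) is omitted:

  #include <nvimgcodec.h>
  #include <vector>

  // Decodes a batch and returns the per-sample processing status.
  std::vector<nvimgcodecProcessingStatus_t> DecodeBatch(
      nvimgcodecDecoder_t decoder,
      std::vector<nvimgcodecCodeStream_t> &streams,
      std::vector<nvimgcodecImage_t> &images,
      nvimgcodecDecodeParams_t &params) {
    nvimgcodecFuture_t future;
    nvimgcodecDecoderDecode(decoder, streams.data(), images.data(),
                            static_cast<int>(streams.size()), &params, &future);
    nvimgcodecFutureWaitForAll(future);  // block until the whole batch has been processed
    std::vector<nvimgcodecProcessingStatus_t> status(streams.size());
    size_t status_size = 0;
    nvimgcodecFutureGetProcessingStatus(future, status.data(), &status_size);
    nvimgcodecFutureDestroy(future);
    return status;  // compare entries against NVIMGCODEC_PROCESSING_STATUS_SUCCESS
  }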