From 571b1431db31c8f1eef92e5c1ffe92e873e01d58 Mon Sep 17 00:00:00 2001
From: Rafal Banas
Date: Wed, 18 Sep 2024 15:10:15 +0200
Subject: [PATCH] Add experimental CV-CUDA resize

Signed-off-by: Rafal Banas
---
 cmake/Dependencies.common.cmake                    |   2 +-
 dali/operators/image/resize/CMakeLists.txt         |   7 +-
 .../image/resize/experimental/CMakeLists.txt       |  18 ++
 .../image/resize/experimental/resize.cc            |  79 ++++++
 .../image/resize/experimental/resize.h             | 135 ++++++++++
 .../experimental/resize_op_impl_cvcuda.h           | 246 ++++++++++++++++++
 dali/operators/image/resize/resize_base.cc         |  11 +-
 dali/operators/image/resize/resize_base.h          |  20 +-
 dali/operators/nvcvop/nvcvop.cc                    |  74 +++++-
 dali/operators/nvcvop/nvcvop.h                     |  88 +++++--
 .../test_dali_stateless_operators.py               |   6 +
 dali/test/python/operator_2/test_resize.py         |  90 +++++--
 dali/test/python/test_dali_cpu_only.py             |   1 +
 .../python/test_dali_variable_batch_size.py        |   1 +
 14 files changed, 717 insertions(+), 61 deletions(-)
 create mode 100644 dali/operators/image/resize/experimental/CMakeLists.txt
 create mode 100644 dali/operators/image/resize/experimental/resize.cc
 create mode 100644 dali/operators/image/resize/experimental/resize.h
 create mode 100644 dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h

diff --git a/cmake/Dependencies.common.cmake b/cmake/Dependencies.common.cmake
index c9dad1fb70..2a6f564d8d 100644
--- a/cmake/Dependencies.common.cmake
+++ b/cmake/Dependencies.common.cmake
@@ -264,7 +264,7 @@ if (BUILD_CVCUDA)
   set(DALI_BUILD_PYTHON ${BUILD_PYTHON})
   set(BUILD_PYTHON OFF)
-  # for now we use only median blur from CV-CUDA
-  set(CV_CUDA_SRC_PATERN medianblur median_blur morphology warp)
+  # build only the CV-CUDA operators that DALI actually uses
+  set(CV_CUDA_SRC_PATERN medianblur median_blur morphology warp HQResize)
   check_and_add_cmake_submodule(${PROJECT_SOURCE_DIR}/third_party/cvcuda)
   set(BUILD_PYTHON ${DALI_BUILD_PYTHON})
 endif()
diff --git a/dali/operators/image/resize/CMakeLists.txt b/dali/operators/image/resize/CMakeLists.txt
index 4a89bcdd69..3034ffa344 100644
--- a/dali/operators/image/resize/CMakeLists.txt
+++ b/dali/operators/image/resize/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2017-2024, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+if (BUILD_CVCUDA)
+  add_subdirectory(experimental)
+endif()
+
 # Get all the source files and dump test files
 collect_headers(DALI_INST_HDRS PARENT_SCOPE)
 collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
 collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE)
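With this build wiring in place, the operator below is exposed as
`fn.experimental.resize`. A minimal usage sketch (the data path and sizes are
illustrative, not part of this patch):

    from nvidia.dali import fn, pipeline_def, types

    @pipeline_def(batch_size=8, num_threads=4, device_id=0)
    def pipe():
        jpegs, _ = fn.readers.file(file_root="/data/images")  # illustrative path
        images = fn.decoders.image(jpegs, device="mixed")
        return fn.experimental.resize(images, resize_x=224, resize_y=224,
                                      interp_type=types.INTERP_LANCZOS3)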
diff --git a/dali/operators/image/resize/experimental/CMakeLists.txt b/dali/operators/image/resize/experimental/CMakeLists.txt
new file mode 100644
index 0000000000..89d48d9f60
--- /dev/null
+++ b/dali/operators/image/resize/experimental/CMakeLists.txt
@@ -0,0 +1,18 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+collect_headers(DALI_INST_HDRS PARENT_SCOPE)
+collect_sources(DALI_OPERATOR_SRCS PARENT_SCOPE)
+collect_test_sources(DALI_OPERATOR_TEST_SRCS PARENT_SCOPE)
diff --git a/dali/operators/image/resize/experimental/resize.cc b/dali/operators/image/resize/experimental/resize.cc
new file mode 100644
index 0000000000..ac301057f9
--- /dev/null
+++ b/dali/operators/image/resize/experimental/resize.cc
@@ -0,0 +1,79 @@
+// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#define DALI_RESIZE_BASE_CC
+
+#include "dali/operators/image/resize/experimental/resize.h"
+#include <cassert>
+#include "dali/pipeline/data/views.h"
+
+namespace dali {
+
+DALI_SCHEMA(experimental__Resize)
+    .DocStr(R"code(Resize images, sequences and volumetric data.)code")
+    .NumInput(1)
+    .NumOutput(1)
+    .AdditionalOutputsFn([](const OpSpec& spec) {
+      return static_cast<int>(spec.GetArgument<bool>("save_attrs"));
+    })
+    .InputLayout(0, {"HWC", "FHWC", "CHW", "FCHW", "CFHW",
+                     "DHWC", "FDHWC", "CDHW", "FCDHW", "CFDHW"})
+    .AddOptionalArg("save_attrs",
+                    R"code(Save reshape attributes for testing.)code", false)
+    .AddOptionalArg<DALIImageType>("image_type", "Image type", nullptr)
+    .DeprecateArg("image_type")  // deprecated since 0.25dev
+    .SupportVolumetric()
+    .AllowSequences()
+    .AddParent("ResizeAttr")
+    .AddParent("ResamplingFilterAttr");
+
+CvCudaResize::CvCudaResize(const OpSpec &spec)
+    : StatelessOperator<GPUBackend>(spec), ResizeBase<GPUBackend>(spec) {
+  save_attrs_ = this->spec_.HasArgument("save_attrs");
+  InitializeBackend();
+}
+
+void CvCudaResize::InitializeBackend() {
+  InitializeGPU(spec_.GetArgument<int>("minibatch_size"),
+                spec_.GetArgument<int64_t>("temp_buffer_hint"));
+}
+
+void CvCudaResize::RunImpl(Workspace &ws) {
+  const auto &input = ws.Input<GPUBackend>(0);
+  auto &output = ws.Output<GPUBackend>(0);
+
+  RunResize(ws, output, input);
+  output.SetLayout(input.GetLayout());
+
+  if (save_attrs_) {
+    auto &attr_out = ws.Output<GPUBackend>(1);
+    const auto &attr_shape = attr_out.shape();
+    assert(attr_shape.num_samples() == input.shape().num_samples() &&
+           attr_shape.sample_dim() == 1 &&
+           is_uniform(attr_shape) &&
+           attr_shape[0][0] == NumSpatialDims());
+
+    if (!attr_staging_.has_data())
+      attr_staging_.set_pinned(true);
+    attr_staging_.Resize(attr_out.shape(), DALI_INT32);
+    auto attr_view = view<int, 1>(attr_staging_);
+    SaveAttrs(attr_view, input.shape());
+    attr_out.Copy(attr_staging_, ws.stream());
+  }
+}
+
+DALI_REGISTER_OPERATOR(experimental__Resize, CvCudaResize, GPU);
+
+}  // namespace dali
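When `save_attrs=True`, the operator registered above produces a second output
holding each sample's original spatial extents. A short sketch (the target size
is illustrative):

    resized, orig_extents = fn.experimental.resize(
        images, size=[224, 224], save_attrs=True)
    # orig_extents[i] is an int32 vector with the pre-resize (H, W)
    # (or (D, H, W) for volumetric inputs) of sample i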
diff --git a/dali/operators/image/resize/experimental/resize.h b/dali/operators/image/resize/experimental/resize.h
new file mode 100644
index 0000000000..5b09f4196d
--- /dev/null
+++ b/dali/operators/image/resize/experimental/resize.h
@@ -0,0 +1,135 @@
+// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DALI_OPERATORS_IMAGE_RESIZE_EXPERIMENTAL_RESIZE_H_
+#define DALI_OPERATORS_IMAGE_RESIZE_EXPERIMENTAL_RESIZE_H_
+
+#include <cassert>
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "dali/core/common.h"
+#include "dali/core/error_handling.h"
+#include "dali/kernels/context.h"
+#include "dali/kernels/imgproc/resample/params.h"
+#include "dali/kernels/scratch.h"
+#include "dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h"
+#include "dali/operators/image/resize/resize_attr.h"
+#include "dali/operators/image/resize/resize_base.h"
+#include "dali/pipeline/operator/checkpointing/stateless_operator.h"
+#include "dali/pipeline/operator/common.h"
+
+namespace dali {
+
+class CvCudaResize : public StatelessOperator<GPUBackend>, protected ResizeBase<GPUBackend> {
+ public:
+  explicit CvCudaResize(const OpSpec &spec);
+
+ protected:
+  void SetupResize(TensorListShape<> &out_shape, DALIDataType out_type,
+                   const TensorListShape<> &in_shape, DALIDataType in_type,
+                   span<const kernels::ResamplingParams> params, int spatial_ndim,
+                   int first_spatial_dim) {
+    VALUE_SWITCH(spatial_ndim, static_spatial_ndim, (2, 3),
+      (
+        using ImplType = ResizeOpImplCvCuda<static_spatial_ndim>;
+        SetImpl<ImplType>([&]{ return std::make_unique<ImplType>(GetMinibatchSize()); });
+        impl_->Setup(out_shape, in_shape, first_spatial_dim, params);
+      ),  // NOLINT
+      (DALI_FAIL(make_string("Unsupported number of resized dimensions: ", spatial_ndim))));
+  }
+
+  int NumSpatialDims() const {
+    return resize_attr_.spatial_ndim_;
+  }
+  int FirstSpatialDim() const {
+    return resize_attr_.first_spatial_dim_;
+  }
+
+  bool CanInferOutputs() const override {
+    return true;
+  }
+
+  bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;
+
+  void RunImpl(Workspace &ws) override;
+
+  void SaveAttrs(const TensorListView<StorageCPU, int, 1> &shape_data,
+                 const TensorListShape<> &orig_shape) const {
+    int N = orig_shape.num_samples();
+    int D = NumSpatialDims();
+    assert(shape_data.sample_dim() == 1);
+    for (int i = 0; i < N; i++) {
+      auto sample_shape = orig_shape.tensor_shape_span(i);
+      assert(static_cast<int>(shape_data.shape[i][0]) == D);
+      int *out_shape = shape_data.data[i];
+      for (int d = 0; d < D; d++) {
+        out_shape[d] = sample_shape[FirstSpatialDim() + d];
+      }
+    }
+  }
+
+  void PrepareParams(const ArgumentWorkspace &ws, const TensorListShape<> &input_shape,
+                     const TensorLayout &layout) {
+    resize_attr_.PrepareResizeParams(spec_, ws, input_shape, layout);
+    assert(NumSpatialDims() >= 1 && NumSpatialDims() <= 3);
+    assert(FirstSpatialDim() >= 0);
+    int N = input_shape.num_samples();
+    resample_params_.resize(N * NumSpatialDims());
+    resampling_attr_.PrepareFilterParams(spec_, ws, N);
+    resampling_attr_.GetResamplingParams(make_span(resample_params_),
+                                         make_cspan(resize_attr_.params_));
+  }
+
+  void InitializeBackend();
+
+  USE_OPERATOR_MEMBERS();
+  std::vector<kernels::ResamplingParams> resample_params_;
+  TensorList<CPUBackend> attr_staging_;
+  using Operator<GPUBackend>::RunImpl;
+  bool save_attrs_ = false;
+
+  ResizeAttr resize_attr_;
+  ResamplingFilterAttr resampling_attr_;
+};
+
+bool CvCudaResize::SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) {
+  output_desc.resize(save_attrs_ ? 2 : 1);
+  auto &input = ws.Input<GPUBackend>(0);
+
+  const auto &in_shape = input.shape();
+  auto in_type = input.type();
+  auto in_layout = input.GetLayout();
+  int N = in_shape.num_samples();
+
+  PrepareParams(ws, in_shape, in_layout);
+
+  auto out_type = resampling_attr_.GetOutputType(in_type);
+
+  output_desc[0].type = out_type;
+  this->SetupResize(output_desc[0].shape, out_type, in_shape, in_type,
+                    make_cspan(this->resample_params_), NumSpatialDims(), FirstSpatialDim());
+
+  if (save_attrs_) {
+    output_desc[1].shape = uniform_list_shape(N, TensorShape<1>({NumSpatialDims()}));
+    output_desc[1].type = DALI_INT32;
+  }
+  return true;
+}
+
+}  // namespace dali
+
+#endif  // DALI_OPERATORS_IMAGE_RESIZE_EXPERIMENTAL_RESIZE_H_
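`SetupResize` above relies on DALI's `VALUE_SWITCH` macro to turn the runtime
`spatial_ndim` into a compile-time template argument. A standalone sketch of
the same dispatch idea, in plain C++ without the DALI macro:

    #include <iostream>
    #include <stdexcept>

    template <int ndim>
    void ResizeImpl() {  // stands in for ResizeOpImplCvCuda<ndim>
      std::cout << "resizing with ndim = " << ndim << "\n";
    }

    // Map a runtime value onto one of the allowed compile-time constants.
    void Dispatch(int spatial_ndim) {
      switch (spatial_ndim) {
        case 2: ResizeImpl<2>(); break;
        case 3: ResizeImpl<3>(); break;
        default: throw std::invalid_argument("Unsupported number of resized dimensions");
      }
    }

    int main() { Dispatch(2); Dispatch(3); }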
diff --git a/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h b/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h
new file mode 100644
index 0000000000..d79e3e9176
--- /dev/null
+++ b/dali/operators/image/resize/experimental/resize_op_impl_cvcuda.h
@@ -0,0 +1,246 @@
+// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DALI_OPERATORS_IMAGE_RESIZE_EXPERIMENTAL_RESIZE_OP_IMPL_CVCUDA_H_
+#define DALI_OPERATORS_IMAGE_RESIZE_EXPERIMENTAL_RESIZE_OP_IMPL_CVCUDA_H_
+
+#include <cvcuda/OpHQResize.hpp>
+
+#include <vector>
+
+#include "dali/kernels/imgproc/resample/params.h"
+#include "dali/operators/image/resize/resize_op_impl.h"
+#include "dali/operators/nvcvop/nvcvop.h"
+
+namespace dali {
+
+template <int spatial_ndim>
+class ResizeOpImplCvCuda : public ResizeBase<GPUBackend>::Impl {
+ public:
+  explicit ResizeOpImplCvCuda(int minibatch_size) : minibatch_size_(minibatch_size) {}
+
+  static_assert(spatial_ndim == 2 || spatial_ndim == 3, "Only 2D and 3D resizing is supported");
+
+  /// Dimensionality of each separate frame. If the input contains no channel dimension, one is added.
+  static constexpr int frame_ndim = spatial_ndim + 1;
+
+  void Setup(TensorListShape<> &out_shape, const TensorListShape<> &in_shape,
+             int first_spatial_dim, span<const kernels::ResamplingParams> params) override {
+    // Calculate the output shape of the input, as supplied (sequences, planar images, etc.)
+    GetResizedShape(out_shape, in_shape, params, spatial_ndim, first_spatial_dim);
+
+    // Create "frames" from outer dimensions and "channels" from inner dimensions.
+    GetFrameShapesAndParams(in_shape_, params_, in_shape, params, first_spatial_dim);
+
+    // Now that we have per-frame parameters, we can calculate the output shape of the
+    // effective frames (from videos, channel planes, etc.).
+    GetResizedShape(out_shape_, in_shape_, make_cspan(params_), 0);
+
+    curr_minibatch_size_ = minibatch_size_;
+    // CV-CUDA tensors cannot represent empty samples; fall back to minibatches of
+    // size 1, so that each empty sample can be skipped individually.
+    if (HasEmptySamples(in_shape_)) {
+      curr_minibatch_size_ = 1;
+    }
+
+    // Now that we know how many logical frames there are, calculate batch subdivision.
+    CalculateMinibatchPartition(in_shape_.num_samples(), curr_minibatch_size_);
+
+    SetupKernel();
+  }
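`CalculateMinibatchPartition` (defined further down in this file) splits
`total_frames` into `div_ceil(total, minibatch_size)` nearly equal chunks via
`end = (i + 1) * total / num`. A self-contained sketch of the arithmetic, with
illustrative frame counts:

    #include <iostream>
    #include <vector>

    // Mirrors the partition arithmetic: 10 frames with minibatch_size 4
    // yield 3 minibatches of sizes 3, 3 and 4.
    int main() {
      int total_frames = 10, minibatch_size = 4;
      int num = (total_frames + minibatch_size - 1) / minibatch_size;  // div_ceil
      std::vector<int> counts;
      for (int i = 0, start = 0; i < num; i++) {
        int end = (i + 1) * total_frames / num;
        counts.push_back(end - start);
        start = end;
      }
      for (int c : counts) std::cout << c << " ";  // prints: 3 3 4
    }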
+  void SetupKernel() {
+    const int dim = in_shape_.sample_dim();
+    kernels::KernelContext ctx;
+    rois_.clear();
+    rois_.reserve(minibatch_size_ * minibatches_.size());
+    workspace_reqs_ = {};
+    std::vector<HQResizeTensorShapeI> mb_input_shapes(curr_minibatch_size_);
+    std::vector<HQResizeTensorShapeI> mb_output_shapes(curr_minibatch_size_);
+    for (int mb_idx = 0, num_mb = minibatches_.size(); mb_idx < num_mb; mb_idx++) {
+      auto &mb = minibatches_[mb_idx];
+
+      int end = mb.start + mb.count;
+      auto param_slice = make_span(&params_[mb.start], mb.count);
+      for (int i = mb.start, j = 0; i < end; i++, j++) {
+        rois_.push_back(GetRoi(param_slice[j]));
+        for (int d = 0; d < dim; d++) {
+          mb_input_shapes[j].extent[d] = in_shape_.tensor_shape_span(i)[d];
+          mb_output_shapes[j].extent[d] = out_shape_.tensor_shape_span(i)[d];
+        }
+      }
+      int num_channels = in_shape_[0][frame_ndim - 1];
+      HQResizeTensorShapesI mb_input_shape{mb_input_shapes.data(), mb.count, spatial_ndim,
+                                           num_channels};
+      HQResizeTensorShapesI mb_output_shape{mb_output_shapes.data(), mb.count, spatial_ndim,
+                                            num_channels};
+      mb.rois = HQResizeRoisF{mb.count, spatial_ndim, &rois_.data()[mb.start]};
+
+      auto param = params_[mb.start][0];
+      mb.min_interpolation = GetInterpolationType(param.min_filter);
+      mb.mag_interpolation = GetInterpolationType(param.mag_filter);
+      mb.antialias = param.min_filter.antialias;
+      auto ws_req = resize_op_.getWorkspaceRequirements(mb.count, mb_input_shape, mb_output_shape,
+                                                        mb.min_interpolation, mb.mag_interpolation,
+                                                        mb.antialias, mb.rois);
+      workspace_reqs_ = nvcvop::MaxWorkspaceRequirements(workspace_reqs_, ws_req);
+    }
+  }
+
+  static bool HasEmptySamples(const TensorListShape<> &in_shape) {
+    for (int s = 0; s < in_shape.num_samples(); s++) {
+      if (volume(in_shape.tensor_shape_span(s)) == 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  HQResizeRoiF GetRoi(const ResamplingParamsND<spatial_ndim> &params) {
+    HQResizeRoiF roi;
+    for (int d = 0; d < spatial_ndim; d++) {
+      roi.lo[d] = params[d].roi.start;
+      roi.hi[d] = params[d].roi.end;
+    }
+    return roi;
+  }
+
+  NVCVInterpolationType GetInterpolationType(kernels::FilterDesc filter_desc) {
+    using kernels::ResamplingFilterType;
+    switch (filter_desc.type) {
+      case ResamplingFilterType::Nearest:
+        return NVCVInterpolationType::NVCV_INTERP_NEAREST;
+      case ResamplingFilterType::Linear:
+        return NVCVInterpolationType::NVCV_INTERP_LINEAR;
+      case ResamplingFilterType::Triangular:
+        // closest CV-CUDA equivalent of DALI's triangular filter
+        return NVCVInterpolationType::NVCV_INTERP_LINEAR;
+      case ResamplingFilterType::Cubic:
+        return NVCVInterpolationType::NVCV_INTERP_CUBIC;
+      case ResamplingFilterType::Lanczos3:
+        return NVCVInterpolationType::NVCV_INTERP_LANCZOS;
+      case ResamplingFilterType::Gaussian:
+        return NVCVInterpolationType::NVCV_INTERP_GAUSSIAN;
+      default:
+        DALI_FAIL("Unsupported filter type");
+    }
+  }
+  void RunResize(Workspace &ws, TensorList<GPUBackend> &output,
+                 const TensorList<GPUBackend> &input) override {
+    TensorList<GPUBackend> in_frames;
+    in_frames.ShareData(input);
+    in_frames.Resize(in_shape_);
+    PrepareInput(in_frames);
+
+    TensorList<GPUBackend> out_frames;
+    out_frames.ShareData(output);
+    out_frames.Resize(out_shape_);
+    PrepareOutput(out_frames);
+
+    kernels::DynamicScratchpad scratchpad({}, AccessOrder(ws.stream()));
+
+    auto workspace_mem = op_workspace_.Allocate(workspace_reqs_, scratchpad);
+
+    for (size_t b = 0; b < minibatches_.size(); b++) {
+      MiniBatch &mb = minibatches_[b];
+      if (mb.input.numTensors() == 0 || mb.output.numTensors() == 0) {
+        // empty samples get minibatches of size 1, so the whole minibatch is empty
+        assert(mb.count == 1);
+        continue;
+      }
+      resize_op_(ws.stream(), workspace_mem, mb.input, mb.output, mb.min_interpolation,
+                 mb.mag_interpolation, mb.antialias, mb.rois);
+    }
+  }
+
+  int CalculateMinibatchPartition(int total_frames, int minibatch_size) {
+    int num_minibatches = div_ceil(total_frames, minibatch_size);
+
+    minibatches_.resize(num_minibatches);
+    int start = 0;
+    for (int i = 0; i < num_minibatches; i++) {
+      // balanced partition: minibatch sizes differ by at most one frame
+      int end = (i + 1) * total_frames / num_minibatches;
+      auto &mb = minibatches_[i];
+      mb.start = start;
+      mb.count = end - start;
+      start = end;
+    }
+    return num_minibatches;
+  }
+
+  TensorListShape<frame_ndim> in_shape_, out_shape_;
+  std::vector<ResamplingParamsND<spatial_ndim>> params_;
+
+  cvcuda::HQResize resize_op_{};
+  nvcvop::NVCVOpWorkspace op_workspace_;
+  cvcuda::WorkspaceRequirements workspace_reqs_{};
+  std::vector<HQResizeRoiF> rois_;
+  const TensorLayout sample_layout_ = (spatial_ndim == 2) ? "HWC" : "DHWC";
+
+  struct MiniBatch {
+    int start, count;
+    nvcv::TensorBatch input;
+    nvcv::TensorBatch output;
+    NVCVInterpolationType min_interpolation;
+    NVCVInterpolationType mag_interpolation;
+    bool antialias;
+    HQResizeRoisF rois;
+  };
+
+  std::vector<MiniBatch> minibatches_;
+
+  void PrepareInput(const TensorList<GPUBackend> &input) {
+    for (auto &mb : minibatches_) {
+      int curr_capacity = mb.input ? mb.input.capacity() : 0;
+      if (mb.count > curr_capacity) {
+        int new_capacity = std::max(mb.count, curr_capacity * 2);
+        auto reqs = nvcv::TensorBatch::CalcRequirements(new_capacity);
+        mb.input = nvcv::TensorBatch(reqs);
+      } else {
+        mb.input.clear();
+      }
+      for (int i = mb.start; i < mb.start + mb.count; ++i) {
+        if (volume(in_shape_.tensor_shape_span(i)) != 0) {
+          mb.input.pushBack(nvcvop::AsTensor(input[i], sample_layout_));
+        }
+      }
+    }
+  }
+  void PrepareOutput(const TensorList<GPUBackend> &out) {
+    for (auto &mb : minibatches_) {
+      int curr_capacity = mb.output ? mb.output.capacity() : 0;
+      if (mb.count > curr_capacity) {
+        int new_capacity = std::max(mb.count, curr_capacity * 2);
+        auto reqs = nvcv::TensorBatch::CalcRequirements(new_capacity);
+        mb.output = nvcv::TensorBatch(reqs);
+      } else {
+        mb.output.clear();
+      }
+      for (int i = mb.start; i < mb.start + mb.count; ++i) {
+        if (volume(out_shape_.tensor_shape_span(i)) != 0) {
+          mb.output.pushBack(nvcvop::AsTensor(out[i], sample_layout_));
+        }
+      }
+    }
+  }
+
+  int minibatch_size_;
+  int curr_minibatch_size_;
+};
+
+}  // namespace dali
+
+#endif  // DALI_OPERATORS_IMAGE_RESIZE_EXPERIMENTAL_RESIZE_OP_IMPL_CVCUDA_H_
diff --git a/dali/operators/image/resize/resize_base.cc b/dali/operators/image/resize/resize_base.cc
index d6731d92ed..9c8ce87031 100644
--- a/dali/operators/image/resize/resize_base.cc
+++ b/dali/operators/image/resize/resize_base.cc
@@ -95,17 +95,10 @@ void ResizeBase<CPUBackend>::SetupResizeStatic(
     span<const kernels::ResamplingParams> params,
     int first_spatial_dim) {
   using ImplType = ResizeOpImplCPU<Out, In, spatial_ndim>;
-  auto *impl = dynamic_cast<ImplType *>(impl_.get());
-  if (!impl) {
-    impl_.reset();
-    auto unq_impl = std::make_unique<ImplType>(kmgr_);
-    impl = unq_impl.get();
-    impl_ = std::move(unq_impl);
-  }
-  impl->Setup(out_shape, in_shape, first_spatial_dim, params);
+  SetImpl<ImplType>([&]{ return std::make_unique<ImplType>(kmgr_); });
+  impl_->Setup(out_shape, in_shape, first_spatial_dim, params);
 }
 
-
 template <>
 void ResizeBase<CPUBackend>::InitializeCPU(int num_threads) {
   if (num_threads != num_threads_) {
diff --git a/dali/operators/image/resize/resize_base.h b/dali/operators/image/resize/resize_base.h
index 415115e43b..446cecc572 100644
--- a/dali/operators/image/resize/resize_base.h
+++ b/dali/operators/image/resize/resize_base.h
@@ -74,6 +74,25 @@ class DLL_PUBLIC ResizeBase {
   }
 
   struct Impl;  // this needs to be public, because implementations inherit from it
+
+ protected:
+  template <typename ImplType, typename ImplTypeFactory>
+  void SetImpl(const ImplTypeFactory &impl_factory) {
+    auto *impl = dynamic_cast<ImplType *>(impl_.get());
+    if (!impl) {
+      impl_.reset();
+      auto unq_impl = impl_factory();
+      impl = unq_impl.get();
+      impl_ = std::move(unq_impl);
+    }
+  }
+
+  int GetMinibatchSize() const {
+    return minibatch_size_;
+  }
+
+  std::unique_ptr<Impl> impl_;
+
  private:
   template <typename Out, typename In, int spatial_ndim>
   void SetupResizeStatic(TensorListShape<> &out_shape,
@@ -90,7 +109,6 @@ class DLL_PUBLIC ResizeBase {
 
   int num_threads_ = 1;
   int minibatch_size_ = 32;
-  std::unique_ptr<Impl> impl_;
   kernels::KernelManager kmgr_;
 };
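The `SetImpl` helper above keeps the previously created implementation alive
across iterations and only reallocates when the requested type changes. The
same idea in a standalone sketch (the `Impl2D`/`Impl3D` names are illustrative):

    #include <iostream>
    #include <memory>

    struct Impl { virtual ~Impl() = default; };
    struct Impl2D : Impl {};
    struct Impl3D : Impl {};

    std::unique_ptr<Impl> impl_;

    // Reuse the current impl if it already has the requested dynamic type;
    // otherwise replace it with a freshly constructed one.
    template <typename ImplType, typename Factory>
    void SetImpl(const Factory &factory) {
      if (!dynamic_cast<ImplType *>(impl_.get()))
        impl_ = factory();
    }

    int main() {
      SetImpl<Impl2D>([] { return std::make_unique<Impl2D>(); });
      Impl *first = impl_.get();
      SetImpl<Impl2D>([] { return std::make_unique<Impl2D>(); });
      std::cout << (impl_.get() == first) << "\n";  // 1: instance reused
      SetImpl<Impl3D>([] { return std::make_unique<Impl3D>(); });
      std::cout << (impl_.get() != first) << "\n";  // 1: replaced with 3D impl
    }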
diff --git a/dali/operators/nvcvop/nvcvop.cc b/dali/operators/nvcvop/nvcvop.cc
index 8e17b87cb4..a82982eb71 100644
--- a/dali/operators/nvcvop/nvcvop.cc
+++ b/dali/operators/nvcvop/nvcvop.cc
@@ -14,6 +14,7 @@
 
 #include "dali/operators/nvcvop/nvcvop.h"
 
+#include <nvcv/TensorBatch.hpp>
 
 namespace dali::nvcvop {
@@ -163,8 +164,23 @@ void PushImagesToBatch(nvcv::ImageBatchVarShape &batch, const TensorList<GPUBackend> &t_list,
 
 nvcv::Tensor AsTensor(const Tensor<GPUBackend> &tensor, TensorLayout layout,
                       const std::optional<TensorShape<>> &reshape) {
   auto orig_shape = tensor.shape();
-  auto dtype = GetDataType(tensor.type(), 1);
+  TensorShape<> shape;
+  if (reshape.has_value()) {
+    DALI_ENFORCE(volume(*reshape) == volume(orig_shape),
+                 make_string("Cannot reshape from ", orig_shape, " to ", *reshape, "."));
+    shape = reshape.value();
+  } else {
+    shape = orig_shape;
+  }
+  TensorLayout out_layout = layout.empty() ? tensor.GetLayout() : layout;
+
+  return AsTensor(const_cast<void *>(tensor.raw_data()), shape, tensor.type(), out_layout);
+}
+
+nvcv::Tensor AsTensor(SampleView<GPUBackend> sample, TensorLayout layout,
+                      const std::optional<TensorShape<>> &reshape) {
+  auto orig_shape = sample.shape();
   TensorShape<> shape;
   if (reshape.has_value()) {
     DALI_ENFORCE(volume(*reshape) == volume(orig_shape),
                  make_string("Cannot reshape from ", orig_shape, " to ", *reshape, "."));
     shape = reshape.value();
   } else {
     shape = orig_shape;
   }
 
@@ -174,20 +190,62 @@ nvcv::Tensor AsTensor(const Tensor<GPUBackend> &tensor, TensorLayout layout,
+  return AsTensor(sample.raw_mutable_data(), shape, sample.type(), layout);
+}
+
+nvcv::Tensor AsTensor(ConstSampleView<GPUBackend> sample, TensorLayout layout,
+                      const std::optional<TensorShape<>> &reshape) {
+  auto orig_shape = sample.shape();
+  TensorShape<> shape;
+  if (reshape.has_value()) {
+    DALI_ENFORCE(volume(*reshape) == volume(orig_shape),
+                 make_string("Cannot reshape from ", orig_shape, " to ", *reshape, "."));
+    shape = reshape.value();
+  } else {
+    shape = orig_shape;
+  }
+
+  return AsTensor(const_cast<void *>(sample.raw_data()), shape, sample.type(), layout);
+}
+
+nvcv::Tensor AsTensor(void *data, const TensorShape<> &shape, DALIDataType daliDType,
+                      TensorLayout layout) {
+  auto dtype = GetDataType(daliDType, 1);
   nvcv::TensorDataStridedCuda::Buffer inBuf;
-  inBuf.basePtr = reinterpret_cast<NVCVByte *>(const_cast<void *>(tensor.raw_data()));
+  inBuf.basePtr = reinterpret_cast<NVCVByte *>(data);
   inBuf.strides[shape.size() - 1] = dtype.strideBytes();
   for (int d = shape.size() - 2; d >= 0; --d) {
     inBuf.strides[d] = shape[d + 1] * inBuf.strides[d + 1];
   }
-  TensorLayout out_layout = layout.empty() ? tensor.GetLayout() : layout;
-  DALI_ENFORCE(out_layout.empty() || out_layout.size() == shape.size(),
-               make_string("Layout ", out_layout,
-                           " does not match the number of dimensions: ", shape.size()));
-  nvcv::TensorShape out_shape(shape.data(), shape.size(),
-                              nvcv::TensorLayout(out_layout.c_str()));
+  DALI_ENFORCE(
+      layout.empty() || layout.size() == shape.size(),
+      make_string("Layout ", layout, " does not match the number of dimensions: ", shape.size()));
+  nvcv::TensorShape out_shape(shape.data(), shape.size(), nvcv::TensorLayout(layout.c_str()));
   nvcv::TensorDataStridedCuda inData(out_shape, dtype, inBuf);
   return nvcv::TensorWrapData(inData);
 }
 
+void PushTensorsToBatch(nvcv::TensorBatch &batch, const TensorList<GPUBackend> &t_list,
+                        TensorLayout layout) {
+  for (int s = 0; s < t_list.num_samples(); ++s) {
+    batch.pushBack(AsTensor(t_list[s], layout));
+  }
+}
+
+cvcuda::Workspace NVCVOpWorkspace::Allocate(const cvcuda::WorkspaceRequirements &reqs,
+                                            kernels::Scratchpad &scratchpad) {
+  auto *hostBuffer = scratchpad.AllocateHost<uint8_t>(reqs.hostMem.size, reqs.hostMem.alignment);
+  auto *pinnedBuffer =
+      scratchpad.AllocatePinned<uint8_t>(reqs.pinnedMem.size, reqs.pinnedMem.alignment);
+  auto *gpuBuffer = scratchpad.AllocateGPU<uint8_t>(reqs.cudaMem.size, reqs.cudaMem.alignment);
+
+  workspace_.hostMem.data = hostBuffer;
+  workspace_.hostMem.req = reqs.hostMem;
+  workspace_.pinnedMem.data = pinnedBuffer;
+  workspace_.pinnedMem.req = reqs.pinnedMem;
+  workspace_.cudaMem.data = gpuBuffer;
+  workspace_.cudaMem.req = reqs.cudaMem;
+  return workspace_;
+}
+
 }  // namespace dali::nvcvop
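`AsTensor` computes dense row-major strides from the innermost dimension
outwards. For example, wrapping a uint8 HWC image of shape (480, 640, 3) gives
byte strides {1920, 3, 1}:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Dense strides for shape (480, 640, 3) with a 1-byte element
    // (dtype.strideBytes() == 1 for single-channel uint8).
    int main() {
      std::vector<int64_t> shape{480, 640, 3}, strides(shape.size());
      strides.back() = 1;
      for (int d = shape.size() - 2; d >= 0; --d)
        strides[d] = shape[d + 1] * strides[d + 1];
      for (auto s : strides) std::cout << s << " ";  // prints: 1920 3 1
    }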
#include "dali/pipeline/operator/operator.h" #include "dali/pipeline/operator/sequence_operator.h" +#include "dali/core/cuda_event_pool.h" namespace dali::nvcvop { @@ -52,6 +55,22 @@ NVCVInterpolationType GetInterpolationType(DALIInterpType interpolation_type); */ nvcv::DataKind GetDataKind(DALIDataType dtype); +/** + * @brief Construct a DataType object with a given number of channels and given channel type + */ +nvcv::DataType GetDataType(DALIDataType dtype, int num_channels = 1); + +/** + * @brief Construct a DataType object with a given number of channels and given channel type + * + * @tparam T channel type + */ +template +nvcv::DataType GetDataType(int num_channels = 1) { + return GetDataType(TypeTable::GetTypeId(), num_channels); +} + + /** * @brief Get image format for an image with a given data type and number of channels. * @@ -87,6 +106,15 @@ nvcv::Image AsImage(ConstSampleView sample, const nvcv::ImageFormat nvcv::Tensor AsTensor(const Tensor &tensor, TensorLayout layout = "", const std::optional> &reshape = std::nullopt); +nvcv::Tensor AsTensor(SampleView sample, TensorLayout layout = "", + const std::optional> &reshape = std::nullopt); + +nvcv::Tensor AsTensor(ConstSampleView sample, TensorLayout layout = "", + const std::optional> &reshape = std::nullopt); + +nvcv::Tensor AsTensor(void *data, const TensorShape<> shape, DALIDataType dtype, + TensorLayout layout); + /** * @brief Allocates an image batch using a dynamic scratchpad. * Allocated images have the shape and data type matching the samples in the given TensorList. @@ -99,24 +127,54 @@ void AllocateImagesLike(nvcv::ImageBatchVarShape &output, const TensorList &t_list); -/** - * @brief Construct a DataType object with a given number of channels and given channel type - */ -nvcv::DataType GetDataType(DALIDataType dtype, int num_channels = 1); -/** - * @brief Construct a DataType object with a given number of channels and given channel type - * - * @tparam T channel type - */ -template -nvcv::DataType GetDataType(int num_channels = 1) { - return GetDataType(TypeTable::GetTypeId(), num_channels); +void PushTensorsToBatch(nvcv::TensorBatch &batch, const TensorList &t_list, + TensorLayout layout); + +class NVCVOpWorkspace { + public: + NVCVOpWorkspace() { + CUDA_CALL(cudaGetDevice(&device_id_)); + auto &eventPool = CUDAEventPool::instance(); + workspace_.hostMem.ready = eventPool.Get(device_id_).release(); + workspace_.pinnedMem.ready = eventPool.Get(device_id_).release(); + workspace_.cudaMem.ready = eventPool.Get(device_id_).release(); + } + + cvcuda::Workspace Allocate(const cvcuda::WorkspaceRequirements &reqs, + kernels::Scratchpad &scratchpad); + + ~NVCVOpWorkspace() { + CUDA_DTOR_CALL(cudaEventSynchronize(workspace_.hostMem.ready)); + CUDA_DTOR_CALL(cudaEventSynchronize(workspace_.pinnedMem.ready)); + CUDA_DTOR_CALL(cudaEventSynchronize(workspace_.cudaMem.ready)); + + auto &eventPool = CUDAEventPool::instance(); + eventPool.Put(CUDAEvent(workspace_.hostMem.ready), device_id_); + eventPool.Put(CUDAEvent(workspace_.pinnedMem.ready), device_id_); + eventPool.Put(CUDAEvent(workspace_.cudaMem.ready), device_id_); + } + + private: + cvcuda::Workspace workspace_{}; + int device_id_{}; +}; + +inline cvcuda::WorkspaceRequirements MaxWorkspaceRequirements( + const cvcuda::WorkspaceRequirements &a, const cvcuda::WorkspaceRequirements &b) { + cvcuda::WorkspaceRequirements max; + max.hostMem.size = std::max(a.hostMem.size, b.hostMem.size); + max.hostMem.alignment = std::max(a.hostMem.alignment, b.hostMem.alignment); + 
+inline cvcuda::WorkspaceRequirements MaxWorkspaceRequirements(
+    const cvcuda::WorkspaceRequirements &a, const cvcuda::WorkspaceRequirements &b) {
+  cvcuda::WorkspaceRequirements max;
+  max.hostMem.size = std::max(a.hostMem.size, b.hostMem.size);
+  max.hostMem.alignment = std::max(a.hostMem.alignment, b.hostMem.alignment);
+  max.pinnedMem.size = std::max(a.pinnedMem.size, b.pinnedMem.size);
+  max.pinnedMem.alignment = std::max(a.pinnedMem.alignment, b.pinnedMem.alignment);
+  max.cudaMem.size = std::max(a.cudaMem.size, b.cudaMem.size);
+  max.cudaMem.alignment = std::max(a.cudaMem.alignment, b.cudaMem.alignment);
+  return max;
+}
 
 /**
diff --git a/dali/test/python/checkpointing/test_dali_stateless_operators.py b/dali/test/python/checkpointing/test_dali_stateless_operators.py
index e42d3756e1..cb3cd6031d 100644
--- a/dali/test/python/checkpointing/test_dali_stateless_operators.py
+++ b/dali/test/python/checkpointing/test_dali_stateless_operators.py
@@ -1006,6 +1006,12 @@ def test_warp_perspective_stateless(device):
     check_single_input(fn.experimental.warp_perspective, device, matrix=np.eye(3))
 
 
+@params("gpu")
+@stateless_signed_off("experimental.resize")
+def test_experimental_resize(device):
+    check_single_input(fn.experimental.resize, device, resize_x=50, resize_y=50)
+
+
 @params("cpu")
 @stateless_signed_off("zeros", "ones", "full", "zeros_like", "ones_like", "full_like")
 def test_full_operator_family(device):
diff --git a/dali/test/python/operator_2/test_resize.py b/dali/test/python/operator_2/test_resize.py
index 11280599da..2d56f7e269 100644
--- a/dali/test/python/operator_2/test_resize.py
+++ b/dali/test/python/operator_2/test_resize.py
@@ -183,7 +183,26 @@ def resize_PIL(dim, channel_first, dtype, interp, data, size, roi_start, roi_end
     )
 
 
+def resize_op(backend):
+    if backend in ["cpu", "gpu"]:
+        return fn.resize
+    elif backend == "cvcuda":
+        return fn.experimental.resize
+    else:
+        assert False
+
+
+def backend_device(backend):
+    if backend in ["cvcuda", "gpu"]:
+        return "gpu"
+    elif backend == "cpu":
+        return "cpu"
+    else:
+        assert False
+
+
 def resize_dali(
+    backend,
     input,
     channel_first,
     dtype,
@@ -198,7 +217,7 @@ def resize_dali(
     minibatch_size,
     max_size,
 ):
-    return fn.resize(
+    return resize_op(backend)(
         input,
         interp_type=interp,
         dtype=dtype,
@@ -301,7 +320,7 @@ def max_size(dim):
 
 
 def build_pipes(
-    device,
+    backend,
     dim,
     batch_size,
     channel_first,
@@ -315,7 +334,14 @@ def build_pipes(
     use_size_input,
     use_roi,
 ):
-    dali_pipe = Pipeline(batch_size=batch_size, num_threads=8, device_id=0, seed=1234)
+    dali_pipe = Pipeline(
+        batch_size=batch_size,
+        num_threads=8,
+        device_id=0,
+        seed=1234,
+        exec_async=False,
+        exec_pipelined=False,
+    )
     with dali_pipe:
         if dim == 2:
             files, labels = dali.fn.readers.caffe(path=db_2d_folder, random_shuffle=True)
@@ -323,7 +349,7 @@ def build_pipes(
         else:
             images_cpu = dali.fn.external_source(source=random_3d_loader(batch_size), layout="DHWC")
 
-        images_hwc = images_cpu if device == "cpu" else images_cpu.gpu()
+        images_hwc = images_cpu if backend_device(backend) == "cpu" else images_cpu.gpu()
 
         if channel_first:
             images = dali.fn.transpose(
@@ -363,6 +389,7 @@ def build_pipes(
             size = [300, 400] if dim == 2 else [80, 100, 120]
 
             resized = resize_dali(
+                backend,
                 images,
                 channel_first,
                 dtype,
@@ -398,6 +425,7 @@ def build_pipes(
                 d = 31  # some other fixed value
 
             resized = resize_dali(
+                backend,
                 images,
                 channel_first,
                 dtype,
@@ -456,7 +484,7 @@ def interior(array, channel_first):
 
 
 def _test_ND(
-    device,
+    backend,
     dim,
     batch_size,
     channel_first,
@@ -471,7 +499,7 @@ def _test_ND(
     use_roi,
 ):
     dali_pipe, pil_pipe = build_pipes(
-        device,
+        backend,
         dim,
         batch_size,
         channel_first,
@@ -567,7 +595,7 @@ def get_output():
     check_batch(dali_interior, ref_interior, batch_size, max_avg_err, max_err)
 
 
-def _tests(dim, device):
+def _tests(dim, backend):
     batch_size = 2 if dim == 3 else 10
 
     # - Cannot test linear against PIL, because PIL uses a triangular filter when downscaling
    # - Cannot test Nearest Neighbor because rounding errors cause gross discrepancies (pixel shift)
@@ -593,7 +621,7 @@ def _tests(dim, backend):
             interp = [types.INTERP_TRIANGULAR, types.INTERP_LANCZOS3][interp]
             yield (
                 _test_ND,
-                device,
+                backend,
                 dim,
                 batch_size,
                 False,
@@ -629,7 +657,17 @@ def test_3D_cpu():
         yield (f, *args)
 
 
+def test_2D_cvcuda():
+    for f, *args in _tests(2, "cvcuda"):
+        yield (f, *args)
+
+
+def test_3D_cvcuda():
+    for f, *args in _tests(3, "cvcuda"):
+        yield (f, *args)
+
+
-def _test_stitching(device, dim, channel_first, dtype, interp):
+def _test_stitching(backend, dim, channel_first, dtype, interp):
     batch_size = 1 if dim == 3 else 10
     pipe = dali.pipeline.Pipeline(
         batch_size=batch_size, num_threads=1, device_id=0, seed=1234, prefetch_queue_depth=1
@@ -641,7 +679,7 @@ def _test_stitching(backend, dim, channel_first, dtype, interp):
     else:
         images_cpu = dali.fn.external_source(source=random_3d_loader(batch_size), layout="DHWC")
 
-    images_hwc = images_cpu if device == "cpu" else images_cpu.gpu()
+    images_hwc = images_cpu if backend_device(backend) == "cpu" else images_cpu.gpu()
 
     if channel_first:
         images = dali.fn.transpose(
@@ -656,7 +694,7 @@ def _test_stitching(backend, dim, channel_first, dtype, interp):
     roi_start = [0] * dim
     roi_end = [1] * dim
 
-    resized = fn.resize(
+    resized = resize_op(backend)(
         images, dtype=dtype, min_filter=interp, mag_filter=interp, size=out_size_full
     )
 
@@ -689,7 +727,7 @@ def _test_stitching(backend, dim, channel_first, dtype, interp):
     pipe.build()
     for iter in range(1):
         out = pipe.run()
-        if device == "gpu":
+        if backend_device(backend) == "gpu":
             out = [x.as_cpu() for x in out]
         whole = out[0]
         tiled = []
@@ -716,7 +754,7 @@ def _test_stitching(backend, dim, channel_first, dtype, interp):
 
 
 def test_stitching():
-    for device in ["cpu", "gpu"]:
+    for backend in ["cpu", "gpu", "cvcuda"]:
         for dim in [3]:
             for dtype in [types.UINT8, types.FLOAT]:
                 for channel_first in [False, True]:
@@ -726,10 +764,10 @@ def test_stitching():
                         types.INTERP_TRIANGULAR,
                         types.INTERP_LANCZOS3,
                     ]:
-                        yield _test_stitching, device, dim, channel_first, dtype, interp
+                        yield _test_stitching, backend, dim, channel_first, dtype, interp
 
 
-def _test_empty_input(dim, device):
+def _test_empty_input(dim, backend):
     batch_size = 8
     pipe = Pipeline(batch_size=batch_size, num_threads=8, device_id=0, seed=1234)
     if dim == 2:
@@ -738,7 +776,7 @@ def _test_empty_input(dim, backend):
     else:
         images_cpu = dali.fn.external_source(source=random_3d_loader(batch_size), layout="DHWC")
 
-    images = images_cpu if device == "cpu" else images_cpu.gpu()
+    images = images_cpu if backend_device(backend) == "cpu" else images_cpu.gpu()
 
     in_rel_shapes = np.ones([batch_size, dim], dtype=np.float32)
 
@@ -751,15 +789,15 @@ def _test_empty_input(dim, backend):
     sizes = np.random.randint(20, 50, [batch_size, dim], dtype=np.int32)
     size_inp = fn.external_source(lambda: [x.astype(np.float32) for x in sizes])
 
-    resize_no_empty = fn.resize(images, size=size_inp, mode="not_larger")
-    resize_with_empty = fn.resize(degenerate_images, size=size_inp, mode="not_larger")
+    resize_no_empty = resize_op(backend)(images, size=size_inp, mode="not_larger")
+    resize_with_empty = resize_op(backend)(degenerate_images, size=size_inp, mode="not_larger")
 
     pipe.set_outputs(resize_no_empty, resize_with_empty)
     pipe.build()
 
     for it in range(3):
         out_no_empty, out_with_empty = pipe.run()
-        if device == "gpu":
+        if backend_device(backend) == "gpu":
             out_no_empty = out_no_empty.as_cpu()
             out_with_empty = out_with_empty.as_cpu()
 
        for i in range(batch_size):
@@ -770,12 +808,12 @@ def _test_empty_input(dim, backend):
 
 
 def test_empty_input():
-    for device in ["cpu", "gpu"]:
+    for backend in ["cpu", "gpu", "cvcuda"]:
         for dim in [2, 3]:
-            yield _test_empty_input, dim, device
+            yield _test_empty_input, dim, backend
 
 
-def _test_very_small_output(dim, device):
+def _test_very_small_output(dim, backend):
     batch_size = 8
     pipe = Pipeline(batch_size=batch_size, num_threads=8, device_id=0, seed=1234)
     if dim == 2:
@@ -784,9 +822,9 @@ def _test_very_small_output(dim, backend):
     else:
         images_cpu = dali.fn.external_source(source=random_3d_loader(batch_size), layout="DHWC")
 
-    images = images_cpu if device == "cpu" else images_cpu.gpu()
+    images = images_cpu if backend_device(backend) == "cpu" else images_cpu.gpu()
 
-    resize_tiny = fn.resize(images, size=1e-10)
+    resize_tiny = resize_op(backend)(images, size=1e-10)
 
     pipe.set_outputs(resize_tiny)
     pipe.build()
@@ -799,9 +837,9 @@ def _test_very_small_output(dim, backend):
 
 
 def test_very_small_output():
-    for device in ["cpu", "gpu"]:
+    for backend in ["cpu", "gpu", "cvcuda"]:
         for dim in [2, 3]:
-            yield _test_very_small_output, dim, backend
+            yield _test_very_small_output, dim, backend
 
 
 large_data = None
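The tests above parameterize over a `backend` that selects either the classic
resize or the CV-CUDA variant. Outside the suite, a quick parity sketch of the
two GPU backends might look like this (the input data and tolerance are
illustrative; `INTERP_TRIANGULAR` maps to CV-CUDA's linear filter, so a small
difference is expected):

    import numpy as np
    from nvidia.dali import fn, pipeline_def, types

    @pipeline_def(batch_size=1, num_threads=2, device_id=0)
    def parity_pipe():
        data = fn.external_source(
            source=lambda: [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)],
            layout="HWC",
        )
        images = data.gpu()
        classic = fn.resize(images, size=[224, 224], interp_type=types.INTERP_TRIANGULAR)
        cvcuda = fn.experimental.resize(images, size=[224, 224],
                                        interp_type=types.INTERP_TRIANGULAR)
        return classic, cvcuda

    p = parity_pipe()
    p.build()
    classic, cvcuda = (np.asarray(o.as_cpu()[0]) for o in p.run())
    # illustrative tolerance, not a normative bound
    assert np.abs(classic.astype(np.int32) - cvcuda.astype(np.int32)).mean() < 2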
diff --git a/dali/test/python/test_dali_cpu_only.py b/dali/test/python/test_dali_cpu_only.py
index fa18df260f..6e1331e5a6 100644
--- a/dali/test/python/test_dali_cpu_only.py
+++ b/dali/test/python/test_dali_cpu_only.py
@@ -1627,6 +1627,7 @@ def test_io_file_read_cpu():
     "experimental.dilate",  # not supported for CPU
     "experimental.erode",  # not supported for CPU
     "experimental.warp_perspective",  # not supported for CPU
+    "experimental.resize",  # not supported for CPU
     "plugin.video.decoder",  # not supported for CPU
 ]
 
diff --git a/dali/test/python/test_dali_variable_batch_size.py b/dali/test/python/test_dali_variable_batch_size.py
index 2b159dc3d1..7944ba5976 100644
--- a/dali/test/python/test_dali_variable_batch_size.py
+++ b/dali/test/python/test_dali_variable_batch_size.py
@@ -375,6 +375,7 @@ def numba_setup_out_shape(out_shape, in_shape):
     (fn.experimental.dilate, {"devices": ["gpu"]}),
     (fn.experimental.erode, {"devices": ["gpu"]}),
     (fn.experimental.warp_perspective, {"matrix": np.eye(3), "devices": ["gpu"]}),
+    (fn.experimental.resize, {"resize_x": 50, "resize_y": 50, "devices": ["gpu"]}),
     (fn.zeros_like, {"devices": ["cpu"]}),
     (fn.ones_like, {"devices": ["cpu"]}),
 ]