diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index a2cb3279438a6..48bd6e8a33927 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -79,7 +79,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv);

 #if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, FastGelu);
@@ -175,7 +174,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
       BuildKernelCreateInfo,
       BuildKernelCreateInfo,
 #endif
-      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, FusedConv)>,
   };

   for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cuda/fused_conv.cc b/onnxruntime/contrib_ops/cuda/fused_conv.cc
deleted file mode 100644
index 6cce3658719b2..0000000000000
--- a/onnxruntime/contrib_ops/cuda/fused_conv.cc
+++ /dev/null
@@ -1,126 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "core/providers/cuda/nn/conv.h"
-#include "core/providers/cuda/cuda_common.h"
-
-namespace onnxruntime {
-namespace contrib {
-namespace cuda {
-
-template <typename T>
-class FusedConv : public onnxruntime::cuda::Conv<T> {
- public:
-  using Base = onnxruntime::cuda::Conv<T>;
-  FusedConv(const OpKernelInfo& info) : onnxruntime::cuda::Conv<T>(info) {
-    std::string activation;
-    if (info.GetAttr("activation", &activation) == Status::OK() &&
-        MapMode(activation) == Status::OK() &&
-        cudnnCreateActivationDescriptor(&activation_desc_) == CUDNN_STATUS_SUCCESS) {
-      status_ = cudnnSetActivationDescriptor(activation_desc_,
-                                             activation_mode_,
-                                             cudnnNanPropagation_t::CUDNN_NOT_PROPAGATE_NAN,
-                                             std::numeric_limits<double>::max());
-    }
-  }
-
-  ORT_DISALLOW_COPY_AND_ASSIGNMENT(FusedConv);
-
-  ~FusedConv() {
-    if (activation_desc_) {
-      cudnnDestroyActivationDescriptor(activation_desc_);
-      status_ = CUDNN_STATUS_NOT_INITIALIZED;
-      activation_desc_ = nullptr;
-    }
-  }
-
-  Status ComputeInternal(OpKernelContext* context) const override {
-    CUDNN_RETURN_IF_ERROR(status_);
-    std::lock_guard<OrtMutex> lock(Base::s_.mutex);
-    ORT_RETURN_IF_ERROR(Base::UpdateState(context, true));
-    if (Base::s_.Y->Shape().Size() == 0) {
-      return Status::OK();
-    }
-    bool has_z = nullptr != Base::s_.z_data;
-    bool has_b = nullptr != Base::s_.b_data;
-    auto alpha = &(Base::alpha_);
-    auto beta = &(Base::beta_);
-    IAllocatorUniquePtr<void> workspace = Base::GetWorkSpace();
-    auto cudnn_status = cudnnConvolutionBiasActivationForward(Base::CudnnHandle(),
-                                                              alpha,
-                                                              Base::s_.x_tensor,
-                                                              Base::s_.x_data,
-                                                              Base::s_.w_desc,
-                                                              Base::s_.w_data,
-                                                              Base::s_.conv_desc,
-                                                              Base::s_.algo,
-                                                              workspace.get(),
-                                                              Base::s_.workspace_bytes,
-                                                              has_z ? alpha : beta,
-                                                              has_z ? Base::s_.z_tensor : Base::s_.y_tensor,
-                                                              has_z ? Base::s_.z_data : Base::s_.y_data,
-                                                              Base::s_.b_tensor,
-                                                              has_b ? Base::s_.b_data : Base::s_.b_zero,
-                                                              activation_desc_,
-                                                              Base::s_.y_tensor,
-                                                              Base::s_.y_data);
-    if (CUDNN_STATUS_SUCCESS != cudnn_status) {
-      CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(Base::CudnnHandle(),
-                                                    alpha,
-                                                    Base::s_.x_tensor,
-                                                    Base::s_.x_data,
-                                                    Base::s_.w_desc,
-                                                    Base::s_.w_data,
-                                                    Base::s_.conv_desc,
-                                                    Base::s_.algo,
-                                                    workspace.get(),
-                                                    Base::s_.workspace_bytes,
-                                                    beta,
-                                                    Base::s_.y_tensor,
-                                                    Base::s_.y_data));
-      if (has_b) {
-        CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), alpha, Base::s_.b_tensor, Base::s_.b_data,
-                                             alpha, Base::s_.y_tensor, Base::s_.y_data));
-      }
-      if (has_z) {
-        CUDNN_RETURN_IF_ERROR(cudnnAddTensor(Base::CudnnHandle(), alpha, Base::s_.z_tensor, Base::s_.z_data,
-                                             alpha, Base::s_.y_tensor, Base::s_.y_data));
-      }
-      CUDNN_RETURN_IF_ERROR(cudnnActivationForward(Base::CudnnHandle(), activation_desc_, alpha, Base::s_.y_tensor,
-                                                   Base::s_.y_data, beta, Base::s_.y_tensor, Base::s_.y_data));
-    }
-    if (Base::s_.post_slicing_required) {
-      onnxruntime::cuda::SliceOutUnwantedOutputSection(this->Stream(), Base::s_.y_data, Base::s_.y_dims_with_adjusted_pads, Base::s_.Y->MutableDataRaw(),
-                                                       Base::s_.y_dims, Base::s_.slice_starts, Base::s_.slice_ends, Base::s_.slice_axes, Base::s_.element_size);
-    }
-    return Status::OK();
-  }
-
- private:
-  Status MapMode(const std::string& activaton_mode) {
-    if (activaton_mode == "Relu") {
-      activation_mode_ = cudnnActivationMode_t::CUDNN_ACTIVATION_RELU;
-    } else {
-      return Status(common::StatusCategory::ONNXRUNTIME,
-                    common::StatusCode::INVALID_ARGUMENT,
-                    "unsupported conv activation mode");
-    }
-    return Status::OK();
-  }
-  cudnnStatus_t status_ = CUDNN_STATUS_NOT_INITIALIZED;
-  cudnnActivationMode_t activation_mode_;
-  cudnnActivationDescriptor_t activation_desc_ = nullptr;
-};
-
-ONNX_OPERATOR_TYPED_KERNEL_EX(
-    FusedConv,
-    kMSDomain,
-    1,
-    float,
-    kCudaExecutionProvider,
-    KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
-    FusedConv<float>);
-
-}  // namespace cuda
-}  // namespace contrib
-}  // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index fdda10f467a5e..43c539867d477 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -1273,12 +1273,6 @@ activation.)DOC")
             "",
             "T",
             OpSchema::Optional)
-        .Input(
-            3,
-            "Z",
-            "",
-            "T",
-            OpSchema::Optional)
         .Output(
             0,
             "Y",
diff --git a/onnxruntime/core/optimizer/conv_activation_fusion.cc b/onnxruntime/core/optimizer/conv_activation_fusion.cc
index 933d374eaf26c..295468d61877a 100644
--- a/onnxruntime/core/optimizer/conv_activation_fusion.cc
+++ b/onnxruntime/core/optimizer/conv_activation_fusion.cc
@@ -100,106 +100,50 @@ Status ConvActivationFusion::ApplyImpl(Graph& graph, bool& modified, int graph_l
       continue;
     }

-    if (node->GetExecutionProviderType() == onnxruntime::kCudaExecutionProvider) {
-      if (node->InputDefs()[0]->TypeAsProto()->tensor_type().elem_type() !=
-          ONNX_NAMESPACE::TensorProto_DataType_FLOAT) {
-        continue;
-      }
-      if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13})) {
-        Node& conv_node = *node;
-        Node& act_node = *graph.GetNode(next_node.Index());
-        auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" + act_node.Name());
-        Node& fused_conv = graph.AddNode(node_name,
-                                         "FusedConv",
-                                         node_name,
-                                         conv_node.MutableInputDefs(),
-                                         {},
-                                         &conv_node.GetAttributes(),
-                                         onnxruntime::kMSDomain);
-        fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
-        fused_conv.AddAttribute("activation", "Relu");
-        graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
-        modified = true;
-      } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Add", {6, 7, 13})) {
-        const auto& last_node = *(next_node.OutputNodesBegin());
-        if (last_node.GetExecutionProviderType() != node->GetExecutionProviderType()) {
-          continue;
-        }
-        if (graph_utils::IsSupportedOptypeVersionAndDomain(last_node, "Relu", {6, 13}) &&
-            next_node.GetOutputEdgesCount() == 1) {
-          Node& conv_node = *node;
-          Node& add_node = *graph.GetNode(next_node.Index());
-          Node& act_node = *graph.GetNode(last_node.Index());
-          auto conv_inputs = conv_node.MutableInputDefs();
-          auto conv_outputs = conv_node.MutableOutputDefs();
-          auto add_inputs = add_node.MutableInputDefs();
-          for (auto add_input : add_inputs) {
-            if (add_input->Name() != conv_outputs[0]->Name()) {
-              conv_inputs.push_back(add_input);
-              break;
-            }
-          }
-          auto node_name = graph.GenerateNodeName(conv_node.Name() + "_" +
-                                                  add_node.Name() + "_" +
-                                                  act_node.Name());
-          Node& fused_conv = graph.AddNode(node_name,
-                                           "FusedConv",
-                                           node_name,
-                                           conv_inputs,
-                                           {}, &conv_node.GetAttributes(),
-                                           onnxruntime::kMSDomain);
-          fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
-          fused_conv.AddAttribute("activation", "Relu");
-          graph_utils::FinalizeNodeFusion(graph, {conv_node, add_node, act_node}, fused_conv);
-          modified = true;
-        }
-      }
-    } else {
-      // Test if this is an activation that can be fused and also extract the
-      // activation's parameters.
-      std::vector<float> activation_params;
-      if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
-          !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
-          !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
-        if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
-          activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
-        } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
-          float min, max;
-          if (GetClipConstantMinMax(graph, next_node, min, max)) {
-            activation_params.push_back(min);
-            activation_params.push_back(max);
-          } else {
-            continue;
-          }
+    // Test if this is an activation that can be fused and also extract the
+    // activation's parameters.
+    std::vector<float> activation_params;
+    if (!graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Relu", {6, 13}) &&
+        !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Sigmoid", {6, 13}) &&
+        !graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Tanh", {6, 13})) {
+      if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "LeakyRelu", {6})) {
+        activation_params.push_back(graph_utils::GetNodeAttribute(next_node, "alpha")->f());
+      } else if (graph_utils::IsSupportedOptypeVersionAndDomain(next_node, "Clip", {6, 11, 12, 13})) {
+        float min, max;
+        if (GetClipConstantMinMax(graph, next_node, min, max)) {
+          activation_params.push_back(min);
+          activation_params.push_back(max);
        } else {
          continue;
        }
+      } else {
+        continue;
+      }
+    }

-      Node& conv_node = *node;
-      Node& act_node = *graph.GetNode(next_node.Index());
+    Node& conv_node = *node;
+    Node& act_node = *graph.GetNode(next_node.Index());

-      Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node.Name()), "FusedConv",
-                                       "fused Conv " + conv_node.Name() + "with activation " + act_node.OpType(),
-                                       conv_node.MutableInputDefs(),
-                                       {},
-                                       &conv_node.GetAttributes(),
-                                       "com.microsoft");
+    Node& fused_conv = graph.AddNode(graph.GenerateNodeName("fused " + conv_node.Name()), "FusedConv",
+                                     "fused Conv " + conv_node.Name() + " with activation " + act_node.OpType(),
+                                     conv_node.MutableInputDefs(),
+                                     {},
+                                     &conv_node.GetAttributes(),
+                                     "com.microsoft");

-      // Assign provider to this new node. Provider should be same as the provider for old node.
-      fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());
+    // Assign provider to this new node. The provider should be the same as the provider of the old node.
+    fused_conv.SetExecutionProviderType(conv_node.GetExecutionProviderType());

-      // Add attributes to specify the activation type and parameters.
-      fused_conv.AddAttribute("activation", next_node.OpType());
-      if (activation_params.size() > 0) {
-        fused_conv.AddAttribute("activation_params", activation_params);
-      }
+    // Add attributes to specify the activation type and parameters.
+    fused_conv.AddAttribute("activation", next_node.OpType());
+    if (activation_params.size() > 0) {
+      fused_conv.AddAttribute("activation_params", activation_params);
+    }

-      // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
-      graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);
+    // move output definitions and edges from act_node to fused_conv. delete conv_node and act_node.
+    graph_utils::FinalizeNodeFusion(graph, {conv_node, act_node}, fused_conv);

-      modified = true;
-    }
+    modified = true;
   }

   return Status::OK();
diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc
index d9e31a976f27a..2e8ed66f3af41 100644
--- a/onnxruntime/core/optimizer/graph_transformer_utils.cc
+++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -142,9 +142,9 @@ std::vector<std::unique_ptr<GraphTransformer>> GenerateTransformers(TransformerL
       transformers.emplace_back(onnxruntime::make_unique(cpu_execution_providers));

       std::unordered_set<std::string> cpu_acl_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider};
-      std::unordered_set<std::string> cpu_cuda_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider};
+      std::unordered_set<std::string> cpu_acl_armnn_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kAclExecutionProvider, onnxruntime::kArmNNExecutionProvider};

-      transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_cuda_acl_armnn_execution_providers));
+      transformers.emplace_back(onnxruntime::make_unique<ConvActivationFusion>(cpu_acl_armnn_execution_providers));

       std::unordered_set<std::string> cpu_cuda_execution_providers = {onnxruntime::kCpuExecutionProvider, onnxruntime::kCudaExecutionProvider};
       transformers.emplace_back(onnxruntime::make_unique(cpu_cuda_execution_providers));
diff --git a/onnxruntime/core/providers/cuda/nn/conv.cc b/onnxruntime/core/providers/cuda/nn/conv.cc
index f61f93fab597d..75d8b2479f3dc 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv.cc
@@ -34,15 +34,15 @@ REGISTER_KERNEL_TYPED(float)
 REGISTER_KERNEL_TYPED(double)
 REGISTER_KERNEL_TYPED(MLFloat16)

-Status SliceOutUnwantedOutputSection(cudaStream_t stream,
-                                     const void* input_data,
-                                     const std::vector<int64_t>& input_dims,
-                                     void* output_data,
-                                     const std::vector<int64_t>& output_dims,
-                                     std::vector<int64_t> starts,
-                                     const std::vector<int64_t>& ends,
-                                     const std::vector<int64_t>& axes,
-                                     size_t element_size) {
+static Status SliceOutUnwantedOutputSection(cudaStream_t stream,
+                                            const void* input_data,
+                                            const std::vector<int64_t>& input_dims,
+                                            void* output_data,
+                                            const std::vector<int64_t>& output_dims,
+                                            std::vector<int64_t> starts,
+                                            const std::vector<int64_t>& ends,
+                                            const std::vector<int64_t>& axes,
+                                            size_t element_size) {
   SliceOp::PrepareForComputeMetadata compute_metadata(input_dims);

   SliceBase::PrepareForCompute(starts, ends, axes, compute_metadata);
@@ -54,264 +54,275 @@ Status SliceOutUnwantedOutputSection(cudaStream_t stream,
 }

 template <typename T>
-Status Conv<T>::UpdateState(OpKernelContext* context, bool bias_expected) const {
-  //set X
+Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
+  typedef typename ToCudaType<T>::MappedType CudaT;
+
   const Tensor* X = context->Input<Tensor>(0);
   const TensorShape& x_shape = X->Shape();
   const auto& x_dims = x_shape.GetDims();
-  s_.x_data = reinterpret_cast<const CudaT*>(X->template Data<T>());
-  s_.element_size = X->DataType()->Size();
-  //set W
+  auto x_data = reinterpret_cast<const CudaT*>(X->template Data<T>());
+
   const Tensor* W = context->Input<Tensor>(1);
   const TensorShape& w_shape = W->Shape();
   std::vector<int64_t> w_dims = w_shape.GetDims();
-  s_.w_data = reinterpret_cast<const CudaT*>(W->template Data<T>());
-  //set B
-  if (context->InputCount() >= 3) {
-    const Tensor* B = context->Input<Tensor>(2);
-    s_.b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
-  } else {
-    s_.b_data = nullptr;
-  }
-  //set Z
-  if (context->InputCount() >= 4) {
-    const Tensor* Z = context->Input<Tensor>(3);
-    ORT_RETURN_IF_ERROR(s_.z_tensor.Set(Z->Shape().GetDims(), CudnnTensor::GetDataType<CudaT>()));
-    s_.z_data = reinterpret_cast<const CudaT*>(Z->template Data<T>());
-  } else {
-    s_.z_data = nullptr;
-  }
-  bool input_dims_changed = (s_.last_x_dims != x_dims);
-  bool w_dims_changed = (s_.last_w_dims != w_dims);
-  if (input_dims_changed || w_dims_changed) {
-    if (input_dims_changed)
-      s_.last_x_dims = x_dims;
-
-    if (w_dims_changed) {
-      s_.last_w_dims = w_dims;
-      s_.cached_benchmark_results.clear();
-    }
+  auto w_data = reinterpret_cast<const CudaT*>(W->template Data<T>());

-    const int64_t N = X->Shape()[0];
-    const int64_t M = W->Shape()[0];
+  size_t num_inputs = OpKernel::Node().InputDefs().size();
+  bool has_bias = (num_inputs == 3);

-    ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
+  CudaT* y_data = nullptr;

-    std::vector<int64_t> kernel_shape;
-    ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
-    auto rank = kernel_shape.size();
-    std::vector<int64_t> pads(conv_attrs_.pads);
-    if (pads.empty()) {
-      pads.resize(rank * 2, 0);
-    }
-    std::vector<int64_t> dilations(conv_attrs_.dilations);
-    if (dilations.empty()) {
-      dilations.resize(rank, 1);
-    }
-    std::vector<int64_t> strides(conv_attrs_.strides);
-    if (strides.empty()) {
-      strides.resize(rank, 1);
-    }
+  AllocatorPtr allocator;
+  ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));

-    std::vector<int64_t> y_dims;
-    y_dims.reserve(2 + rank);  // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C'
-    y_dims.insert(y_dims.begin(), {N, M});
-
-    std::vector<int64_t> y_dims_with_adjusted_pads;
-    y_dims_with_adjusted_pads.reserve(2 + rank);  // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C'
-    y_dims_with_adjusted_pads.insert(y_dims_with_adjusted_pads.begin(), {N, M});
-
-    bool post_slicing_required = false;
-    std::vector<int64_t> slice_starts;
-    slice_starts.reserve(rank);
-
-    std::vector<int64_t> slice_ends;
-    slice_ends.reserve(rank);
-
-    std::vector<int64_t> slice_axes;
-    slice_axes.reserve(rank);
-
-    ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(x_shape.Slice(2), kernel_shape,
-                                                                     strides, dilations, pads, y_dims, y_dims_with_adjusted_pads,
-                                                                     post_slicing_required, slice_starts, slice_ends, slice_axes));
-    ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size());
-    s_.y_dims = y_dims;
-    s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads;
-    s_.post_slicing_required = post_slicing_required;
-    s_.slice_starts = slice_starts;
-    s_.slice_ends = slice_ends;
-    s_.slice_axes = slice_axes;
-
-    s_.Y = context->Output(0, TensorShape(s_.y_dims));
-    if (s_.Y->Shape().Size() == 0) {
-      return Status::OK();
-    }
-    if (post_slicing_required) {
-      // Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
-      s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * s_.element_size);
-      s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
-    } else {
-      // No post slicing needed. Fill the output tensor's buffer directly.
-      s_.y_data = reinterpret_cast<CudaT*>(s_.Y->template MutableData<T>());
-    }
+  size_t element_size = X->DataType()->Size();

-    std::vector<int64_t> x_dims_cudnn = x_dims;
-    std::vector<int64_t> y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads;
-    if (rank < 2) {
-      // cudnn only takes 4D or 5D input, so pad dimensions if needed
-      x_dims_cudnn.push_back(1);
-      y_dims_cudnn.push_back(1);
-      w_dims.push_back(1);
-      pads.insert(pads.begin() + rank, 0);
-      pads.insert(pads.end(), 0);
-      kernel_shape.push_back(1);
-      strides.push_back(1);
-      dilations.push_back(1);
-    }
+  Tensor* Y = nullptr;

-    if (w_dims_changed) {
-      ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
-    }
-    ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
-    ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
-    ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
-                                         CUDNN_CROSS_CORRELATION, CudnnTensor::GetDataType<CudaT>()));
-    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast<int>(conv_attrs_.group)));
+  // We may have to write the CuDNN Conv results to a temporary buffer when we deal with
+  // asymmetric padding as we have to take the results written to this temporary buffer and slice out
+  // extraneous portions of the result
+  IAllocatorUniquePtr<void> memory_for_cudnn_conv_results;

-    if (context->InputCount() >= 3) {
-      const Tensor* B = context->Input<Tensor>(2);
-      const auto& b_shape = B->Shape();
-      ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
-      std::vector<int64_t> b_dims(2 + kernel_shape.size(), 1);
-      b_dims[1] = b_shape[0];
-      ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
-      //s_.b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
-    } else if (bias_expected) {
-      std::vector<int64_t> b_dims(2 + kernel_shape.size(), 1);
-      b_dims[1] = w_dims[0];
-      auto malloc_size = b_dims[1] * sizeof(CudaT);
-      ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
-      if (s_.b_zero) {
-        CUDA_CALL_THROW(cudaFree(s_.b_zero));
-        s_.b_zero = nullptr;
+  {
+    std::lock_guard<OrtMutex> lock(s_.mutex);
+    // TODO: add a global cache if need to handle cases for multiple frames running simultaneously with different batch_size
+    bool input_dims_changed = (s_.last_x_dims != x_dims);
+    bool w_dims_changed = (s_.last_w_dims != w_dims);
+    if (input_dims_changed || w_dims_changed) {
+      if (input_dims_changed)
+        s_.last_x_dims = x_dims;
+
+      if (w_dims_changed) {
+        s_.last_w_dims = w_dims;
+        s_.cached_benchmark_results.clear();
+      }
+
+      const int64_t N = X->Shape()[0];
+      const int64_t M = W->Shape()[0];
+
+      ORT_RETURN_IF_ERROR(conv_attrs_.ValidateInputShape(X, W));
+
+      std::vector<int64_t> kernel_shape;
+      ORT_RETURN_IF_ERROR(conv_attrs_.ComputeKernelShape(W->Shape(), kernel_shape));
+      auto rank = kernel_shape.size();
+      std::vector<int64_t> pads(conv_attrs_.pads);
+      if (pads.empty()) {
+        pads.resize(rank * 2, 0);
+      }
+      std::vector<int64_t> dilations(conv_attrs_.dilations);
+      if (dilations.empty()) {
+        dilations.resize(rank, 1);
+      }
+      std::vector<int64_t> strides(conv_attrs_.strides);
+      if (strides.empty()) {
+        strides.resize(rank, 1);
+      }
+
+      std::vector<int64_t> y_dims;
+      y_dims.reserve(2 + rank);  // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C'
+      y_dims.insert(y_dims.begin(), {N, M});
+
+      std::vector<int64_t> y_dims_with_adjusted_pads;
+      y_dims_with_adjusted_pads.reserve(2 + rank);  // rank indicates number of feature dimensions - so add 2 to account for 'N' and 'C'
+      y_dims_with_adjusted_pads.insert(y_dims_with_adjusted_pads.begin(), {N, M});
+
+      bool post_slicing_required = false;
+      std::vector<int64_t> slice_starts;
+      slice_starts.reserve(rank);
+
+      std::vector<int64_t> slice_ends;
+      slice_ends.reserve(rank);
+
+      std::vector<int64_t> slice_axes;
+      slice_axes.reserve(rank);
+
+      ORT_RETURN_IF_ERROR(conv_attrs_.InferOutputShapeWithAdjustedPads(x_shape.Slice(2), kernel_shape,
+                                                                       strides, dilations, pads, y_dims, y_dims_with_adjusted_pads,
+                                                                       post_slicing_required, slice_starts, slice_ends, slice_axes));
+      ORT_ENFORCE(y_dims.size() == y_dims_with_adjusted_pads.size());
+      s_.y_dims = y_dims;
+      s_.y_dims_with_adjusted_pads = y_dims_with_adjusted_pads;
+      s_.post_slicing_required = post_slicing_required;
+      s_.slice_starts = slice_starts;
+      s_.slice_ends = slice_ends;
+      s_.slice_axes = slice_axes;
+
+      Y = context->Output(0, TensorShape(s_.y_dims));
+      if (!post_slicing_required) {
+        // No post slicing needed. Fill the output tensor's buffer directly.
+        y_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
+      } else {
+        // Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
+        memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(y_dims_with_adjusted_pads).Size() * element_size);
+        y_data = reinterpret_cast<CudaT*>(memory_for_cudnn_conv_results.get());
+      }
+
+      std::vector<int64_t> x_dims_cudnn = x_dims;
+      std::vector<int64_t> y_dims_cudnn = !post_slicing_required ? y_dims : y_dims_with_adjusted_pads;
+      if (rank < 2) {
+        // cudnn only takes 4D or 5D input, so pad dimensions if needed
+        x_dims_cudnn.push_back(1);
+        y_dims_cudnn.push_back(1);
+        w_dims.push_back(1);
+        pads.insert(pads.begin() + rank, 0);
+        pads.insert(pads.end(), 0);
+        kernel_shape.push_back(1);
+        strides.push_back(1);
+        dilations.push_back(1);
+      }
+
+      if (w_dims_changed)
+        ORT_RETURN_IF_ERROR(s_.filter_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
+
+      // Special case when there is a dim value of 0 in the shape.
+      // Return only after we have cached the following for subsequent runs :
+      // 1) `w_dims` in the `filter_desc`
+      // 2) `y_dims` in s_.y_dims
+      if (Y->Shape().Size() == 0) {
+        return Status::OK();
+      }
+
+      ORT_RETURN_IF_ERROR(s_.x_tensor.Set(x_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
+      ORT_RETURN_IF_ERROR(s_.y_tensor.Set(y_dims_cudnn, CudnnTensor::GetDataType<CudaT>()));
+
+      cudnnConvolutionMode_t mode = CUDNN_CROSS_CORRELATION;
+      ORT_RETURN_IF_ERROR(s_.conv_desc.Set(kernel_shape.size(), pads, strides, dilations,
+                                           mode, CudnnTensor::GetDataType<CudaT>()));
+      CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionGroupCount(s_.conv_desc, gsl::narrow_cast<int>(conv_attrs_.group)));
+
+      if (has_bias) {
+        const Tensor* B = context->Input<Tensor>(2);
+        const auto& b_shape = B->Shape();
+        ORT_RETURN_IF_NOT(b_shape.NumDimensions() == 1, "bias should be 1D");
+        std::vector<int64_t> b_dims(2 + kernel_shape.size());
+        b_dims[0] = 1;           // N
+        b_dims[1] = b_shape[0];  // C
+        for (size_t i = 0; i < kernel_shape.size(); i++) b_dims[2 + i] = 1;
+
+        ORT_RETURN_IF_ERROR(s_.b_tensor.Set(b_dims, CudnnTensor::GetDataType<CudaT>()));
      }
-      CUDA_CALL_THROW(cudaMalloc(&s_.b_zero, malloc_size));
-      CUDA_CALL_THROW(cudaMemsetAsync(s_.b_zero, 0, malloc_size, Stream()));
+
+      if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
+        IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize);
+
+        // set math type to tensor core before algorithm search
+        if (std::is_same<T, MLFloat16>::value)
+          CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
+
+        cudnnConvolutionFwdAlgoPerf_t perf;
+        int algo_count = 1;
+        const CUDAExecutionProvider* cuda_ep = static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
+        int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
+        ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo);
+        switch (cudnn_conv_algo) {
+          case 0:
+            CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
+                CudnnHandle(),
+                s_.x_tensor,
+                x_data,
+                s_.filter_desc,
+                w_data,
+                s_.conv_desc,
+                s_.y_tensor,
+                y_data,
+                1,
+                &algo_count,
+                &perf,
+                algo_search_workspace.get(),
+                AlgoSearchWorkspaceSize));
+            break;
+
+          case 1:
+            CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
+                CudnnHandle(),
+                s_.x_tensor,
+                s_.filter_desc,
+                s_.conv_desc,
+                s_.y_tensor,
+                1,
+                &algo_count,
+                &perf));
+            break;
+
+          default:
+            perf.algo = kDefaultConvAlgo;
+            CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
+                CudnnHandle(),
+                s_.x_tensor,
+                s_.filter_desc,
+                s_.conv_desc,
+                s_.y_tensor,
+                perf.algo,
+                &perf.memory));
+            if (std::is_same<T, MLFloat16>::value) {
+              perf.mathType = CUDNN_TENSOR_OP_MATH;
+            }
+            else {
+              perf.mathType = CUDNN_DEFAULT_MATH;
+            }
+        }
+
+        s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType});
+      }
+
+      const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn);
+      CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType));
+      s_.algo = perf.algo;
+      s_.workspace_bytes = perf.memory;
    }

-    if (!s_.cached_benchmark_results.contains(x_dims_cudnn)) {
-      IAllocatorUniquePtr<void> algo_search_workspace = GetScratchBuffer<void>(AlgoSearchWorkspaceSize);
-
-      // set math type to tensor core before algorithm search
-      if (std::is_same<T, MLFloat16>::value)
-        CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, CUDNN_TENSOR_OP_MATH));
-
-      cudnnConvolutionFwdAlgoPerf_t perf;
-      int algo_count = 1;
-      const CUDAExecutionProvider* cuda_ep = static_cast<const CUDAExecutionProvider*>(this->Info().GetExecutionProvider());
-      int cudnn_conv_algo = cuda_ep->GetCudnnConvAlgo();
-      ORT_ENFORCE(cudnn_conv_algo > -1 && cudnn_conv_algo < 3, "cudnn_conv_algo should be 0, 1 or 2, but got ", cudnn_conv_algo);
-      switch (cudnn_conv_algo) {
-        case 0:
-          CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionForwardAlgorithmEx(
-              CudnnHandle(),
-              s_.x_tensor,
-              s_.x_data,
-              s_.w_desc,
-              s_.w_data,
-              s_.conv_desc,
-              s_.y_tensor,
-              s_.y_data,
-              1,
-              &algo_count,
-              &perf,
-              algo_search_workspace.get(),
-              AlgoSearchWorkspaceSize));
-          break;
-
-        case 1:
-          CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardAlgorithm_v7(
-              CudnnHandle(),
-              s_.x_tensor,
-              s_.w_desc,
-              s_.conv_desc,
-              s_.y_tensor,
-              1,
-              &algo_count,
-              &perf));
-          break;
-
-        default:
-          perf.algo = kDefaultConvAlgo;
-          CUDNN_RETURN_IF_ERROR(cudnnGetConvolutionForwardWorkspaceSize(
-              CudnnHandle(),
-              s_.x_tensor,
-              s_.w_desc,
-              s_.conv_desc,
-              s_.y_tensor,
-              perf.algo,
-              &perf.memory));
-          if (std::is_same<T, MLFloat16>::value) {
-            perf.mathType = CUDNN_TENSOR_OP_MATH;
-          } else {
-            perf.mathType = CUDNN_DEFAULT_MATH;
-          }
+    if (!y_data) {
+      Y = context->Output(0, TensorShape(s_.y_dims));
+      // special case when there is a dim value of 0 in the shape.
+      if (Y->Shape().Size() == 0)
+        return Status::OK();
+
+      if (!s_.post_slicing_required) {
+        y_data = reinterpret_cast<CudaT*>(Y->template MutableData<T>());
+      } else {
+        // Post slicing needed. Create and fill in the Conv results in an intermediate buffer.
+        memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(s_.y_dims_with_adjusted_pads).Size() * element_size);
+        y_data = reinterpret_cast<CudaT*>(memory_for_cudnn_conv_results.get());
      }
-      s_.cached_benchmark_results.insert(x_dims_cudnn, {perf.algo, perf.memory, perf.mathType});
-    }
-    const auto& perf = s_.cached_benchmark_results.at(x_dims_cudnn);
-    CUDNN_RETURN_IF_ERROR(cudnnSetConvolutionMathType(s_.conv_desc, perf.mathType));
-    s_.algo = perf.algo;
-    s_.workspace_bytes = perf.memory;
-  } else {
-    //set Y
-    s_.Y = context->Output(0, TensorShape(s_.y_dims));
-    if (s_.Y->Shape().Size() == 0) {
-      return Status::OK();
+
+    const auto alpha = Consts<CudaT>::One;
+    const auto beta = Consts<CudaT>::Zero;
+
+    IAllocatorUniquePtr<void> workspace = GetScratchBuffer<void>(s_.workspace_bytes);
+
+    CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(CudnnHandle(),
+                                                  &alpha,
+                                                  s_.x_tensor,
+                                                  x_data,
+                                                  s_.filter_desc,
+                                                  w_data,
+                                                  s_.conv_desc,
+                                                  s_.algo,
+                                                  workspace.get(),
+                                                  s_.workspace_bytes,
+                                                  &beta,
+                                                  s_.y_tensor,
+                                                  y_data));
+
+    if (has_bias) {
+      const Tensor* B = context->Input<Tensor>(2);
+      auto b_data = reinterpret_cast<const CudaT*>(B->template Data<T>());
+      CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha, s_.b_tensor, b_data, &alpha, s_.y_tensor,
+                                           y_data));
    }
+
+    // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions
+    // This may have led to extra results that are unnecessary and hence we slice that off here
    if (s_.post_slicing_required) {
-      s_.memory_for_cudnn_conv_results = GetScratchBuffer<void>(TensorShape(s_.y_dims_with_adjusted_pads).Size() * s_.element_size);
-      s_.y_data = reinterpret_cast<CudaT*>(s_.memory_for_cudnn_conv_results.get());
-    } else {
-      s_.y_data = reinterpret_cast<CudaT*>(s_.Y->template MutableData<T>());
+      SliceOutUnwantedOutputSection(Stream(), y_data, s_.y_dims_with_adjusted_pads, Y->MutableDataRaw(),
+                                    s_.y_dims, s_.slice_starts, s_.slice_ends, s_.slice_axes, element_size);
    }
  }
-  return Status::OK();
-}
-
-template <typename T>
-Status Conv<T>::ComputeInternal(OpKernelContext* context) const {
-  std::lock_guard<OrtMutex> lock(s_.mutex);
-  ORT_RETURN_IF_ERROR(UpdateState(context));
-  if (s_.Y->Shape().Size() == 0) {
-    return Status::OK();
-  }
-  IAllocatorUniquePtr<void> workspace = GetWorkSpace();
-  CUDNN_RETURN_IF_ERROR(cudnnConvolutionForward(CudnnHandle(),
-                                                &alpha_,
-                                                s_.x_tensor,
-                                                s_.x_data,
-                                                s_.w_desc,
-                                                s_.w_data,
-                                                s_.conv_desc,
-                                                s_.algo,
-                                                workspace.get(),
-                                                s_.workspace_bytes,
-                                                &beta_,
-                                                s_.y_tensor,
-                                                s_.y_data));
-  if (nullptr != s_.b_data) {
-    CUDNN_RETURN_IF_ERROR(cudnnAddTensor(CudnnHandle(), &alpha_, s_.b_tensor, s_.b_data,
-                                         &alpha_, s_.y_tensor, s_.y_data));
-  }
-  // To deal with asymmetric padding, we may have over-padded on one or both sides of the spatial dimensions
-  // This may have lead to extra results that are unnecessary and hence we slice that off here
-  if (s_.post_slicing_required) {
-    SliceOutUnwantedOutputSection(Stream(), s_.y_data, s_.y_dims_with_adjusted_pads, s_.Y->MutableDataRaw(),
-                                  s_.y_dims, s_.slice_starts, s_.slice_ends, s_.slice_axes, s_.element_size);
-  }
   return Status::OK();
-}  // namespace cuda
+}

 CudnnConvolutionDescriptor::CudnnConvolutionDescriptor() : desc_(nullptr) {
 }
diff --git a/onnxruntime/core/providers/cuda/nn/conv.h b/onnxruntime/core/providers/cuda/nn/conv.h
index 04f9865a1a16c..c0ad84abc025f 100644
--- a/onnxruntime/core/providers/cuda/nn/conv.h
+++ b/onnxruntime/core/providers/cuda/nn/conv.h
@@ -121,18 +121,9 @@ struct CudnnConvState {
   size_t workspace_bytes;
   decltype(AlgoPerfType().algo) algo;
   CudnnTensor x_tensor;
-  const void* x_data = nullptr;
-  size_t element_size = 0;
-  CudnnFilterDescriptor w_desc;
-  const void* w_data = nullptr;
+  CudnnFilterDescriptor filter_desc;
   CudnnTensor b_tensor;
-  const void* b_data = nullptr;
-  void* b_zero = nullptr;
   CudnnTensor y_tensor;
-  Tensor* Y = nullptr;
-  void* y_data = nullptr;
-  CudnnTensor z_tensor;
-  const void* z_data = nullptr;
   CudnnConvolutionDescriptor conv_desc;

   struct PerfResultParams {
@@ -151,14 +142,6 @@ struct CudnnConvState {
   // note that conv objects are shared between execution frames, and a lock is needed to avoid multi-thread racing
   OrtMutex mutex;
-  IAllocatorUniquePtr<void> memory_for_cudnn_conv_results;
-
-  ~CudnnConvState() {
-    if (b_zero) {
-      CUDA_CALL_THROW(cudaFree(b_zero));
-      b_zero = nullptr;
-    }
-  }
 };

 enum : size_t {
@@ -168,8 +151,6 @@ enum : size_t {
 template <typename T>
 class Conv : public CudaKernel {
  public:
-  using CudaT = typename ToCudaType<T>::MappedType;
-
   Conv(const OpKernelInfo& info) : CudaKernel(info), conv_attrs_(info) {
     auto pads_size = conv_attrs_.pads.size();
     ORT_ENFORCE(pads_size % 2 == 0);
@@ -177,26 +158,10 @@ class Conv : public CudaKernel {

   Status ComputeInternal(OpKernelContext* context) const override;

- protected:
-  inline IAllocatorUniquePtr<void> GetWorkSpace() const {
-    return GetScratchBuffer<void>(s_.workspace_bytes);
-  }
-  const CudaT alpha_ = Consts<CudaT>::One;
-  const CudaT beta_ = Consts<CudaT>::Zero;
-  Status UpdateState(OpKernelContext* context, bool bias_expected = false) const;
+ private:
   ConvAttributes conv_attrs_;
   mutable CudnnConvState<cudnnConvolutionFwdAlgoPerf_t> s_;
   constexpr static auto kDefaultConvAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
 };
-
-Status SliceOutUnwantedOutputSection(cudaStream_t stream,
-                                     const void* input_data,
-                                     const std::vector<int64_t>& input_dims,
-                                     void* output_data,
-                                     const std::vector<int64_t>& output_dims,
-                                     std::vector<int64_t> starts,
-                                     const std::vector<int64_t>& ends,
-                                     const std::vector<int64_t>& axes,
-                                     size_t element_size);
 }  // namespace cuda
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
index 8a795b5d903b9..e258d8d685533 100644
--- a/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
+++ b/onnxruntime/core/providers/cuda/nn/conv_transpose.cc
@@ -94,11 +94,11 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
       s_.y_dims = y_dims;

       if (w_dims_changed)
-        ORT_RETURN_IF_ERROR(s_.w_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));
+        ORT_RETURN_IF_ERROR(s_.filter_desc.Set(w_dims, CudnnTensor::GetDataType<CudaT>()));

       // Special case when there is a dim value of 0 in the shape.
       // Return only after we have cached the following for subsequent runs :
-      // 1) `w_dims` in the `w_desc`
+      // 1) `w_dims` in the `filter_desc`
       // 2) `y_dims` in s_.y_dims
       if (p.Y->Shape().Size() == 0) {
         return Status::OK();
@@ -138,7 +138,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
         int algo_count = 1;
         CUDNN_RETURN_IF_ERROR(cudnnFindConvolutionBackwardDataAlgorithmEx(
             CudnnHandle(),
-            s_.w_desc,
+            s_.filter_desc,
             w_data,
             s_.x_tensor,
             x_data,
@@ -184,7 +184,7 @@ Status ConvTranspose<T>::DoConvTranspose(OpKernelContext* context, bool dynamic_
       cudnnConvolutionBackwardData(
           CudnnHandle(),
           &alpha,
-          s_.w_desc,
+          s_.filter_desc,
           w_data,
           s_.x_tensor,
           x_data,
diff --git a/onnxruntime/test/contrib_ops/fused_conv_test.cc b/onnxruntime/test/contrib_ops/fused_conv_test.cc
deleted file mode 100644
index bce624a7233ad..0000000000000
--- a/onnxruntime/test/contrib_ops/fused_conv_test.cc
+++ /dev/null
@@ -1,140 +0,0 @@
-// Copyright (c) Microsoft Corporation. All rights reserved.
-// Licensed under the MIT License.
-
-#include "gtest/gtest.h"
-#include "test/providers/provider_test_utils.h"
-
-namespace onnxruntime {
-namespace test {
-
-#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
-using namespace std;
-
-struct ConvOpAndTestAttributes {
-  string auto_pad;
-  vector<int64_t> dilations;
-  int64_t group;
-  vector<int64_t> kernel_shape;
-  vector<int64_t> pads;
-  vector<int64_t> strides;
-  string activation;
-};
-
-static std::unordered_set<std::string> excluded_providers = {
-    kCpuExecutionProvider,
-    kDnnlExecutionProvider,
-    kOpenVINOExecutionProvider,
-    kNupharExecutionProvider,
-    kVitisAIExecutionProvider,
-    kTensorrtExecutionProvider,
-    kNnapiExecutionProvider,
-    kRknpuExecutionProvider,
-    kDmlExecutionProvider,
-    kMIGraphXExecutionProvider,
-    kAclExecutionProvider,
-    kArmNNExecutionProvider,
-    kRocmExecutionProvider};
-
-void TestConvOp(const ConvOpAndTestAttributes& attributes,
-                const vector<vector<float>>& inputs,
-                const vector<vector<int64_t>>& input_shapes,
-                const std::initializer_list<float>& expected_output,
-                const vector<int64_t>& expected_output_shape,
-                bool weight_is_initializer = false,
-                OpTester::ExpectResult expect_result = OpTester::ExpectResult::kExpectSuccess,
-                const std::string& err_str = "") {
-  OpTester test("FusedConv", 1, onnxruntime::kMSDomain);
-  test.AddAttribute("group", attributes.group);
-  test.AddAttribute("kernel_shape", attributes.kernel_shape);
-
-  if (!attributes.dilations.empty()) {
-    test.AddAttribute("dilations", attributes.dilations);
-  }
-
-  // Only one of pads / auto_pad can be present
-  if (!attributes.pads.empty()) {
-    test.AddAttribute("pads", attributes.pads);
-  } else {
-    test.AddAttribute("auto_pad", attributes.auto_pad);
-  }
-
-  if (!attributes.strides.empty()) {
-    test.AddAttribute("strides", attributes.strides);
-  }
-
-  ORT_ENFORCE(!attributes.activation.empty(), "activation must be set");
-  test.AddAttribute("activation", attributes.activation);
-
-  const char* szNames[] = {"X", "W", "B", "Z"};
-  test.AddInput<float>(szNames[0], input_shapes[0], inputs[0]);
-  test.AddInput<float>(szNames[1], input_shapes[1], inputs[1], weight_is_initializer);
-  if (inputs.size() >= 3)
-    test.AddInput<float>(szNames[2], input_shapes[2], inputs[2]);
-  if (inputs.size() >= 4)
-    test.AddInput<float>(szNames[3], input_shapes[3], inputs[3]);
-  test.AddOutput<float>("Y", expected_output_shape, expected_output);
-  test.Run(expect_result, err_str, excluded_providers);
-}
-
-TEST(FusedConvTest, Conv2D_Relu) {
-  ConvOpAndTestAttributes attrs = {
-      "",                           // auto_pad
-      vector<int64_t>{1, 1},        // dilations
-      1,                            // group
-      vector<int64_t>{2, 2},        // kernel_shape
-      vector<int64_t>{0, 0, 0, 0},  // pads
-      vector<int64_t>{1, 1},        // strides
-      "Relu"                        // activation
-  };
-
-  vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
-  vector<int64_t> X_shape = {1, 1, 3, 3};
-  vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, -1.0f, -1.0f, -1.0f, -1.0f};
-  vector<int64_t> W_shape = {2, 1, 2, 2};
-  vector<int64_t> Y_shape = {1, 2, 2, 2};
-  auto expected_vals = {12.0f, 16.0f, 24.0f, 28.0f, 0.0f, 0.0f, 0.0f, 0.0f};
-  TestConvOp(attrs, {X, W}, {X_shape, W_shape}, expected_vals, Y_shape);
-}
-
-TEST(FusedConvTest, Conv2D_Bias_Relu) {
-  ConvOpAndTestAttributes attrs = {
-      "",                           // auto_pad
-      vector<int64_t>{1, 1},        // dilations
-      1,                            // group
-      vector<int64_t>{2, 2},        // kernel_shape
-      vector<int64_t>{0, 0, 0, 0},  // pads
-      vector<int64_t>{1, 1},        // strides
-      "Relu"                        // activation
-  };
-
-  vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
-  vector<int64_t> X_shape = {1, 1, 3, 3};
-  vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
-  vector<int64_t> W_shape = {2, 1, 2, 2};
-  vector<int64_t> Y_shape = {1, 2, 2, 2};
-  vector<float> B = {1.0f, -1.0f};
-  vector<int64_t> B_shape = {2};
-  auto expected_vals = {13.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 27.0f};
-  TestConvOp(attrs, {X, W, B}, {X_shape, W_shape, B_shape}, expected_vals, Y_shape);
-}
-
-TEST(FusedConvTest, Conv2D_Bias_Z_Relu) {
-  ConvOpAndTestAttributes attrs = {
-      "",                           // auto_pad
-      vector<int64_t>{1, 1},        // dilations
-      1,                            // group
-      vector<int64_t>{2, 2},        // kernel_shape
-      vector<int64_t>{0, 0, 0, 0},  // pads
-      vector<int64_t>{1, 1},        // strides
-      "Relu"                        // activation
-  };
-
-  vector<float> X = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
-  vector<int64_t> X_shape = {1, 1, 3, 3};
-  vector<float> W = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
-  vector<int64_t> W_shape = {2, 1, 2, 2};
-  vector<int64_t> Y_shape = {1, 2, 2, 2};
-  vector<float> B = {1.0f, -1.0f};
-  vector<int64_t> B_shape = {2};
-  vector<float> Z = {-1.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 1.0f};
-  vector<int64_t> Z_shape = {1, 2, 2, 2};
-  auto expected_vals = {12.0f, 17.0f, 25.0f, 29.0f, 11.0f, 15.0f, 23.0f, 28.0f};
-  TestConvOp(attrs, {X, W, B, Z}, {X_shape, W_shape, B_shape, Z_shape}, expected_vals, Y_shape);
-}
-#endif
-
-}  // namespace test
-}  // namespace onnxruntime
diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc
index 9a01db9ee48b4..5ef67047f8daf 100644
--- a/onnxruntime/test/optimizer/graph_transform_test.cc
+++ b/onnxruntime/test/optimizer/graph_transform_test.cc
@@ -523,48 +523,20 @@ TEST_F(GraphTransformationTests, FuseConvBNMulAddUnsqueeze) {
   }
 }

-#if defined(USE_CUDA) && !defined(DISABLE_CONTRIB_OPS)
-TEST_F(GraphTransformationTests, FuseCudaConvAddRelu) {
-  auto model_uri = MODEL_FOLDER "fusion/conv_add_relu.onnx";
-  std::shared_ptr<Model> p_model;
-  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
-  Graph& graph = p_model->MainGraph();
-  for (auto& node : p_model->MainGraph().Nodes()) {
-    node.SetExecutionProviderType(kCudaExecutionProvider);
-  }
-  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["Add"] == 1);
-  ASSERT_TRUE(op_to_count["Relu"] == 1);
-  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
-  graph_transformation_mgr.Register(onnxruntime::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
-  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));
-  op_to_count = CountOpsInGraph(graph);
-  ASSERT_TRUE(op_to_count["Add"] == 0);
-  ASSERT_TRUE(op_to_count["Relu"] == 0);
-}
-#endif
-
 #ifndef DISABLE_CONTRIB_OPS
 TEST_F(GraphTransformationTests, FuseConvActivation) {
-#ifdef USE_CUDA
-  std::unordered_map<std::basic_string<ORTCHAR_T>, std::string> model_to_op_name{{ORT_TSTR("fusion/conv_relu.onnx"), "Relu"}};
-#else
   std::unordered_map<std::basic_string<ORTCHAR_T>, std::string> model_to_op_name{{ORT_TSTR("fusion/conv_relu.onnx"), "Relu"},
                                                                                  {ORT_TSTR("fusion/conv_clip.onnx"), "Clip"},
                                                                                  {ORT_TSTR("fusion/conv_sigmoid.onnx"), "Sigmoid"},
                                                                                  {ORT_TSTR("fusion/conv_tanh.onnx"), "Tanh"},
                                                                                  {ORT_TSTR("fusion/conv_leakyrelu.onnx"), "LeakyRelu"}};
-#endif
+
   for (const auto& model : model_to_op_name) {
     auto model_uri = MODEL_FOLDER + model.first;
     std::shared_ptr<Model> p_model;
     ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
     Graph& graph = p_model->MainGraph();
-#ifdef USE_CUDA
-    for (auto& node : p_model->MainGraph().Nodes()) {
-      node.SetExecutionProviderType(kCudaExecutionProvider);
-    }
-#endif
+
     std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
     ASSERT_TRUE(op_to_count[model.second] >= 1);
diff --git a/onnxruntime/test/testdata/transform/fusion/conv_add_relu.onnx b/onnxruntime/test/testdata/transform/fusion/conv_add_relu.onnx
deleted file mode 100644
index 85ac455090d93..0000000000000
--- a/onnxruntime/test/testdata/transform/fusion/conv_add_relu.onnx
+++ /dev/null
@@ -1,38 +0,0 @@
[38 lines of raw binary ONNX protobuf elided: the deleted test model wires Conv(X, W, B) -> C, Add(C, A) -> S, Relu(S) -> Y]
diff --git a/tools/ci_build/amd_hipify.py b/tools/ci_build/amd_hipify.py
index 4816783f25ced..56a8ff8d50cae 100644
--- a/tools/ci_build/amd_hipify.py
+++ b/tools/ci_build/amd_hipify.py
@@ -66,8 +66,7 @@
     'conv_transpose_with_dynamic_pads.h',
     'cuda_contrib_kernels.cc',
     'cuda_contrib_kernels.h',
-    'inverse.cc',
-    'fused_conv.cc'
+    'inverse.cc'
 ]

 provider_excluded_files = [
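With the CUDA FusedConv kernel and its tests removed, Conv+activation fusion now only fires for the CPU, ACL, and ArmNN providers, and the retained FuseConvActivation test is the sole coverage for that path. Below is a minimal sketch of a focused regression test for the surviving CPU-side Conv+Relu fusion, modeled on the deleted FuseCudaConvAddRelu test above; the fusion/conv_relu.onnx model path is taken from the retained test, but the test itself is illustrative and not part of this change:

```cpp
// Sketch only: exercises the surviving Conv+Relu fusion using the same
// fixture and helpers as the tests in graph_transform_test.cc.
TEST_F(GraphTransformationTests, FuseConvReluOnDefaultProvider) {
  auto model_uri = MODEL_FOLDER "fusion/conv_relu.onnx";  // model reused from FuseConvActivation
  std::shared_ptr<Model> p_model;
  ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_));
  Graph& graph = p_model->MainGraph();

  // Nodes keep the default (CPU) provider, which is still in the
  // transformer's allowed-provider set after this change.
  std::map<std::string, int> op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Relu"] >= 1);

  onnxruntime::GraphTransformerManager graph_transformation_mgr{5};
  graph_transformation_mgr.Register(onnxruntime::make_unique<ConvActivationFusion>(), TransformerLevel::Level2);
  ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level2, *logger_));

  op_to_count = CountOpsInGraph(graph);
  ASSERT_TRUE(op_to_count["Relu"] == 0);  // the activation has been folded into a FusedConv node
}
```

Conversely, tagging every node with kCudaExecutionProvider before running the transformer should now leave the Relu in place, since the CUDA provider was dropped from ConvActivationFusion's provider set in graph_transformer_utils.cc.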