From bea5080a9e2fc0cfb7fbae130e8ce5b0d25d02ac Mon Sep 17 00:00:00 2001 From: linzs148 <1483100349@qq.com> Date: Tue, 30 Apr 2024 14:55:40 +0000 Subject: [PATCH] refine functions --- oneflow/core/device/cudnn_conv_util.cpp | 288 ++++++++++---------- oneflow/core/device/cudnn_conv_util.h | 55 +--- oneflow/user/kernels/conv_cudnn_kernels.cpp | 61 +++-- 3 files changed, 186 insertions(+), 218 deletions(-) diff --git a/oneflow/core/device/cudnn_conv_util.cpp b/oneflow/core/device/cudnn_conv_util.cpp index fbdcebf2ec8..39f8e7e83a4 100644 --- a/oneflow/core/device/cudnn_conv_util.cpp +++ b/oneflow/core/device/cudnn_conv_util.cpp @@ -335,6 +335,90 @@ CudnnConvArgs::CudnnConvArgs(const user_op::KernelComputeContext& ctx, DataType params.max_ws_size = max_workspace_size; } +cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id) { + auto dim = t->shape_view(); + auto stride = t->stride(); + return cudnn_frontend::TensorBuilder() + .setDim(dim.size(), dim.data()) + .setStride(stride.size(), stride.data()) + .setId(id) + .setAlignment(32) + .setDataType(GetCudnnDataType(t->data_type())) + .build(); +} + +cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id) { + auto dim = t.shape(); + auto stride = t.stride(); + return cudnn_frontend::TensorBuilder() + .setDim(dim.size(), dim.data()) + .setStride(stride.size(), stride.data()) + .setId(id) + .setAlignment(32) + .setDataType(GetCudnnDataType(t.data_type())) + .build(); +} + +cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx, + cudnnDataType_t data_type) { + if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) { + data_type = CUDNN_DATA_FLOAT; + } + + std::vector padding; + const auto& padding_before = ctx.Attr>("padding_before"); + copy(padding_before.begin(), padding_before.end(), back_inserter(padding)); + + std::vector stride; + const auto& strides = ctx.Attr>("strides"); + copy(strides.begin(), strides.end(), back_inserter(stride)); + + std::vector dilation; + const auto& dilation_rate = ctx.Attr>("dilation_rate"); + copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation)); + + uint64_t ndim = stride.size(); + return cudnn_frontend::ConvDescBuilder() + .setDataType(data_type) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(ndim) + .setStrides(ndim, stride.data()) + .setPrePadding(ndim, padding.data()) + .setPostPadding(ndim, padding.data()) + .setDilation(ndim, dilation.data()) + .build(); +} + +cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx, + cudnnDataType_t data_type) { + if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) { + data_type = CUDNN_DATA_FLOAT; + } + + std::vector padding; + const auto& padding_before = ctx.Attr>("padding_before"); + copy(padding_before.begin(), padding_before.end(), back_inserter(padding)); + + std::vector stride; + const auto& strides = ctx.Attr>("strides"); + copy(strides.begin(), strides.end(), back_inserter(stride)); + + std::vector dilation; + const auto& dilation_rate = ctx.Attr>("dilation_rate"); + copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation)); + + uint64_t ndim = stride.size(); + return cudnn_frontend::ConvDescBuilder() + .setDataType(data_type) + .setMathMode(CUDNN_CROSS_CORRELATION) + .setNDims(ndim) + .setStrides(ndim, stride.data()) + .setPrePadding(ndim, padding.data()) + .setPostPadding(ndim, padding.data()) + .setDilation(ndim, dilation.data()) + .build(); +} + CudnnConvArgsV8::CudnnConvArgsV8(const user_op::InferContext& ctx, const user_op::TensorDesc& x, const user_op::TensorDesc& y, const user_op::TensorDesc& w) : xdesc(GetTensorDescriptor(x, 'x')), @@ -443,30 +527,6 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso args.wdesc.Get(), algo, sz); } -void RunSingleConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, - user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w, user_op::Tensor* b, - const CudnnConvArgsV8& args) { - std::string tag; - auto configs = - GetConfigs(handle, desc, args.xdesc, args.ydesc, args.wdesc, args.cdesc, args.beta, tag); - TryConfigs(handle, x, y, w, b, configs, tag); -} - -cudnn_frontend::EngineConfigList GetConfigs(const cudnnHandle_t handle, - const cudnnBackendDescriptorType_t desc, - const cudnn_frontend::Tensor& xdesc, - const cudnn_frontend::Tensor& ydesc, - const cudnn_frontend::Tensor& wdesc, - const cudnn_frontend::ConvDesc& cdesc, float beta, - std::string& tag) { - auto op_graph = BuildConvOpGraph(handle, desc, xdesc, ydesc, wdesc, cdesc, beta); - tag = op_graph.getTag(); - auto sources = GetGeneratorSources(desc); - cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data()); - auto configs = generator.generate_engine_config(op_graph); - return configs; -} - cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, const cudnn_frontend::Tensor& xdesc, @@ -488,88 +548,20 @@ cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle, return op_graph; } -cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id) { - auto dim = t->shape_view(); - auto stride = t->stride(); - return cudnn_frontend::TensorBuilder() - .setDim(dim.size(), dim.data()) - .setStride(stride.size(), stride.data()) - .setId(id) - .setAlignment(32) - .setDataType(GetCudnnDataType(t->data_type())) - .build(); -} - -cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id) { - auto dim = t.shape(); - auto stride = t.stride(); - return cudnn_frontend::TensorBuilder() - .setDim(dim.size(), dim.data()) - .setStride(stride.size(), stride.data()) - .setId(id) - .setAlignment(32) - .setDataType(GetCudnnDataType(t.data_type())) - .build(); -} - -cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx, - cudnnDataType_t data_type) { - if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) { - data_type = CUDNN_DATA_FLOAT; - } - - std::vector padding; - const auto& padding_before = ctx.Attr>("padding_before"); - copy(padding_before.begin(), padding_before.end(), back_inserter(padding)); - - std::vector stride; - const auto& strides = ctx.Attr>("strides"); - copy(strides.begin(), strides.end(), back_inserter(stride)); - - std::vector dilation; - const auto& dilation_rate = ctx.Attr>("dilation_rate"); - copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation)); - - uint64_t ndim = stride.size(); - return cudnn_frontend::ConvDescBuilder() - .setDataType(data_type) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(ndim) - .setStrides(ndim, stride.data()) - .setPrePadding(ndim, padding.data()) - .setPostPadding(ndim, padding.data()) - .setDilation(ndim, dilation.data()) - .build(); -} - -cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx, - cudnnDataType_t data_type) { - if (data_type == CUDNN_DATA_HALF || data_type == CUDNN_DATA_BFLOAT16) { - data_type = CUDNN_DATA_FLOAT; - } - - std::vector padding; - const auto& padding_before = ctx.Attr>("padding_before"); - copy(padding_before.begin(), padding_before.end(), back_inserter(padding)); - - std::vector stride; - const auto& strides = ctx.Attr>("strides"); - copy(strides.begin(), strides.end(), back_inserter(stride)); - - std::vector dilation; - const auto& dilation_rate = ctx.Attr>("dilation_rate"); - copy(dilation_rate.begin(), dilation_rate.end(), back_inserter(dilation)); - - uint64_t ndim = stride.size(); - return cudnn_frontend::ConvDescBuilder() - .setDataType(data_type) - .setMathMode(CUDNN_CROSS_CORRELATION) - .setNDims(ndim) - .setStrides(ndim, stride.data()) - .setPrePadding(ndim, padding.data()) - .setPostPadding(ndim, padding.data()) - .setDilation(ndim, dilation.data()) - .build(); +void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from, + cudnn_frontend::EngineConfigList& to, bool deterministic) { + auto filter = [=](cudnnBackendDescriptor_t c) { + if (deterministic) { + if (cudnn_frontend::hasNumericalNote(c)) { + return true; + } + } + if (cudnn_frontend::hasNumericalNote(c)) { + return true; + } + return false; + }; + cudnn_frontend::filter(from, to, filter); } std::vector GetGeneratorSources( @@ -579,7 +571,7 @@ std::vector GetGeneratorSources( .cudnn_conf() .cudnn_conv_use_deterministic_algo_only(); bool heuristic = ParseBooleanFromEnv("ONEFLOW_CUDNN_USE_HEURISTIC_MODE_B", false); - auto heur_mode = heuristic ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_A; + auto heur_mode = heuristic ? CUDNN_HEUR_MODE_B : CUDNN_HEUR_MODE_INSTANT; // Method for engine config generator based on heuristics const auto heurgen_method = [deterministic, @@ -610,20 +602,43 @@ std::vector GetGeneratorSources( return sources; } -void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from, - cudnn_frontend::EngineConfigList& to, bool deterministic) { - auto filter = [=](cudnnBackendDescriptor_t c) { - if (deterministic) { - if (cudnn_frontend::hasNumericalNote(c)) { - return true; - } - } - if (cudnn_frontend::hasNumericalNote(c)) { - return true; - } +cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle, + const cudnnBackendDescriptorType_t desc, + const cudnn_frontend::Tensor& xdesc, + const cudnn_frontend::Tensor& ydesc, + const cudnn_frontend::Tensor& wdesc, + const cudnn_frontend::ConvDesc& cdesc, + float beta, std::string& tag) { + auto op_graph = BuildConvOpGraph(handle, desc, xdesc, ydesc, wdesc, cdesc, beta); + tag = op_graph.getTag(); + auto sources = GetGeneratorSources(desc); + cudnn_frontend::EngineConfigGenerator generator(sources.size(), sources.data()); + auto configs = generator.generate_engine_config(op_graph); + return configs; +} + +bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag) { + static nlohmann::json errata_json_handle; + static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, ""); + if (!has_json) { return false; - }; - cudnn_frontend::filter(from, to, filter); + } else { + return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle, + []() { return true; }); + } +} + +void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y, + user_op::Tensor* w, user_op::Tensor* buf, + const cudnn_frontend::ExecutionPlan& plan) { + void* data[] = {x->mut_dptr(), y->mut_dptr(), w->mut_dptr()}; + int64_t ids[] = {'x', 'y', 'w'}; + auto variantPack = cudnn_frontend::VariantPackBuilder() + .setWorkspacePointer(buf->mut_dptr()) + .setDataPointers(3, data) + .setUids(3, ids) + .build(); + OF_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); } void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y, @@ -642,6 +657,15 @@ void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* } } +void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, + user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w, + user_op::Tensor* b, const CudnnConvArgsV8& args) { + std::string tag; + auto configs = CudnnFrontendGetConfigs(handle, desc, args.xdesc, args.ydesc, args.wdesc, + args.cdesc, args.beta, tag); + TryConfigs(handle, x, y, w, b, configs, tag); +} + size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle, cudnn_frontend::EngineConfigList& configs, const std::string& tag) { @@ -658,30 +682,6 @@ size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle, return 1L; } -bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag) { - static nlohmann::json errata_json_handle; - static bool has_json = cudnn_frontend::load_from_config(errata_json_handle, ""); - if (!has_json) { - return false; - } else { - return cudnn_frontend::check_errata(errata_json_handle, executionPlanTag, handle, - []() { return true; }); - } -} - -void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y, - user_op::Tensor* w, user_op::Tensor* buf, - const cudnn_frontend::ExecutionPlan& plan) { - void* data[] = {x->mut_dptr(), y->mut_dptr(), w->mut_dptr()}; - int64_t ids[] = {'x', 'y', 'w'}; - auto variantPack = cudnn_frontend::VariantPackBuilder() - .setWorkspacePointer(buf->mut_dptr()) - .setDataPointers(3, data) - .setUids(3, ids) - .build(); - OF_CUDNN_CHECK(cudnnBackendExecute(handle, plan.get_raw_desc(), variantPack.get_raw_desc())); -} - template<> struct CudnnConvAlgorithmSearch { using perf_t = cudnnConvolutionFwdAlgoPerf_t; diff --git a/oneflow/core/device/cudnn_conv_util.h b/oneflow/core/device/cudnn_conv_util.h index b804f76b884..90946a61a9e 100644 --- a/oneflow/core/device/cudnn_conv_util.h +++ b/oneflow/core/device/cudnn_conv_util.h @@ -186,55 +186,22 @@ cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvReso cudnnStatus_t GetCudnnConvWorkspaceSize(const CudnnConvArgs& args, CudnnConvResource* res, cudnnConvolutionBwdFilterAlgo_t algo, size_t* sz); -void RunSingleConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, - user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w, user_op::Tensor* b, - const CudnnConvArgsV8& args); - -cudnn_frontend::EngineConfigList GetConfigs(const cudnnHandle_t handle, - const cudnnBackendDescriptorType_t desc, - const cudnn_frontend::Tensor& xdesc, - const cudnn_frontend::Tensor& ydesc, - const cudnn_frontend::Tensor& wdesc, - const cudnn_frontend::ConvDesc& cdesc, float beta, - std::string& tag); - -cudnn_frontend::OperationGraph BuildConvOpGraph(const cudnnHandle_t handle, - const cudnnBackendDescriptorType_t desc, - const cudnn_frontend::Tensor& xdesc, - const cudnn_frontend::Tensor& ydesc, - const cudnn_frontend::Tensor& wdesc, - const cudnn_frontend::ConvDesc& cdesc, float beta); - -cudnn_frontend::Tensor GetTensorDescriptor(const user_op::Tensor* t, const int64_t id); - -cudnn_frontend::Tensor GetTensorDescriptor(const user_op::TensorDesc& t, const int64_t id); - -cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::InferContext& ctx, - cudnnDataType_t data_type); - -cudnn_frontend::ConvDesc GetConvDescriptor(const user_op::KernelComputeContext& ctx, - cudnnDataType_t data_type); - -std::vector GetGeneratorSources( - const cudnnBackendDescriptorType_t desc); - -void FilterEngineConfigs(cudnn_frontend::EngineConfigList& from, - cudnn_frontend::EngineConfigList& to, bool deterministic); - -void TryConfigs(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y, - user_op::Tensor* w, user_op::Tensor* buf, cudnn_frontend::EngineConfigList& configs, - const std::string& tag); +cudnn_frontend::EngineConfigList CudnnFrontendGetConfigs(const cudnnHandle_t handle, + const cudnnBackendDescriptorType_t desc, + const cudnn_frontend::Tensor& xdesc, + const cudnn_frontend::Tensor& ydesc, + const cudnn_frontend::Tensor& wdesc, + const cudnn_frontend::ConvDesc& cdesc, + float beta, std::string& tag); + +void CudnnFrontendRunConv(const cudnnHandle_t handle, const cudnnBackendDescriptorType_t desc, + user_op::Tensor* x, user_op::Tensor* y, user_op::Tensor* w, + user_op::Tensor* b, const CudnnConvArgsV8& args); size_t GetCudnnConvWorkspaceSizeV8(const cudnnHandle_t handle, cudnn_frontend::EngineConfigList& configs, const std::string& tag); -bool PlanErrataException(const cudnnHandle_t handle, const std::string& executionPlanTag); - -void RunConvPlan(const cudnnHandle_t handle, user_op::Tensor* x, user_op::Tensor* y, - user_op::Tensor* w, user_op::Tensor* buf, - const cudnn_frontend::ExecutionPlan& plan); - template perf_t FindCudnnConvAlgorithm(CudnnConvArgs* args); diff --git a/oneflow/user/kernels/conv_cudnn_kernels.cpp b/oneflow/user/kernels/conv_cudnn_kernels.cpp index db116a103a5..f8475c30202 100644 --- a/oneflow/user/kernels/conv_cudnn_kernels.cpp +++ b/oneflow/user/kernels/conv_cudnn_kernels.cpp @@ -305,8 +305,8 @@ class ConvGpuKernelV8 final : public user_op::OpKernel, public user_op::CudaGrap // trigger conv compute auto handle = ctx->stream()->As()->cudnn_handle(); - RunSingleConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, input, output, - weight, buffer, args); + CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, input, + output, weight, buffer, args); // process bias auto bias = ctx->Tensor4ArgNameAndIndex("bias", 0); @@ -331,24 +331,25 @@ class ConvGpuKernelV8 final : public user_op::OpKernel, public user_op::CudaGrap } }; -#define REGISTER_CONV_KERNEL_V8(op_name, ndims) \ - REGISTER_USER_KERNEL(#op_name) \ - .SetCreateFn>() \ - .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA \ - && user_op::HobEnvBool("ONEFLOW_KERNEL_ENABLE_CUDNN_V8", false)) \ - .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ - auto& input = ctx->InputTensorDesc("in", 0); \ - auto& output = ctx->InputTensorDesc("out", 0); \ - auto& weight = ctx->InputTensorDesc("weight", 0); \ - CudnnConvArgsV8 args(*ctx, input, output, weight); \ - auto handle = Singleton::Get()->Get(); \ - std::string tag; \ - auto configs = GetConfigs(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, \ - args.xdesc, args.ydesc, args.wdesc, args.cdesc, args.beta, tag); \ - size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag); \ - Singleton::Get()->Put(handle); \ - return workspace_size; \ - }) \ +#define REGISTER_CONV_KERNEL_V8(op_name, ndims) \ + REGISTER_USER_KERNEL(#op_name) \ + .SetCreateFn>() \ + .SetIsMatchedHob(user_op::HobDeviceType() == DeviceType::kCUDA \ + && user_op::HobEnvBool("ONEFLOW_KERNEL_ENABLE_CUDNN_V8", false)) \ + .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t { \ + auto& input = ctx->InputTensorDesc("in", 0); \ + auto& output = ctx->InputTensorDesc("out", 0); \ + auto& weight = ctx->InputTensorDesc("weight", 0); \ + CudnnConvArgsV8 args(*ctx, input, output, weight); \ + auto handle = Singleton::Get()->Get(); \ + std::string tag; \ + auto configs = CudnnFrontendGetConfigs( \ + handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, args.xdesc, \ + args.ydesc, args.wdesc, args.cdesc, args.beta, tag); \ + size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag); \ + Singleton::Get()->Put(handle); \ + return workspace_size; \ + }) \ .SetPriority(user_op::kKernelPriorityOptimized); REGISTER_CONV_KERNEL_V8(conv1d, 1); @@ -457,8 +458,8 @@ class ConvDataGradGpuKernelV8 final : public user_op::OpKernel, public user_op:: // trigger conv compute auto handle = ctx->stream()->As()->cudnn_handle(); - RunSingleConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, input_diff, - output_diff, weight, buffer, args); + CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, + input_diff, output_diff, weight, buffer, args); } bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, @@ -481,9 +482,9 @@ REGISTER_USER_KERNEL("conv_data_grad") CudnnConvArgsV8 args(*ctx, input_diff, output_diff, weight); auto handle = Singleton::Get()->Get(); std::string tag; - auto configs = - GetConfigs(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, - args.xdesc, args.ydesc, args.wdesc, args.cdesc, args.beta, tag); + auto configs = CudnnFrontendGetConfigs( + handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, args.xdesc, + args.ydesc, args.wdesc, args.cdesc, args.beta, tag); size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag); Singleton::Get()->Put(handle); return workspace_size; @@ -582,8 +583,8 @@ class ConvFilterGradGpuKernelV8 final : public user_op::OpKernel, public user_op // trigger conv compute auto handle = ctx->stream()->As()->cudnn_handle(); - RunSingleConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, input, - output_diff, weight_diff, buffer, args); + CudnnFrontendRunConv(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, + input, output_diff, weight_diff, buffer, args); } bool IsCudaGraphSupported(user_op::KernelInitContext* ctx, @@ -606,9 +607,9 @@ REGISTER_USER_KERNEL("conv_filter_grad") CudnnConvArgsV8 args(*ctx, input, output_diff, weight_diff); auto handle = Singleton::Get()->Get(); std::string tag; - auto configs = - GetConfigs(handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, - args.xdesc, args.ydesc, args.wdesc, args.cdesc, args.beta, tag); + auto configs = CudnnFrontendGetConfigs( + handle, CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, args.xdesc, + args.ydesc, args.wdesc, args.cdesc, args.beta, tag); size_t workspace_size = GetCudnnConvWorkspaceSizeV8(handle, configs, tag); Singleton::Get()->Put(handle); return workspace_size;