From 2d1fd997a6fcf801ac4ca7a63e0656ef1cb9f488 Mon Sep 17 00:00:00 2001 From: bibek <108366729+bghimireamd@users.noreply.github.com> Date: Thu, 24 Oct 2024 12:40:46 -0500 Subject: [PATCH] Enable tuning in Batch norm CK solver (#3326) --- src/batchnorm/problem_description.cpp | 6 + .../miopen/batchnorm/invoke_params.hpp | 4 +- .../miopen/batchnorm/problem_description.hpp | 79 ++++- src/include/miopen/batchnorm/solvers.hpp | 211 +++++++++++- .../miopen/conv/problem_description.hpp | 35 -- .../miopen/problem_description_base.hpp | 35 ++ src/ocl/batchnormocl.cpp | 2 +- src/solver/batchnorm/backward_ck.cpp | 299 ++++++++++++++---- src/solver/batchnorm/forward_inference_ck.cpp | 217 +++++++++---- .../batchnorm/forward_per_activation.cpp | 2 +- .../batchnorm/forward_spatial_multiple.cpp | 2 +- .../batchnorm/forward_spatial_single.cpp | 2 +- src/solver/batchnorm/forward_training_ck.cpp | 282 ++++++++++++++--- .../conv_ck_igemm_fwd_bias_activ_fused.cpp | 1 + .../mha/mha_ck_fa_v2_solver_forward.cpp | 9 +- 15 files changed, 965 insertions(+), 221 deletions(-) diff --git a/src/batchnorm/problem_description.cpp b/src/batchnorm/problem_description.cpp index d1458ef9fb..ac63fdf73e 100644 --- a/src/batchnorm/problem_description.cpp +++ b/src/batchnorm/problem_description.cpp @@ -138,6 +138,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const ss << "fp32" << static_cast(IsFp32()); ss << "fp64" << static_cast(IsFp64()); ss << "fbf16" << static_cast(IsBfp16()); + ss << "fmix" << static_cast(IsMix()); ss << "c" << c; } else @@ -154,6 +155,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const ss << "fp32" << static_cast(IsFp32()); ss << "fp64" << static_cast(IsFp64()); ss << "fbf16" << static_cast(IsBfp16()); + ss << "fmix" << static_cast(IsMix()); ss << "single" << static_cast(single); ss << "n" << n; ss << "c" << c; @@ -172,6 +174,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const ss << "fp32" << static_cast(IsFp32()); ss << "fp64" << static_cast(IsFp64()); ss << "fbf16" << static_cast(IsBfp16()); + ss << "fmix" << static_cast(IsMix()); ss << "gx" << xgridsize; ss << "gy" << ygridsize; ss << "lx" << xlocalsize; @@ -201,6 +204,7 @@ NetworkConfig ProblemDescription::MakeForwardInferenceNetworkConfig() const ss << "fp32" << static_cast(IsFp32()); ss << "fp64" << static_cast(IsFp64()); ss << "fbf16" << static_cast(IsBfp16()); + ss << "fmix" << static_cast(IsMix()); ss << "mode" << bn_mode; ss << "HWdims" << in_cstride; ss << "C" << c; @@ -308,6 +312,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const ss << "fp32" << static_cast(IsFp32()); ss << "fp64" << static_cast(IsFp64()); ss << "fbf16" << static_cast(IsBfp16()); + ss << "fmix" << static_cast(IsMix()); ss << "single" << static_cast(single); ss << "gcn" << ldsgcn; } @@ -330,6 +335,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const ss << "fp32" << static_cast(IsFp32()); ss << "fp64" << static_cast(IsFp64()); ss << "fbf16" << static_cast(IsBfp16()); + ss << "fmix" << static_cast(IsMix()); ss << "nhw" << in_nhw; } ss << "layout" << in_layout; diff --git a/src/include/miopen/batchnorm/invoke_params.hpp b/src/include/miopen/batchnorm/invoke_params.hpp index efab2f3ddb..5c3df3cbcf 100644 --- a/src/include/miopen/batchnorm/invoke_params.hpp +++ b/src/include/miopen/batchnorm/invoke_params.hpp @@ -32,9 +32,9 @@ namespace miopen { namespace batchnorm { -struct InvokeParams : public miopen::InvokeParams +struct FwdTrainInvokeParams : public miopen::InvokeParams { - InvokeParams() = default; + FwdTrainInvokeParams() = default; ConstData_t x = nullptr; Data_t y = nullptr; diff --git a/src/include/miopen/batchnorm/problem_description.hpp b/src/include/miopen/batchnorm/problem_description.hpp index d28e91adfd..8054111128 100644 --- a/src/include/miopen/batchnorm/problem_description.hpp +++ b/src/include/miopen/batchnorm/problem_description.hpp @@ -52,7 +52,12 @@ struct ProblemDescriptionTag { }; -struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, ProblemDescriptionTag +struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, + ProblemDescriptionTag +#if MIOPEN_ENABLE_SQLITE + , + SQLiteSerializable +#endif { // Forward Training ProblemDescription(miopenBatchNormMode_t bn_mode_, @@ -218,10 +223,49 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob bool IsFp64() const { return xDesc.GetType() == miopenDouble; } bool IsFp32() const { return xDesc.GetType() == miopenFloat; } bool IsFp16() const { return xDesc.GetType() == miopenHalf; } + bool IsMix() const + { + return xDesc.GetType() == miopenHalf && sMeanDesc.GetType() == miopenFloat; + } bool IsBfp16() const { return xDesc.GetType() == miopenBFloat16; } + void Serialize(std::ostream& stream) const { stream << MakeNetworkConfig().ToString(); } + NetworkConfig MakeNetworkConfig() const override; + template + static void Visit(Self&& self, std::function f) + { + // The column names match the driver command line argument names + f(self.spatial_dim, "spatial_dim"); + f(self.GetBatchSize(), "batchsize"); + f(self.GetChannel(), "in_channels"); + f(self.GetHeight(), "in_h"); + f(self.GetWidth(), "in_w"); + f(self.GetDepth(), "in_d"); + + f(self.resultsave, "resultsave"); + f(self.resultrunning, "resultrunning"); + f(self.useSaved, "useSaved"); + } + + template + static void Visit(Self&& self, std::function f) + { + f(self.ComputeInLayout(), "layout"); + f(self.GetDirectionStr(), "direction"); + f(GetDataTypeName(self.xDesc.GetType()), "data_type"); + f(self.GetModeStr(), "mode"); + } + + template + static void VisitAll(Self&& self, const Visitor& f) + { + Visit(std::forward(self), [&](int64_t value, std::string name) { f(value, name); }); + Visit(std::forward(self), + [&](std::string value, std::string name) { f(value, name); }); + } + // This declaration marks batchnorm as a primitive with tuning enabled. // Any tunable solver would be able pick it and fetch a db instance in ExecutePrimitive. // It has to be discoverable via ADL from problem description. @@ -267,6 +311,39 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob std::string ComputeInLayout() const { return ComputeLayout(xDesc); } std::string ComputeOutLayout() const { return ComputeLayout(yOrDyDesc); } std::string ComputeDinLayout() const { return ComputeLayout(dxDesc); } + + size_t GetSpatialDims() const { return spatial_dim; } + + std::size_t GetBatchSize() const { return GetN5(GetSpatialDims(), xDesc.GetLengths()); } + std::size_t GetChannel() const { return GetC5(GetSpatialDims(), xDesc.GetLengths()); } + std::size_t GetHeight() const { return GetH5(GetSpatialDims(), xDesc.GetLengths()); } + std::size_t GetWidth() const { return GetW5(GetSpatialDims(), xDesc.GetLengths()); } + std::size_t GetDepth() const { return GetD5(GetSpatialDims(), xDesc.GetLengths()); } + + std::string GetDirectionStr() const + { + std::string s; + + switch(direction) + { + case Direction::ForwardInference: return "Inf"; ; + case Direction::ForwardTraining: return "Trn"; + case Direction::Backward: return "Bwd"; + default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Direction provided"); + } + + return s; + } + + std::string GetModeStr() const + { + switch(bn_mode) + { + case miopenBNPerActivation: return "0"; + case miopenBNSpatial: return "1"; + default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Mode provided"); + } + } }; } // namespace batchnorm diff --git a/src/include/miopen/batchnorm/solvers.hpp b/src/include/miopen/batchnorm/solvers.hpp index 7edba36b49..947976f403 100644 --- a/src/include/miopen/batchnorm/solvers.hpp +++ b/src/include/miopen/batchnorm/solvers.hpp @@ -44,6 +44,11 @@ namespace batchnorm { using BatchnormSolver = NonTunableSolverBase; +template +using BatchNormTunableSolver = + TunableSolverMixin; +; + struct BnFwdTrainingSpatialSingle final : BatchnormSolver { const std::string& SolverDbId() const override @@ -132,34 +137,210 @@ struct BnFwdInference final : BatchnormSolver const miopen::batchnorm::ProblemDescription& problem) const override; }; -struct BnCKFwdInference final : BatchnormSolver +struct PerformanceConfigBnCKFwdInference : PerfConfigBase +{ + int index; + std::string kernel_id; + std::vector valid_kernels; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdInference(int idx, std::string kernl_id) + : index(idx), kernel_id(kernl_id) + { + } + PerformanceConfigBnCKFwdInference() : PerformanceConfigBnCKFwdInference(0, "") {} + PerformanceConfigBnCKFwdInference(bool) : PerformanceConfigBnCKFwdInference(0, "") {} + MIOPEN_INTERNALS_EXPORT void + HeuristicInit(const miopen::batchnorm::ProblemDescription& problem_desc); + MIOPEN_INTERNALS_EXPORT bool + SetNextValue(const miopen::batchnorm::ProblemDescription& problem_desc); + MIOPEN_INTERNALS_EXPORT bool IsValidValue() const; + MIOPEN_INTERNALS_EXPORT bool + IsValid(const ExecutionContext&, + const miopen::batchnorm::ProblemDescription& problem_desc) const; + + template + static void Visit(Self&& s, F f) + { + f(s.kernel_id, "kernel_id"); + } + MIOPEN_INTERNALS_EXPORT bool operator==(const PerformanceConfigBnCKFwdInference& other) const; + +private: + template + void Init(const miopen::batchnorm::ProblemDescription&); + template + bool CheckIsSupportCKArgs(const miopen::batchnorm::ProblemDescription&) const; +}; + +struct BnCKFwdInference final : BatchNormTunableSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& context, - const miopen::batchnorm::ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& context, - const miopen::batchnorm::ProblemDescription& problem) const override; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdInference GetDefaultPerformanceConfig( + const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc) const override; + MIOPEN_INTERNALS_EXPORT bool + IsValidPerformanceConfig(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKFwdInference& config) const override; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdInference + Search(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const AnyInvokeParams& invoke_ctx) const override; + MIOPEN_INTERNALS_EXPORT bool + IsApplicable(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc) const override; + MIOPEN_INTERNALS_EXPORT ConvSolution + GetSolution(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKFwdInference& config) const override; +}; + +struct PerformanceConfigBnCKBwdBackward : PerfConfigBase +{ + int index; + std::string kernel_id; + std::vector valid_kernels; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKBwdBackward(int idx, std::string kernl_id) + : index(idx), kernel_id(kernl_id) + { + } + PerformanceConfigBnCKBwdBackward() : PerformanceConfigBnCKBwdBackward(0, "") {} + PerformanceConfigBnCKBwdBackward(bool) : PerformanceConfigBnCKBwdBackward(0, "") {} + MIOPEN_INTERNALS_EXPORT void + HeuristicInit(const miopen::batchnorm::ProblemDescription& problem_desc); + MIOPEN_INTERNALS_EXPORT bool + SetNextValue(const miopen::batchnorm::ProblemDescription& problem_desc); + MIOPEN_INTERNALS_EXPORT bool IsValidValue() const; + MIOPEN_INTERNALS_EXPORT bool + IsValid(const ExecutionContext&, + const miopen::batchnorm::ProblemDescription& problem_desc) const; + + template + static void Visit(Self&& s, F f) + { + f(s.kernel_id, "kernel_id"); + } + MIOPEN_INTERNALS_EXPORT bool operator==(const PerformanceConfigBnCKBwdBackward& other) const; + +private: + template + void Init(const miopen::batchnorm::ProblemDescription&); + template + bool CheckIsSupportCKArgs(const miopen::batchnorm::ProblemDescription&) const; }; -struct BnCKBwdBackward final : BatchnormSolver +struct BnCKBwdBackward final : BatchNormTunableSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& context, - const miopen::batchnorm::ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& context, - const miopen::batchnorm::ProblemDescription& problem) const override; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKBwdBackward GetDefaultPerformanceConfig( + const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc) const override; + MIOPEN_INTERNALS_EXPORT bool + IsValidPerformanceConfig(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKBwdBackward& config) const override; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKBwdBackward + Search(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const AnyInvokeParams& invoke_ctx) const override; + MIOPEN_INTERNALS_EXPORT bool + IsApplicable(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc) const override; + MIOPEN_INTERNALS_EXPORT ConvSolution + GetSolution(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKBwdBackward& config) const override; +}; + +struct PerformanceConfigBnCKFwdTraining : PerfConfigBase +{ + int index; + std::string kernel_id; + std::vector valid_kernels; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdTraining(int idx, std::string kernl_id) + : index(idx), kernel_id(kernl_id) + { + } + PerformanceConfigBnCKFwdTraining() : PerformanceConfigBnCKFwdTraining(0, "") {} + PerformanceConfigBnCKFwdTraining(bool) : PerformanceConfigBnCKFwdTraining(0, "") {} + MIOPEN_INTERNALS_EXPORT void + HeuristicInit(const miopen::batchnorm::ProblemDescription& problem_desc); + MIOPEN_INTERNALS_EXPORT bool + SetNextValue(const miopen::batchnorm::ProblemDescription& problem_desc); + MIOPEN_INTERNALS_EXPORT bool IsValidValue() const; + MIOPEN_INTERNALS_EXPORT bool + IsValid(const ExecutionContext&, + const miopen::batchnorm::ProblemDescription& problem_desc) const; + + template + static void Visit(Self&& s, F f) + { + f(s.kernel_id, "kernel_id"); + } + MIOPEN_INTERNALS_EXPORT bool operator==(const PerformanceConfigBnCKFwdTraining& other) const; + +private: + template + void Init(const miopen::batchnorm::ProblemDescription&); + template + bool CheckIsSupportCKArgs(const miopen::batchnorm::ProblemDescription&) const; }; -struct BnCKFwdTraining final : BatchnormSolver +struct BnCKFwdTraining final : BatchNormTunableSolver { const std::string& SolverDbId() const override { return GetSolverDbId(); } - bool IsApplicable(const ExecutionContext& context, - const miopen::batchnorm::ProblemDescription& problem) const override; - ConvSolution GetSolution(const ExecutionContext& context, - const miopen::batchnorm::ProblemDescription& problem) const override; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdTraining GetDefaultPerformanceConfig( + const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc) const override; + MIOPEN_INTERNALS_EXPORT bool + IsValidPerformanceConfig(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKFwdTraining& config) const override; + MIOPEN_INTERNALS_EXPORT PerformanceConfigBnCKFwdTraining + Search(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const AnyInvokeParams& invoke_ctx) const override; + MIOPEN_INTERNALS_EXPORT bool + IsApplicable(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc) const override; + MIOPEN_INTERNALS_EXPORT ConvSolution + GetSolution(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKFwdTraining& config) const override; }; } // namespace batchnorm diff --git a/src/include/miopen/conv/problem_description.hpp b/src/include/miopen/conv/problem_description.hpp index f0ac17d0b0..6472d99ed8 100644 --- a/src/include/miopen/conv/problem_description.hpp +++ b/src/include/miopen/conv/problem_description.hpp @@ -32,7 +32,6 @@ #include #include -#include #include #if MIOPEN_ENABLE_SQLITE @@ -101,36 +100,6 @@ constexpr TElement GetWofCHWN(const std::vector& data) return std::get<2>(GetCHWN(data)); } -template -constexpr TElement GetN5(unsigned spatial_dims, const std::vector& data) -{ - return std::get<0>(GetNCDHW(spatial_dims, data)); -} - -template -constexpr TElement GetC5(unsigned spatial_dims, const std::vector& data) -{ - return std::get<1>(GetNCDHW(spatial_dims, data)); -} - -template -constexpr TElement GetD5(unsigned spatial_dims, const std::vector& data) -{ - return std::get<2>(GetNCDHW(spatial_dims, data)); -} - -template -constexpr TElement GetH5(unsigned spatial_dims, const std::vector& data) -{ - return std::get<3>(GetNCDHW(spatial_dims, data)); -} - -template -constexpr TElement GetW5(unsigned spatial_dims, const std::vector& data) -{ - return std::get<4>(GetNCDHW(spatial_dims, data)); -} - namespace conv { MIOPEN_INTERNALS_EXPORT miopenAlphaBetaCase_t ClassifyAlphaBeta(const Scalar& alpha, @@ -391,10 +360,6 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase return os; } -#if MIOPEN_ENABLE_SQLITE - static std::string table_name() { return "config"; } -#endif - template static void Visit(Self&& self, std::function f) { diff --git a/src/include/miopen/problem_description_base.hpp b/src/include/miopen/problem_description_base.hpp index 3e19f85e8c..a7914a8b5a 100644 --- a/src/include/miopen/problem_description_base.hpp +++ b/src/include/miopen/problem_description_base.hpp @@ -28,7 +28,9 @@ #include #include +#include +#include #include namespace miopen { @@ -51,6 +53,36 @@ inline std::string GetDataTypeName(miopenDataType_t data_type) return "Unknown(" + std::to_string(data_type) + ")"; } +template +constexpr TElement GetN5(unsigned spatial_dims, const std::vector& data) +{ + return std::get<0>(GetNCDHW(spatial_dims, data)); +} + +template +constexpr TElement GetC5(unsigned spatial_dims, const std::vector& data) +{ + return std::get<1>(GetNCDHW(spatial_dims, data)); +} + +template +constexpr TElement GetD5(unsigned spatial_dims, const std::vector& data) +{ + return std::get<2>(GetNCDHW(spatial_dims, data)); +} + +template +constexpr TElement GetH5(unsigned spatial_dims, const std::vector& data) +{ + return std::get<3>(GetNCDHW(spatial_dims, data)); +} + +template +constexpr TElement GetW5(unsigned spatial_dims, const std::vector& data) +{ + return std::get<4>(GetNCDHW(spatial_dims, data)); +} + struct ProblemDescriptionBase { ProblemDescriptionBase() = default; @@ -60,6 +92,9 @@ struct ProblemDescriptionBase ProblemDescriptionBase& operator=(const ProblemDescriptionBase&) = default; [[nodiscard]] virtual NetworkConfig MakeNetworkConfig() const = 0; +#if MIOPEN_ENABLE_SQLITE + static std::string table_name() { return "config"; } +#endif }; } // namespace miopen diff --git a/src/ocl/batchnormocl.cpp b/src/ocl/batchnormocl.cpp index f33c5ac5db..dca94078be 100644 --- a/src/ocl/batchnormocl.cpp +++ b/src/ocl/batchnormocl.cpp @@ -136,7 +136,7 @@ void BatchNormForwardTraining(Handle& handle, : AlgorithmName{"miopenBatchNormForwardTrainingPerActivation"}; const auto invoke_params = [&]() { - auto tmp = batchnorm::InvokeParams{}; + auto tmp = miopen::batchnorm::FwdTrainInvokeParams{}; tmp.type = InvokeType::Run; tmp.x = x; tmp.y = y; diff --git a/src/solver/batchnorm/backward_ck.cpp b/src/solver/batchnorm/backward_ck.cpp index bca7afc3a5..c99a67250b 100644 --- a/src/solver/batchnorm/backward_ck.cpp +++ b/src/solver/batchnorm/backward_ck.cpp @@ -25,23 +25,24 @@ *******************************************************************************/ #include +#include #include #include #include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL #include -#include #include +#include #endif -MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_CK_BN_BACK) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CK_BN_BACK) namespace miopen { namespace solver { namespace batchnorm { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL -using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using index_t = int32_t; +using PassThroughOp = ck::tensor_operation::element_wise::PassThrough; +using index_t = int32_t; constexpr index_t Rank = 4; constexpr index_t NumBatchNormReduceDim = 3; @@ -66,7 +67,7 @@ using DeviceOpBNBwdPtrs = ck::tensor_operation::device::instance::DeviceOperatio ScaleDataType, DscaleDbiasDataType, MeanVarDataType, - PassThrough, + PassThroughOp, Rank, NumBatchNormReduceDim>>; @@ -111,7 +112,7 @@ struct CKArgsBNormBwd data_ctx.savedMean, data_ctx.savedInvVariance, epsilon, - PassThrough{}, + PassThroughOp{}, data_ctx.dx, data_ctx.resultBnScaleDiff, data_ctx.resultBnBiasDiff); @@ -135,6 +136,81 @@ struct CKArgsBNormBwd std::array reduceDims{0, 1, 2}; }; +template +void PerformanceConfigBnCKBwdBackward::Init( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ + const auto& args = CKArgsBNormBwd{problem_desc}; + const auto bn_bwd_ptrs = DeviceOpBNBwdPtrs::GetInstances(); + if(bn_bwd_ptrs.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKBwdBackward bn_bwd_ptrs empty"); + + for(const auto& it : bn_bwd_ptrs) + { + auto argument_ptr = it->MakeArgumentPointer(args.lens, + args.in_strides, + args.in_strides, + args.in_strides, + args.reduceDims, + args.arrScaleBiasMeanVarLengths, + args.arrScaleBiasMeanVarStrides, + args.arrScaleBiasMeanVarStrides, + args.arrScaleBiasMeanVarStrides, + nullptr, + nullptr, + nullptr, + nullptr, + nullptr, + 0.0, + PassThroughOp{}, + nullptr, + nullptr, + nullptr); + if(it->IsSupportedArgument(argument_ptr.get())) + { + valid_kernels.push_back(it->GetTypeString()); + } + } + + if(valid_kernels.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKBwdBackward valid_kernels empty"); + + this->index = 0; + this->kernel_id = valid_kernels[0]; +} + +template +bool PerformanceConfigBnCKBwdBackward::CheckIsSupportCKArgs( + const miopen::batchnorm::ProblemDescription& problem) const +{ + return IsCKArgsSupported, + CKArgsBNormBwd>(problem, this->kernel_id); +} + template , CKArgsBNormBwd>(problem); } +#endif + +void PerformanceConfigBnCKBwdBackward::HeuristicInit( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; +#else + switch(problem_desc.GetXDesc().GetType()) + { + case miopenHalf: Init(problem_desc); break; + case miopenBFloat16: Init(problem_desc); break; + case miopenFloat: Init(problem_desc); break; + case miopenDouble: Init(problem_desc); break; + case miopenFloat8: + case miopenBFloat8: + case miopenInt8: + case miopenInt32: + case miopenInt64: + default: MIOPEN_THROW("Unsupported datatype"); + } + +#endif +} -template -static ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem) +bool PerformanceConfigBnCKBwdBackward::SetNextValue( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; + return false; +#else + if(this->valid_kernels.empty()) + { + this->HeuristicInit(problem_desc); + return true; + } + if((this->index + 1) < valid_kernels.size()) + { + ++this->index; + this->kernel_id = this->valid_kernels[index]; + return true; + } + else + return false; +#endif +} + +bool PerformanceConfigBnCKBwdBackward::IsValidValue() const { - const auto& valid_kernel_ids = FillValidKernelsIDs, - CKArgsBNormBwd>(bn_problem); - assert(!valid_kernel_ids.empty()); - const auto& kernel_id = valid_kernel_ids[0]; - return InitAnyInvokerFactory, - CKArgsBNormBwd, - miopen::batchnorm::BwdInvokeParams>(bn_problem, kernel_id); + return this->index >= 0 && this->index < valid_kernels.size(); } +bool PerformanceConfigBnCKBwdBackward::IsValid( + const ExecutionContext&, const miopen::batchnorm::ProblemDescription& problem_desc) const +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; + return false; +#else + switch(problem_desc.GetXDesc().GetType()) + { + case miopenHalf: return CheckIsSupportCKArgs(problem_desc); + case miopenBFloat16: + return CheckIsSupportCKArgs(problem_desc); + case miopenFloat: return CheckIsSupportCKArgs(problem_desc); + case miopenDouble: return CheckIsSupportCKArgs(problem_desc); + case miopenFloat8: + case miopenBFloat8: + case miopenInt8: + case miopenInt32: + case miopenInt64: + default: MIOPEN_THROW("Unsupported datatype"); + } + return false; #endif +} + +bool PerformanceConfigBnCKBwdBackward::operator==( + const PerformanceConfigBnCKBwdBackward& other) const +{ + return this->kernel_id == other.kernel_id; +} + +PerformanceConfigBnCKBwdBackward BnCKBwdBackward::GetDefaultPerformanceConfig( + const ExecutionContext&, const miopen::batchnorm::ProblemDescription& problem_desc) const +{ + PerformanceConfigBnCKBwdBackward pp; + pp.HeuristicInit(problem_desc); + MIOPEN_LOG_I(pp.ToString()); + return pp; +} + +bool BnCKBwdBackward::IsValidPerformanceConfig( + const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKBwdBackward& config) const +{ + return config.IsValid(ctx, problem_desc); +} + +PerformanceConfigBnCKBwdBackward +BnCKBwdBackward::Search(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const AnyInvokeParams& invoke_ctx) const +{ + return GenericSearch(*this, ctx, problem_desc, invoke_ctx); +} bool BnCKBwdBackward::IsApplicable( [[maybe_unused]] const ExecutionContext& context, [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - if(env::disabled(MIOPEN_DEBUG_CONV_CK_BN_BACK)) + if(env::disabled(MIOPEN_DEBUG_CK_BN_BACK)) return false; if(!bn_problem.IsLayoutNHWC()) return false; if(!ck_utility::is_ck_supported_hardware(context.GetStream())) return false; - if(bn_problem.GetXDesc().GetType() != bn_problem.GetScaleBiasDiffDesc().GetType()) + if(!bn_problem.Is2D()) return false; if(bn_problem.GetDirection() != miopen::batchnorm::Direction::Backward) return false; - if(!bn_problem.Is2D()) - return false; + switch(bn_problem.GetXDesc().GetType()) { - case miopenFloat: return CheckCKApplicability(bn_problem); - case miopenDouble: return CheckCKApplicability(bn_problem); case miopenHalf: return CheckCKApplicability(bn_problem); case miopenBFloat16: return CheckCKApplicability(bn_problem); + case miopenFloat: return CheckCKApplicability(bn_problem); + case miopenDouble: return CheckCKApplicability(bn_problem); case miopenInt64: case miopenInt32: case miopenInt8: - case miopenBFloat8: - case miopenFloat8: break; + case miopenFloat8: + case miopenBFloat8: break; } #endif return false; } -ConvSolution BnCKBwdBackward::GetSolution( - [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const +template +ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& problem, + InvokerFactoryMakerNHWC&& invoker_factory_maker_nhwc) { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - switch(bn_problem.GetXDesc().GetType()) + if(problem.IsLayoutNHWC()) + { + switch(problem.GetXDesc().GetType()) + { + case miopenFloat: return invoker_factory_maker_nhwc(F32{}); + case miopenDouble: return invoker_factory_maker_nhwc(F64{}); + case miopenHalf: return invoker_factory_maker_nhwc(F16{}); + case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{}); + default: + MIOPEN_THROW(miopenStatusInternalError, + "BnCKBwdBackward operation does not support this data type"); + } + } + // Todo: problem.IsLayoutDefault() + else { - - case miopenFloat: return MakeAnyInvokerFactory(bn_problem); - case miopenDouble: return MakeAnyInvokerFactory(bn_problem); - case miopenHalf: return MakeAnyInvokerFactory(bn_problem); - case miopenBFloat16: - return MakeAnyInvokerFactory(bn_problem); - case miopenInt8: - case miopenInt32: - case miopenInt64: - case miopenBFloat8: - case miopenFloat8: - default: MIOPEN_THROW(miopenStatusInternalError, - "BnCKBwdBackward operation does not support this data type"); + "BnCKBwdBackward operation does not support this data layout"); } +#else + return {}; #endif +} + +ConvSolution BnCKBwdBackward::GetSolution( + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem, + [[maybe_unused]] const PerformanceConfigBnCKBwdBackward& config) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + return MakeAnyInvokerFactory( + bn_problem, + [&](auto data_type_val) { + using T = decltype(data_type_val); + + using AccTy = std::conditional_t, + T, // T==F64 + F32>; // T==F32 + return InitAnyInvokerFactory, + CKArgsBNormBwd, + miopen::batchnorm::BwdInvokeParams, + miopen::batchnorm::ProblemDescription>(bn_problem, + config.kernel_id); + } + // Todo: InvokerFactoryMakerNCHW + ); +#else + std::ignore = bn_problem; + std::ignore = config; return {}; +#endif } } // namespace batchnorm diff --git a/src/solver/batchnorm/forward_inference_ck.cpp b/src/solver/batchnorm/forward_inference_ck.cpp index 4288dd8cb5..bec1805de6 100644 --- a/src/solver/batchnorm/forward_inference_ck.cpp +++ b/src/solver/batchnorm/forward_inference_ck.cpp @@ -26,13 +26,15 @@ *******************************************************************************/ #include +#include #include #include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL #include +#include #include #endif -MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_CK_BN_INFER) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CK_BN_INFER) namespace miopen { namespace solver { @@ -110,6 +112,13 @@ struct CKArgsBNormFwd {data_ctx.y}, Normalize{data_ctx.epsilon}); } + + template + bool IsSupportedBy(const ConvPtr& invoker_ptr) const + { + auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::InfInvokeParams{}); + return invoker_ptr->IsSupportedArgument(arg_ptr.get()); + } }; template -static int CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem) +void PerformanceConfigBnCKFwdInference::Init( + const miopen::batchnorm::ProblemDescription& problem_desc) { - const auto& args = CKArgsBNormFwd{problem}; + const auto& args = CKArgsBNormFwd{problem_desc}; const auto bn_fwd_ptrs = DeviceOpBnFwdInfPtrs:: GetInstances(); - assert(!bn_fwd_ptrs.empty()); - int count = 0; + if(bn_fwd_ptrs.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdInference bn_fwd_ptrs empty"); + for(const auto& it : bn_fwd_ptrs) { auto argument_ptr = it->MakeArgumentPointer(args.xyLengths, @@ -140,11 +151,14 @@ static int CheckCKApplicability(const miopen::batchnorm::ProblemDescription& pro Normalize{0.0}); if(it->IsSupportedArgument(argument_ptr.get())) { - return count; + valid_kernels.push_back(it->GetTypeString()); } - count++; } - return -1; + + if(valid_kernels.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdInference valid_kernels empty"); + this->index = 0; + this->kernel_id = valid_kernels[0]; } template -ConvSolution InvokerFactoryMakerNHWC(const miopen::batchnorm::ProblemDescription& bn_problem) +bool PerformanceConfigBnCKFwdInference::CheckIsSupportCKArgs( + const miopen::batchnorm::ProblemDescription& problem) const { - ConvSolution result; - const auto kernel_index = CheckCKApplicability(bn_problem); - auto bn_fwd_ptrs = - DeviceOpBnFwdInfPtrs:: - GetInstances(); + return IsCKArgsSupported< + DeviceOpBnFwdInfPtrs, + CKArgsBNormFwd>(problem, this->kernel_id); +} - assert(kernel_index >= 0 && !bn_fwd_ptrs.empty() && kernel_index < bn_fwd_ptrs.size()); - auto bn_ptr = std::move(bn_fwd_ptrs.at(kernel_index)); - - result.invoker_factory = [args = CKArgsBNormFwd{bn_problem}, - sh_bn_ptr = std::shared_ptr{std::move(bn_ptr)}]( - const std::vector& /*kernels*/) mutable { - return [args = std::move(args), sh_bn_ptr = std::move(sh_bn_ptr)]( - const Handle& handle, const AnyInvokeParams& primitive_parameters) { - const auto& params = primitive_parameters.CastTo(); - - auto argument_ptr = args.MakeArgPtr(sh_bn_ptr, params); - - auto invoker_ptr = sh_bn_ptr->MakeInvokerPointer(); - const auto enable_profiling = handle.IsProfilingEnabled(); - - float elapsed_time = - invoker_ptr->Run(argument_ptr.get(), {handle.GetStream(), enable_profiling}); - if(enable_profiling) - { - handle.ResetKernelTime(); - handle.AccumKernelTime(elapsed_time); - } - }; - }; - return result; +template +static bool CheckCKApplicability(const miopen::batchnorm::ProblemDescription& problem) +{ + return IsCKApplicable< + DeviceOpBnFwdInfPtrs, + CKArgsBNormFwd>(problem); } #endif +void PerformanceConfigBnCKFwdInference::HeuristicInit( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; +#else + switch(problem_desc.GetXDesc().GetType()) + { + case miopenHalf: Init(problem_desc); break; + case miopenBFloat16: Init(problem_desc); break; + case miopenFloat: Init(problem_desc); break; + case miopenDouble: Init(problem_desc); break; + case miopenFloat8: + case miopenBFloat8: + case miopenInt8: + case miopenInt32: + case miopenInt64: + default: MIOPEN_THROW("Unsupported datatype"); + } + +#endif +} + +bool PerformanceConfigBnCKFwdInference::SetNextValue( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; + return false; +#else + if(this->valid_kernels.empty()) + { + this->HeuristicInit(problem_desc); + if(valid_kernels.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdInference valid_kernels empty"); + return true; + } + if((this->index + 1) < valid_kernels.size()) + { + ++this->index; + this->kernel_id = this->valid_kernels[index]; + return true; + } + else + return false; +#endif +} + +bool PerformanceConfigBnCKFwdInference::IsValidValue() const +{ + return this->index >= 0 && this->index < valid_kernels.size(); +} + +bool PerformanceConfigBnCKFwdInference::IsValid( + const ExecutionContext&, const miopen::batchnorm::ProblemDescription& problem_desc) const +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; + return false; +#else + switch(problem_desc.GetXDesc().GetType()) + { + case miopenHalf: return CheckIsSupportCKArgs(problem_desc); + case miopenBFloat16: + return CheckIsSupportCKArgs(problem_desc); + case miopenFloat: return CheckIsSupportCKArgs(problem_desc); + case miopenDouble: return CheckIsSupportCKArgs(problem_desc); + case miopenFloat8: + case miopenBFloat8: + case miopenInt8: + case miopenInt32: + case miopenInt64: + default: MIOPEN_THROW("Unsupported datatype"); + } + return false; +#endif +} + +bool PerformanceConfigBnCKFwdInference::operator==( + const PerformanceConfigBnCKFwdInference& other) const +{ + return this->kernel_id == other.kernel_id; +} + +PerformanceConfigBnCKFwdInference BnCKFwdInference::GetDefaultPerformanceConfig( + const ExecutionContext&, const miopen::batchnorm::ProblemDescription& problem_desc) const +{ + PerformanceConfigBnCKFwdInference pp; + pp.HeuristicInit(problem_desc); + MIOPEN_LOG_I(pp.ToString()); + return pp; +} + +bool BnCKFwdInference::IsValidPerformanceConfig( + const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKFwdInference& config) const +{ + return config.IsValid(ctx, problem_desc); +} + +PerformanceConfigBnCKFwdInference +BnCKFwdInference::Search(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const AnyInvokeParams& invoke_ctx) const +{ + return GenericSearch(*this, ctx, problem_desc, invoke_ctx); +} + bool BnCKFwdInference::IsApplicable( [[maybe_unused]] const ExecutionContext& context, [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - if(env::disabled(MIOPEN_DEBUG_CONV_CK_BN_INFER)) + if(env::disabled(MIOPEN_DEBUG_CK_BN_INFER)) return false; if(!bn_problem.IsLayoutNHWC()) return false; @@ -212,12 +316,10 @@ bool BnCKFwdInference::IsApplicable( switch(bn_problem.GetXDesc().GetType()) { - case miopenHalf: return (CheckCKApplicability(bn_problem) != -1); - case miopenFloat: return (CheckCKApplicability(bn_problem) != -1); - case miopenDouble: - return (CheckCKApplicability(bn_problem) != -1); - case miopenBFloat16: - return (CheckCKApplicability(bn_problem) != -1); + case miopenHalf: return CheckCKApplicability(bn_problem); + case miopenBFloat16: return CheckCKApplicability(bn_problem); + case miopenFloat: return CheckCKApplicability(bn_problem); + case miopenDouble: return CheckCKApplicability(bn_problem); case miopenInt64: case miopenInt32: case miopenInt8: @@ -258,8 +360,9 @@ ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& } ConvSolution BnCKFwdInference::GetSolution( - [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem, + [[maybe_unused]] const PerformanceConfigBnCKFwdInference& config) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL return MakeAnyInvokerFactory( @@ -270,11 +373,17 @@ ConvSolution BnCKFwdInference::GetSolution( using AccTy = std::conditional_t, T, // T==F64 F32>; // T==F32 - return InvokerFactoryMakerNHWC(bn_problem); + return InitAnyInvokerFactory, + CKArgsBNormFwd, + miopen::batchnorm::InfInvokeParams, + miopen::batchnorm::ProblemDescription>(bn_problem, + config.kernel_id); } // Todo: InvokerFactoryMakerNCHW ); #else + std::ignore = bn_problem; + std::ignore = config; return {}; #endif } diff --git a/src/solver/batchnorm/forward_per_activation.cpp b/src/solver/batchnorm/forward_per_activation.cpp index 7c1b34dea8..0e4d6da799 100644 --- a/src/solver/batchnorm/forward_per_activation.cpp +++ b/src/solver/batchnorm/forward_per_activation.cpp @@ -130,7 +130,7 @@ BnFwdTrainingPerActivation::GetSolution(const ExecutionContext& context, result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) kernel = handle_.Run(kernels.front()); - decltype(auto) params = raw_params.CastTo(); + decltype(auto) params = raw_params.CastTo(); const auto resultsave = params.resultSaveMean != nullptr && params.resultSaveInvVariance != nullptr; const auto resultrunning = diff --git a/src/solver/batchnorm/forward_spatial_multiple.cpp b/src/solver/batchnorm/forward_spatial_multiple.cpp index ad065cece3..9f85c8d7ab 100644 --- a/src/solver/batchnorm/forward_spatial_multiple.cpp +++ b/src/solver/batchnorm/forward_spatial_multiple.cpp @@ -208,7 +208,7 @@ ConvSolution BnFwdTrainingSpatialMultiple::GetSolution( result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) kernel = handle_.Run(kernels.front()); - decltype(auto) params = raw_params.CastTo(); + decltype(auto) params = raw_params.CastTo(); const auto resultsave = params.resultSaveMean != nullptr && params.resultSaveInvVariance != nullptr; const auto resultrunning = diff --git a/src/solver/batchnorm/forward_spatial_single.cpp b/src/solver/batchnorm/forward_spatial_single.cpp index eb45ec5ea2..732181073a 100644 --- a/src/solver/batchnorm/forward_spatial_single.cpp +++ b/src/solver/batchnorm/forward_spatial_single.cpp @@ -244,7 +244,7 @@ BnFwdTrainingSpatialSingle::GetSolution(const ExecutionContext& context, result.invoker_factory = [=](const std::vector& kernels) { return [=](const Handle& handle_, const AnyInvokeParams& raw_params) { decltype(auto) kernel = handle_.Run(kernels.front()); - decltype(auto) params = raw_params.CastTo(); + decltype(auto) params = raw_params.CastTo(); const auto resultsave = params.resultSaveMean != nullptr && params.resultSaveInvVariance != nullptr; const auto resultrunning = diff --git a/src/solver/batchnorm/forward_training_ck.cpp b/src/solver/batchnorm/forward_training_ck.cpp index 90ddc3043f..cdaafe7b58 100644 --- a/src/solver/batchnorm/forward_training_ck.cpp +++ b/src/solver/batchnorm/forward_training_ck.cpp @@ -27,13 +27,14 @@ #include #include +#include #include #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL #include -#include #include +#include #endif -MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING) +MIOPEN_DECLARE_ENV_VAR_BOOL(MIOPEN_DEBUG_CK_BN_FWD_TRAINING) namespace miopen { namespace solver { @@ -119,7 +120,7 @@ struct CKArgsBNormFwdTraining template bool IsSupportedBy(const ConvPtr& invoker_ptr) const { - auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::InvokeParams{}); + auto arg_ptr = MakeArgPtr(invoker_ptr, miopen::batchnorm::FwdTrainInvokeParams{}); return invoker_ptr->IsSupportedArgument(arg_ptr.get()); } @@ -133,6 +134,77 @@ struct CKArgsBNormFwdTraining std::array reduceDims{0, 1, 2}; }; +template +void PerformanceConfigBnCKFwdTraining::Init( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ + const auto& args = CKArgsBNormFwdTraining{problem_desc}; + const auto bn_fwd_ptrs = DeviceOpBNFwdTrainingPtrs::GetInstances(); + if(bn_fwd_ptrs.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdTraining bn_fwd_ptrs empty"); + + for(const auto& it : bn_fwd_ptrs) + { + auto argument_ptr = it->MakeArgumentPointer(args.xyLengths, + args.xyStrides, + args.xyStrides, + args.reduceDims, + args.arrScaleBiasMeanVarLengths, + args.arrScaleBiasMeanVarStrides, + args.arrScaleBiasMeanVarStrides, + args.arrScaleBiasMeanVarStrides, + nullptr, + nullptr, + nullptr, + 0.0, + PassThroughOp{}, + nullptr, + nullptr, + nullptr, + 0.0, + nullptr, + nullptr); + if(it->IsSupportedArgument(argument_ptr.get())) + { + valid_kernels.push_back(it->GetTypeString()); + } + } + + if(valid_kernels.empty()) + MIOPEN_THROW(miopenStatusInternalError, "BnCKFwdTraining valid_kernels empty"); + + this->index = 0; + this->kernel_id = valid_kernels[0]; +} + +template +bool PerformanceConfigBnCKFwdTraining::CheckIsSupportCKArgs( + const miopen::batchnorm::ProblemDescription& problem) const +{ + return IsCKArgsSupported, + CKArgsBNormFwdTraining>(problem, this->kernel_id); +} + template , CKArgsBNormFwdTraining>(problem); } +#endif -template -static ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& bn_problem) +void PerformanceConfigBnCKFwdTraining::HeuristicInit( + const miopen::batchnorm::ProblemDescription& problem_desc) { - const auto& valid_kernel_ids = FillValidKernelsIDs, - CKArgsBNormFwdTraining>(bn_problem); - assert(!valid_kernel_ids.empty()); - const auto& kernel_id = valid_kernel_ids[0]; - return InitAnyInvokerFactory, - CKArgsBNormFwdTraining, - miopen::batchnorm::InvokeParams>(bn_problem, kernel_id); +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; +#else + switch(problem_desc.GetXDesc().GetType()) + { + case miopenHalf: Init(problem_desc); break; + case miopenBFloat16: Init(problem_desc); break; + case miopenFloat: Init(problem_desc); break; + case miopenDouble: Init(problem_desc); break; + case miopenFloat8: + case miopenBFloat8: + case miopenInt8: + case miopenInt32: + case miopenInt64: + default: MIOPEN_THROW("Unsupported datatype"); + } + +#endif } + +bool PerformanceConfigBnCKFwdTraining::SetNextValue( + const miopen::batchnorm::ProblemDescription& problem_desc) +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; + return false; +#else + if(this->valid_kernels.empty()) + { + this->HeuristicInit(problem_desc); + return true; + } + if((this->index + 1) < valid_kernels.size()) + { + ++this->index; + this->kernel_id = this->valid_kernels[index]; + return true; + } + else + return false; #endif +} + +bool PerformanceConfigBnCKFwdTraining::IsValidValue() const +{ + return this->index >= 0 && this->index < valid_kernels.size(); +} + +bool PerformanceConfigBnCKFwdTraining::IsValid( + const ExecutionContext&, const miopen::batchnorm::ProblemDescription& problem_desc) const +{ +#if !MIOPEN_BACKEND_HIP || !MIOPEN_USE_COMPOSABLEKERNEL + std::ignore = problem_desc; + return false; +#else + switch(problem_desc.GetXDesc().GetType()) + { + case miopenHalf: return CheckIsSupportCKArgs(problem_desc); + case miopenBFloat16: + return CheckIsSupportCKArgs(problem_desc); + case miopenFloat: return CheckIsSupportCKArgs(problem_desc); + case miopenDouble: return CheckIsSupportCKArgs(problem_desc); + case miopenFloat8: + case miopenBFloat8: + case miopenInt8: + case miopenInt32: + case miopenInt64: + default: MIOPEN_THROW("Unsupported datatype"); + } + return false; +#endif +} + +bool PerformanceConfigBnCKFwdTraining::operator==( + const PerformanceConfigBnCKFwdTraining& other) const +{ + return this->kernel_id == other.kernel_id; +} + +PerformanceConfigBnCKFwdTraining BnCKFwdTraining::GetDefaultPerformanceConfig( + const ExecutionContext&, const miopen::batchnorm::ProblemDescription& problem_desc) const +{ + PerformanceConfigBnCKFwdTraining pp; + pp.HeuristicInit(problem_desc); + MIOPEN_LOG_I(pp.ToString()); + return pp; +} + +bool BnCKFwdTraining::IsValidPerformanceConfig( + const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const PerformanceConfigBnCKFwdTraining& config) const +{ + return config.IsValid(ctx, problem_desc); +} + +PerformanceConfigBnCKFwdTraining +BnCKFwdTraining::Search(const ExecutionContext& ctx, + const miopen::batchnorm::ProblemDescription& problem_desc, + const AnyInvokeParams& invoke_ctx) const +{ + return GenericSearch(*this, ctx, problem_desc, invoke_ctx); +} bool BnCKFwdTraining::IsApplicable( [[maybe_unused]] const ExecutionContext& context, [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - if(env::disabled(MIOPEN_DEBUG_CONV_CK_BN_FWD_TRAINING)) + if(env::disabled(MIOPEN_DEBUG_CK_BN_FWD_TRAINING)) return false; if(!bn_problem.IsLayoutNHWC()) return false; @@ -197,45 +349,75 @@ bool BnCKFwdTraining::IsApplicable( switch(bn_problem.GetXDesc().GetType()) { case miopenHalf: return CheckCKApplicability(bn_problem); + case miopenBFloat16: return CheckCKApplicability(bn_problem); case miopenFloat: return CheckCKApplicability(bn_problem); case miopenDouble: return CheckCKApplicability(bn_problem); - case miopenBFloat16: { - bool var = CheckCKApplicability(bn_problem); - return var; - } case miopenInt64: case miopenInt32: case miopenInt8: - case miopenBFloat8: - case miopenFloat8: break; + case miopenFloat8: + case miopenBFloat8: break; } #endif return false; } -ConvSolution BnCKFwdTraining::GetSolution( - [[maybe_unused]] const ExecutionContext& context, - [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem) const +template +ConvSolution MakeAnyInvokerFactory(const miopen::batchnorm::ProblemDescription& problem, + InvokerFactoryMakerNHWC&& invoker_factory_maker_nhwc) { #if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL - switch(bn_problem.GetXDesc().GetType()) + if(problem.IsLayoutNHWC()) + { + switch(problem.GetXDesc().GetType()) + { + case miopenFloat: return invoker_factory_maker_nhwc(F32{}); + case miopenDouble: return invoker_factory_maker_nhwc(F64{}); + case miopenHalf: return invoker_factory_maker_nhwc(F16{}); + case miopenBFloat16: return invoker_factory_maker_nhwc(BF16{}); + default: + MIOPEN_THROW(miopenStatusInternalError, + "BnCKFwdTraining operation does not support this data type"); + } + } + // Todo: problem.IsLayoutDefault() + else { - - case miopenFloat: return MakeAnyInvokerFactory(bn_problem); - case miopenDouble: return MakeAnyInvokerFactory(bn_problem); - case miopenHalf: return MakeAnyInvokerFactory(bn_problem); - case miopenBFloat16: return MakeAnyInvokerFactory(bn_problem); - case miopenInt8: - case miopenInt32: - case miopenInt64: - case miopenBFloat8: - case miopenFloat8: - default: MIOPEN_THROW(miopenStatusInternalError, - "BnCKFwdTraining operation does not support this data type"); + "BnCKFwdTraining operation does not support this data layout"); } +#else + return {}; #endif +} + +ConvSolution BnCKFwdTraining::GetSolution( + [[maybe_unused]] const ExecutionContext&, + [[maybe_unused]] const miopen::batchnorm::ProblemDescription& bn_problem, + [[maybe_unused]] const PerformanceConfigBnCKFwdTraining& config) const +{ +#if MIOPEN_BACKEND_HIP && MIOPEN_USE_COMPOSABLEKERNEL + return MakeAnyInvokerFactory( + bn_problem, + [&](auto data_type_val) { + using T = decltype(data_type_val); + + using AccTy = std::conditional_t, + T, // T==F64 + F32>; // T==F32 + return InitAnyInvokerFactory, + CKArgsBNormFwdTraining, + miopen::batchnorm::FwdTrainInvokeParams, + miopen::batchnorm::ProblemDescription>(bn_problem, + config.kernel_id); + } + // Todo: InvokerFactoryMakerNCHW + ); +#else + std::ignore = bn_problem; + std::ignore = config; return {}; +#endif } } // namespace batchnorm diff --git a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp index 6e29710d7e..9ec44dc9f1 100644 --- a/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp +++ b/src/solver/conv_ck_igemm_fwd_bias_activ_fused.cpp @@ -366,6 +366,7 @@ bool PerformanceConfigConvCKIgemmFwdBiasActivFused::operator==( { return this->kernel_id == other.kernel_id; } + PerformanceConfigConvCKIgemmFwdBiasActivFused ConvCKIgemmFwdBiasActivFused::GetDefaultPerformanceConfig( const FusionContext&, const FusionDescription& fdesc_problem) const diff --git a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp index 6313f3920f..bd36667f12 100644 --- a/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp +++ b/src/solver/mha/mha_ck_fa_v2_solver_forward.cpp @@ -239,8 +239,13 @@ MhaCKFlashAttentionV2Forward::GetSolution([[maybe_unused]] const ExecutionContex // and isn't async. fmha_runtime_args.p_drop = probability; - fmha_runtime_args.drop_seed_offset = - std::make_pair(dataFwd.dropoutSeedData, dataFwd.dropoutOffsetData); + // fmha_runtime_args.drop_seed_offset = + // std::make_pair(dataFwd.dropoutSeedData), + // dataFwd.dropoutOffsetData); + + // using dataFwd.dropoutSeedData gpu pointer was causing compiler error + // since dropout is disabled for now, placing 0. + fmha_runtime_args.drop_seed_offset = std::make_pair(0, 0); // Create stream_config, and set it to not time kernel. ck_tile::stream_config stream_config;