Skip to content

Commit

Permalink
Enable tuning in Batch norm CK solver (#3326)
Browse files Browse the repository at this point in the history
  • Loading branch information
bghimireamd authored Oct 24, 2024
1 parent 42d6e9f commit 2d1fd99
Show file tree
Hide file tree
Showing 15 changed files with 965 additions and 221 deletions.
6 changes: 6 additions & 0 deletions src/batchnorm/problem_description.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "c" << c;
}
else
Expand All @@ -154,6 +155,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "n" << n;
ss << "c" << c;
Expand All @@ -172,6 +174,7 @@ NetworkConfig ProblemDescription::MakeForwardTrainingNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "gx" << xgridsize;
ss << "gy" << ygridsize;
ss << "lx" << xlocalsize;
Expand Down Expand Up @@ -201,6 +204,7 @@ NetworkConfig ProblemDescription::MakeForwardInferenceNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "mode" << bn_mode;
ss << "HWdims" << in_cstride;
ss << "C" << c;
Expand Down Expand Up @@ -308,6 +312,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "single" << static_cast<int>(single);
ss << "gcn" << ldsgcn;
}
Expand All @@ -330,6 +335,7 @@ NetworkConfig ProblemDescription::MakeBackwardNetworkConfig() const
ss << "fp32" << static_cast<int>(IsFp32());
ss << "fp64" << static_cast<int>(IsFp64());
ss << "fbf16" << static_cast<int>(IsBfp16());
ss << "fmix" << static_cast<int>(IsMix());
ss << "nhw" << in_nhw;
}
ss << "layout" << in_layout;
Expand Down
4 changes: 2 additions & 2 deletions src/include/miopen/batchnorm/invoke_params.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@
namespace miopen {
namespace batchnorm {

struct InvokeParams : public miopen::InvokeParams
struct FwdTrainInvokeParams : public miopen::InvokeParams
{
InvokeParams() = default;
FwdTrainInvokeParams() = default;

ConstData_t x = nullptr;
Data_t y = nullptr;
Expand Down
79 changes: 78 additions & 1 deletion src/include/miopen/batchnorm/problem_description.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ struct ProblemDescriptionTag
{
};

struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, ProblemDescriptionTag
struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase,
ProblemDescriptionTag
#if MIOPEN_ENABLE_SQLITE
,
SQLiteSerializable<ProblemDescription>
#endif
{
// Forward Training
ProblemDescription(miopenBatchNormMode_t bn_mode_,
Expand Down Expand Up @@ -218,10 +223,49 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
bool IsFp64() const { return xDesc.GetType() == miopenDouble; }
bool IsFp32() const { return xDesc.GetType() == miopenFloat; }
bool IsFp16() const { return xDesc.GetType() == miopenHalf; }
bool IsMix() const
{
return xDesc.GetType() == miopenHalf && sMeanDesc.GetType() == miopenFloat;
}
bool IsBfp16() const { return xDesc.GetType() == miopenBFloat16; }

void Serialize(std::ostream& stream) const { stream << MakeNetworkConfig().ToString(); }

NetworkConfig MakeNetworkConfig() const override;

template <class Self>
static void Visit(Self&& self, std::function<void(int64_t, std::string)> f)
{
// The column names match the driver command line argument names
f(self.spatial_dim, "spatial_dim");
f(self.GetBatchSize(), "batchsize");
f(self.GetChannel(), "in_channels");
f(self.GetHeight(), "in_h");
f(self.GetWidth(), "in_w");
f(self.GetDepth(), "in_d");

f(self.resultsave, "resultsave");
f(self.resultrunning, "resultrunning");
f(self.useSaved, "useSaved");
}

template <class Self>
static void Visit(Self&& self, std::function<void(std::string, std::string)> f)
{
f(self.ComputeInLayout(), "layout");
f(self.GetDirectionStr(), "direction");
f(GetDataTypeName(self.xDesc.GetType()), "data_type");
f(self.GetModeStr(), "mode");
}

template <class Self, class Visitor>
static void VisitAll(Self&& self, const Visitor& f)
{
Visit(std::forward<Self>(self), [&](int64_t value, std::string name) { f(value, name); });
Visit(std::forward<Self>(self),
[&](std::string value, std::string name) { f(value, name); });
}

// This declaration marks batchnorm as a primitive with tuning enabled.
// Any tunable solver would be able pick it and fetch a db instance in ExecutePrimitive.
// It has to be discoverable via ADL from problem description.
Expand Down Expand Up @@ -267,6 +311,39 @@ struct MIOPEN_INTERNALS_EXPORT ProblemDescription : ProblemDescriptionBase, Prob
std::string ComputeInLayout() const { return ComputeLayout(xDesc); }
std::string ComputeOutLayout() const { return ComputeLayout(yOrDyDesc); }
std::string ComputeDinLayout() const { return ComputeLayout(dxDesc); }

size_t GetSpatialDims() const { return spatial_dim; }

std::size_t GetBatchSize() const { return GetN5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetChannel() const { return GetC5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetHeight() const { return GetH5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetWidth() const { return GetW5(GetSpatialDims(), xDesc.GetLengths()); }
std::size_t GetDepth() const { return GetD5(GetSpatialDims(), xDesc.GetLengths()); }

std::string GetDirectionStr() const
{
std::string s;

switch(direction)
{
case Direction::ForwardInference: return "Inf"; ;
case Direction::ForwardTraining: return "Trn";
case Direction::Backward: return "Bwd";
default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Direction provided");
}

return s;
}

std::string GetModeStr() const
{
switch(bn_mode)
{
case miopenBNPerActivation: return "0";
case miopenBNSpatial: return "1";
default: MIOPEN_THROW(miopenStatusInvalidValue, "Wrong Batchnorm Mode provided");
}
}
};

} // namespace batchnorm
Expand Down
Loading

0 comments on commit 2d1fd99

Please sign in to comment.