feat(kernel): add gi float32 nchw backdata conv
GitOrigin-RevId: c4daadfc9590f5f0c05e04c48335e359f775102d
megvii-mge committed Aug 21, 2023
1 parent 999e286 commit 525e463
Showing 6 changed files with 378 additions and 2 deletions.
24 changes: 24 additions & 0 deletions compiler/lib/KernelGen/GeneralIntrinsic/ConvKernel/ConvKernel.h
@@ -127,6 +127,30 @@ class Conv1x1FloatMk4 : public GIConvImpl {
    std::shared_ptr<TContext> GetInnerCtx(TContext* ctx) const;
    MatmulM4N12MK4Kernel m_inner_gemm;
};

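//! float32 NCHW convolution backward-data: the weight is pre-packed as matmul
//! operand A at init time, the output gradient is packed as operand B at run
//! time, and the matmul result is scattered into the input gradient by a
//! col2img pass.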
class Float32NchwBackward : public GIConvImpl {
public:
    std::string GetKernelSymbol(TContext* context) const override;
    bool IsAvailable(TContext* context) const override;
    //! kernel gen
    std::string GetKernelBody(TContext* context) const override;
    //! init gen
    std::string GetInitBody(TContext* context) const override;
    std::vector<KernelObj> GetDependInternalSymbol(TContext* context) const override;

    std::string GetWorkspaceBody(TContext* ctx) const override {
        return GetWorkspaceBodyCondition(ctx, false);
    }
    std::string GetWorkspaceBodyAndJitExec(TContext* ctx) const override {
        return GetWorkspaceBodyCondition(ctx, true);
    }

private:
    std::string GetWorkspaceBodyCondition(TContext* ctx, bool jit) const;
    std::shared_ptr<TContext> GetInnerGemmCtx(TContext* ctx) const;
    MatmulM4N12Kernel m_inner_gemm;
};

class WinogradFloatF43NCHW44 : public GIConvImpl {
    mutable WinogradFrameNchw44 m_framework;
    mutable WinogradF43Strategy4x16MK4 m_winograd_strategy;
271 changes: 271 additions & 0 deletions (new file)
@@ -0,0 +1,271 @@
#include <sstream>
#include <string>
#include "GeneralIntrinsic/Activation.h"
#include "GeneralIntrinsic/ConvKernel/ConvKernel.h"
#include "GeneralIntrinsic/InternalKernel/InternalKernel.h"
#include "compiler/KernelGen/KernelGen.h"

using namespace megcc;
using namespace KernelGen;
using namespace GeneralIntrinsic;

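//! Only shapes the matmul + col2img path can handle are accepted: square
//! kernel, equal strides, no dilation, NCHW layout, DENSE or GROUP sparsity,
//! CROSS_CORRELATION mode, float32 operands, and no fused bias or activation.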
bool Float32NchwBackward::IsAvailable(TContext* ctx) const {
    bool param_value_ok =
            ctx->getAttrUInt("kernel_h") == ctx->getAttrUInt("kernel_w") &&
            ctx->getAttrUInt("stride_h") == ctx->getAttrUInt("stride_w") &&
            ctx->getAttrUInt("dilate_h") == 1 && ctx->getAttrUInt("dilate_w") == 1;
    bool param_mode_ok = (ctx->getAttrStr("sparse") == "DENSE" ||
                          ctx->getAttrStr("sparse") == "GROUP") &&
                         ctx->getAttrStr("format") == "NCHW" &&
                         ctx->getAttrStr("mode") == "CROSS_CORRELATION";
    bool noline_ok = !ctx->haveAttr("nonlineMode");
    bool no_bias_ok = !ConvImpl::is_bias(ctx);

    bool type_ok = ctx->getAttrInt("nr_operands") >= 3 &&
                   ctx->getAttrOprand("operand:0").dtype == "f32" &&
                   ctx->getAttrOprand("operand:1").dtype == "f32" &&
                   ctx->getAttrOprand("operand:2").dtype == "f32";
    return param_value_ok && param_mode_ok && type_ok && noline_ok && no_bias_ok;
}

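//! e.g. a dense 3x3 kernel with pad 1, stride 1, dilation 1 (the first case in
//! the benchmark below) renders to
//! "GI_kernel_back_data_conv2d_NCHW_3x3_DENSE_p1x1_s1x1_d1x1_f32f32f32".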
std::string Float32NchwBackward::GetKernelSymbol(TContext* ctx) const {
    CC_ASSERT(ctx);
    std::string name_temp =
            "GI_kernel_back_data_conv2d_${format}_${kernel_h}x${kernel_w}_${"
            "sparse}_p${pad_h}x${pad_w}_s${stride_h}x${stride_w}_d${"
            "dilate_h}x${dilate_w}_f32f32f32";
    return StringTemplate::StringTemplateArgs(ctx)
            .add_ctx_int("kernel_h")
            .add_ctx_int("kernel_w")
            .add_ctx_str("format")
            .add_ctx_str("sparse")
            .add_ctx_int("pad_h")
            .add_ctx_int("pad_w")
            .add_ctx_int("stride_h")
            .add_ctx_int("stride_w")
            .add_ctx_int("dilate_h")
            .add_ctx_int("dilate_w")
            .render(name_temp);
}

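//! Init pre-packs the weight as matmul operand A, one block per group: kmax
//! spans the forward output channels and ymax spans ic * kernel_h * kernel_w,
//! matching the transposed GEMM configured in GetInnerGemmCtx().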
std::string Float32NchwBackward::GetInitBody(TContext* ctx) const {
    std::stringstream writer;
    auto inner_gemm_ctx = GetInnerGemmCtx(ctx);
    writer << m_inner_gemm.GetPackASignature(inner_gemm_ctx.get()) << ";\n";
    writer << m_inner_gemm.GetPackAWorkspaceSignature(inner_gemm_ctx.get()) << ";\n";
    writer << GenCommonRet() << " " << GetInitSignature(ctx);
    uint32_t nr_out_weight = 1;
    std::string common_def = R"(
    Tensor* in_weights = inputs[0]; //! The weight is the first input in BackwardConv.
    const uint32_t nr_dim = in_weights->layout.nr_dim;
    const uint32_t group = nr_dim == 5 ? in_weights->layout.dims[0] : 1;
    const uint32_t ic_idx = nr_dim == 5 ? 1 : 0;
    int ymax = in_weights->layout.dims[ic_idx + 1] * in_weights->layout.dims[ic_idx + 2] * in_weights->layout.dims[ic_idx + 3];
    int kmax = in_weights->layout.dims[ic_idx + 0];
    int ldin = ymax;
    )";
    std::string fill_weight_attr =
            R"(
    out_weights->layout.nr_dim = 2;
    out_weights->layout.dims[0] = group;
    out_weights->layout.dims[1] = )" +
            m_inner_gemm.GetPackAWorkspaceSymbol(inner_gemm_ctx.get()) +
            R"((0, ymax, 0, kmax) / sizeof(float);
    out_weights->layout.stride[0] = out_weights->layout.dims[1];
    out_weights->layout.stride[1] = 1;
    out_weights->dtype.type_enum = TinyNN_FLOAT;
    out_weights->name = in_weights->name;
    )";
    std::string fill_weight_transform =
            R"(
    for (int g = 0; g < group; ++g) {
        float* outptr = (float*)(out_weights->ptr) + g * out_weights->layout.stride[0];
        float* inptr = (float*)(in_weights->ptr) + g * kmax * ymax;
    )" + m_inner_gemm.GetPackASymbol(inner_gemm_ctx.get()) +
            "(outptr, inptr, ldin, 0, ymax, 0, kmax);}";
    writer << StringTemplate::render_init_body(
            nr_out_weight, fill_weight_attr, fill_weight_transform, common_def);

    return writer.str();
}

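//! The workspace holds the packed operand B (the per-group output-gradient
//! matrix, dy channels per group by ih * iw) plus, except in the 1x1 stride-1
//! pad-0 case, an oc * kh * kw * ih * iw float staging buffer for the matmul
//! result that col2img later scatters into the real gradient image.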
std::string Float32NchwBackward::GetWorkspaceBodyCondition(
        TContext* ctx, bool jit) const {
    std::stringstream ss;
    auto inner_ctx = GetInnerGemmCtx(ctx);
    if (jit) {
        ss << m_inner_gemm.GetPackBWorkspaceBody(inner_ctx.get()) << ";\n";
    } else {
        ss << "extern " << m_inner_gemm.GetPackBWorkspaceSignature(inner_ctx.get())
           << ";\n";
    }
    ss << GenCommonRet() << " " << GetWorkspaceSignature(ctx);
    std::string workspace_temp =
            R"({
    TINYNN_ASSERT(workspace);
    const Layout weight_layout = inputs[0]->layout;
    const uint32_t group = weight_layout.nr_dim == 5 ? weight_layout.dims[0] : 1;
    const Layout in_layout = inputs[1]->layout;
    const uint32_t ic = in_layout.dims[1] / group;
    const uint32_t ih = in_layout.dims[2];
    const uint32_t iw = in_layout.dims[3];
    const uint32_t hw = ih * iw;
    *workspace = )" +
            m_inner_gemm.GetPackBWorkspaceSymbol(inner_ctx.get()) +
            StringTemplate::StringTemplateArgs(ctx)
                    .add_ctx_int("kernel_h")
                    .add_ctx_int("kernel_w")
                    .add_ctx_int("stride_h")
                    .add_ctx_int("stride_w")
                    .add_ctx_int("pad_h")
                    .add_ctx_int("pad_w")
                    .render(R"((0, hw, 0, ic);
    const uint32_t kh = ${kernel_h};
    const uint32_t kw = ${kernel_w};
    if (kh != 1 || kw != 1 || ${stride_h} != 1 || ${stride_w} != 1 || ${pad_h} != 0 || ${pad_w} != 0) {
        const uint32_t oc = weight_layout.nr_dim == 5 ? weight_layout.dims[2] : weight_layout.dims[1];
        *workspace = *workspace + oc * kh * kw * hw * sizeof(float);
    }
    return TinyNN_SUCCESS;
})");
    ss << workspace_temp;
    return ss.str();
}

std::vector<KernelObj> Float32NchwBackward::GetDependInternalSymbol(
        TContext* ctx) const {
    auto inner_gemm_ctx = GetInnerGemmCtx(ctx);

    return {
            {m_inner_gemm.GetKernelSymbol(inner_gemm_ctx.get()),
             m_inner_gemm.GetKernelBody(inner_gemm_ctx.get()),
             m_inner_gemm.GetBodyGuardBegin(inner_gemm_ctx.get()),
             m_inner_gemm.GetBodyGuardEnd(inner_gemm_ctx.get())}};
}

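//! transposeA mirrors the init-time packing: per group the GEMM computes
//! col(grad_in) = W^T * grad_out, where W is packed as oc x (ic * kh * kw)
//! and grad_out as oc x (ih * iw).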
std::shared_ptr<TContext> Float32NchwBackward::GetInnerGemmCtx(TContext* ctx) const {
    auto inner_ctx = std::make_shared<CodeGenContext>();
    inner_ctx->setAttr("with_bias", false);
    inner_ctx->setAttr("transposeA", true);
    inner_ctx->setAttr("transposeB", false);
    inner_ctx->setAttr("format", "NCHW");
    inner_ctx->setAttr("dtype", "f32");
    return inner_ctx;
}

std::string Float32NchwBackward::GetKernelBody(TContext* ctx) const {
    std::stringstream writer;
    auto inner_ctx = GetInnerGemmCtx(ctx);
    writer << m_inner_gemm.GetNakedKernelSignature(inner_ctx.get()) << ";\n";
    writer << m_inner_gemm.GetPackBSignature(inner_ctx.get()) << ";\n";
    writer << GenCommonRet() << " " << GetKernelSignature(ctx);

    std::string col2img = R"(
        for (int oc = 0; oc < out_c; ++oc) {
            for (int kh_idx = 0; kh_idx < kh; ++kh_idx) {
                for (int kw_idx = 0; kw_idx < kw; ++kw_idx) {
                    for (int ih = 0; ih < in_h; ++ih) {
                        const int oh = ih * sh + kh_idx - ph;
                        for (int iw = 0; iw < in_w; ++iw) {
                            const int ow = iw * sw + kw_idx - pw;
                            if (oh >= 0 && oh < out_h && ow >= 0 && ow < out_w) {
                                output_data[oc * out_h * out_w + oh * out_w + ow] += *out_ptr;
                            }
                            ++out_ptr;
                        }
                    }
                }
            }
        }
)";
if (ctx->getAttrInt("pad_h") == 0 && ctx->getAttrInt("pad_w") == 0 &&
ctx->getAttrInt("stride_h") == 1 && ctx->getAttrInt("stride_w") == 1) {
col2img = R"(
for (int oc = 0; oc < out_c; ++oc) {
for(int kh_idx = 0; kh_idx < kh; ++kh_idx) {
for(int kw_idx = 0; kw_idx < kw; ++kw_idx) {
for(int ih = 0; ih < in_h; ++ih) {
const int oh = ih + kh_idx;
for(int iw = 0; iw < in_w; ++iw){
const int ow = iw + kw_idx;
output_data[oc * out_h * out_w + oh * out_w + ow] += *out_ptr++;
}
}
}
}
}
)";
}

    writer << StringTemplate::StringTemplateArgs(ctx)
                      .add_ctx_int("kernel_h")
                      .add_ctx_int("kernel_w")
                      .add_ctx_int("pad_h")
                      .add_ctx_int("pad_w")
                      .add_ctx_int("stride_h")
                      .add_ctx_int("stride_w")
                      .render(R"({
    float* input_data = inputs[1]->ptr;
    float* output_data = outputs[0]->ptr;
    Layout weight_layout = inputs[0]->layout;
    Layout in_layout = inputs[1]->layout;
    Layout out_layout = outputs[0]->layout;
    // inputs[0] is the weight packed at init time: its layout was rewritten to
    // {group, packed_size}, so dims[0] is the group count (1 for dense) and
    // stride[0] the per-group packed size.
    const int group = weight_layout.dims[0];
    const int in_n = in_layout.dims[0];
    const int in_c = in_layout.dims[1] / group;
    const int in_h = in_layout.dims[2];
    const int in_w = in_layout.dims[3];
    const size_t N = in_h * in_w;
    const int out_c = out_layout.dims[1] / group;
    const int out_h = out_layout.dims[2];
    const int out_w = out_layout.dims[3];
    const int kh = ${kernel_h};
    const int kw = ${kernel_w};
    const int ph = ${pad_h};
    const int pw = ${pad_w};
    const int sh = ${stride_h};
    const int sw = ${stride_w};
    const size_t M = out_c * kh * kw;
    const int LDB = N;
    const int LDC = LDB;
    void* workspace_ptr = workspace->ptr;
    float* out_buffer = NULL;
    const int need_out_buffer = (kh != 1 || kw != 1 || sh != 1 || sw != 1 || ph != 0 || pw != 0);
    if (need_out_buffer) {
        out_buffer = workspace->ptr;
        workspace_ptr = out_buffer + M * N;
    }
    for (int n_idx = 0; n_idx < in_n; ++n_idx) {
        float* weight_data = inputs[0]->ptr;
        for (int g = 0; g < group; ++g) {
            )")
           << m_inner_gemm.GetPackBSymbol(inner_ctx.get())
           << R"((workspace_ptr, input_data, LDB, 0, in_h * in_w, 0, in_c);
            float* out_ptr = need_out_buffer ? out_buffer : output_data;
            )"
           << m_inner_gemm.GetNakedKernelSymbol(inner_ctx.get())
           << StringTemplate::StringTemplateArgs()
                      .add("col2img", col2img)
                      .render(R"((weight_data, workspace_ptr, out_ptr, LDC, M, N, in_c, NULL);
            if (need_out_buffer) {
                memset(output_data, 0, sizeof(float) * out_c * out_h * out_w);
                ${col2img}
            }
            weight_data += weight_layout.stride[0];
            input_data += in_c * in_h * in_w;
            output_data += out_c * out_h * out_w;
        }
    }
    return TinyNN_SUCCESS;
})");

    return writer.str();
}

// vim: syntax=cpp.doxygen
3 changes: 3 additions & 0 deletions compiler/lib/KernelGen/GeneralIntrinsic/KernelPack.cpp
@@ -59,6 +59,9 @@ struct AllGICommonKernel {

        };

        inner_map[KernelPack::KernType::ConvBackDataKernel] = {
                std::make_shared<GeneralIntrinsic::Float32NchwBackward>()};

        inner_map[KernelPack::KernType::PoolingKernel] = {
                std::make_shared<GeneralIntrinsic::PoolingNchw88Fp16>(),
                std::make_shared<GeneralIntrinsic::PoolingNchw44Fp32>(),
4 changes: 2 additions & 2 deletions compiler/test/kernel/common/src/workload_proxy.cpp
@@ -199,8 +199,8 @@ template <>
size_t WorkloadOprProxy<megdnn::ConvolutionBackwardData>::get_compute_workload(
        megdnn::ConvolutionBackwardData* opr, const TensorNDArray& tensors) {
    auto param = opr->param();
-    auto src_layout = tensors[0].layout;
-    auto filter_layout = tensors[1].layout;
+    auto src_layout = tensors[1].layout;
+    auto filter_layout = tensors[0].layout;
    auto dst_layout = tensors[2].layout;
    return get_conv_compute_workload(
            src_layout, filter_layout, dst_layout, param.format, param.sparse);
27 changes: 27 additions & 0 deletions compiler/test/kernel/opr/generalIntrinsic/benchmark_conv.cpp
@@ -181,4 +181,31 @@ TEST(GI, BenchmarkConvBiasIm2colNCHW88) {
            .print();
}
#endif
TEST(GI, BenchmarkConvBackDataNCHW) {
    Benchmarker<ConvolutionBackwardData> benchmarker(Arch::BAREMETAL);
    benchmarker.set_kernel_symbol("GI_kernel_back_data_conv2d_.*");
    ConvolutionBackwardData::Param param;
    param.compute_mode = ConvolutionBackwardData::Param::ComputeMode::DEFAULT;
    param.format = ConvolutionBackwardData::Param::Format::NCHW;

    param.pad_h = 1;
    param.pad_w = 1;
    param.stride_h = 1;
    param.stride_w = 1;
    param.sparse = ConvolutionBackwardData::Param::Sparse::DENSE;
    benchmarker.set_param(param);
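    //! Operands are {filter, diff, grad}: an 8x4x3x3 OIHW filter, a
    //! (2, 8, 128, 128) output gradient, and a (2, 4, 128, 128) input gradient.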
    uint32_t kernel = 3;
    benchmarker.execs({{8, 4, kernel, kernel}, {2, 8, 128, 128}, {2, 4, 128, 128}})
            .print();

    param.pad_h = 2;
    param.pad_w = 2;
    param.stride_h = 2;
    param.stride_w = 2;
    param.sparse = ConvolutionBackwardData::Param::Sparse::GROUP;
    benchmarker.set_param(param);
    kernel = 5;
    benchmarker.execs({{4, 2, 1, kernel, kernel}, {2, 8, 128, 128}, {2, 4, 256, 256}})
            .print();
}
#endif