Skip to content

Commit

Permalink
feat(kernel): add naive and arm64 QMUL/qsi32xqsi32=qsi8 elemwise multi type kernel
Browse files Browse the repository at this point in the history

GitOrigin-RevId: 46034a9e0ed2830414a5b5a633f5f7b26f0e4dbf
  • Loading branch information
megvii-mge committed Mar 22, 2024
1 parent eb91db8 commit d01fbe7
Show file tree
Hide file tree
Showing 6 changed files with 1,198 additions and 22 deletions.
1,010 changes: 1,010 additions & 0 deletions compiler/lib/KernelGen/Arm/Arm64/Elemwise/ElemwiseMultiType.cpp

Large diffs are not rendered by default.

24 changes: 24 additions & 0 deletions compiler/lib/KernelGen/Arm/Arm64/Elemwise/ElemwiseMultiType.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#pragma once
#include <sstream>
#include <string>
#include "compiler/Common/Logger.h"
#include "compiler/KernelGen/KernelGen.h"

namespace megcc {
namespace KernelGen {
namespace Arm64 {

//! Arm64 code generator for quantized multi-dtype elementwise kernels
//! (commit adds QMUL with qsi32 x qsi32 -> qsi8; presumably also the other
//! Q* modes handled by the sibling BareMetal generator — confirm against
//! Arm64/Elemwise/ElemwiseMultiType.cpp, which is not shown here).
class ElemwiseMultiTypeKernel : public KernelFunc {
public:
//! Whether this generator can emit a kernel for the operator described by
//! `context` (mode, operand count, dtypes, shapes).
bool IsAvailable(TContext* context) const override;
//! kernel gen
//! Unique C symbol name for the generated kernel function.
std::string GetKernelSymbol(TContext* context) const override;

//! Full C source text of the generated kernel body.
std::string GetKernelBody(TContext* context) const override;
};

} // namespace Arm64
} // namespace KernelGen
} // namespace megcc

// vim: syntax=cpp.doxygen
4 changes: 4 additions & 0 deletions compiler/lib/KernelGen/Arm/Arm64/KernelPack.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include "BatchedMatmul/BatchedMatmul.h"
#include "ConvKernel.h"
#include "Elemwise/Elemwise.h"
#include "Elemwise/ElemwiseMultiType.h"
#include "InternalKernel/InternalKernel.h"
#include "KernelPack.h"
#include "MatMulKernel/MatMul.h"
Expand Down Expand Up @@ -44,6 +45,9 @@ struct AllA64Kernel {
inner_map[KernelPack::KernType::ElemwiseKernel] = {
std::make_shared<Arm64::ElemwiseKernel>()};

inner_map[KernelPack::KernType::ElemwiseMultiKernel] = {
std::make_shared<Arm64::ElemwiseMultiTypeKernel>()};

inner_map[KernelPack::KernType::BatchMatmulKernel] = {
std::make_shared<Arm64::Fp32BatchedMatmul>()};

Expand Down
91 changes: 70 additions & 21 deletions compiler/lib/KernelGen/BareMetal/ElemwiseMultiType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ std::string gen_binary(std::string mode) {
float val1 = scale_1 * val_1;
int8_t out_val = fp32_to_int8( ((val0 + val1) > 0? (val0 + val1):0) * scale_div);
)";
} else if (mode == "QMUL") {
return R"(
int8_t out_val = fp32_to_int8(val_0 * val_1 * scale_mul);
)";
} else {
CC_ABORT << "not support mode " << mode.c_str() << "\n";
}
Expand All @@ -57,9 +61,34 @@ bool ElemwiseMultiTypeKernel::IsAvailable(TContext* context) const {
auto nr_operands = context->getAttrInt("nr_operands");
bool nr_operands_ok = nr_operands == 2 || nr_operands == 3;
bool mode_ok_unary = nr_operands == 2 && mode == "QRELU";
bool dtype_ok_unary =
nr_operands == 2 &&
Utils::is_quant_dtype(context->getAttrOprand("operand:0").dtype) &&
Utils::is_quant_dtype(context->getAttrOprand("operand:1").dtype, 8);
bool mode_ok_binary =
nr_operands == 3 && (mode == "QADD" || mode == "QFUSE_ADD_RELU");
return nr_operands_ok && (mode_ok_unary || mode_ok_binary);
nr_operands == 3 &&
(mode == "QADD" || mode == "QFUSE_ADD_RELU" || mode == "QMUL");
bool dtype_ok_binary =
nr_operands == 3 &&
Utils::is_quant_dtype(context->getAttrOprand("operand:0").dtype) &&
Utils::is_quant_dtype(context->getAttrOprand("operand:1").dtype) &&
Utils::is_quant_dtype(context->getAttrOprand("operand:2").dtype, 8);
const auto& op0_shape = context->getAttrOprand("operand:0").shape;
const auto& op1_shape = context->getAttrOprand("operand:1").shape;
size_t op1_nr_elem = 1;
for (auto dim : op1_shape) {
op1_nr_elem *= dim;
}
//! broadcast mode 0: op0 shape: (a, b, c, d, ...), op1 shape: (1, b, 1, 1, ...)
//! broadcast mode 1: op0 shape: (a, b, c, d, ...), op1_nr_elem = 1
bool shape_ok_binary =
nr_operands == 3 &&
((op0_shape == op1_shape) ||
(op0_shape.size() == op1_shape.size() && op0_shape.size() > 2 &&
op0_shape[1] == op1_shape[1] && op1_nr_elem == op1_shape[1]) ||
(op1_nr_elem == 1));
return nr_operands_ok && ((mode_ok_unary && dtype_ok_unary) ||
(mode_ok_binary && dtype_ok_binary && shape_ok_binary));
}

std::string ElemwiseMultiTypeKernel::GetKernelSymbol(TContext* context) const {
Expand All @@ -80,12 +109,9 @@ std::string ElemwiseMultiTypeKernel::GetKernelBody(TContext* context) const {
if (context->getAttrInt("nr_operands") == 2) {
auto op0 = context->getAttrOprand("operand:0");
auto dst = context->getAttrOprand("operand:1");
CC_ASSERT(
Utils::is_quant_dtype(op0.dtype, 8) &&
Utils::is_quant_dtype(dst.dtype, 8));
auto op0_specifier = Utils::cvt_dtype_specifier(op0.dtype);
auto dst_specifier = Utils::cvt_dtype_specifier(dst.dtype);
std::string binary_str = R"({
std::string unary_str = R"({
${op0_specifier}* input_0 = (${op0_specifier}*)inputs[0]->ptr;
float scale_0 = inputs[0]->dtype.param.scale;
TINYNN_ASSERT(input_0);
Expand All @@ -111,16 +137,11 @@ std::string ElemwiseMultiTypeKernel::GetKernelBody(TContext* context) const {
.add("op0_specifier", op0_specifier)
.add("dst_specifier", dst_specifier)
.add("act", gen_unary(mode))
.render(binary_str);
.render(unary_str);
} else if (context->getAttrInt("nr_operands") == 3) {
auto op0 = context->getAttrOprand("operand:0");
auto op1 = context->getAttrOprand("operand:1");
auto dst = context->getAttrOprand("operand:2");
CC_ASSERT(
Utils::is_quant_dtype(op0.dtype, 8) &&
Utils::is_quant_dtype(op1.dtype, 8) &&
Utils::is_quant_dtype(dst.dtype, 8));
CC_ASSERT(op0.shape == op1.shape) << "no support broadcast\n";
auto op0_specifier = Utils::cvt_dtype_specifier(op0.dtype);
auto op1_specifier = Utils::cvt_dtype_specifier(op1.dtype);
auto dst_specifier = Utils::cvt_dtype_specifier(dst.dtype);
Expand All @@ -135,17 +156,45 @@ std::string ElemwiseMultiTypeKernel::GetKernelBody(TContext* context) const {
float scale_dst = outputs[0]->dtype.param.scale;
TINYNN_ASSERT(output_data);
float scale_div = 1.f / scale_dst;
float scale_mul = scale_0 * scale_1 * scale_div;
Layout in_layout = inputs[0]->layout;
size_t nr_elem = 1;
for (int i = 0; i < in_layout.nr_dim; ++i) {
nr_elem *= in_layout.dims[i];
Layout in_layout0 = inputs[0]->layout;
size_t nr_elem0 = 1;
for (int i = 0; i < in_layout0.nr_dim; ++i) {
nr_elem0 *= in_layout0.dims[i];
}
for(size_t i = 0; i < nr_elem; ++i){
${op0_specifier} val_0 = input_0[i];
${op1_specifier} val_1 = input_1[i];
${act};
output_data[i] = out_val;
Layout in_layout1 = inputs[1]->layout;
size_t nr_elem1 = 1;
for (int i = 0; i < in_layout1.nr_dim; ++i) {
nr_elem1 *= in_layout1.dims[i];
}
if (nr_elem0 == nr_elem1) {
for(size_t i = 0; i < nr_elem0; ++i){
${op0_specifier} val_0 = input_0[i];
${op1_specifier} val_1 = input_1[i];
${act};
output_data[i] = out_val;
}
} else if (nr_elem1 == 1) {
${op1_specifier} val_1 = input_1[0];
for(size_t i = 0; i < nr_elem0; ++i){
${op0_specifier} val_0 = input_0[i];
${act};
output_data[i] = out_val;
}
} else {
TINYNN_ASSERT(nr_elem0 > nr_elem1);
for (int i = 0; i < in_layout0.dims[0]; ++i) {
for (int j = 0; j < in_layout0.dims[1]; ++j) {
${op1_specifier} val_1 = input_1[j];
for (int k = 0; k < in_layout0.stride[1]; ++k) {
int idx = i * in_layout0.stride[0] + j * in_layout0.stride[1] + k;
${op0_specifier} val_0 = input_0[idx];
${act};
output_data[idx] = out_val;
}
}
}
}
return TinyNN_SUCCESS;
}
Expand Down
63 changes: 63 additions & 0 deletions compiler/test/kernel/opr/arm/elemwise_multitype.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#include "test/kernel/common/checker.h"
using namespace megdnn;
using namespace megcc::test;
using MODE = ElemwiseMultiType::Param::Mode;

TEST(AARCH64, ElementwiseMultitypeUnary) {
    //! Unary quantized activation on Arm64: run QRELU over the same shape
    //! set for both dtype configurations (qsi32->qsi8 and qsi8->qsi8).
    Checker<ElemwiseMultiType> checker(megcc::KernelGen::Arch::ARM64);
    checker.set_kernel_symbol("Arm64_kernel_.*");
    checker.set_epsilon(1e-4);
    ElemwiseMultiType::Param param;
    param.mode = MODE::QRELU;
    //! Executes the shared shapes with whatever dtypes are configured.
    auto exec_shapes = [&checker, &param]() {
        checker.set_param(param);
        checker.execs({{1}, {}});
        checker.execs({{1, 33}, {}});
        checker.execs({{1, 10, 12, 13}, {}});
    };

    checker.set_dtype(0, dtype::QuantizedS32(1.f));
    checker.set_dtype(1, dtype::QuantizedS8(2.f));
    exec_shapes();

    checker.set_dtype(0, dtype::QuantizedS8(1.f));
    checker.set_dtype(1, dtype::QuantizedS8(3.f));
    exec_shapes();
}

TEST(AARCH64, ElementwiseMultitypeBinary) {
    //! Binary quantized modes on Arm64: every mode runs the same shape set —
    //! equal shapes, channel broadcast (1, C, 1, 1) and scalar broadcast {1}.
    Checker<ElemwiseMultiType> checker(megcc::KernelGen::Arch::ARM64);
    checker.set_kernel_symbol("Arm64_kernel_.*");
    checker.set_epsilon(1e-4);
    ElemwiseMultiType::Param param;
    //! Iterates all binary modes with whatever dtypes are configured.
    auto run_modes = [&checker, &param]() {
        for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) {
            param.mode = mode;
            checker.set_param(param);
            checker.execs({{1}, {1}, {}});
            checker.execs({{1, 18}, {1, 18}, {}});
            checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}});
            checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}});
            checker.execs({{2, 3, 4, 5}, {1}, {}});
        }
    };

    checker.set_dtype(0, dtype::QuantizedS8(1.f));
    checker.set_dtype(1, dtype::QuantizedS8(2.f));
    checker.set_dtype(2, dtype::QuantizedS8(3.f));
    run_modes();

    //! Output dtype stays QuantizedS8(3.f) from the first configuration.
    checker.set_dtype(0, dtype::QuantizedS32(0.73f));
    checker.set_dtype(1, dtype::QuantizedS32(2.21f));
    run_modes();
}
28 changes: 27 additions & 1 deletion compiler/test/kernel/opr/naive/elemwise_multitype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,17 @@ TEST(NAIVE, ElementwiseMultitypeUnary) {
checker.execs({{1, 10}, {}});
checker.execs({{1, 10, 12, 13}, {}});
}

checker.set_dtype(0, dtype::QuantizedS32(1.f));
checker.set_dtype(1, dtype::QuantizedS8(2.f));

for (auto mode : {MODE::QRELU}) {
param.mode = mode;
checker.set_param(param);
checker.execs({{1}, {}});
checker.execs({{1, 10}, {}});
checker.execs({{1, 10, 12, 13}, {}});
}
}

TEST(NAIVE, ElementwiseMultitypeBinary) {
Expand All @@ -28,11 +39,26 @@ TEST(NAIVE, ElementwiseMultitypeBinary) {
checker.set_dtype(2, dtype::QuantizedS8(3.f));
ElemwiseMultiType::Param param;

for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU}) {
for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) {
param.mode = mode;
checker.set_param(param);
checker.execs({{1}, {1}, {}});
checker.execs({{1, 10}, {1, 10}, {}});
checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}});
checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}});
checker.execs({{2, 3, 4, 5}, {1}, {}});
}

checker.set_dtype(0, dtype::QuantizedS32(1.f));
checker.set_dtype(1, dtype::QuantizedS32(2.f));

for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) {
param.mode = mode;
checker.set_param(param);
checker.execs({{1}, {1}, {}});
checker.execs({{1, 10}, {1, 10}, {}});
checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}});
checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}});
checker.execs({{2, 3, 4, 5}, {1}, {}});
}
}

0 comments on commit d01fbe7

Please sign in to comment.