Skip to content

Commit

Permalink
feat(kernel): add naive QMUL/qsi32xqsi32=qsi8 elemwise multi type kernel
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 622b670c16e07781741c07c114f8d0bd6a70837b
  • Loading branch information
megvii-mge committed Mar 22, 2024
1 parent bebf0bb commit 49cb5fd
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 20 deletions.
87 changes: 68 additions & 19 deletions compiler/lib/KernelGen/BareMetal/ElemwiseMultiType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ std::string gen_binary(std::string mode) {
float val1 = scale_1 * val_1;
int8_t out_val = fp32_to_int8( ((val0 + val1) > 0? (val0 + val1):0) * scale_div);
)";
} else if (mode == "QMUL") {
return R"(
int8_t out_val = fp32_to_int8(val_0 * val_1 * scale_mul);
)";
} else {
CC_ABORT << "not support mode " << mode.c_str() << "\n";
}
Expand All @@ -57,9 +61,34 @@ bool ElemwiseMultiTypeKernel::IsAvailable(TContext* context) const {
auto nr_operands = context->getAttrInt("nr_operands");
bool nr_operands_ok = nr_operands == 2 || nr_operands == 3;
bool mode_ok_unary = nr_operands == 2 && mode == "QRELU";
bool dtype_ok_unary =
nr_operands == 2 &&
Utils::is_quant_dtype(context->getAttrOprand("operand:0").dtype) &&
Utils::is_quant_dtype(context->getAttrOprand("operand:1").dtype, 8);
bool mode_ok_binary =
nr_operands == 3 && (mode == "QADD" || mode == "QFUSE_ADD_RELU");
return nr_operands_ok && (mode_ok_unary || mode_ok_binary);
nr_operands == 3 &&
(mode == "QADD" || mode == "QFUSE_ADD_RELU" || mode == "QMUL");
bool dtype_ok_binary =
nr_operands == 3 &&
Utils::is_quant_dtype(context->getAttrOprand("operand:0").dtype) &&
Utils::is_quant_dtype(context->getAttrOprand("operand:1").dtype) &&
Utils::is_quant_dtype(context->getAttrOprand("operand:2").dtype, 8);
const auto& op0_shape = context->getAttrOprand("operand:0").shape;
const auto& op1_shape = context->getAttrOprand("operand:1").shape;
size_t op1_nr_elem = 1;
for (auto dim : op1_shape) {
op1_nr_elem *= dim;
}
//! broadcast mode 0: op0 shape: (a, b, c, d, ...), op1 shape: (1, b, 1, 1, ...)
//! broadcast mode 1: op0 shape: (a, b, c, d, ...), op1_nr_elem = 1
bool shape_ok_binary =
nr_operands == 3 &&
((op0_shape == op1_shape) ||
(op0_shape.size() == op1_shape.size() && op0_shape.size() > 2 &&
op0_shape[1] == op1_shape[1] && op1_nr_elem == op1_shape[1]) ||
(op1_nr_elem == 1));
return nr_operands_ok && ((mode_ok_unary && dtype_ok_unary) ||
(mode_ok_binary && dtype_ok_binary && shape_ok_binary));
}

std::string ElemwiseMultiTypeKernel::GetKernelSymbol(TContext* context) const {
Expand All @@ -80,9 +109,6 @@ std::string ElemwiseMultiTypeKernel::GetKernelBody(TContext* context) const {
if (context->getAttrInt("nr_operands") == 2) {
auto op0 = context->getAttrOprand("operand:0");
auto dst = context->getAttrOprand("operand:1");
CC_ASSERT(
Utils::is_quant_dtype(op0.dtype, 8) &&
Utils::is_quant_dtype(dst.dtype, 8));
auto op0_specifier = Utils::cvt_dtype_specifier(op0.dtype);
auto dst_specifier = Utils::cvt_dtype_specifier(dst.dtype);
std::string binary_str = R"({
Expand Down Expand Up @@ -116,11 +142,6 @@ std::string ElemwiseMultiTypeKernel::GetKernelBody(TContext* context) const {
auto op0 = context->getAttrOprand("operand:0");
auto op1 = context->getAttrOprand("operand:1");
auto dst = context->getAttrOprand("operand:2");
CC_ASSERT(
Utils::is_quant_dtype(op0.dtype, 8) &&
Utils::is_quant_dtype(op1.dtype, 8) &&
Utils::is_quant_dtype(dst.dtype, 8));
CC_ASSERT(op0.shape == op1.shape) << "no support broadcast\n";
auto op0_specifier = Utils::cvt_dtype_specifier(op0.dtype);
auto op1_specifier = Utils::cvt_dtype_specifier(op1.dtype);
auto dst_specifier = Utils::cvt_dtype_specifier(dst.dtype);
Expand All @@ -135,17 +156,45 @@ std::string ElemwiseMultiTypeKernel::GetKernelBody(TContext* context) const {
float scale_dst = outputs[0]->dtype.param.scale;
TINYNN_ASSERT(output_data);
float scale_div = 1.f / scale_dst;
float scale_mul = scale_0 * scale_1 * scale_div;
Layout in_layout = inputs[0]->layout;
size_t nr_elem = 1;
for (int i = 0; i < in_layout.nr_dim; ++i) {
nr_elem *= in_layout.dims[i];
Layout in_layout0 = inputs[0]->layout;
size_t nr_elem0 = 1;
for (int i = 0; i < in_layout0.nr_dim; ++i) {
nr_elem0 *= in_layout0.dims[i];
}
for(size_t i = 0; i < nr_elem; ++i){
${op0_specifier} val_0 = input_0[i];
${op1_specifier} val_1 = input_1[i];
${act};
output_data[i] = out_val;
Layout in_layout1 = inputs[1]->layout;
size_t nr_elem1 = 1;
for (int i = 0; i < in_layout1.nr_dim; ++i) {
nr_elem1 *= in_layout1.dims[i];
}
if (nr_elem0 == nr_elem1) {
for(size_t i = 0; i < nr_elem0; ++i){
${op0_specifier} val_0 = input_0[i];
${op1_specifier} val_1 = input_1[i];
${act};
output_data[i] = out_val;
}
} else if (nr_elem1 == 1) {
${op1_specifier} val_1 = input_1[0];
for(size_t i = 0; i < nr_elem0; ++i){
${op0_specifier} val_0 = input_0[i];
${act};
output_data[i] = out_val;
}
} else {
TINYNN_ASSERT(nr_elem0 > nr_elem1);
for (int i = 0; i < in_layout0.dims[0]; ++i) {
for (int j = 0; j < in_layout0.dims[1]; ++j) {
${op1_specifier} val_1 = input_1[j];
for (int k = 0; k < in_layout0.stride[1]; ++k) {
int idx = i * in_layout0.stride[0] + j * in_layout0.stride[1] + k;
${op0_specifier} val_0 = input_0[idx];
${act};
output_data[idx] = out_val;
}
}
}
}
return TinyNN_SUCCESS;
}
Expand Down
17 changes: 16 additions & 1 deletion compiler/test/kernel/opr/naive/elemwise_multitype.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,26 @@ TEST(NAIVE, ElementwiseMultitypeBinary) {
checker.set_dtype(2, dtype::QuantizedS8(3.f));
ElemwiseMultiType::Param param;

for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU}) {
for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) {
param.mode = mode;
checker.set_param(param);
checker.execs({{1}, {1}, {}});
checker.execs({{1, 10}, {1, 10}, {}});
checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}});
checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}});
checker.execs({{2, 3, 4, 5}, {1}, {}});
}

checker.set_dtype(0, dtype::QuantizedS32(1.f));
checker.set_dtype(1, dtype::QuantizedS32(2.f));

for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) {
param.mode = mode;
checker.set_param(param);
checker.execs({{1}, {1}, {}});
checker.execs({{1, 10}, {1, 10}, {}});
checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}});
checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}});
checker.execs({{2, 3, 4, 5}, {1}, {}});
}
}

0 comments on commit 49cb5fd

Please sign in to comment.