Skip to content

Commit

Permalink
fix(kernel): add some elemwise opr mode
Browse files Browse the repository at this point in the history
GitOrigin-RevId: 996f257ee19ed77445f0b662aad957b127a6c0fe
  • Loading branch information
megvii-mge committed Sep 8, 2023
1 parent f8d1b8f commit eabc1d9
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def NegateKernel: UnaryElemwiseKernel<"NEGATE">;
def RoundKernel: UnaryElemwiseKernel<"ROUND">;
def SILUKernel: UnaryElemwiseKernel<"SILU">;
def ERFKernel: UnaryElemwiseKernel<"ERF">;
def SQRTKernel: UnaryElemwiseKernel<"SQRT">;
def SINKernel: UnaryElemwiseKernel<"SIN">;
def COSKernel: UnaryElemwiseKernel<"COS">;


def AddKernel: BinaryElemwiseKernel<"ADD">;
def SubKernel: BinaryElemwiseKernel<"SUB">;
Expand Down
6 changes: 6 additions & 0 deletions compiler/lib/Conversion/MGBToKernel/MGBToKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ class ConvertElemwise final : public OpConversionPattern<MGB::Elemwise> {
return createOp<Kernel::SILUKernel>(op, operands, rewriter);
case Mode::ERF:
return createOp<Kernel::ERFKernel>(op, operands, rewriter);
case Mode::SQRT:
return createOp<Kernel::SQRTKernel>(op, operands, rewriter);
case Mode::SIN:
return createOp<Kernel::SINKernel>(op, operands, rewriter);
case Mode::COS:
return createOp<Kernel::COSKernel>(op, operands, rewriter);
default:
CC_ABORT << "Unsupport Elemwise mode :" << static_cast<int>(op.mode())
<< "\n";
Expand Down
9 changes: 8 additions & 1 deletion compiler/lib/KernelGen/BareMetal/ElemwiseKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ std::string gen_unary(std::string mode) {
return "val / (1 + expf(-val))";
} else if (mode == "ERF") {
return "erff(val)";
} else if (mode == "SQRT") {
return "sqrtf(val)";
} else if (mode == "SIN") {
return "sinf(val)";
} else if (mode == "COS") {
return "cosf(val)";
} else {
CC_ABORT << "not support mode " << mode.c_str() << "\n";
}
Expand Down Expand Up @@ -414,7 +420,8 @@ bool ElmwiseKernel::IsAvailable(TContext* context) const {
bool mode_ok_unary = mode == "RELU" || mode == "SIGMOID" || mode == "EXP" ||
mode == "NEGATE" || mode == "ROUND" || mode == "ABS" ||
mode == "H_SWISH" || mode == "LOG" || mode == "SILU" ||
mode == "ERF";
mode == "ERF" || mode == "SQRT" || mode == "SIN" ||
mode == "COS";
bool mode_ok_binary = mode == "ADD" || mode == "SUB" || mode == "MUL" ||
mode == "MAX" || mode == "MIN" || mode == "LEQ" ||
mode == "LT" || mode == "FLOOR_DIV" || mode == "EQ" ||
Expand Down
9 changes: 8 additions & 1 deletion compiler/test/kernel/opr/naive/elementwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ TEST(NAIVE, ElementwiseUnique) {
ElemwiseForward::Param param;
for (auto mode :
{MODE::RELU, MODE::SIGMOID, MODE::EXP, MODE::NEGATE, MODE::ROUND,
MODE::H_SWISH, MODE::ABS, MODE::ERF}) {
MODE::H_SWISH, MODE::ABS, MODE::ERF, MODE::SIN, MODE::COS}) {
param.mode = mode;
checker.set_param(param);
checker.execs({{1}, {}});
Expand All @@ -26,6 +26,13 @@ TEST(NAIVE, ElementwiseUnique) {
megcc::test::UniformRNG rng(1e-5, 3);
checker.set_rng(0, &rng);
checker.execs({{1, 10}, {}});

param.mode = MODE::SQRT;
checker.set_param(param);
checker.set_rng(0, &rng);
checker.execs({{1}, {}});
checker.execs({{1, 10}, {}});
checker.execs({{1, 10, 12, 13}, {}});
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,8 @@ struct EnumTrait<::megdnn::param::Elemwise::Mode> : public std::true_type {
return "OR";
case ::megdnn::param::Elemwise::Mode::XOR:
return "XOR";
case ::megdnn::param::Elemwise::Mode::SQRT:
return "SQRT";
default:
return {};
}
Expand Down

0 comments on commit eabc1d9

Please sign in to comment.