diff --git a/compiler/include/compiler/Dialect/Kernel/IR/AbstractKernels.td b/compiler/include/compiler/Dialect/Kernel/IR/AbstractKernels.td index e85d65db..4598f09f 100644 --- a/compiler/include/compiler/Dialect/Kernel/IR/AbstractKernels.td +++ b/compiler/include/compiler/Dialect/Kernel/IR/AbstractKernels.td @@ -64,6 +64,10 @@ def NegateKernel: UnaryElemwiseKernel<"NEGATE">; def RoundKernel: UnaryElemwiseKernel<"ROUND">; def SILUKernel: UnaryElemwiseKernel<"SILU">; def ERFKernel: UnaryElemwiseKernel<"ERF">; +def SQRTKernel: UnaryElemwiseKernel<"SQRT">; +def SINKernel: UnaryElemwiseKernel<"SIN">; +def COSKernel: UnaryElemwiseKernel<"COS">; + def AddKernel: BinaryElemwiseKernel<"ADD">; def SubKernel: BinaryElemwiseKernel<"SUB">; diff --git a/compiler/lib/Conversion/MGBToKernel/MGBToKernel.cpp b/compiler/lib/Conversion/MGBToKernel/MGBToKernel.cpp index 0336d464..09b02266 100644 --- a/compiler/lib/Conversion/MGBToKernel/MGBToKernel.cpp +++ b/compiler/lib/Conversion/MGBToKernel/MGBToKernel.cpp @@ -218,6 +218,12 @@ class ConvertElemwise final : public OpConversionPattern { return createOp(op, operands, rewriter); case Mode::ERF: return createOp(op, operands, rewriter); + case Mode::SQRT: + return createOp(op, operands, rewriter); + case Mode::SIN: + return createOp(op, operands, rewriter); + case Mode::COS: + return createOp(op, operands, rewriter); default: CC_ABORT << "Unsupport Elemwise mode :" << static_cast(op.mode()) << "\n"; diff --git a/compiler/lib/KernelGen/BareMetal/ElemwiseKernel.cpp b/compiler/lib/KernelGen/BareMetal/ElemwiseKernel.cpp index d91da96e..083eae14 100644 --- a/compiler/lib/KernelGen/BareMetal/ElemwiseKernel.cpp +++ b/compiler/lib/KernelGen/BareMetal/ElemwiseKernel.cpp @@ -53,6 +53,12 @@ std::string gen_unary(std::string mode) { return "val / (1 + expf(-val))"; } else if (mode == "ERF") { return "erff(val)"; + } else if (mode == "SQRT") { + return "sqrtf(val)"; + } else if (mode == "SIN") { + return "sinf(val)"; + } else if (mode == "COS") { + return "cosf(val)"; } else { CC_ABORT << "not support mode " << mode.c_str() << "\n"; } @@ -414,7 +420,8 @@ bool ElmwiseKernel::IsAvailable(TContext* context) const { bool mode_ok_unary = mode == "RELU" || mode == "SIGMOID" || mode == "EXP" || mode == "NEGATE" || mode == "ROUND" || mode == "ABS" || mode == "H_SWISH" || mode == "LOG" || mode == "SILU" || - mode == "ERF"; + mode == "ERF" || mode == "SQRT" || mode == "SIN" || + mode == "COS"; bool mode_ok_binary = mode == "ADD" || mode == "SUB" || mode == "MUL" || mode == "MAX" || mode == "MIN" || mode == "LEQ" || mode == "LT" || mode == "FLOOR_DIV" || mode == "EQ" || diff --git a/compiler/test/kernel/opr/naive/elementwise.cpp b/compiler/test/kernel/opr/naive/elementwise.cpp index 9b5109ee..31b20514 100644 --- a/compiler/test/kernel/opr/naive/elementwise.cpp +++ b/compiler/test/kernel/opr/naive/elementwise.cpp @@ -13,7 +13,7 @@ TEST(NAIVE, ElementwiseUnique) { ElemwiseForward::Param param; for (auto mode : {MODE::RELU, MODE::SIGMOID, MODE::EXP, MODE::NEGATE, MODE::ROUND, - MODE::H_SWISH, MODE::ABS, MODE::ERF}) { + MODE::H_SWISH, MODE::ABS, MODE::ERF, MODE::SIN, MODE::COS}) { param.mode = mode; checker.set_param(param); checker.execs({{1}, {}}); @@ -26,6 +26,13 @@ TEST(NAIVE, ElementwiseUnique) { megcc::test::UniformRNG rng(1e-5, 3); checker.set_rng(0, &rng); checker.execs({{1, 10}, {}}); + + param.mode = MODE::SQRT; + checker.set_param(param); + checker.set_rng(0, &rng); + checker.execs({{1}, {}}); + checker.execs({{1, 10}, {}}); + checker.execs({{1, 10, 12, 13}, {}}); } } diff --git a/compiler/test/kernel/prebuild_include/megbrain/enum_reflection.h.inl b/compiler/test/kernel/prebuild_include/megbrain/enum_reflection.h.inl index dda15796..58699473 100644 --- a/compiler/test/kernel/prebuild_include/megbrain/enum_reflection.h.inl +++ b/compiler/test/kernel/prebuild_include/megbrain/enum_reflection.h.inl @@ -493,6 +493,8 @@ struct EnumTrait<::megdnn::param::Elemwise::Mode> : public std::true_type { return "OR"; case ::megdnn::param::Elemwise::Mode::XOR: return "XOR"; + case ::megdnn::param::Elemwise::Mode::SQRT: + return "SQRT"; default: return {}; }