-
Notifications
You must be signed in to change notification settings - Fork 56
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(kernel): add naive and arm64 QMUL/qsi32xqsi32=qsi8 elemwise mult…
…i type kernel GitOrigin-RevId: 46034a9e0ed2830414a5b5a633f5f7b26f0e4dbf
- Loading branch information
1 parent
eb91db8
commit d01fbe7
Showing
6 changed files
with
1,198 additions
and
22 deletions.
There are no files selected for viewing
1,010 changes: 1,010 additions & 0 deletions
1,010
compiler/lib/KernelGen/Arm/Arm64/Elemwise/ElemwiseMultiType.cpp
Large diffs are not rendered by default.
Oops, something went wrong.
24 changes: 24 additions & 0 deletions
24
compiler/lib/KernelGen/Arm/Arm64/Elemwise/ElemwiseMultiType.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
#pragma once | ||
#include <sstream> | ||
#include <string> | ||
#include "compiler/Common/Logger.h" | ||
#include "compiler/KernelGen/KernelGen.h" | ||
|
||
namespace megcc { | ||
namespace KernelGen { | ||
namespace Arm64 { | ||
|
||
class ElemwiseMultiTypeKernel : public KernelFunc { | ||
public: | ||
bool IsAvailable(TContext* context) const override; | ||
//! kernel gen | ||
std::string GetKernelSymbol(TContext* context) const override; | ||
|
||
std::string GetKernelBody(TContext* context) const override; | ||
}; | ||
|
||
} // namespace Arm64 | ||
} // namespace KernelGen | ||
} // namespace megcc | ||
|
||
// vim: syntax=cpp.doxygen |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#include "test/kernel/common/checker.h" | ||
using namespace megdnn; | ||
using namespace megcc::test; | ||
using MODE = ElemwiseMultiType::Param::Mode; | ||
|
||
TEST(AARCH64, ElementwiseMultitypeUnary) { | ||
Checker<ElemwiseMultiType> checker(megcc::KernelGen::Arch::ARM64); | ||
checker.set_kernel_symbol("Arm64_kernel_.*"); | ||
checker.set_epsilon(1e-4); | ||
checker.set_dtype(0, dtype::QuantizedS32(1.f)); | ||
checker.set_dtype(1, dtype::QuantizedS8(2.f)); | ||
ElemwiseMultiType::Param param; | ||
for (auto mode : {MODE::QRELU}) { | ||
param.mode = mode; | ||
checker.set_param(param); | ||
checker.execs({{1}, {}}); | ||
checker.execs({{1, 33}, {}}); | ||
checker.execs({{1, 10, 12, 13}, {}}); | ||
} | ||
|
||
checker.set_dtype(0, dtype::QuantizedS8(1.f)); | ||
checker.set_dtype(1, dtype::QuantizedS8(3.f)); | ||
for (auto mode : {MODE::QRELU}) { | ||
param.mode = mode; | ||
checker.set_param(param); | ||
checker.execs({{1}, {}}); | ||
checker.execs({{1, 33}, {}}); | ||
checker.execs({{1, 10, 12, 13}, {}}); | ||
} | ||
} | ||
|
||
TEST(AARCH64, ElementwiseMultitypeBinary) { | ||
Checker<ElemwiseMultiType> checker(megcc::KernelGen::Arch::ARM64); | ||
checker.set_kernel_symbol("Arm64_kernel_.*"); | ||
checker.set_epsilon(1e-4); | ||
checker.set_dtype(0, dtype::QuantizedS8(1.f)); | ||
checker.set_dtype(1, dtype::QuantizedS8(2.f)); | ||
checker.set_dtype(2, dtype::QuantizedS8(3.f)); | ||
ElemwiseMultiType::Param param; | ||
|
||
for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) { | ||
param.mode = mode; | ||
checker.set_param(param); | ||
checker.execs({{1}, {1}, {}}); | ||
checker.execs({{1, 18}, {1, 18}, {}}); | ||
checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}}); | ||
checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}}); | ||
checker.execs({{2, 3, 4, 5}, {1}, {}}); | ||
} | ||
|
||
checker.set_dtype(0, dtype::QuantizedS32(0.73f)); | ||
checker.set_dtype(1, dtype::QuantizedS32(2.21f)); | ||
|
||
for (auto mode : {MODE::QADD, MODE::QFUSE_ADD_RELU, MODE::QMUL}) { | ||
param.mode = mode; | ||
checker.set_param(param); | ||
checker.execs({{1}, {1}, {}}); | ||
checker.execs({{1, 18}, {1, 18}, {}}); | ||
checker.execs({{2, 3, 4, 5}, {2, 3, 4, 5}, {}}); | ||
checker.execs({{2, 3, 4, 5}, {1, 3, 1, 1}, {}}); | ||
checker.execs({{2, 3, 4, 5}, {1}, {}}); | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters