diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index b3272e13a..b2d24a6b3 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -44,6 +44,7 @@ MNN uses CMake to build the project; the CMake macro options are listed below: | MNN_CUDA | Whether to build the `Cuda` backend, defaults to `OFF` | | MNN_CUDA_PROFILE | Whether to enable the CUDA profiling tool, defaults to `OFF` | | MNN_CUDA_QUANT | Whether to compile the CUDA quantization files, defaults to `OFF` | +| MNN_CUDA_BF16 | Whether to compile the CUDA Bf16 files, defaults to `OFF` | | MNN_TENSORRT | Whether to build the `TensorRT` backend, defaults to `OFF` | | MNN_COREML | Whether to build the `CoreML` backend, defaults to `OFF` | | MNN_NNAPI | Whether to build the `NNAPI` backend, defaults to `OFF` | diff --git a/docs/index.rst b/docs/index.rst index 3330d4b79..ac2730945 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -47,6 +47,7 @@ :maxdepth: 1 :caption: Expression :name: expr + inference/expr .. toctree:: diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index fbf425f47..8b8005baa 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h @@ -68,7 +68,7 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ #define STR_IMP(x) #x #define STR(x) STR_IMP(x) #define MNN_VERSION_MAJOR 2 -#define MNN_VERSION_MINOR 5 -#define MNN_VERSION_PATCH 3 +#define MNN_VERSION_MINOR 6 +#define MNN_VERSION_PATCH 0 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." STR(MNN_VERSION_PATCH) #endif /* MNNDefine_h */
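The version bump above relies on the two-level stringification idiom visible in the context lines: STR expands its argument first, and STR_IMP then stringifies the expanded token. A self-contained C++ sketch of the same idiom (it mirrors the header's macro names purely for illustration):

    #include <cstdio>
    #define STR_IMP(x) #x      // stringifies exactly the token it receives
    #define STR(x) STR_IMP(x)  // extra indirection forces macro expansion first
    #define MNN_VERSION_MINOR 6
    int main() {
        std::printf("%s\n", STR(MNN_VERSION_MINOR));     // prints: 6
        std::printf("%s\n", STR_IMP(MNN_VERSION_MINOR)); // prints: MNN_VERSION_MINOR
        return 0;
    }

Without the indirection, MNN_VERSION would stringify to the literal macro names instead of 2.6.0.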
diff --git a/project/ios/MNN.xcodeproj/project.pbxproj b/project/ios/MNN.xcodeproj/project.pbxproj index 9a464b851..64f45f266 100644 --- a/project/ios/MNN.xcodeproj/project.pbxproj +++ b/project/ios/MNN.xcodeproj/project.pbxproj @@ -763,6 +763,8 @@ C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */ = {isa = PBXBuildFile; fileRef = C4F906B127688C3A0026B847 /* NMSModule.hpp */; }; C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C4F906B227688C3A0026B847 /* NMSModule.cpp */; }; C4FB6CB22769DF0800963B07 /* GeometryCumSum.cpp in Sources */ = {isa = PBXBuildFile; fileRef = C4FB6CB12769DF0800963B07 /* GeometryCumSum.cpp */; }; + CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */; }; + CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */; }; CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */; }; CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */; }; CE9AFED728E54E3300566949 /* CPUInterp3D.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */; }; @@ -785,9 +787,7 @@ CEDB211C2846D59C00AE9DC4 /* mobilenet_v2.caffe.mnn in Resources */ = {isa = PBXBuildFile; fileRef = CEDB211B2846D59C00AE9DC4 /* mobilenet_v2.caffe.mnn */; }; CEDB211D284706F900AE9DC4 /* MNN.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 0F1465B71FA18D1000F9860A /* MNN.framework */; }; CEDB211E2847070600AE9DC4 /* MNN.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 0F1465B71FA18D1000F9860A /* MNN.framework */; }; - CEE9B9522A3AA4C4006438F2 /* MNNBilinearSampleC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B94E2A3AA4C4006438F2 /* MNNBilinearSampleC16.S */; }; CEE9B9532A3AA4C4006438F2 /* MNNCubicLineC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B94F2A3AA4C4006438F2 /* MNNCubicLineC16.S */; }; - CEE9B9542A3AA4C4006438F2 /* MNNBilinearLineC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9502A3AA4C4006438F2 /* MNNBilinearLineC16.S */; }; CEE9B9552A3AA4C4006438F2 /* MNNCubicSampleC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9512A3AA4C4006438F2 /* MNNCubicSampleC16.S */; }; CEE9B95A2A3AA4D4006438F2 /* MNNCubicLineC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9562A3AA4D4006438F2 /* MNNCubicLineC16.S */; }; CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9572A3AA4D4006438F2 /* MNNBilinearLineC8.S */; }; @@ -1590,6 +1590,8 @@ C4F906B127688C3A0026B847 /* NMSModule.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = NMSModule.hpp; sourceTree = "<group>"; }; C4F906B227688C3A0026B847 /* NMSModule.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = NMSModule.cpp; sourceTree = "<group>"; }; C4FB6CB12769DF0800963B07 /* GeometryCumSum.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryCumSum.cpp; sourceTree = "<group>"; }; + CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = "<group>"; }; + CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = "<group>"; }; CE7DBFFF28E2DE6B00797689 /* ShapeConvTranspose3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ShapeConvTranspose3D.cpp; sourceTree = "<group>"; }; CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUInterp3D.cpp; sourceTree = "<group>"; }; CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUInterp3D.hpp; sourceTree = "<group>"; }; @@ -1614,9 +1616,7 @@ CEDB21172846D58200AE9DC4 /* testcat.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; name = testcat.jpg; path = ../../../demo/model/MobileNet/testcat.jpg; sourceTree = "<group>"; }; CEDB21182846D58200AE9DC4 /* synset_words.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; name = synset_words.txt; path = ../../../demo/model/MobileNet/synset_words.txt; sourceTree = "<group>"; }; CEDB211B2846D59C00AE9DC4 /* mobilenet_v2.caffe.mnn */ = {isa = PBXFileReference; lastKnownFileType = file; name = mobilenet_v2.caffe.mnn; path = ../../../resource/model/MobileNet/v2/mobilenet_v2.caffe.mnn; sourceTree = "<group>"; }; - CEE9B94E2A3AA4C4006438F2 /* MNNBilinearSampleC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC16.S; sourceTree = "<group>"; }; CEE9B94F2A3AA4C4006438F2 /* MNNCubicLineC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNCubicLineC16.S; sourceTree = "<group>"; }; - CEE9B9502A3AA4C4006438F2 /* MNNBilinearLineC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC16.S; sourceTree = "<group>"; }; CEE9B9512A3AA4C4006438F2 /* MNNCubicSampleC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNCubicSampleC16.S; sourceTree = "<group>"; }; CEE9B9562A3AA4D4006438F2 /* MNNCubicLineC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNCubicLineC16.S; sourceTree = "<group>"; }; CEE9B9572A3AA4D4006438F2 /* 
MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = "<group>"; }; @@ -2501,8 +2501,8 @@ 92FF013A23AA0B4E00AC97F6 /* arm32 */ = { isa = PBXGroup; children = ( - CEE9B9502A3AA4C4006438F2 /* MNNBilinearLineC16.S */, - CEE9B94E2A3AA4C4006438F2 /* MNNBilinearSampleC16.S */, + CE125CC72A52BF6B003698C9 /* MNNBilinearLineC8.S */, + CE125CC62A52BF6B003698C9 /* MNNBilinearSampleC8.S */, CEE9B94F2A3AA4C4006438F2 /* MNNCubicLineC16.S */, CEE9B9512A3AA4C4006438F2 /* MNNCubicSampleC16.S */, 950B28DF29F627E00002F454 /* MNNBinaryAddInt8.S */, @@ -3356,6 +3356,7 @@ 950B28ED29F627F70002F454 /* MNNBinaryMulInt8.S in Sources */, 481FA853259C27E00047F01F /* ShapeTensorArray.cpp in Sources */, 6A131E3F25823349002EC3D6 /* PluginShapeInference.cpp in Sources */, + CE125CC82A52BF6B003698C9 /* MNNBilinearSampleC8.S in Sources */, 92FF025723AA0B5A00AC97F6 /* CPUQuanConvolutionDepthwise.cpp in Sources */, 48034563254157CE004738E3 /* MNNNV21ToBGRAUnit.S in Sources */, 48FA474823AA127B00172C3B /* Expr.cpp in Sources */, @@ -3375,7 +3376,6 @@ 48747D61245D9E33000B9709 /* ConvertUtils.cpp in Sources */, 92FF043B23AA0B7100AC97F6 /* ShapeDetectionPostProcess.cpp in Sources */, 48417FF124D13BF50056D9A7 /* GeometryELU.cpp in Sources */, - CEE9B9522A3AA4C4006438F2 /* MNNBilinearSampleC16.S in Sources */, 48C84B9A250F720C00EE7666 /* CPULayerNorm.cpp in Sources */, 4DF87C4A2887D3560003E2D4 /* calib3d.cpp in Sources */, 48F34734273A7C8400C45394 /* ImageProcessFunction.cpp in Sources */, @@ -3515,6 +3515,7 @@ CECF8C7D299CAD9400D3875B /* md5.c in Sources */, 92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */, 92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */, + CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */, 92FF03A123AA0B5A00AC97F6 /* Int8FunctionsOpt.cpp in Sources */, 92FF026523AA0B5A00AC97F6 /* CPUQuantizedAvgPool.cpp in Sources */, 92FF029423AA0B5A00AC97F6 /* CPUMatMul.cpp in Sources */, @@ -3555,7 +3556,6 @@ 482BFBD028351BA1009210E4 /* AllShader.cpp in Sources */, 92FF04BA23AA0BFB00AC97F6 /* WrapExecution.cpp in Sources */, 11A01A06258785EA00745FA7 /* MNNVectorTop1Int32.S in Sources */, - CEE9B9542A3AA4C4006438F2 /* MNNBilinearLineC16.S in Sources */, 48FB9DC124A8445A008E1A2D /* MNNAxByClampBroadcastC4.S in Sources */, EBD4842F2485FF660083CE95 /* Arm82Interp.cpp in Sources */, 4819FB3B24C69E680050BD09 /* GeometrySpatialProduct.cpp in Sources */, diff --git a/pymnn/src/util.h b/pymnn/src/util.h index 2e99f896a..0630292e6 100644 --- a/pymnn/src/util.h +++ b/pymnn/src/util.h @@ -107,13 +107,23 @@ inline int64_t unpackLong(PyObject* obj) { } return (int64_t)value; } +inline double unpackDoubleOrLong(PyObject* obj) { + if (PyLong_Check(obj) +#if PY_MAJOR_VERSION < 3 + || PyInt_Check(obj) +#endif + ) { + return static_cast<double>(unpackLong(obj)); + } + return unpackDouble(obj); +} inline void store_scalar(void* data, int dtype, PyObject* obj) { switch (dtype) { case 4: *(uint8_t*)data = (uint8_t)unpackLong(obj); break; case 3: *(int32_t*)data = (int32_t)unpackLong(obj); break; case 9: *(int64_t*)data = unpackLong(obj); break; - case 1: *(float*)data = (float)unpackDouble(obj); break; - case 2: *(double*)data = (double)unpackDouble(obj); break; + case 1: *(float*)data = (float)unpackDoubleOrLong(obj); break; + case 2: *(double*)data = (double)unpackDoubleOrLong(obj); break; case 6: *(int8_t*)data = (int8_t)unpackLong(obj); break; default: PyMNN_ERROR_LOG("store_scalar: invalid type"); }
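The new unpackDoubleOrLong helper lets store_scalar accept Python ints for the float and double dtypes (cases 1 and 2) instead of sending them straight through unpackDouble. A minimal sketch of the same dispatch against the CPython API, with simplified stand-ins for the real unpack helpers (assumption: the double path only handles float-like objects):

    #include <Python.h>
    // Route integral objects through integer extraction, everything else through float.
    static double unpackNumber(PyObject* obj) {
        if (PyLong_Check(obj)) {                  // Python 3 int
            return (double)PyLong_AsLongLong(obj);
        }
        return PyFloat_AsDouble(obj);             // float, or raises for bad types
    }

The PY_MAJOR_VERSION guard in the patch adds the same routing for Python 2's separate PyInt type.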
diff --git a/source/backend/cpu/BinaryUtils.hpp b/source/backend/cpu/BinaryUtils.hpp index dff13c01f..ed93f9406 100644 --- a/source/backend/cpu/BinaryUtils.hpp +++ b/source/backend/cpu/BinaryUtils.hpp @@ -330,7 +330,7 @@ void execute(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int } template <typename Func> -void executeInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) { +void executeInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast) { Func f; int size = elementSize; #ifdef MNN_USE_NEON @@ -355,19 +355,19 @@ void executeInt8(int8_t* outputRaw, const int8_t* input #endif for (int i = 0; i < size; ++i) { if (needBroadcast == 0) { - inp0 = (inputData0[0]- zeroPoint) * inputScale0[0]; - inp1 = (inputData1[i]- zeroPoint) * inputScale1[0]; + inp0 = (inputData0[0]- zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + inp1 = (inputData1[i]- zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; output = f(inp0, inp1); } else if (needBroadcast == 1) { - inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - inp1 = (inputData1[0] - zeroPoint) * inputScale1[0]; + inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + inp1 = (inputData1[0] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; output = f(inp0, inp1); } else { - inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; + inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; output = f(inp0, inp1); } - int value = (int)roundf(output * outputScale[0]) + zeroPoint; + int value = (int)roundf(output * inputScalesFp32[2]) + zeroPoint + outputOffset[0]; if (value > maxValue) { value = maxValue; } diff --git a/source/backend/cpu/CPUBinaryInt8.cpp b/source/backend/cpu/CPUBinaryInt8.cpp index 569a4988c..a34ff6901 100644 --- a/source/backend/cpu/CPUBinaryInt8.cpp +++ b/source/backend/cpu/CPUBinaryInt8.cpp @@ -16,8 +16,6 @@ #include "BinaryUtils.hpp" #include "math/Vec.hpp" -using Vec16 = MNN::Math::Vec; - namespace MNN { ErrorCode CPUBinaryInt8::onResize(const std::vector<Tensor *>& inputs, const std: auto core = static_cast<CPUBackend*>(backend())->functions(); - mInputQuant0.resize(core->pack); // prepare for arm neon. float32x4 - mInputQuant1.resize(core->pack); - mOutputQuant.resize(core->pack); - std::fill(mInputQuant0.begin(), mInputQuant0.end(), TensorUtils::getDescribe(inputs[0])->quantAttr->scale); - std::fill(mInputQuant1.begin(), mInputQuant1.end(), TensorUtils::getDescribe(inputs[1])->quantAttr->scale); + mInputOffset0.resize(1); + mInputOffset1.resize(1); + mOutputOffset.resize(1); + mQuantScalesInt32.resize(2); // When computing with int32 scales, the output scale is not needed.
+ mQuantScalesFp32.resize(3); + mQuantScalesInt32[0] = TensorUtils::getDescribe(inputs[0])->quantAttr->scale * (1 << 16); + mQuantScalesInt32[1] = TensorUtils::getDescribe(inputs[1])->quantAttr->scale * (1 << 16); + mQuantScalesFp32[0] = TensorUtils::getDescribe(inputs[0])->quantAttr->scale; + mQuantScalesFp32[1] = TensorUtils::getDescribe(inputs[1])->quantAttr->scale; if (TensorUtils::getDescribe(outputs[0])->quantAttr->scale != 0) { - std::fill(mOutputQuant.begin(), mOutputQuant.end(), 1 / TensorUtils::getDescribe(outputs[0])->quantAttr->scale); + mQuantScalesFp32[2] = 1 / TensorUtils::getDescribe(outputs[0])->quantAttr->scale; } else { - std::fill(mOutputQuant.begin(), mOutputQuant.end(), 0); + mQuantScalesFp32[2] = 0; } - + mInputOffset0[0] = (int8_t)TensorUtils::getDescribe(inputs[0])->quantAttr->zero; + mInputOffset1[0] = (int8_t)TensorUtils::getDescribe(inputs[1])->quantAttr->zero; + mOutputOffset[0] = (int8_t)TensorUtils::getDescribe(outputs[0])->quantAttr->zero; - if(mActivationType == 1 && outputs[0]->getType().code == halide_type_float) { - mActivationExe.reset(new CPURelu(backend(), 0.0)); - mActivationExe->onResize(outputs, outputs); - } return NO_ERROR; } @@ -79,9 +79,9 @@ ErrorCode CPUBinaryInt8::onExecute(const std::vector<Tensor *>& inputs, const std if (realSize > 0) { auto inp0 = input0Ptr + start * inpBytes; auto inp1 = input1Ptr + start * inpBytes; - auto scale0 = mInputQuant0.data() + start; - auto scale1 = mInputQuant1.data() + start; - auto scaleDst = mOutputQuant.data() + start; + auto offset0 = mInputOffset0.data(); + auto offset1 = mInputOffset1.data(); + auto offsetDst = mOutputOffset.data(); if (mNeedBroadcastIndex == 0) { inp0 = input0Ptr; } else if (mNeedBroadcastIndex == 1) { @@ -89,17 +89,14 @@ ErrorCode CPUBinaryInt8::onExecute(const std::vector<Tensor *>& inputs, const std } auto out = outputPtr + start * outBytes; #ifdef MNN_USE_NEON - mProc(out, inp0, inp1, scale0, scale1, scaleDst, realSize / 4, mNeedBroadcastIndex); + mProc(out, inp0, inp1, mQuantScalesInt32.data(), mQuantScalesFp32.data(), offset0, offset1, offsetDst, realSize / 4, mNeedBroadcastIndex); #else - mProc(out, inp0, inp1, scale0, scale1, scaleDst, realSize, mNeedBroadcastIndex); + mProc(out, inp0, inp1, mQuantScalesInt32.data(), mQuantScalesFp32.data(), offset0, offset1, offsetDst, realSize, mNeedBroadcastIndex); #endif } } MNN_CONCURRENCY_END(); - - if(mActivationType == 1 && output->getType().code == halide_type_float) { - mActivationExe->onExecute(outputs, outputs);; - } + return NO_ERROR; } diff --git a/source/backend/cpu/CPUBinaryInt8.hpp b/source/backend/cpu/CPUBinaryInt8.hpp index 39f94e787..2924c426a 100644 --- a/source/backend/cpu/CPUBinaryInt8.hpp +++ b/source/backend/cpu/CPUBinaryInt8.hpp @@ -31,9 +31,11 @@ class CPUBinaryInt8 : public Execution { int mTotalSize; int mActivationType = 0; std::shared_ptr<Execution> mActivationExe; - std::vector<float> mInputQuant0; - std::vector<float> mInputQuant1; - std::vector<float> mOutputQuant; + std::vector<ssize_t> mQuantScalesInt32; // input0 and input1 + std::vector<float> mQuantScalesFp32; // input0, input1 and output + std::vector<int8_t> mInputOffset0; + std::vector<int8_t> mInputOffset1; + std::vector<int8_t> mOutputOffset; }; } // namespace MNN #endif /* CPUBinary_hpp */
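onResize now caches each input scale twice: as float, and as Q16 fixed point (scale * (1 << 16)) in mQuantScalesInt32. The integer form is what the arm32 add/sub/max/min kernels later in this patch consume: they widen int8 to int32, multiply by the Q16 factor, combine, then narrow with a saturating right shift by 16 (the vqshrn.s32 ..., #16 sequences) and re-apply the output offset. A scalar transliteration of what those kernels compute per element (names are mine, illustrative only):

    #include <algorithm>
    #include <cstdint>
    // What the Q16 add kernel computes, minus the SIMD saturating narrows:
    static inline int8_t quantAddQ16(int8_t x0, int8_t x1,
                                     int32_t s0q16, int32_t s1q16, // scale * (1 << 16)
                                     int8_t z0, int8_t z1, int8_t zOut) {
        int32_t acc = (x0 - z0) * s0q16 + (x1 - z1) * s1q16; // zero-centered, Q16
        int32_t v   = (acc >> 16) + zOut;                    // drop the fraction, re-offset
        return (int8_t)std::min(127, std::max(-128, v));
    }

mQuantScalesFp32, including the inverse output scale in slot [2], still backs the multiply and squared-difference kernels, which stay in float.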
diff --git a/source/backend/cpu/CPUInterp.cpp b/source/backend/cpu/CPUInterp.cpp index cd153320a..7ade8f7b8 100644 --- a/source/backend/cpu/CPUInterp.cpp +++ b/source/backend/cpu/CPUInterp.cpp @@ -51,10 +51,10 @@ ErrorCode CPUInterp::onExecute(const std::vector<Tensor *> &inputs, const std::v case 2: CPUResizeBilinearC4(CPUBilinearSampleC4, CPUBilinearLineC4, inputs, outputs, mWidthPosition.host<int>(), mWidthFactor.host<float>(), mHeightPosition.host<int>(), mHeightFactor.host<float>(), - mLineBuffer.host<float>(), ((CPUBackend *)backend())->threadNumber()); + mLineBuffer.host<float>(), ((CPUBackend *)backend())->threadNumber(), &mInputQuantZero, &mOutputQuantZero); break; case 3: - CPUResizeCubicC4(MNNCubicSampleC4, MNNCubicLineC4, inputs, outputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset); + CPUResizeCubicC4(MNNCubicSampleC4, MNNCubicLineC4, inputs, outputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset, &mInputQuantZero, &mOutputQuantZero, mOutputQuantMin, mOutputQuantMax); break; case 4: CPUResizeNearestneighborRoundC4(inputs, outputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset); @@ -92,10 +92,10 @@ ErrorCode CPUInterp::onExecute(const std::vector<Tensor *> &inputs, const std::v CPUResizeNearestneighborC4(int8ExeInputs, int8ExeOutputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset); break; case 2: - CPUResizeBilinearC4(MNNBilinearSampleC8, MNNBilinearLineC8, int8ExeInputs, int8ExeOutputs, mWidthPosition.host<int>(), mWidthFactor.host<float>(), mHeightPosition.host<int>(), mHeightFactor.host<float>(), mLineBuffer.host<int16_t>(), ((CPUBackend *)backend())->threadNumber()); + CPUResizeBilinearC4(MNNBilinearSampleC8, MNNBilinearLineC8, int8ExeInputs, int8ExeOutputs, mWidthPosition.host<int>(), mWidthFactor.host<float>(), mHeightPosition.host<int>(), mHeightFactor.host<float>(), mLineBuffer.host<int16_t>(), ((CPUBackend *)backend())->threadNumber(), &mInputQuantZero, &mOutputQuantZero); break; case 3: - CPUResizeCubicC4(MNNCubicSampleC16, MNNCubicLineC16, int8ExeInputs, int8ExeOutputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset); + CPUResizeCubicC4(MNNCubicSampleC16, MNNCubicLineC16, int8ExeInputs, int8ExeOutputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset, &mInputQuantZero, &mOutputQuantZero, mOutputQuantMin, mOutputQuantMax); break; case 4: CPUResizeNearestneighborRoundC4(int8ExeInputs, int8ExeOutputs, mWidthScale, mHeightScale, mWidthOffset, mHeightOffset); @@ -126,7 +126,8 @@ ErrorCode CPUInterp::onResize(const std::vector<Tensor *> &inputs, const std::ve if (mResizeType == 3 || mResizeType == 4) { packInt8 = 16; } - if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) { + bool useInt8 = (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) && (CPUBackend::getDataType(outputs[0]) == DataType_DT_INT8 || outputs[0]->getType().bytes() == 1); + if (useInt8) { mInputTemp.reset(Tensor::createDevice<int8_t>({inputs[0]->batch(), inH, inW, UP_DIV(inputs[0]->channel(), packInt8) * packInt8})); mOutputTemp.reset(Tensor::createDevice<int8_t>({outputs[0]->batch(), outH, outW, UP_DIV(outputs[0]->channel(), packInt8) * packInt8})); bool allocSucc = backend()->onAcquireBuffer(mInputTemp.get(), Backend::DYNAMIC); @@ -134,6 +135,10 @@ ErrorCode CPUInterp::onResize(const std::vector<Tensor *> &inputs, const std::ve if (!allocSucc) { return OUT_OF_MEMORY; } + mInputQuantZero = TensorUtils::getQuantInfo(inputs[0])[1]; + mOutputQuantZero = TensorUtils::getQuantInfo(outputs[0])[1]; + mOutputQuantMin = TensorUtils::getQuantInfo(outputs[0])[2]; + mOutputQuantMax = TensorUtils::getQuantInfo(outputs[0])[3]; } if (mResizeType != 2) {
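The zero points and clamp bounds cached at the end of onResize are read out of TensorUtils::getQuantInfo by index; as used here, the returned array is evidently laid out as {scale, zero point, min, max}. A short sketch of that assumed lookup (illustrative, the layout is an inference from the indices in the patch):

    // Assumed layout: info[0] = scale, info[1] = zero, info[2] = min, info[3] = max
    auto info = TensorUtils::getQuantInfo(outputs[0]);
    int8_t  outputZero = (int8_t)info[1];  // handed to the resize line functions
    ssize_t outputMin  = (ssize_t)info[2]; // clamp floor for the cubic path
    ssize_t outputMax  = (ssize_t)info[3]; // clamp ceiling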
diff --git a/source/backend/cpu/CPUInterp.hpp b/source/backend/cpu/CPUInterp.hpp index 6aa69c606..46c9fe37c 100644 --- a/source/backend/cpu/CPUInterp.hpp +++ b/source/backend/cpu/CPUInterp.hpp @@ -36,6 +36,10 @@ class CPUInterp : public CPUResizeCommon { bool mInit = false; std::shared_ptr<Tensor> mInputTemp; std::shared_ptr<Tensor> mOutputTemp; + int8_t mInputQuantZero = 0; + int8_t mOutputQuantZero = 0; + ssize_t mOutputQuantMin = -127; + ssize_t mOutputQuantMax = 127; }; } // namespace MNN diff --git a/source/backend/cpu/CPUResize.hpp b/source/backend/cpu/CPUResize.hpp index 0ca4da9d8..826aebfa8 100644 --- a/source/backend/cpu/CPUResize.hpp +++ b/source/backend/cpu/CPUResize.hpp @@ -13,6 +13,7 @@ #include "core/Execution.hpp" #include "core/Concurrency.h" #include "backend/cpu/CPUBackend.hpp" +#include "core/TensorUtils.hpp" #include "math/Vec.hpp" #include "core/Macro.h" #include @@ -21,16 +22,16 @@ using Vec4 = MNN::Math::Vec<float, 4>; #ifdef __cplusplus extern "C" { #endif -void CPUBilinearSampleC4(const float* src, float* dst, const int32_t* position, const float* factor, size_t number); -void CPUBilinearLineC4(float* dst, const float* A, const float* B, const float* t, size_t number); -void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, size_t number); -void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, size_t number); -void MNNCubicSampleC4(const float* src, float* dst, int32_t* position, const float* factor, size_t number); -void MNNCubicLineC4(float* dst, const float* A, const float* B, const float* C, const float* D, float* t, - size_t number); -void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const float* factor, size_t number); -void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, - size_t number); +void CPUBilinearSampleC4(const float* src, float* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); +void CPUBilinearLineC4(float* dst, const float* A, const float* B, const float* t, int8_t* zeroPoint, size_t number); +void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); +void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, int8_t* zeroPoint, size_t number); +void MNNCubicSampleC4(const float* src, float* dst, int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); +void MNNCubicLineC4(float* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint, + size_t number, ssize_t minValue, ssize_t maxValue); +void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); +void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint, + size_t number, ssize_t minValue, ssize_t maxValue); #ifdef __cplusplus } #endif @@ -54,8 +55,8 @@ class CPUResizeCommon : public Execution { virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) = 0; template <typename T, typename U> - void CPUResizeBilinearC4(void sampleFunction(const T*, U*, const int32_t*, const float*, size_t), void lineFunction(T*, const U*, const U*, const float*, size_t), const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, const int* widthPosition, const float* widthFactor, const int* heightPosition, - const float* heightFactor, U* lineBuffer, int threadNumber) { + void CPUResizeBilinearC4(void sampleFunction(const T*, U*, const int32_t*, const float*, int8_t*, size_t), void lineFunction(T*, const U*, const U*, const float*, int8_t*, size_t), const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, const int* widthPosition, const float* widthFactor, const int* 
heightPosition, + const float* heightFactor, U* lineBuffer, int threadNumber, int8_t* inputQuantZero, int8_t* outputQuantZero) { auto input = inputs[0]; auto output = outputs[0]; const int batches = input->batch(); @@ -106,7 +107,7 @@ class CPUResizeCommon : public Execution { yCache[k] = yp[j]; yUsed[k] = 1; yCacheLine[j] = yCacheStorage[k]; - sampleFunction(bottomY0, yCacheLine[j], widthPosition, widthFactor, outW); + sampleFunction(bottomY0, yCacheLine[j], widthPosition, widthFactor, inputQuantZero, outW); break; } } } T* topY = topData + outW * pack * dy; // Sample Input - lineFunction(topY, yCacheLine[0], yCacheLine[1], &heightFactor[dy], outW); + lineFunction(topY, yCacheLine[0], yCacheLine[1], &heightFactor[dy], outputQuantZero, outW); } } @@ -126,8 +127,8 @@ } template <typename T> - void CPUResizeCubicC4(void sampleFunction(const T*, float*, int32_t*, const float*, size_t), void lineFunction(T*, const float*, const float*, const float*, const float*, float*, size_t), - const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, float xFactor, float yFactor, float wOffset, float hOffset) { + void CPUResizeCubicC4(void sampleFunction(const T*, float*, int32_t*, const float*, int8_t*, size_t), void lineFunction(T*, const float*, const float*, const float*, const float*, float*, int8_t*, size_t, ssize_t, ssize_t), + const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, float xFactor, float yFactor, float wOffset, float hOffset, int8_t* inputQuantZero, int8_t* outputQuantZero, ssize_t minValue, ssize_t maxValue) { auto input = inputs[0]; auto output = outputs[0]; const int batches = input->batch(); @@ -202,7 +203,7 @@ yCache[k] = yp[j]; yUsed[k] = 1; yCacheLine[j] = yCacheStorage[k]; - sampleFunction(bottomY0, yCacheLine[j], _linePosition, _lineFactor, outW); + sampleFunction(bottomY0, yCacheLine[j], _linePosition, _lineFactor, inputQuantZero, outW); break; } } @@ -212,7 +213,7 @@ // Sample Input float yFract = (float)(y - floor(y)); auto topY = topData + outW * pack * dy; - lineFunction(topY, yCacheLine[0], yCacheLine[1], yCacheLine[2], yCacheLine[3], &yFract, outW); + lineFunction(topY, yCacheLine[0], yCacheLine[1], yCacheLine[2], yCacheLine[3], &yFract, outputQuantZero, outW, minValue, maxValue); } } MNN_CONCURRENCY_END();
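CPUResizeBilinearC4 resizes in two passes with a small row cache between them: sampleFunction interpolates horizontally once per distinct source row, and lineFunction blends two cached rows vertically; the new parameters thread the input zero point into the sample pass and the output zero point into the line pass. A scalar sketch of the int8 variant for one channel, under the assumption that sampling works zero-centered and the offset is restored on output (illustrative only, not the SIMD kernels):

    #include <cmath>
    #include <cstdint>
    // Horizontal pass: blend the two source pixels around `pos` with weight `f`.
    static int16_t bilinearSample(const int8_t* row, int pos, float f, int8_t zin) {
        float a = row[pos] - zin, b = row[pos + 1] - zin;   // zero-centered
        return (int16_t)lroundf((1.0f - f) * a + f * b);
    }
    // Vertical pass: blend two sampled rows, then restore the output zero point.
    static int8_t bilinearLine(int16_t top, int16_t bottom, float t, int8_t zout) {
        int v = (int)lroundf((1.0f - t) * top + t * bottom) + zout;
        return (int8_t)(v < -128 ? -128 : (v > 127 ? 127 : v));
    }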
diff --git a/source/backend/cpu/CPUScaleInt8.cpp b/source/backend/cpu/CPUScaleInt8.cpp index cb91e275f..6daa3fc02 100644 --- a/source/backend/cpu/CPUScaleInt8.cpp +++ b/source/backend/cpu/CPUScaleInt8.cpp @@ -17,15 +17,6 @@ namespace MNN { -static int minPow2GeaterThanN(int n) { - int k = 0, pow = 1; - while (pow < n) { - k++; - pow = pow<<1; - } - return 20 - k; -} - CPUScaleInt8::CPUScaleInt8(const Op* op, Backend* bn) : MNN::Execution(bn) { auto scale = op->main_as_Scale(); auto core = static_cast<CPUBackend*>(bn)->functions(); @@ -132,11 +123,6 @@ ErrorCode CPUScaleInt8::onResize(const std::vector<Tensor *> &inputs, const std: ::memcpy(biasPtr_, bias_.data(), outputCount * sizeof(int32_t)); mOutputQuantInfo[0] = outputScale; - int planeNumber = 1; - for (int i = 2; i < input->buffer().dimensions; ++i) { - planeNumber *= input->length(i); - } - auto depthStride = planeNumber * core->pack; return NO_ERROR; } @@ -161,12 +147,14 @@ ErrorCode CPUScaleInt8::onExecute(const std::vector<Tensor *>& inputs, const std: int numberThread = ((CPUBackend*)backend())->threadNumber(); MNN_CONCURRENCY_BEGIN(tId, numberThread) { + int8_t inputZeroPoint = (int8_t)mInputQuantInfo[1]; + int8_t outputZeroPoint = (int8_t)mOutputQuantInfo[1]; for (int i = tId; i < totalDepth; i+=numberThread) { auto depthIndex = i / batch; const int8_t* inputPtr = input->host<int8_t>() + depthStride * i; const int32_t* biasPtr_ = (const int32_t*)(biasPtr + core->pack * core->bytes * depthIndex); const int32_t* scalePtr_ = (const int32_t*)(scalePtr + core->pack * core->bytes * depthIndex); - MNNScaleAndAddBiasInt8(output->host<int8_t>() + depthStride * i, inputPtr, biasPtr_, scalePtr_, mShiftBits, (ssize_t)mOutputQuantInfo[2], (ssize_t)mOutputQuantInfo[3], (ssize_t)mOutputQuantInfo[1], planeNumber, 1, core->pack); + MNNScaleAndAddBiasInt8(output->host<int8_t>() + depthStride * i, inputPtr, biasPtr_, scalePtr_, mShiftBits, (ssize_t)mOutputQuantInfo[2], (ssize_t)mOutputQuantInfo[3], &inputZeroPoint, &outputZeroPoint, planeNumber, 1, core->pack); } } MNN_CONCURRENCY_END(); diff --git a/source/backend/cpu/CPUSoftMaxInt8.cpp b/source/backend/cpu/CPUSoftMaxInt8.cpp index b21bea795..1630ae52c 100644 --- a/source/backend/cpu/CPUSoftMaxInt8.cpp +++ b/source/backend/cpu/CPUSoftMaxInt8.cpp @@ -93,7 +93,7 @@ void CPUSoftmaxInt8::QuantizedSoftmax(const uint8_t* inputData, int outerSize, i #endif MNN_CONCURRENCY_BEGIN(tId, threadNum) { auto inputDataPtr = src_ + tId * depth; - auto outputDataPtr = dst_ + tId * depth; + uint8_t* outputDataPtr = (uint8_t*)dst_ + tId * depth; for (int b = (int)tId; b < outerSize; b += threadNum, inputDataPtr += depth * threadNum, outputDataPtr += depth * threadNum) { // Determine the largest entry in the current row int8_t maxInRow = -128; @@ -227,7 +227,7 @@ void CPUSoftmaxInt8::QuantizedSoftmax(const uint8_t* inputData, int outerSize, i vsubq_s16(input_s16, max_in_row_s16); int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); - int8x8_t mask = vmovn_s16(vcgeq_s16(input_diff_s16, diff_min_s16)); + uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16)); FixedPointScaledDiffInt32x4 scaled_diff_0 = input_beta_multiplier_f0 * FixedPointScaledDiffInt32x4::FromRaw( @@ -246,9 +246,9 @@ void CPUSoftmaxInt8::QuantizedSoftmax(const uint8_t* inputData, int outerSize, i numBitsOverUnit + 31 - 8); int16x8_t output_s16 = vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1)); - int8x8_t output_s8 = vqmovn_s16(output_s16); - int8x8_t masked_output = vbsl_s8(mask, output_s8, vdup_n_s8(0)); - vst1_s8(outputDataPtr + c, masked_output); + uint8x8_t output_s8 = vqmovun_s16(output_s16); + uint8x8_t masked_output = vbsl_u8(mask, output_s8, vdup_n_u8(0)); + vst1_u8(outputDataPtr + c, masked_output); } #endif for (; c < depth; ++c) {
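The QuantizedSoftmax fixes are purely about NEON signedness: vcgeq_s16 yields an unsigned mask (uint16x8_t), so it must be narrowed with vmovn_u16 into a uint8x8_t, and because the softmax output row is uint8, the final narrow has to be the signed-to-unsigned saturation vqmovun_s16 followed by an unsigned select and store. A compilable sketch of the corrected tail (the types are the point here, the values are placeholders):

    #include <arm_neon.h>
    void maskedNarrowStore(int16x8_t out_s16, int16x8_t diff_s16,
                           int16x8_t diff_min_s16, uint8_t* dst) {
        uint8x8_t mask = vmovn_u16(vcgeq_s16(diff_s16, diff_min_s16)); // compare -> unsigned
        uint8x8_t out  = vqmovun_s16(out_s16);                         // saturate to [0, 255]
        vst1_u8(dst, vbsl_u8(mask, out, vdup_n_u8(0)));                // zero masked lanes
    }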
diff --git a/source/backend/cpu/arm/arm32/MNNBilinearLineC16.S b/source/backend/cpu/arm/arm32/MNNBilinearLineC8.S similarity index 89% rename from source/backend/cpu/arm/arm32/MNNBilinearLineC16.S rename to source/backend/cpu/arm/arm32/MNNBilinearLineC8.S index c3ee9f11b..e9365714f 100644 --- a/source/backend/cpu/arm/arm32/MNNBilinearLineC16.S +++ b/source/backend/cpu/arm/arm32/MNNBilinearLineC8.S @@ -14,13 +14,14 @@ .align 5 asm_function MNNBilinearLineC8 -// void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, size_t number) +// void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, int8_t* zeroPoint, size_t number) // Auto load: r0: dst, r1: A, r2: B, r3: t -// r4: number +// r5: zeroPoint, r4: number push {r4-r8, r10, lr} // avoid to touch platform-register r-9 -ldr r4, [sp, #28] +ldr r5, [sp, #28] +ldr r4, [sp, #32] ldr r3, [r3, #0] vpush {q4-q7} @@ -57,6 +58,7 @@ vshr.s32 q3, q3, #14 vqmovn.s32 d4, q2 vqmovn.s32 d5, q3 + vqmovn.s16 d4, q2 vst1.8 {d4}, [r0]! diff --git a/source/backend/cpu/arm/arm32/MNNBilinearSampleC16.S b/source/backend/cpu/arm/arm32/MNNBilinearSampleC8.S similarity index 88% rename from source/backend/cpu/arm/arm32/MNNBilinearSampleC16.S rename to source/backend/cpu/arm/arm32/MNNBilinearSampleC8.S index b209d89c9..fe1afab3a 100644 --- a/source/backend/cpu/arm/arm32/MNNBilinearSampleC16.S +++ b/source/backend/cpu/arm/arm32/MNNBilinearSampleC8.S @@ -14,12 +14,13 @@ .align 5 asm_function MNNBilinearSampleC8 -// void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, size_t number); +// void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); // Auto load: r0: src, r1: dst, r2: position, r3: factor -// r4: number +// r12: zeroPoint, r4: number push {r4-r8, r10, lr} -ldr r4, [sp, #28] +ldr r12, [sp, #28] +ldr r4, [sp, #32] mov lr, #8 vpush {q4-q7} diff --git a/source/backend/cpu/arm/arm32/MNNBinaryAddInt8.S b/source/backend/cpu/arm/arm32/MNNBinaryAddInt8.S index 42140fb19..a0ffd277a 100644 --- a/source/backend/cpu/arm/arm32/MNNBinaryAddInt8.S +++ b/source/backend/cpu/arm/arm32/MNNBinaryAddInt8.S @@ -15,33 +15,45 @@ .align 5 asm_function MNNBinaryAddInt8 -// MNNBinaryAddInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// Auto Load -// r0: dst, r1:src0, r2:src1, r3: scale0 -// Load from sp -// r4:scale1, r5: outputScale, r6:size, r7: needBroadcast -push {r4, r5, r6, r7, r8, lr} - -ldr r4, [sp, #24] -ldr r5, [sp, #28] -ldr r6, [sp, #32] -ldr r7, [sp, #36] +// MNNBinaryAddInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// r0: dst, r1:src0, r2:src1, r3:quantScalesInt32 +// Load from sp: +// r4:quantScalesFp32, r5: offset0, r6: offset1, r7: outputOffset +// r8: size, r9: needBroadcast +push {r4, r5, r6, r7, r8, r9, lr} +
+ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r8, [sp, #44] +ldr r9, [sp, #48] vpush {q4-q7} -vld1.32 {q13}, [r3] // q13: scale0 -vld1.32 {q14}, [r4] // q14: scale1 -vld1.32 {q15}, [r5] // q15: outputScale +ldr r12, [r3] +vdup.s32 q13, r12 // scale +ldr r4, [r3, #4] +vdup.s32 q14, r4 + +vld1.8 {d0[0]}, [r5] +vld1.8 {d1[0]}, [r6] +vld1.8 {d4[0]}, [r7] +vdup.8 d0, d0[0] +vdup.8 d1, d1[0] +vdup.8 d4, d4[0] L4: -cmp r6, #4 +cmp r8, #4 blt L1 L4Loop: - cmp r7, #0 + cmp r9, #0 beq L4NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: @@ -50,131 +62,116 @@ L4Loop: b L4Compute L4NeedBroadcast0: - vld1.8 {d22[0]}, [r1] - vdup.8 d22, d22[0] - vdup.8 d23, d22[0] + ldr r4, [r1] + vdup.s8 q11, r4 vld1.32 {q12}, [r2]! b L4Compute L4NeedBroadcast1: vld1.32 {q11}, [r1]! 
- vld1.8 {d24[0]}, [r2] - vdup.8 d24, d24[0] - vdup.8 d25, d24[0] + ldr r4, [r2] + vdup.s8 q12, r4 b L4Compute L4Compute: - sub r6, r6, #4 + sub r8, r8, #4 vmovl.s8 q4, d22 vmovl.s8 q5, d23 - - vmovl.s16 q0, d8 - vmovl.s16 q1, d9 - vmovl.s16 q2, d10 - vmovl.s16 q3, d11 - - vcvtq.f32.s32 q0, q0 - vcvtq.f32.s32 q1, q1 - vcvtq.f32.s32 q2, q2 - vcvtq.f32.s32 q3, q3 - vmovl.s8 q6, d24 vmovl.s8 q7, d25 - vmulq.f32 q0, q0, q13 - vmulq.f32 q1, q1, q13 - vmulq.f32 q2, q2, q13 - vmulq.f32 q3, q3, q13 - - vmovl.s16 q8, d12 - vmovl.s16 q9, d13 - vmovl.s16 q10, d14 - vmovl.s16 q11, d15 - vcvtq.f32.s32 q8, q8 - vcvtq.f32.s32 q9, q9 - vcvtq.f32.s32 q10, q10 - vcvtq.f32.s32 q11, q11 - vmulq.f32 q8, q8, q14 - vmulq.f32 q9, q9, q14 - vmulq.f32 q10, q10, q14 - vmulq.f32 q11, q11, q14 - - vaddq.f32 q0, q0, q8 - vaddq.f32 q1, q1, q9 - vaddq.f32 q2, q2, q10 - vaddq.f32 q3, q3, q11 - - vmulq.f32 q0, q0, q15 - vmulq.f32 q1, q1, q15 - vmulq.f32 q2, q2, q15 - vmulq.f32 q3, q3, q15 - - vcvtq.s32.f32 q0, q0 - vcvtq.s32.f32 q1, q1 - vcvtq.s32.f32 q2, q2 - vcvtq.s32.f32 q3, q3 - - vqmovn.s32 d8, q0 - vqmovn.s32 d9, q1 - vqmovn.s32 d10, q2 - vqmovn.s32 d11, q3 + vsubw.s8 q4, q4, d0 + vsubw.s8 q5, q5, d0 + vsubw.s8 q6, q6, d1 + vsubw.s8 q7, q7, d1 + + vmovl.s16 q8, d8 + vmovl.s16 q9, d9 + vmovl.s16 q10, d10 + vmovl.s16 q11, d11 + + vmovl.s16 q3, d12 + vmovl.s16 q12, d13 + vmovl.s16 q15, d14 + vmovl.s16 q4, d15 + + vmulq.s32 q8, q8, q13 + vmulq.s32 q9, q9, q13 + vmulq.s32 q10, q10, q13 + vmulq.s32 q11, q11, q13 + + vmulq.s32 q3, q3, q14 + vmulq.s32 q12, q12, q14 + vmulq.s32 q15, q15, q14 + vmulq.s32 q4, q4, q14 + + vaddq.s32 q8, q8, q3 + vaddq.s32 q9, q9, q12 + vaddq.s32 q10, q10, q15 + vaddq.s32 q11, q11, q4 + + vqshrn.s32 d6, q8, #16 + vqshrn.s32 d7, q9, #16 + vqshrn.s32 d8, q10, #16 + vqshrn.s32 d9, q11, #16 + + vaddw.s8 q3, q3, d4 + vaddw.s8 q4, q4, d4 - vqmovn.s16 d0, q4 - vqmovn.s16 d1, q5 - cmp r6, #4 - vst1.32 {q0}, [r0]! + vqmovn.s16 d12, q3 + vqmovn.s16 d13, q4 + cmp r8, #4 + vst1.32 {q6}, [r0]! bge L4Loop L1: -cmp r6, #0 +cmp r8, #0 beq End L1Loop: - cmp r7, #0 + cmp r9, #0 beq L1NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - vld1.32 {d0[0]}, [r1]! + vld1.32 {d6[0]}, [r1]! vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast0: - vld1.8 {d0[0]}, [r1] - vdup.8 d0, d0[0] + ldr r4, [r1] + vdup.s8 d6, r4 vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast1: - vld1.32 {d0[0]}, [r1]! - vld1.8 {d8[0]}, [r2] - vdup.8 d8, d8[0] + vld1.32 {d6[0]}, [r1]! + ldr r4, [r2] + vdup.s8 d8, r4 b L1Compute L1Compute: - subs r6, r6, #1 - vmovl.s8 q1, d0 - vmovl.s16 q2, d2 - vcvtq.f32.s32 q3, q2 - vmulq.f32 q3, q3, q13 + subs r8, r8, #1 + vmovl.s8 q3, d6 + vsubw.s8 q3, q3, d0 + vmovl.s16 q3, d6 + vmulq.s32 q3, q3, q13 vmovl.s8 q5, d8 + vsubw.s8 q5, q5, d1 vmovl.s16 q6, d10 - vcvtq.f32.s32 q7, q6 - vmulq.f32 q7, q7, q14 - - vaddq.f32 q3, q3, q7 + vmulq.s32 q6, q6, q14 - vmulq.f32 q3, q3, q15 - vcvtq.s32.f32 q0, q3 - vqmovn.s32 d2, q0 - vqmovn.s16 d6, q1 + vaddq.s32 q3, q3, q6 + vqshrn.s32 d6, q3, #16 + vaddw.s8 q3, q3, d4 + vqmovn.s16 d6, q3 vst1.32 {d6[0]}, [r0]! 
bne L1Loop End: vpop {q4-q7} -pop {r4, r5, r6, r7, r8, pc} +pop {r4, r5, r6, r7, r8, r9, pc} #endif #endif diff --git a/source/backend/cpu/arm/arm32/MNNBinaryMaxInt8.S b/source/backend/cpu/arm/arm32/MNNBinaryMaxInt8.S index 671126a08..9ec7d9ba0 100644 --- a/source/backend/cpu/arm/arm32/MNNBinaryMaxInt8.S +++ b/source/backend/cpu/arm/arm32/MNNBinaryMaxInt8.S @@ -15,33 +15,45 @@ .align 5 asm_function MNNBinaryMaxInt8 -// MNNBinaryMaxInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// Auto Load -// r0: dst, r1:src0, r2:src1, r3: scale0 -// Load from sp -// r4:scale1, r5: outputScale, r6:size, r7: needBroadcast -push {r4, r5, r6, r7, r8, lr} - -ldr r4, [sp, #24] -ldr r5, [sp, #28] -ldr r6, [sp, #32] -ldr r7, [sp, #36] +// MNNBinaryMaxInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// r0: dst, r1:src0, r2:src1, r3:quantScalesInt32 +// Load from sp: +// r4:quantScalesFp32, r5: offset0, r6: offset1, r7: outputOffset +// r8: size, r9: needBroadcast +push {r4, r5, r6, r7, r8, r9, lr} +
+ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r8, [sp, #44] +ldr r9, [sp, #48] vpush {q4-q7} -vld1.32 {q13}, [r3] // q13: scale0 -vld1.32 {q14}, [r4] // q14: scale1 -vld1.32 {q15}, [r5] // q15: outputScale +ldr r12, [r3] +vdup.s32 q13, r12 // scale +ldr r4, [r3, #4] +vdup.s32 q14, r4 + +vld1.8 {d0[0]}, [r5] +vld1.8 {d1[0]}, [r6] +vld1.8 {d4[0]}, [r7] +vdup.8 d0, d0[0] +vdup.8 d1, d1[0] +vdup.8 d4, d4[0] L4: -cmp r6, #4 +cmp r8, #4 blt L1 L4Loop: - cmp r7, #0 + cmp r9, #0 beq L4NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: @@ -50,131 +62,117 @@ L4Loop: b L4Compute L4NeedBroadcast0: - vld1.8 {d22[0]}, [r1] - vdup.8 d22, d22[0] - vdup.8 d23, d22[0] + ldr r4, [r1] + vdup.s8 q11, r4 vld1.32 {q12}, [r2]! b L4Compute L4NeedBroadcast1: vld1.32 {q11}, [r1]! - vld1.8 {d24[0]}, [r2] - vdup.8 d24, d24[0] - vdup.8 d25, d24[0] + ldr r4, [r2] + vdup.s8 q12, r4 b L4Compute L4Compute: - sub r6, r6, #4 + sub r8, r8, #4 vmovl.s8 q4, d22 vmovl.s8 q5, d23 - - vmovl.s16 q0, d8 - vmovl.s16 q1, d9 - vmovl.s16 q2, d10 - vmovl.s16 q3, d11 - - vcvtq.f32.s32 q0, q0 - vcvtq.f32.s32 q1, q1 - vcvtq.f32.s32 q2, q2 - vcvtq.f32.s32 q3, q3 - vmovl.s8 q6, d24 vmovl.s8 q7, d25 - vmulq.f32 q0, q0, q13 - vmulq.f32 q1, q1, q13 - vmulq.f32 q2, q2, q13 - vmulq.f32 q3, q3, q13 - - vmovl.s16 q8, d12 - vmovl.s16 q9, d13 - vmovl.s16 q10, d14 - vmovl.s16 q11, d15 - vcvtq.f32.s32 q8, q8 - vcvtq.f32.s32 q9, q9 - vcvtq.f32.s32 q10, q10 - vcvtq.f32.s32 q11, q11 - vmulq.f32 q8, q8, q14 - vmulq.f32 q9, q9, q14 - vmulq.f32 q10, q10, q14 - vmulq.f32 q11, q11, q14 - - vmaxq.f32 q0, q0, q8 - vmaxq.f32 q1, q1, q9 - vmaxq.f32 q2, q2, q10 - vmaxq.f32 q3, q3, q11 - - vmulq.f32 q0, q0, q15 - vmulq.f32 q1, q1, q15 - vmulq.f32 q2, q2, q15 - vmulq.f32 q3, q3, q15 - - vcvtq.s32.f32 q0, q0 - vcvtq.s32.f32 q1, q1 - vcvtq.s32.f32 q2, q2 - vcvtq.s32.f32 q3, q3 - - vqmovn.s32 d8, q0 - vqmovn.s32 d9, q1 - vqmovn.s32 d10, q2 - vqmovn.s32 d11, q3 + vsubw.s8 q4, q4, d0 + vsubw.s8 q5, q5, d0 + vsubw.s8 q6, q6, d1 + vsubw.s8 q7, q7, d1 - vqmovn.s16 d0, q4 - vqmovn.s16 d1, q5 - cmp r6, #4 - vst1.32 {q0}, [r0]! 
+ vmovl.s16 q8, d8 + vmovl.s16 q9, d9 + vmovl.s16 q10, d10 + vmovl.s16 q11, d11 + + vmovl.s16 q3, d12 + vmovl.s16 q12, d13 + vmovl.s16 q15, d14 + vmovl.s16 q4, d15 + + vmulq.s32 q8, q8, q13 + vmulq.s32 q9, q9, q13 + vmulq.s32 q10, q10, q13 + vmulq.s32 q11, q11, q13 + + vmulq.s32 q3, q3, q14 + vmulq.s32 q12, q12, q14 + vmulq.s32 q15, q15, q14 + vmulq.s32 q4, q4, q14 + + vmax.s32 q8, q8, q3 + vmax.s32 q9, q9, q12 + vmax.s32 q10, q10, q15 + vmax.s32 q11, q11, q4 + + vqshrn.s32 d6, q8, #16 + vqshrn.s32 d7, q9, #16 + vqshrn.s32 d8, q10, #16 + vqshrn.s32 d9, q11, #16 + + vaddw.s8 q3, q3, d4 + vaddw.s8 q4, q4, d4 + + vqmovn.s16 d12, q3 + vqmovn.s16 d13, q4 + cmp r8, #4 + vst1.32 {q6}, [r0]! bge L4Loop L1: -cmp r6, #0 +cmp r8, #0 beq End L1Loop: - cmp r7, #0 + cmp r9, #0 beq L1NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - vld1.32 {d0[0]}, [r1]! + vld1.32 {d6[0]}, [r1]! vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast0: - vld1.8 {d0[0]}, [r1] - vdup.8 d0, d0[0] + ldr r4, [r1] + vdup.s8 d6, r4 vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast1: - vld1.32 {d0[0]}, [r1]! - vld1.8 {d8[0]}, [r2] - vdup.8 d8, d8[0] + vld1.32 {d6[0]}, [r1]! + ldr r4, [r2] + vdup.s8 d8, r4 b L1Compute L1Compute: - subs r6, r6, #1 - vmovl.s8 q1, d0 - vmovl.s16 q2, d2 - vcvtq.f32.s32 q3, q2 - vmulq.f32 q3, q3, q13 + subs r8, r8, #1 + vmovl.s8 q3, d6 + vsubw.s8 q3, q3, d0 + vmovl.s16 q3, d6 + vmulq.s32 q3, q3, q13 vmovl.s8 q5, d8 + vsubw.s8 q5, q5, d1 vmovl.s16 q6, d10 - vcvtq.f32.s32 q7, q6 - vmulq.f32 q7, q7, q14 - - vmaxq.f32 q3, q3, q7 + vmulq.s32 q6, q6, q14 - vmulq.f32 q3, q3, q15 - vcvtq.s32.f32 q0, q3 - vqmovn.s32 d2, q0 - vqmovn.s16 d6, q1 + vmax.s32 q3, q3, q6 + vqshrn.s32 d6, q3, #16 + vaddw.s8 q3, q3, d4 + vqmovn.s16 d6, q3 vst1.32 {d6[0]}, [r0]! 
bne L1Loop End: vpop {q4-q7} -pop {r4, r5, r6, r7, r8, pc} +pop {r4, r5, r6, r7, r8, r9, pc} #endif #endif + diff --git a/source/backend/cpu/arm/arm32/MNNBinaryMinInt8.S b/source/backend/cpu/arm/arm32/MNNBinaryMinInt8.S index f9d93c688..70be4fafa 100644 --- a/source/backend/cpu/arm/arm32/MNNBinaryMinInt8.S +++ b/source/backend/cpu/arm/arm32/MNNBinaryMinInt8.S @@ -15,33 +15,45 @@ .align 5 asm_function MNNBinaryMinInt8 -// MNNBinaryMinInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// Auto Load -// r0: dst, r1:src0, r2:src1, r3: scale0 -// Load from sp -// r4:scale1, r5: outputScale, r6:size, r7: needBroadcast -push {r4, r5, r6, r7, r8, lr} - -ldr r4, [sp, #24] -ldr r5, [sp, #28] -ldr r6, [sp, #32] -ldr r7, [sp, #36] +// MNNBinaryMinInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// r0: dst, r1:src0, r2:src1, r3:quantScalesInt32 +// Load from sp: +// r4:quantScalesFp32, r5: offset0, r6: offset1, r7: outputOffset +// r8: size, r9: needBroadcast +push {r4, r5, r6, r7, r8, r9, lr} +
+ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r8, [sp, #44] +ldr r9, [sp, #48] vpush {q4-q7} -vld1.32 {q13}, [r3] // q13: scale0 -vld1.32 {q14}, [r4] // q14: scale1 -vld1.32 {q15}, [r5] // q15: outputScale +ldr r12, [r3] +vdup.s32 q13, r12 // scale +ldr r4, [r3, #4] +vdup.s32 q14, r4 + +vld1.8 {d0[0]}, [r5] +vld1.8 {d1[0]}, [r6] +vld1.8 {d4[0]}, [r7] +vdup.8 d0, d0[0] +vdup.8 d1, d1[0] +vdup.8 d4, d4[0] L4: -cmp r6, #4 +cmp r8, #4 blt L1 L4Loop: - cmp r7, #0 + cmp r9, #0 beq L4NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: @@ -50,131 +62,117 @@ L4Loop: b L4Compute L4NeedBroadcast0: - vld1.8 {d22[0]}, [r1] - vdup.8 d22, d22[0] - vdup.8 d23, d22[0] + ldr r4, [r1] + vdup.s8 q11, r4 vld1.32 {q12}, [r2]! b L4Compute L4NeedBroadcast1: vld1.32 {q11}, [r1]! - vld1.8 {d24[0]}, [r2] - vdup.8 d24, d24[0] - vdup.8 d25, d24[0] + ldr r4, [r2] + vdup.s8 q12, r4 b L4Compute L4Compute: - sub r6, r6, #4 + sub r8, r8, #4 vmovl.s8 q4, d22 vmovl.s8 q5, d23 - - vmovl.s16 q0, d8 - vmovl.s16 q1, d9 - vmovl.s16 q2, d10 - vmovl.s16 q3, d11 - - vcvtq.f32.s32 q0, q0 - vcvtq.f32.s32 q1, q1 - vcvtq.f32.s32 q2, q2 - vcvtq.f32.s32 q3, q3 - vmovl.s8 q6, d24 vmovl.s8 q7, d25 - vmulq.f32 q0, q0, q13 - vmulq.f32 q1, q1, q13 - vmulq.f32 q2, q2, q13 - vmulq.f32 q3, q3, q13 - - vmovl.s16 q8, d12 - vmovl.s16 q9, d13 - vmovl.s16 q10, d14 - vmovl.s16 q11, d15 - vcvtq.f32.s32 q8, q8 - vcvtq.f32.s32 q9, q9 - vcvtq.f32.s32 q10, q10 - vcvtq.f32.s32 q11, q11 - vmulq.f32 q8, q8, q14 - vmulq.f32 q9, q9, q14 - vmulq.f32 q10, q10, q14 - vmulq.f32 q11, q11, q14 - - vminq.f32 q0, q0, q8 - vminq.f32 q1, q1, q9 - vminq.f32 q2, q2, q10 - vminq.f32 q3, q3, q11 - - vmulq.f32 q0, q0, q15 - vmulq.f32 q1, q1, q15 - vmulq.f32 q2, q2, q15 - vmulq.f32 q3, q3, q15 - - vcvtq.s32.f32 q0, q0 - vcvtq.s32.f32 q1, q1 - vcvtq.s32.f32 q2, q2 - vcvtq.s32.f32 q3, q3 - - vqmovn.s32 d8, q0 - vqmovn.s32 d9, q1 - vqmovn.s32 d10, q2 - vqmovn.s32 d11, q3 + vsubw.s8 q4, q4, d0 + vsubw.s8 q5, q5, d0 + vsubw.s8 q6, q6, d1 + vsubw.s8 q7, q7, d1 - vqmovn.s16 d0, q4 - vqmovn.s16 d1, q5 - cmp r6, #4 - vst1.32 {q0}, [r0]! 
+ vmovl.s16 q8, d8 + vmovl.s16 q9, d9 + vmovl.s16 q10, d10 + vmovl.s16 q11, d11 + + vmovl.s16 q3, d12 + vmovl.s16 q12, d13 + vmovl.s16 q15, d14 + vmovl.s16 q4, d15 + + vmulq.s32 q8, q8, q13 + vmulq.s32 q9, q9, q13 + vmulq.s32 q10, q10, q13 + vmulq.s32 q11, q11, q13 + + vmulq.s32 q3, q3, q14 + vmulq.s32 q12, q12, q14 + vmulq.s32 q15, q15, q14 + vmulq.s32 q4, q4, q14 + + vmin.s32 q8, q8, q3 + vmin.s32 q9, q9, q12 + vmin.s32 q10, q10, q15 + vmin.s32 q11, q11, q4 + + vqshrn.s32 d6, q8, #16 + vqshrn.s32 d7, q9, #16 + vqshrn.s32 d8, q10, #16 + vqshrn.s32 d9, q11, #16 + + vaddw.s8 q3, q3, d4 + vaddw.s8 q4, q4, d4 + + vqmovn.s16 d12, q3 + vqmovn.s16 d13, q4 + cmp r8, #4 + vst1.32 {q6}, [r0]! bge L4Loop L1: -cmp r6, #0 +cmp r8, #0 beq End L1Loop: - cmp r7, #0 + cmp r9, #0 beq L1NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - vld1.32 {d0[0]}, [r1]! + vld1.32 {d6[0]}, [r1]! vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast0: - vld1.8 {d0[0]}, [r1] - vdup.8 d0, d0[0] + ldr r4, [r1] + vdup.s8 d6, r4 vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast1: - vld1.32 {d0[0]}, [r1]! - vld1.8 {d8[0]}, [r2] - vdup.8 d8, d8[0] + vld1.32 {d6[0]}, [r1]! + ldr r4, [r2] + vdup.s8 d8, r4 b L1Compute L1Compute: - subs r6, r6, #1 - vmovl.s8 q1, d0 - vmovl.s16 q2, d2 - vcvtq.f32.s32 q3, q2 - vmulq.f32 q3, q3, q13 + subs r8, r8, #1 + vmovl.s8 q3, d6 + vsubw.s8 q3, q3, d0 + vmovl.s16 q3, d6 + vmulq.s32 q3, q3, q13 vmovl.s8 q5, d8 + vsubw.s8 q5, q5, d1 vmovl.s16 q6, d10 - vcvtq.f32.s32 q7, q6 - vmulq.f32 q7, q7, q14 - - vminq.f32 q3, q3, q7 + vmulq.s32 q6, q6, q14 - vmulq.f32 q3, q3, q15 - vcvtq.s32.f32 q0, q3 - vqmovn.s32 d2, q0 - vqmovn.s16 d6, q1 + vmin.s32 q3, q3, q6 + vqshrn.s32 d6, q3, #16 + vaddw.s8 q3, q3, d4 + vqmovn.s16 d6, q3 vst1.32 {d6[0]}, [r0]! 
bne L1Loop End: vpop {q4-q7} -pop {r4, r5, r6, r7, r8, pc} +pop {r4, r5, r6, r7, r8, r9, pc} #endif #endif + diff --git a/source/backend/cpu/arm/arm32/MNNBinaryMulInt8.S b/source/backend/cpu/arm/arm32/MNNBinaryMulInt8.S index bb66c3211..e369fd118 100644 --- a/source/backend/cpu/arm/arm32/MNNBinaryMulInt8.S +++ b/source/backend/cpu/arm/arm32/MNNBinaryMulInt8.S @@ -15,33 +15,47 @@ .align 5 asm_function MNNBinaryMulInt8 -// MNNBinaryMulInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// Auto Load -// r0: dst, r1:src0, r2:src1, r3: scale0 -// Load from sp -// r4:scale1, r5: outputScale, r6:size, r7: needBroadcast -push {r4, r5, r6, r7, r8, lr} - -ldr r4, [sp, #24] -ldr r5, [sp, #28] -ldr r6, [sp, #32] -ldr r7, [sp, #36] +// MNNBinaryMulInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// r0: dst, r1:src0, r2:src1, r3:quantScalesInt32 +// Load from sp: +// r4:quantScalesFp32, r5: offset0, r6: offset1, r7: outputOffset +// r8: size, r9: needBroadcast +push {r4, r5, r6, r7, r8, r9, lr} +
+ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r8, [sp, #44] +ldr r9, [sp, #48] vpush {q4-q7} -vld1.32 {q13}, [r3] // q13: scale0 -vld1.32 {q14}, [r4] // q14: scale1 -vld1.32 {q15}, [r5] // q15: outputScale +ldr r12, [r4] +vdup.f32 q13, r12 // scale +ldr lr, [r4, #4] +vdup.f32 q14, lr +ldr lr, [r4, #8] +vdup.f32 q15, lr +
+vld1.8 {d0[0]}, [r5] +vld1.8 {d1[0]}, [r6] +vld1.8 {d4[0]}, [r7] +vdup.8 d0, d0[0] +vdup.8 d1, d1[0] +vdup.8 d4, d4[0] L4: -cmp r6, #4 +cmp r8, #4 blt L1 L4Loop: - cmp r7, #0 + cmp r9, #0 beq L4NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: @@ -50,131 +64,140 @@ L4Loop: b L4Compute L4NeedBroadcast0: - vld1.8 {d22[0]}, [r1] - vdup.8 d22, d22[0] - vdup.8 d23, d22[0] + ldr r4, [r1] + vdup.s8 q11, r4 vld1.32 {q12}, [r2]! b L4Compute L4NeedBroadcast1: vld1.32 {q11}, [r1]! - vld1.8 {d24[0]}, [r2] - vdup.8 d24, d24[0] - vdup.8 d25, d24[0] + ldr r4, [r2] + vdup.s8 q12, r4 b L4Compute L4Compute: - sub r6, r6, #4 + sub r8, r8, #4 vmovl.s8 q4, d22 vmovl.s8 q5, d23 - - vmovl.s16 q0, d8 - vmovl.s16 q1, d9 - vmovl.s16 q2, d10 - vmovl.s16 q3, d11 - - vcvtq.f32.s32 q0, q0 - vcvtq.f32.s32 q1, q1 - vcvtq.f32.s32 q2, q2 - vcvtq.f32.s32 q3, q3 - vmovl.s8 q6, d24 vmovl.s8 q7, d25 - vmulq.f32 q0, q0, q13 - vmulq.f32 q1, q1, q13 - vmulq.f32 q2, q2, q13 - vmulq.f32 q3, q3, q13 - - vmovl.s16 q8, d12 - vmovl.s16 q9, d13 - vmovl.s16 q10, d14 - vmovl.s16 q11, d15 - vcvtq.f32.s32 q8, q8 - vcvtq.f32.s32 q9, q9 - vcvtq.f32.s32 q10, q10 - vcvtq.f32.s32 q11, q11 - vmulq.f32 q8, q8, q14 - vmulq.f32 q9, q9, q14 - vmulq.f32 q10, q10, q14 - vmulq.f32 q11, q11, q14 - - vmulq.f32 q0, q0, q8 - vmulq.f32 q1, q1, q9 - vmulq.f32 q2, q2, q10 - vmulq.f32 q3, q3, q11 - - vmulq.f32 q0, q0, q15 - vmulq.f32 q1, q1, q15 - vmulq.f32 q2, q2, q15 - vmulq.f32 q3, q3, q15 - - vcvtq.s32.f32 q0, q0 - vcvtq.s32.f32 q1, q1 - vcvtq.s32.f32 q2, q2 - vcvtq.s32.f32 q3, q3 - - vqmovn.s32 d8, q0 - vqmovn.s32 d9, q1 - vqmovn.s32 d10, q2 - vqmovn.s32 d11, q3 + vsubw.s8 q4, q4, d0 + vsubw.s8 q5, q5, d0 + vsubw.s8 q6, q6, d1 + vsubw.s8 q7, q7, d1 - vqmovn.s16 d0, q4 - vqmovn.s16 d1, q5 - cmp r6, #4 - vst1.32 {q0}, [r0]! 
+ vmovl.s16 q8, d8 + vmovl.s16 q9, d9 + vmovl.s16 q10, d10 + vmovl.s16 q11, d11 + + vmovl.s16 q3, d12 + vmovl.s16 q12, d13 + vmovl.s16 q1, d14 + vmovl.s16 q4, d15 + + vcvt.f32.s32 q8, q8 + vcvt.f32.s32 q9, q9 + vcvt.f32.s32 q10, q10 + vcvt.f32.s32 q11, q11 + + vcvt.f32.s32 q3, q3 + vcvt.f32.s32 q12, q12 + vcvt.f32.s32 q1, q1 + vcvt.f32.s32 q4, q4 + + vmul.f32 q8, q8, q13 + vmul.f32 q9, q9, q13 + vmul.f32 q10, q10, q13 + vmul.f32 q11, q11, q13 + + vmul.f32 q3, q3, q14 + vmul.f32 q12, q12, q14 + vmul.f32 q1, q1, q14 + vmul.f32 q4, q4, q14 + + vmul.f32 q8, q8, q3 + vmul.f32 q9, q9, q12 + vmul.f32 q10, q10, q1 + vmul.f32 q11, q11, q4 + + vmul.f32 q8, q8, q15 + vmul.f32 q9, q9, q15 + vmul.f32 q10, q10, q15 + vmul.f32 q11, q11, q15 + + vcvt.s32.f32 q8, q8 + vcvt.s32.f32 q9, q9 + vcvt.s32.f32 q10, q10 + vcvt.s32.f32 q11, q11 + + vqmovn.s32 d6, q8 + vqmovn.s32 d7, q9 + vqmovn.s32 d8, q10 + vqmovn.s32 d9, q11 + + vaddw.s8 q3, q3, d4 + vaddw.s8 q4, q4, d4 + + vqmovn.s16 d12, q3 + vqmovn.s16 d13, q4 + cmp r8, #4 + vst1.32 {q6}, [r0]! bge L4Loop L1: -cmp r6, #0 +cmp r8, #0 beq End L1Loop: - cmp r7, #0 + cmp r9, #0 beq L1NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - vld1.32 {d0[0]}, [r1]! + vld1.32 {d6[0]}, [r1]! vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast0: - vld1.8 {d0[0]}, [r1] - vdup.8 d0, d0[0] + ldr r4, [r1] + vdup.s8 d6, r4 vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast1: - vld1.32 {d0[0]}, [r1]! - vld1.8 {d8[0]}, [r2] - vdup.8 d8, d8[0] + vld1.32 {d6[0]}, [r1]! + ldr r4, [r2] + vdup.s8 d8, r4 b L1Compute L1Compute: - subs r6, r6, #1 - vmovl.s8 q1, d0 - vmovl.s16 q2, d2 - vcvtq.f32.s32 q3, q2 - vmulq.f32 q3, q3, q13 + subs r8, r8, #1 + vmovl.s8 q3, d6 + vsubw.s8 q3, q3, d0 + vmovl.s16 q3, d6 + vcvt.f32.s32 q3, q3 + vmul.f32 q3, q3, q13 vmovl.s8 q5, d8 + vsubw.s8 q5, q5, d1 vmovl.s16 q6, d10 - vcvtq.f32.s32 q7, q6 - vmulq.f32 q7, q7, q14 - - vmulq.f32 q3, q3, q7 - - vmulq.f32 q3, q3, q15 - vcvtq.s32.f32 q0, q3 - vqmovn.s32 d2, q0 - vqmovn.s16 d6, q1 + vcvt.f32.s32 q6, q6 + vmul.f32 q6, q6, q14 + + vmul.f32 q3, q3, q6 + vmul.f32 q3, q3, q15 + vcvt.s32.f32 q3, q3 + vqmovn.s32 d6, q3 + vaddw.s8 q3, q3, d4 + vqmovn.s16 d6, q3 vst1.32 {d6[0]}, [r0]! 
bne L1Loop End: vpop {q4-q7} -pop {r4, r5, r6, r7, r8, pc} +pop {r4, r5, r6, r7, r8, r9, pc} #endif #endif diff --git a/source/backend/cpu/arm/arm32/MNNBinarySqdInt8.S b/source/backend/cpu/arm/arm32/MNNBinarySqdInt8.S index 831e29d10..527ced2f5 100644 --- a/source/backend/cpu/arm/arm32/MNNBinarySqdInt8.S +++ b/source/backend/cpu/arm/arm32/MNNBinarySqdInt8.S @@ -15,33 +15,47 @@ .align 5 asm_function MNNBinarySqdInt8 -// MNNBinarySqdInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// Auto Load -// r0: dst, r1:src0, r2:src1, r3: scale0 -// Load from sp -// r4:scale1, r5: outputScale, r6:size, r7: needBroadcast -push {r4, r5, r6, r7, r8, lr} - -ldr r4, [sp, #24] -ldr r5, [sp, #28] -ldr r6, [sp, #32] -ldr r7, [sp, #36] +// MNNBinarySqdInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// r0: dst, r1:src0, r2:src1, r3:quantScalesInt32 +// Load from sp: +// r4:quantScalesFp32, r5: offset0, r6: offset1, r7: outputOffset +// r8: size, r9: needBroadcast +push {r4, r5, r6, r7, r8, r9, lr} +
+ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r8, [sp, #44] +ldr r9, [sp, #48] vpush {q4-q7} -vld1.32 {q13}, [r3] // q13: scale0 -vld1.32 {q14}, [r4] // q14: scale1 -vld1.32 {q15}, [r5] // q15: outputScale +ldr r12, [r4] +vdup.f32 q13, r12 // scale +ldr lr, [r4, #4] +vdup.f32 q14, lr +ldr lr, [r4, #8] +vdup.f32 q15, lr +
+vld1.8 {d0[0]}, [r5] +vld1.8 {d1[0]}, [r6] +vld1.8 {d4[0]}, [r7] +vdup.8 d0, d0[0] +vdup.8 d1, d1[0] +vdup.8 d4, d4[0] L4: -cmp r6, #4 +cmp r8, #4 blt L1 L4Loop: - cmp r7, #0 + cmp r9, #0 beq L4NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: @@ -50,137 +64,146 @@ L4Loop: b L4Compute L4NeedBroadcast0: - vld1.8 {d22[0]}, [r1] - vdup.8 d22, d22[0] - vdup.8 d23, d22[0] + ldr r4, [r1] + vdup.s8 q11, r4 vld1.32 {q12}, [r2]! b L4Compute L4NeedBroadcast1: vld1.32 {q11}, [r1]! 
- vld1.8 {d24[0]}, [r2] - vdup.8 d24, d24[0] - vdup.8 d25, d24[0] + ldr r4, [r2] + vdup.s8 q12, r4 b L4Compute L4Compute: - sub r6, r6, #4 + sub r8, r8, #4 vmovl.s8 q4, d22 vmovl.s8 q5, d23 - - vmovl.s16 q0, d8 - vmovl.s16 q1, d9 - vmovl.s16 q2, d10 - vmovl.s16 q3, d11 - - vcvtq.f32.s32 q0, q0 - vcvtq.f32.s32 q1, q1 - vcvtq.f32.s32 q2, q2 - vcvtq.f32.s32 q3, q3 - vmovl.s8 q6, d24 vmovl.s8 q7, d25 - vmulq.f32 q0, q0, q13 - vmulq.f32 q1, q1, q13 - vmulq.f32 q2, q2, q13 - vmulq.f32 q3, q3, q13 - - vmovl.s16 q8, d12 - vmovl.s16 q9, d13 - vmovl.s16 q10, d14 - vmovl.s16 q11, d15 - vcvtq.f32.s32 q8, q8 - vcvtq.f32.s32 q9, q9 - vcvtq.f32.s32 q10, q10 - vcvtq.f32.s32 q11, q11 - vmulq.f32 q8, q8, q14 - vmulq.f32 q9, q9, q14 - vmulq.f32 q10, q10, q14 - vmulq.f32 q11, q11, q14 - - vsubq.f32 q0, q0, q8 - vsubq.f32 q1, q1, q9 - vsubq.f32 q2, q2, q10 - vsubq.f32 q3, q3, q11 - - vmulq.f32 q0, q0, q0 - vmulq.f32 q1, q1, q1 - vmulq.f32 q2, q2, q2 - vmulq.f32 q3, q3, q3 - - vmulq.f32 q0, q0, q15 - vmulq.f32 q1, q1, q15 - vmulq.f32 q2, q2, q15 - vmulq.f32 q3, q3, q15 - - vcvtq.s32.f32 q0, q0 - vcvtq.s32.f32 q1, q1 - vcvtq.s32.f32 q2, q2 - vcvtq.s32.f32 q3, q3 - - vqmovn.s32 d8, q0 - vqmovn.s32 d9, q1 - vqmovn.s32 d10, q2 - vqmovn.s32 d11, q3 + vsubw.s8 q4, q4, d0 + vsubw.s8 q5, q5, d0 + vsubw.s8 q6, q6, d1 + vsubw.s8 q7, q7, d1 - vqmovn.s16 d0, q4 - vqmovn.s16 d1, q5 - cmp r6, #4 - vst1.32 {q0}, [r0]! + vmovl.s16 q8, d8 + vmovl.s16 q9, d9 + vmovl.s16 q10, d10 + vmovl.s16 q11, d11 + + vmovl.s16 q3, d12 + vmovl.s16 q12, d13 + vmovl.s16 q1, d14 + vmovl.s16 q4, d15 + + vcvt.f32.s32 q8, q8 + vcvt.f32.s32 q9, q9 + vcvt.f32.s32 q10, q10 + vcvt.f32.s32 q11, q11 + + vcvt.f32.s32 q3, q3 + vcvt.f32.s32 q12, q12 + vcvt.f32.s32 q1, q1 + vcvt.f32.s32 q4, q4 + + vmul.f32 q8, q8, q13 + vmul.f32 q9, q9, q13 + vmul.f32 q10, q10, q13 + vmul.f32 q11, q11, q13 + + vmul.f32 q3, q3, q14 + vmul.f32 q12, q12, q14 + vmul.f32 q1, q1, q14 + vmul.f32 q4, q4, q14 + + vsub.f32 q8, q8, q3 + vsub.f32 q9, q9, q12 + vsub.f32 q10, q10, q1 + vsub.f32 q11, q11, q4 + + vmul.f32 q8, q8, q8 + vmul.f32 q9, q9, q9 + vmul.f32 q10, q10, q10 + vmul.f32 q11, q11, q11 + + vmul.f32 q8, q8, q15 + vmul.f32 q9, q9, q15 + vmul.f32 q10, q10, q15 + vmul.f32 q11, q11, q15 + + vcvt.s32.f32 q8, q8 + vcvt.s32.f32 q9, q9 + vcvt.s32.f32 q10, q10 + vcvt.s32.f32 q11, q11 + + vqmovn.s32 d6, q8 + vqmovn.s32 d7, q9 + vqmovn.s32 d8, q10 + vqmovn.s32 d9, q11 + + vaddw.s8 q3, q3, d4 + vaddw.s8 q4, q4, d4 + + vqmovn.s16 d12, q3 + vqmovn.s16 d13, q4 + cmp r8, #4 + vst1.32 {q6}, [r0]! bge L4Loop L1: -cmp r6, #0 +cmp r8, #0 beq End L1Loop: - cmp r7, #0 + cmp r9, #0 beq L1NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - vld1.32 {d0[0]}, [r1]! + vld1.32 {d6[0]}, [r1]! vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast0: - vld1.8 {d0[0]}, [r1] - vdup.8 d0, d0[0] + ldr r4, [r1] + vdup.s8 d6, r4 vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast1: - vld1.32 {d0[0]}, [r1]! - vld1.8 {d8[0]}, [r2] - vdup.8 d8, d8[0] + vld1.32 {d6[0]}, [r1]! 
+ ldr r4, [r2] + vdup.s8 d8, r4 b L1Compute L1Compute: - subs r6, r6, #1 - vmovl.s8 q1, d0 - vmovl.s16 q2, d2 - vcvtq.f32.s32 q3, q2 - vmulq.f32 q3, q3, q13 + subs r8, r8, #1 + vmovl.s8 q3, d6 + vsubw.s8 q3, q3, d0 + vmovl.s16 q3, d6 + vcvt.f32.s32 q3, q3 + vmul.f32 q3, q3, q13 vmovl.s8 q5, d8 + vsubw.s8 q5, q5, d1 vmovl.s16 q6, d10 - vcvtq.f32.s32 q7, q6 - vmulq.f32 q7, q7, q14 - - vsubq.f32 q3, q3, q7 - vmulq.f32 q3, q3, q3 - - vmulq.f32 q3, q3, q15 - vcvtq.s32.f32 q0, q3 - vqmovn.s32 d2, q0 - vqmovn.s16 d6, q1 + vcvt.f32.s32 q6, q6 + vmul.f32 q6, q6, q14 + + vsub.f32 q3, q3, q6 + vmul.f32 q3, q3, q3 + vmul.f32 q3, q3, q15 + vcvt.s32.f32 q3, q3 + vqmovn.s32 d6, q3 + vaddw.s8 q3, q3, d4 + vqmovn.s16 d6, q3 vst1.32 {d6[0]}, [r0]! bne L1Loop End: vpop {q4-q7} -pop {r4, r5, r6, r7, r8, pc} +pop {r4, r5, r6, r7, r8, r9, pc} #endif #endif diff --git a/source/backend/cpu/arm/arm32/MNNBinarySubInt8.S b/source/backend/cpu/arm/arm32/MNNBinarySubInt8.S index 83860dfe5..201fc78c5 100644 --- a/source/backend/cpu/arm/arm32/MNNBinarySubInt8.S +++ b/source/backend/cpu/arm/arm32/MNNBinarySubInt8.S @@ -15,33 +15,45 @@ .align 5 asm_function MNNBinarySubInt8 -// MNNBinarySubInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// Auto Load -// r0: dst, r1:src0, r2:src1, r3: scale0 -// Load from sp -// r4:scale1, r5: outputScale, r6:size, r7: needBroadcast -push {r4, r5, r6, r7, r8, lr} - -ldr r4, [sp, #24] -ldr r5, [sp, #28] -ldr r6, [sp, #32] -ldr r7, [sp, #36] +// MNNBinarySubInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// r0: dst, r1:src0, r2:src1, r3:quantScalesInt32 +// Load from sp: +// r4:quantScalesFp32, r5: offset0, r6: offset1, r7: outputOffset +// r8: size, r9: needBroadcast +push {r4, r5, r6, r7, r8, r9, lr} + +ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r8, [sp, #44] +ldr r9, [sp, #48] vpush {q4-q7} -vld1.32 {q13}, [r3] // q13: scale0 -vld1.32 {q14}, [r4] // q14: scale1 -vld1.32 {q15}, [r5] // q15: outputScale +ldr r12, [r3] +vdup.s32 q13, r12 // scale +ldr r4, [r3, #4] +vdup.s32 q14, r4 + +vld1.8 {d0[0]}, [r5] +vld1.8 {d1[0]}, [r6] +vld1.8 {d4[0]}, [r7] +vdup.8 d0, d0[0] +vdup.8 d1, d1[0] +vdup.8 d4, d4[0] L4: -cmp r6, #4 +cmp r8, #4 blt L1 L4Loop: - cmp r7, #0 + cmp r9, #0 beq L4NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: @@ -50,131 +62,116 @@ L4Loop: b L4Compute L4NeedBroadcast0: - vld1.8 {d22[0]}, [r1] - vdup.8 d22, d22[0] - vdup.8 d23, d22[0] + ldr r4, [r1] + vdup.s8 q11, r4 vld1.32 {q12}, [r2]! b L4Compute L4NeedBroadcast1: vld1.32 {q11}, [r1]!
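The vmovl.s8 + vsubw.s8 pairing that opens every rewritten compute block is the standard NEON idiom for "widen to int16, then remove the input zero point". The same step with intrinsics (an illustrative helper, not part of the patch):

```c
#include <arm_neon.h>

// Widen eight int8 lanes to int16 and subtract the input zero point.
static inline int16x8_t center_s8(int8x8_t v, int8_t zeroPoint) {
    int16x8_t wide = vmovl_s8(v);                  // vmovl.s8
    return vsubw_s8(wide, vdup_n_s8(zeroPoint));   // vsubw.s8
}
```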
- vld1.8 {d24[0]}, [r2] - vdup.8 d24, d24[0] - vdup.8 d25, d24[0] + ldr r4, [r2] + vdup.s8 q12, r4 b L4Compute L4Compute: - sub r6, r6, #4 + sub r8, r8, #4 vmovl.s8 q4, d22 vmovl.s8 q5, d23 - - vmovl.s16 q0, d8 - vmovl.s16 q1, d9 - vmovl.s16 q2, d10 - vmovl.s16 q3, d11 - - vcvtq.f32.s32 q0, q0 - vcvtq.f32.s32 q1, q1 - vcvtq.f32.s32 q2, q2 - vcvtq.f32.s32 q3, q3 - vmovl.s8 q6, d24 vmovl.s8 q7, d25 - vmulq.f32 q0, q0, q13 - vmulq.f32 q1, q1, q13 - vmulq.f32 q2, q2, q13 - vmulq.f32 q3, q3, q13 - - vmovl.s16 q8, d12 - vmovl.s16 q9, d13 - vmovl.s16 q10, d14 - vmovl.s16 q11, d15 - vcvtq.f32.s32 q8, q8 - vcvtq.f32.s32 q9, q9 - vcvtq.f32.s32 q10, q10 - vcvtq.f32.s32 q11, q11 - vmulq.f32 q8, q8, q14 - vmulq.f32 q9, q9, q14 - vmulq.f32 q10, q10, q14 - vmulq.f32 q11, q11, q14 - - vsubq.f32 q0, q0, q8 - vsubq.f32 q1, q1, q9 - vsubq.f32 q2, q2, q10 - vsubq.f32 q3, q3, q11 - - vmulq.f32 q0, q0, q15 - vmulq.f32 q1, q1, q15 - vmulq.f32 q2, q2, q15 - vmulq.f32 q3, q3, q15 - - vcvtq.s32.f32 q0, q0 - vcvtq.s32.f32 q1, q1 - vcvtq.s32.f32 q2, q2 - vcvtq.s32.f32 q3, q3 - - vqmovn.s32 d8, q0 - vqmovn.s32 d9, q1 - vqmovn.s32 d10, q2 - vqmovn.s32 d11, q3 + vsubw.s8 q4, q4, d0 + vsubw.s8 q5, q5, d0 + vsubw.s8 q6, q6, d1 + vsubw.s8 q7, q7, d1 + + vmovl.s16 q8, d8 + vmovl.s16 q9, d9 + vmovl.s16 q10, d10 + vmovl.s16 q11, d11 + + vmovl.s16 q3, d12 + vmovl.s16 q12, d13 + vmovl.s16 q15, d14 + vmovl.s16 q4, d15 + + vmulq.s32 q8, q8, q13 + vmulq.s32 q9, q9, q13 + vmulq.s32 q10, q10, q13 + vmulq.s32 q11, q11, q13 + + vmulq.s32 q3, q3, q14 + vmulq.s32 q12, q12, q14 + vmulq.s32 q15, q15, q14 + vmulq.s32 q4, q4, q14 + + vsub.s32 q8, q8, q3 + vsub.s32 q9, q9, q12 + vsub.s32 q10, q10, q15 + vsub.s32 q11, q11, q4 + + vqshrn.s32 d6, q8, #16 + vqshrn.s32 d7, q9, #16 + vqshrn.s32 d8, q10, #16 + vqshrn.s32 d9, q11, #16 + + vaddw.s8 q3, q3, d4 + vaddw.s8 q4, q4, d4 - vqmovn.s16 d0, q4 - vqmovn.s16 d1, q5 - cmp r6, #4 - vst1.32 {q0}, [r0]! + vqmovn.s16 d12, q3 + vqmovn.s16 d13, q4 + cmp r8, #4 + vst1.32 {q6}, [r0]! bge L4Loop L1: -cmp r6, #0 +cmp r8, #0 beq End L1Loop: - cmp r7, #0 + cmp r9, #0 beq L1NeedBroadcast0 - cmp r7, #1 + cmp r9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - vld1.32 {d0[0]}, [r1]! + vld1.32 {d6[0]}, [r1]! vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast0: - vld1.8 {d0[0]}, [r1] - vdup.8 d0, d0[0] + ldr r4, [r1] + vdup.s8 d6, r4 vld1.32 {d8[0]}, [r2]! b L1Compute L1NeedBroadcast1: - vld1.32 {d0[0]}, [r1]! - vld1.8 {d8[0]}, [r2] - vdup.8 d8, d8[0] + vld1.32 {d6[0]}, [r1]! + ldr r4, [r2] + vdup.s8 d8, r4 b L1Compute L1Compute: - subs r6, r6, #1 - vmovl.s8 q1, d0 - vmovl.s16 q2, d2 - vcvtq.f32.s32 q3, q2 - vmulq.f32 q3, q3, q13 + subs r8, r8, #1 + vmovl.s8 q3, d6 + vsubw.s8 q3, q3, d0 + vmovl.s16 q3, d6 + vmulq.s32 q3, q3, q13 vmovl.s8 q5, d8 + vsubw.s8 q5, q5, d1 vmovl.s16 q6, d10 - vcvtq.f32.s32 q7, q6 - vmulq.f32 q7, q7, q14 - - vsubq.f32 q3, q3, q7 + vmulq.s32 q6, q6, q14 - vmulq.f32 q3, q3, q15 - vcvtq.s32.f32 q0, q3 - vqmovn.s32 d2, q0 - vqmovn.s16 d6, q1 + vsub.s32 q3, q3, q6 + vqshrn.s32 d6, q3, #16 + vaddw.s8 q3, q3, d4 + vqmovn.s16 d6, q3 vst1.32 {d6[0]}, [r0]! 
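Unlike Sqd, the Sub kernel stays entirely in integer arithmetic: the centered inputs are multiplied by precomputed int32 scales from quantScalesInt32 and the difference is narrowed with a saturating right shift by 16 (vqshrn.s32 #16). A scalar sketch, assuming the caller prepares the scales as Q16-style fixed-point factors small enough that the products fit in int32:

```c
#include <stdint.h>

static int8_t sub_one_fixed(int8_t x, int8_t y, int32_t scale0,
                            int32_t scale1, int8_t zp0, int8_t zp1,
                            int8_t zpOut) {
    int32_t d = (x - zp0) * scale0 - (y - zp1) * scale1; // vmulq.s32, vsub.s32
    int32_t n = d >> 16;               // vqshrn.s32 #16: arithmetic shift...
    if (n > 32767) n = 32767;          // ...with saturation to int16
    if (n < -32768) n = -32768;
    n += zpOut;                        // vaddw.s8
    if (n > 127) n = 127;              // vqmovn.s16
    if (n < -128) n = -128;
    return (int8_t)n;
}
```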
bne L1Loop End: vpop {q4-q7} -pop {r4, r5, r6, r7, r8, pc} +pop {r4, r5, r6, r7, r8, r9, pc} #endif #endif diff --git a/source/backend/cpu/arm/arm32/MNNCubicLineC16.S b/source/backend/cpu/arm/arm32/MNNCubicLineC16.S index 74a8be5b0..c294ab34d 100644 --- a/source/backend/cpu/arm/arm32/MNNCubicLineC16.S +++ b/source/backend/cpu/arm/arm32/MNNCubicLineC16.S @@ -21,21 +21,24 @@ vcvt.s32.f32 \x, q13 .endm asm_function MNNCubicLineC16 -// void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, -// size_t number); +// void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint, +// size_t number, ssize_t minValue, ssize_t maxValue); // Auto load: r0: dst, r1: A, r2: B, r3: C -// r4: D, r11: t, lr: number +// r4: D, r11: t, r5: zeroPoint, lr: number, r6:minValue, r7:maxValue push {r4-r8, r10-r11, lr} ldr r4, [sp, #32] ldr r11, [sp, #36] - -ldr lr, [sp, #40] +ldr r5, [sp, #40] +ldr lr, [sp, #44] +ldr r6, [sp, #48] +ldr r7, [sp, #52] vpush {q4-q7} cmp lr, #0 beq END ldr r10, [r11, #0] + L1Loop: //B vld1.32 {q3, q4}, [r2]! @@ -119,10 +122,10 @@ vmla.f32 q13, q6, d4[1] vmov.f32 q1, #0.5 vmov.f32 q2, #-0.5 -vmov.s8 d14, #127 -vmov.s8 d15, #0 -vsub.s8 d15, d15, d14 - +vdup.s8 q7, r7 +vdup.s8 q8, r6 +vld1.8 {d0[0]}, [r5] +vdup.8 d0, d0[0] _vroundq_f32 q1, q2, q10 _vroundq_f32 q1, q2, q11 @@ -133,13 +136,15 @@ vqmovn.s32 d20, q10 vqmovn.s32 d21, q11 vqmovn.s32 d22, q12 vqmovn.s32 d23, q13 -vqmovn.s16 d20, q10 // Store in q15. + +vaddw.s8 q10, q10, d0 +vaddw.s8 q11, q11, d0 + +vqmovn.s16 d20, q10 vqmovn.s16 d21, q11 -vmax.s8 d20, d20, d15 -vmin.s8 d20, d20, d14 -vmax.s8 d21, d21, d15 -vmin.s8 d21, d21, d14 +vmax.s8 q10, q10, q8 +vmin.s8 q10, q10, q7 vst1.8 {q10}, [r0]! 
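In MNNCubicLineC16 the clamp bounds are no longer hard-coded ±127 vectors; minValue and maxValue now arrive as arguments, and the output zero point is added before the final narrow. The _vroundq_f32 helper, with its +0.5 and -0.5 constants in q1/q2, appears to implement round-half-away-from-zero ahead of the truncating convert; in scalar form:

```c
#include <stdint.h>

// Epilogue of the cubic line pass, per float lane (sketch).
static int8_t quantize_out(float v, int8_t zpOut, int8_t minV, int8_t maxV) {
    v += (v >= 0.0f) ? 0.5f : -0.5f;   // _vroundq_f32: add +/-0.5 by sign
    int32_t q = (int32_t)v;            // vcvt.s32.f32 (truncate)
    q += zpOut;                        // vaddw.s8 with zeroPoint
    if (q < minV) q = minV;            // vmax.s8 against minValue
    if (q > maxV) q = maxV;            // vmin.s8 against maxValue
    return (int8_t)q;
}
```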
diff --git a/source/backend/cpu/arm/arm32/MNNCubicSampleC16.S b/source/backend/cpu/arm/arm32/MNNCubicSampleC16.S index fa1ae962f..ddb064f6f 100644 --- a/source/backend/cpu/arm/arm32/MNNCubicSampleC16.S +++ b/source/backend/cpu/arm/arm32/MNNCubicSampleC16.S @@ -14,18 +14,22 @@ .align 5 asm_function MNNCubicSampleC16 -// void MNNCubicSampleC16(const int8_t* src, float* dst, const int32_t* position, const float* factor, size_t number); +// void MNNCubicSampleC16(const int8_t* src, float* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); // Auto load: r0: src, r1: dst, r2: position, r3: factor -// r4: number +// r10: zeroPoint, r4: number push {r4-r8, r10, lr} -ldr r4, [sp, #28] +ldr r10, [sp, #28] +ldr r4, [sp, #32] mov lr, #16 vpush {q4-q7} cmp r4, #0 beq END +vld1.8 {d30[0]}, [r10] +vdup.8 d30, d30[0] // zeroPoint + L1Loop: ldr r5, [r2, #0] ldr r6, [r2, #4] @@ -46,6 +50,8 @@ add r8, r8, r0 vld1.8 {q0}, [r6] vmovl.s8 q1, d0 vmovl.s8 q2, d1 +vsubw.s8 q1, q1, d30 +vsubw.s8 q2, q2, d30 vmovl.s16 q3, d2 vmovl.s16 q4, d3 vmovl.s16 q5, d4 @@ -54,6 +60,8 @@ vmovl.s16 q6, d5 vld1.8 {q7}, [r7] vmovl.s8 q8, d14 vmovl.s8 q9, d15 +vsubw.s8 q8, q8, d30 +vsubw.s8 q9, q9, d30 vmovl.s16 q10, d16 vmovl.s16 q11, d17 vmovl.s16 q12, d18 @@ -101,6 +109,8 @@ vmla.f32 q13, q6, d1[1] vld1.8 {q0}, [r5] vmovl.s8 q1, d0 vmovl.s8 q2, d1 +vsubw.s8 q1, q1, d30 +vsubw.s8 q2, q2, d30 vmovl.s16 q3, d2 vmovl.s16 q4, d3 vmovl.s16 q5, d4 @@ -145,6 +155,8 @@ vmla.f32 q13, q6, d3[1] vld1.8 {q7}, [r8] vmovl.s8 q8, d14 vmovl.s8 q9, d15 +vsubw.s8 q8, q8, d30 +vsubw.s8 q9, q9, d30 vmovl.s16 q3, d16 vmovl.s16 q4, d17 vmovl.s16 q5, d18 diff --git a/source/backend/cpu/arm/arm32/MNNScaleAndAddBiasInt8.S b/source/backend/cpu/arm/arm32/MNNScaleAndAddBiasInt8.S index 685cdf1f2..de20ab463 100644 --- a/source/backend/cpu/arm/arm32/MNNScaleAndAddBiasInt8.S +++ b/source/backend/cpu/arm/arm32/MNNScaleAndAddBiasInt8.S @@ -15,23 +15,29 @@ asm_function MNNScaleAndAddBiasInt8 // MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const int32_t* alpha, int32_t mShiftBits, -// ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) +// ssize_t minValue, ssize_t maxValue, int8_t* inputZeroPoint, int8_t* outputZeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) //Auto: r0:dst, r1:src, r2:bias, r3:alpha -//Load from sp: r4:mShiftBits, r5:minValue, r6:maxValue, r7:zeroPoint, r8:planeNumber, r10:biasNumber +//Load from sp: r4:mShiftBits, r5:minValue, r6:maxValue, r7:inputZeroPoint, r12:outputZeroPoint, r8:planeNumber, r10:biasNumber push {r4-r8, r10-r12, lr} ldr r4, [sp, #36] ldr r5, [sp, #40] ldr r6, [sp, #44] ldr r7, [sp, #48] -ldr r8, [sp, #52] -ldr r10, [sp, #56] +ldr r12, [sp, #52] +ldr r8, [sp, #56] +ldr r10, [sp, #60] vpush{q4-q7} vdup.s8 q7, r5 vdup.s8 q8, r6 +vld1.8 {d24[0]}, [r7] +vld1.8 {d26[0]}, [r12] +vdup.8 d24, d24[0] +vdup.8 d26, d26[0] + cmp r8, #0 beq BSEnd @@ -52,6 +58,8 @@ BSLoopZ: vld1.8 {q0}, [r1]! // q0: 4x(4xint8_t) vmovl.s8 q1, d0 vmovl.s8 q2, d1 + vsubw.s8 q1, q1, d24 + vsubw.s8 q2, q2, d24 vmovl.s16 q3, d2 vmovl.s16 q4, d3 vmovl.s16 q5, d4 @@ -72,6 +80,9 @@ BSLoopZ: vrshrn.s32 d10, q5, #15 vrshrn.s32 d11, q6, #15 + vaddw.s8 q3, q3, d26 + vaddw.s8 q5, q5, d26 + vqmovn.s16 d6, q3 vqmovn.s16 d7, q5 @@ -91,7 +102,9 @@ BSLoopZ: BSLoopP2: vld1.8 {d0}, [r1]!
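MNNScaleAndAddBiasInt8 keeps its Q15 fixed-point core; the change is that the single ssize_t zeroPoint becomes separate int8 input and output zero points, subtracted right after widening (vsubw.s8 with d24) and re-added just before the saturating narrow (vaddw.s8 with d26). The vrshrn.s32 #15 in between is a rounding, non-saturating shift; in C terms:

```c
#include <stdint.h>

// vrshrn.s32 #15: round (add 2^14), arithmetic shift right by 15,
// then narrow to int16. Saturation only happens later, at vqmovn.s16.
static int16_t rshrn15(int32_t v) {
    return (int16_t)((v + (1 << 14)) >> 15);
}
```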
// q0: 2x(4xint8_t) + //vsub.s8 d0, d0, d24 vmovl.s8 q1, d0 + vsubw.s8 q1, q1, d24 vmovl.s16 q3, d2 vmovl.s16 q4, d3 @@ -103,6 +116,7 @@ BSLoopZ: vrshrn.s32 d6, q3, #15 vrshrn.s32 d7, q4, #15 + vaddw.s8 q3, q3, d26 vqmovn.s16 d6, q3 @@ -123,6 +137,7 @@ BSLoopZ: vdup.32 d0, lr vmovl.s8 q1, d0 + vsubw.s8 q1, q1, d24 vmovl.s16 q3, d2 vmul.s32 q3, q3, q14 @@ -130,7 +145,7 @@ BSLoopZ: vrshrn.s32 d6, q3, #15 vmov.32 d7, d6 - + vaddw.s8 q3, q3, d26 vqmovn.s16 d6, q3 vmax.s8 d6, d6, d14 diff --git a/source/backend/cpu/arm/arm64/MNNBilinearLineC8.S b/source/backend/cpu/arm/arm64/MNNBilinearLineC8.S index d8e87bf34..0e31ad489 100644 --- a/source/backend/cpu/arm/arm64/MNNBilinearLineC8.S +++ b/source/backend/cpu/arm/arm64/MNNBilinearLineC8.S @@ -11,20 +11,20 @@ .text .align 5 asm_function MNNBilinearLineC8 -// void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, size_t number) +// void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, int8_t* zeroPoint, size_t number) // Auto load: -// x0: dst, x1: src0, x2: src1, x3: factor, x4: number +// x0: dst, x1: src0, x2: src1, x3: factor, x4: zeroPoint, x5: number stp d14, d15, [sp, #-64]! stp d12, d13, [sp, #16] stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] -cmp x4, #0 +cmp x5, #0 beq END -ldr w5, [x3, #0] // factor -dup v31.4s, w5 // v31: df +ld1 {v31.s}[0], [x3] +dup v31.4s, v31.s[0] // v31: df fmov s30, #1.0 // v30: sf=1-df fsub s30, s30, s31 movi v1.4s, #128 // s1=128 @@ -33,13 +33,13 @@ fmul s30, s30, s1 dup v31.8h, v31.h[0] dup v30.8h, v30.h[0] -cmp x4, #0 +cmp x5, #0 beq END -cmp x4, #2 +cmp x5, #2 blt L1Loop -cmp x4, #4 +cmp x5, #4 blt L2Loop -cmp x4, #8 +cmp x5, #8 blt L4Loop L8Loop: @@ -93,7 +93,6 @@ smull2 v1.4s, v19.8h, v30.8h smlal v0.4s, v23.4h, v31.4h smlal2 v1.4s, v23.8h, v31.8h - shrn v8.4h, v8.4s, #14 shrn2 v8.8h, v9.4s, #14 @@ -118,26 +117,26 @@ shrn2 v28.8h, v29.4s, #14 shrn v0.4h, v0.4s, #14 shrn2 v0.8h, v1.4s, #14 -sqxtn v8.8b, v8.8h -sqxtn2 v8.16b, v10.8h -sqxtn v9.8b, v12.8h -sqxtn2 v9.16b, v14.8h +sqxtn v2.8b, v8.8h +sqxtn2 v2.16b, v10.8h +sqxtn v3.8b, v12.8h +sqxtn2 v3.16b, v14.8h -sqxtn v10.8b, v24.8h -sqxtn2 v10.16b, v26.8h -sqxtn v11.8b, v28.8h -sqxtn2 v11.16b, v0.8h +sqxtn v4.8b, v24.8h +sqxtn2 v4.16b, v26.8h +sqxtn v5.8b, v28.8h +sqxtn2 v5.16b, v0.8h -st1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x0], #64 +st1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x0], #64 -sub x4, x4, #8 -cmp x4, #8 +sub x5, x5, #8 +cmp x5, #8 bge L8Loop -cmp x4, #0 +cmp x5, #0 beq END -cmp x4, #2 +cmp x5, #2 blt L1Loop -cmp x4, #4 +cmp x5, #4 blt L2Loop L4Loop: @@ -145,7 +144,6 @@ L4Loop: ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x1], #64 ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x2], #64 - smull v8.4s, v0.4h, v30.4h smull2 v9.4s, v0.8h, v30.8h smlal v8.4s, v4.4h, v31.4h @@ -185,12 +183,12 @@ sqxtn2 v9.16b, v14.8h st1 {v8.16b, v9.16b}, [x0], #32 -sub x4, x4, #4 -cmp x4, #4 +sub x5, x5, #4 +cmp x5, #4 bge L4Loop -cmp x4, #0 +cmp x5, #0 beq END -cmp x4, #2 +cmp x5, #2 blt L1Loop L2Loop: @@ -216,17 +214,15 @@ shrn2 v10.8h, v11.4s, #14 sqxtn v8.8b, v8.8h sqxtn2 v8.16b, v10.8h - st1 {v8.16b}, [x0], #16 -sub x4, x4, #2 -cmp x4, #2 +sub x5, x5, #2 +cmp x5, #2 bge L2Loop -cmp x4, #0 +cmp x5, #0 beq END L1Loop: - ld1 {v0.8h}, [x1], #16 ld1 {v1.8h}, [x2], #16 @@ -237,13 +233,12 @@ smlal2 v9.4s, v1.8h, v31.8h shrn v8.4h, v8.4s, #14 shrn2 v8.8h, v9.4s, #14 - sqxtn v8.8b, v8.8h st1 {v8.8b}, [x0], #8 -sub x4, x4, #1 -cmp x4, #1 +sub x5, x5, #1 +cmp x5, #1 bge L1Loop END: diff --git 
a/source/backend/cpu/arm/arm64/MNNBilinearSampleC8.S b/source/backend/cpu/arm/arm64/MNNBilinearSampleC8.S index f58fe1af3..85cf65bbf 100644 --- a/source/backend/cpu/arm/arm64/MNNBilinearSampleC8.S +++ b/source/backend/cpu/arm/arm64/MNNBilinearSampleC8.S @@ -11,10 +11,10 @@ .text .align 5 asm_function MNNBilinearSampleC8 -// void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, size_t number); +// void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number); // Auto load: -// x0: src, x1: dst, x2: position, x3: factor, x4: number +// x0: src, x1: dst, x2: position, x3: factor, x4: zeroPoint, x5: number stp d14, d15, [sp, #(-16 * 7)]! stp d12, d13, [sp, #16] @@ -28,11 +28,11 @@ mov w15, #8 // w15: pack uxtw x15, w15 movi v14.4s, #128 -cmp x4, #0 +cmp x5, #0 beq END -cmp x4, #2 +cmp x5, #2 blt L1Loop -cmp x4, #4 +cmp x5, #4 blt L2Loop @@ -104,6 +104,10 @@ ld1 {v5.8b}, [x10] ld1 {v6.8b}, [x13] ld1 {v7.8b}, [x14] +cmp w4, #0 +beq L4COMPUTE + +L4COMPUTE: smull v8.8h, v0.8b, v30.8b smlal v8.8h, v1.8b, v31.8b smull v9.8h, v2.8b, v28.8b @@ -115,12 +119,12 @@ smlal v11.8h, v7.8b, v25.8b st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x1], #64 -sub x4, x4, #4 -cmp x4, #4 +sub x5, x5, #4 +cmp x5, #4 bge L4Loop -cmp x4, #0 +cmp x5, #0 beq END -cmp x4, #2 +cmp x5, #2 blt L1Loop L2Loop: @@ -160,6 +164,9 @@ ld1 {v1.8b}, [x8] ld1 {v2.8b}, [x11] ld1 {v3.8b}, [x12] +cmp w4, #0 +beq L2COMPUTE +L2COMPUTE: smull v4.8h, v0.8b, v30.8b smlal v4.8h, v1.8b, v31.8b @@ -168,17 +175,15 @@ smlal v5.8h, v3.8b, v29.8b st1 {v4.8h, v5.8h}, [x1], #32 -sub x4, x4, #2 -cmp x4, #2 +sub x5, x5, #2 +cmp x5, #2 bge L2Loop -cmp x4, #0 +cmp x5, #0 beq END L1Loop: -ldr w5, [x3, #0] -add x3, x3, #4 - -dup v31.4s, w5 +ld1 {v31.s}[0], [x3], #4 +dup v31.4s, v31.s[0] fmov s30, #1.0 fsub s30, s30, s31 fmul s30, s30, s14 // (float)t -> (int16)t @@ -200,14 +205,13 @@ add x10, x0, x8 ld1 {v0.8b}, [x9] ld1 {v8.8b}, [x10] - smull v1.8h, v0.8b, v30.8b smlal v1.8h, v8.8b, v31.8b st1 {v1.8h}, [x1], #16 -sub x4, x4, #1 -cmp x4, #1 +sub x5, x5, #1 +cmp x5, #1 bge L1Loop END: diff --git a/source/backend/cpu/arm/arm64/MNNBinaryAddInt8.S b/source/backend/cpu/arm/arm64/MNNBinaryAddInt8.S index 2716544fe..5df863adf 100644 --- a/source/backend/cpu/arm/arm64/MNNBinaryAddInt8.S +++ b/source/backend/cpu/arm/arm64/MNNBinaryAddInt8.S @@ -13,307 +13,305 @@ .align 5 asm_function MNNBinaryAddInt8 -// MNNBinaryAddInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// x0: dst, x1:src0, x2:src1, x3:scale0, x4:scale1, x5:outputScale, x6:size, x7: needBroadcast - -cmp x6, #0 +// MNNBinaryAddInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// x0: dst, x1:src0, x2:src1, x3:quantScalesInt32, x4:quantScalesFp32, x5: offset0, x6: offset1, x7: outputOffset +// Load from sp: +// x8: size, x9: needBroadcast + +ldr x8, [sp, #0] +ldr x9, [sp, #8] + +stp d14, d15, [sp, #-64]!
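On AArch64 the first eight integer/pointer arguments travel in x0-x7, so growing these binary kernels to ten parameters pushes size and needBroadcast onto the caller's stack; the two ldr instructions read them before sp is moved for the d8-d15 spill (those registers are callee-saved, hence the new stp/ldp bracket). A hedged prototype matching the comment block above:

```c
#include <stddef.h>
#include <stdint.h>
#include <sys/types.h>

// Arguments 9 and 10 are fetched with ldr x8, [sp] and ldr x9, [sp, #8]
// at entry, before the stp sequence lowers sp by 64 bytes.
void MNNBinaryAddInt8(int8_t* dst, const int8_t* src0, const int8_t* src1,
                      ssize_t* quantScalesInt32, float* quantScalesFp32,
                      const int8_t* inputOffset0, const int8_t* inputOffset1,
                      const int8_t* outputOffset,
                      const size_t size, size_t needBroadcast);
```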
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +cmp x8, #0 beq End -ld1 {v29.4s}, [x3] // scale0 -ld1 {v30.4s}, [x4] // scale1 -ld1 {v31.4s}, [x5] // outputScale +ldr w4, [x3, #8] +ldr w3, [x3] +mov v0.s[0], w3 +mov v0.s[1], w4 -cmp x6, #8 +cmp x8, #8 bge L8Loop -cmp x6, #4 +cmp x8, #4 bge L4Loop blt L1 L8Loop: - cmp x7, #0 + cmp x9, #0 beq L8NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L8NeedBroadcast1 L8NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 // input0[0]: int8x16 - ld1 {v28.16b}, [x2], #16 // input1[0]: int8x16 - ld1 {v14.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 // input00, input01 + ld1 {v5.16b, v6.16b}, [x2], #32 // input10, input11 b L8Compute L8NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v14.b}[0], [x1] - dup v14.16b, v14.b[0] - ld1 {v28.16b}, [x2], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1r {v3.16b}, [x1] + ld1r {v4.16b}, [x1] + ld1 {v5.16b, v6.16b}, [x2], #32 b L8Compute L8NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v15.b}[0], [x2] - dup v15.16b, v15.b[0] - ld1 {v27.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v14.16b}, [x1], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 + ld1r {v5.16b}, [x2] + ld1r {v6.16b}, [x2] b L8Compute L8Compute: - sxtl v16.8h, v27.8b // v16: int16x8 - sxtl2 v17.8h, v27.16b // v17: int16x8 - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v10.8h, v14.8b // v10: int16x8 - sxtl2 v11.8h, v14.16b // v11: int16x8 - sxtl v12.8h, v15.8b - sxtl2 v13.8h, v15.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v9.8h, v4.8b + sxtl2 v10.8h, v4.16b + + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + sxtl v13.8h, v6.8b + sxtl2 v14.8h, v6.16b + + INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + ssubw v9.8h, v9.8h, v2.8b + ssubw v10.8h, v10.8h, v2.8b + + INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + ssubw v13.8h, v13.8h, v1.8b + ssubw v14.8h, v14.8h, v1.8b - sxtl v18.4s, v16.4h // v18: int32x4 - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - sxtl v2.4s, v10.4h // v18: int32x4 - sxtl2 v3.4s, v10.8h - sxtl v4.4s, v11.4h - sxtl2 v5.4s, v11.8h - sxtl v6.4s, v12.4h - sxtl2 v7.4s, v12.8h - sxtl v8.4s, v13.4h - sxtl2 v9.4s, v13.8h - - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s - scvtf v24.4s, v24.4s - scvtf v25.4s, v25.4s - scvtf v26.4s, v26.4s - scvtf v27.4s, v27.4s - - scvtf v2.4s, v2.4s - scvtf v3.4s, v3.4s - scvtf v4.4s, v4.4s - scvtf v5.4s, v5.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v29.4s - fmul v5.4s, v5.4s, v29.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - fmul v8.4s, v8.4s, v30.4s - fmul v9.4s, v9.4s, v30.4s - - fmul v18.4s, v18.4s, v29.4s - fmul v19.4s, v19.4s, v29.4s - fmul v20.4s, v20.4s, v29.4s - fmul v21.4s, v21.4s, v29.4s - fmul v24.4s, v24.4s, v30.4s - fmul v25.4s, v25.4s, v30.4s - fmul v26.4s, v26.4s, v30.4s - fmul v27.4s, v27.4s, v30.4s - - fadd v2.4s, v2.4s, v6.4s - fadd v3.4s, v3.4s, v7.4s - fadd v4.4s, v4.4s, v8.4s - fadd v5.4s, v5.4s, v9.4s - - fadd v18.4s, v18.4s, v24.4s - fadd v19.4s, v19.4s, v25.4s - fadd v20.4s, v20.4s, v26.4s - fadd v21.4s, 
v21.4s, v27.4s - - fmul v2.4s, v2.4s, v31.4s - fmul v3.4s, v3.4s, v31.4s - fmul v4.4s, v4.4s, v31.4s - fmul v5.4s, v5.4s, v31.4s - fmul v18.4s, v18.4s, v31.4s - fmul v19.4s, v19.4s, v31.4s - fmul v20.4s, v20.4s, v31.4s - fmul v21.4s, v21.4s, v31.4s - - fcvtzs v18.4s, v18.4s - fcvtzs v19.4s, v19.4s - fcvtzs v20.4s, v20.4s - fcvtzs v21.4s, v21.4s - fcvtzs v2.4s, v2.4s - fcvtzs v3.4s, v3.4s - fcvtzs v4.4s, v4.4s - fcvtzs v5.4s, v5.4s - - sqxtn v18.4h, v18.4s - sqxtn2 v18.8h, v19.4s - sqxtn v19.4h, v20.4s - sqxtn2 v19.8h, v21.4s - - sqxtn v2.4h, v2.4s - sqxtn2 v2.8h, v3.4s - sqxtn v3.4h, v4.4s - sqxtn2 v3.8h, v5.4s - - sqxtn v18.8b, v18.8h - sqxtn2 v18.16b, v19.8h - sqxtn v2.8b, v2.8h - sqxtn2 v2.16b, v3.8h - - st1 {v18.16b}, [x0], #16 - st1 {v2.16b}, [x0], #16 - - sub x6, x6, #8 - cmp x6, #8 + + L8SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + sxtl v19.4s, v9.4h + sxtl2 v20.4s, v9.8h + sxtl v21.4s, v10.4h + sxtl2 v22.4s, v10.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + sxtl v27.4s, v13.4h + sxtl2 v28.4s, v13.8h + sxtl v29.4s, v14.4h + sxtl2 v30.4s, v14.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + mul v19.4s, v19.4s, v0.s[0] + mul v20.4s, v20.4s, v0.s[0] + mul v21.4s, v21.4s, v0.s[0] + mul v22.4s, v22.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + mul v27.4s, v27.4s, v0.s[1] + mul v28.4s, v28.4s, v0.s[1] + mul v29.4s, v29.4s, v0.s[1] + mul v30.4s, v30.4s, v0.s[1] + + add v15.4s, v15.4s, v23.4s + add v16.4s, v16.4s, v24.4s + add v17.4s, v17.4s, v25.4s + add v18.4s, v18.4s, v26.4s + add v19.4s, v19.4s, v27.4s + add v20.4s, v20.4s, v28.4s + add v21.4s, v21.4s, v29.4s + add v22.4s, v22.4s, v30.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + sqrshrn v3.4h, v19.4s, #16 + sqrshrn2 v3.8h, v20.4s, #16 + sqrshrn v4.4h, v21.4s, #16 + sqrshrn2 v4.8h, v22.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + saddw v3.8h, v3.8h, v14.8b + saddw v4.8h, v4.8h, v14.8b + + SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + sqxtn v6.8b, v3.8h + sqxtn2 v6.16b, v4.8h + + st1 {v5.16b, v6.16b}, [x0], #32 + + sub x8, x8, #8 + cmp x8, #8 bge L8Loop - cmp x6, #0 + cmp x8, #0 ble End - cmp x6, #4 + cmp x8, #4 blt L1Loop L4Loop: - cmp x7, #0 + cmp x9, #0 beq L4NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 - ld1 {v28.16b}, [x2], #16 + ld1 {v3.16b}, [x1], #16 // input00, input01 + ld1 {v5.16b}, [x2], #16 // input10, input11 b L4Compute L4NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v28.16b}, [x2], #16 + ld1r {v3.16b}, [x1] + ld1 {v5.16b}, [x2], #16 b L4Compute L4NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v27.16b}, [x1], #16 + ld1 {v3.16b}, [x1], #16 + ld1r {v5.16b}, [x2] b L4Compute L4Compute: - sxtl v16.8h, v27.8b - sxtl2 v17.8h, v27.16b - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v18.4s, v16.4h - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - scvtf v0.4s, v18.4s - scvtf v1.4s, v19.4s - scvtf v2.4s, v20.4s - scvtf v3.4s, v21.4s - scvtf v4.4s, v24.4s - scvtf v5.4s, v25.4s - scvtf v6.4s, v26.4s - scvtf 
v7.4s, v27.4s - - fmul v0.4s, v0.4s, v29.4s - fmul v1.4s, v1.4s, v29.4s - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v30.4s - fmul v5.4s, v5.4s, v30.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - - fadd v0.4s, v0.4s, v4.4s - fadd v1.4s, v1.4s, v5.4s - fadd v2.4s, v2.4s, v6.4s - fadd v3.4s, v3.4s, v7.4s - - fmul v16.4s, v0.4s, v31.4s - fmul v17.4s, v1.4s, v31.4s - fmul v18.4s, v2.4s, v31.4s - fmul v19.4s, v3.4s, v31.4s - - fcvtzs v20.4s, v16.4s - fcvtzs v21.4s, v17.4s - fcvtzs v22.4s, v18.4s - fcvtzs v23.4s, v19.4s - - sqxtn v0.4h, v20.4s - sqxtn2 v0.8h, v21.4s - sqxtn v1.4h, v22.4s - sqxtn2 v1.8h, v23.4s - - sqxtn v2.8b, v0.8h - sqxtn v3.8b, v1.8h - - st1 {v2.8b, v3.8b}, [x0], #16 - sub x6, x6, #4 - cmp x6, #4 + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + + L4_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + + L4_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + + L4SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + + add v15.4s, v15.4s, v23.4s + add v16.4s, v16.4s, v24.4s + add v17.4s, v17.4s, v25.4s + add v18.4s, v18.4s, v26.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + + L4_SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + st1 {v5.16b}, [x0], #16 + sub x8, x8, #4 + cmp x8, #4 bge L4Loop L1: -cmp x6, #0 +cmp x8, #0 beq End L1Loop: - cmp x7, #0 + cmp x9, #0 beq L1NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - ld1 {v27.s}[0], [x1], #4 - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.s}[0], [x1], #4 // input00, input01 + ld1 {v5.s}[0], [x2], #4 // input10, input11 b L1Compute L1NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.8b, v27.b[0] - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.b}[0], [x1] + dup v3.8b, v3.b[0] + ld1 {v5.s}[0], [x2], #4 b L1Compute L1NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.8b, v28.b[0] - ld1 {v27.s}[0], [x1], #4 + ld1 {v3.s}[0], [x1], #4 + ld1r {v5.8b}, [x2] b L1Compute L1Compute: - sxtl v16.8h, v27.8b - sxtl v18.8h, v28.8b - sxtl v17.4s, v16.4h - sxtl v19.4s, v18.4h + sxtl v7.8h, v3.8b + sxtl v11.8h, v5.8b - scvtf v0.4s, v17.4s - scvtf v2.4s, v19.4s - fmul v1.4s, v0.4s, v29.4s - fmul v3.4s, v2.4s, v30.4s + L1_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + L1_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b - fadd v4.4s, v1.4s, v3.4s - fmul v0.4s, v4.4s, v31.4s + L1SXTL_S32: + sxtl v15.4s, v7.4h + sxtl v23.4s, v11.4h - fcvtzs v5.4s, v0.4s - sqxtn v6.4h, v5.4s - sqxtn v7.8b, v6.8h - st1 {v7.s}[0], [x0], #4 + mul v15.4s, v15.4s, v0.s[0] + mul v23.4s, v23.4s, v0.s[1] - subs x6, x6, #1 + add v15.4s, v15.4s, v23.4s + + sqrshrn v1.4h, v15.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + + L1_SQXTN_S8: + sqxtn v5.8b, v1.8h + st1 {v5.s}[0], [x0], #4 + + subs x8, x8, #1 bne L1Loop End: - +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], 
#64 ret #endif diff --git a/source/backend/cpu/arm/arm64/MNNBinaryMaxInt8.S b/source/backend/cpu/arm/arm64/MNNBinaryMaxInt8.S index eeb8a468b..11ddec881 100644 --- a/source/backend/cpu/arm/arm64/MNNBinaryMaxInt8.S +++ b/source/backend/cpu/arm/arm64/MNNBinaryMaxInt8.S @@ -1,6 +1,5 @@ // // MNNBinaryMaxInt8.S -// MNN // // Created by MNN on 2019/08/14. // Copyright © 2018, Alibaba Group Holding Limited @@ -13,308 +12,305 @@ .align 5 asm_function MNNBinaryMaxInt8 -// MNNBinaryMaxInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// x0: dst, x1:src0, x2:src1, x3:scale0, x4:scale1, x5:outputScale, x6:size, x7: needBroadcast - -cmp x6, #0 +// MNNBinaryMaxInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// x0: dst, x1:src0, x2:src1, x3:quantScalesInt32, x4:quantScalesFp32, x5: offset0, x6: offset1, x7: outputOffset +// Load from sp: +// x8: size, x9: needBroadcast + +ldr x8, [sp, #0] +ldr x9, [sp, #8] + +stp d14, d15, [sp, #-64]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +cmp x8, #0 beq End -ld1 {v29.4s}, [x3] // scale0 -ld1 {v30.4s}, [x4] // scale1 -ld1 {v31.4s}, [x5] // outputScale +ldr w4, [x3, #8] +ldr w3, [x3] +mov v0.s[0], w3 +mov v0.s[1], w4 -cmp x6, #8 +cmp x8, #8 bge L8Loop -cmp x6, #4 +cmp x8, #4 bge L4Loop blt L1 L8Loop: - cmp x7, #0 + cmp x9, #0 beq L8NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L8NeedBroadcast1 L8NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 // input0[0]: int8x16 - ld1 {v28.16b}, [x2], #16 // input1[0]: int8x16 - ld1 {v14.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 // input00, input01 + ld1 {v5.16b, v6.16b}, [x2], #32 // input10, input11 b L8Compute L8NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v14.b}[0], [x1] - dup v14.16b, v14.b[0] - ld1 {v28.16b}, [x2], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1r {v3.16b}, [x1] + ld1r {v4.16b}, [x1] + ld1 {v5.16b, v6.16b}, [x2], #32 b L8Compute L8NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v15.b}[0], [x2] - dup v15.16b, v15.b[0] - ld1 {v27.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v14.16b}, [x1], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 + ld1r {v5.16b}, [x2] + ld1r {v6.16b}, [x2] b L8Compute L8Compute: - sxtl v16.8h, v27.8b // v16: int16x8 - sxtl2 v17.8h, v27.16b // v17: int16x8 - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v10.8h, v14.8b // v10: int16x8 - sxtl2 v11.8h, v14.16b // v11: int16x8 - sxtl v12.8h, v15.8b - sxtl2 v13.8h, v15.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v9.8h, v4.8b + sxtl2 v10.8h, v4.16b + + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + sxtl v13.8h, v6.8b + sxtl2 v14.8h, v6.16b + - sxtl v18.4s, v16.4h // v18: int32x4 - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - sxtl v2.4s, v10.4h // v18: int32x4 - sxtl2 v3.4s, v10.8h - sxtl v4.4s, v11.4h - sxtl2 v5.4s, v11.8h - sxtl v6.4s, v12.4h - sxtl2 v7.4s, v12.8h - sxtl v8.4s, v13.4h - sxtl2 v9.4s, v13.8h - - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, 
v21.4s - scvtf v24.4s, v24.4s - scvtf v25.4s, v25.4s - scvtf v26.4s, v26.4s - scvtf v27.4s, v27.4s - - scvtf v2.4s, v2.4s - scvtf v3.4s, v3.4s - scvtf v4.4s, v4.4s - scvtf v5.4s, v5.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v29.4s - fmul v5.4s, v5.4s, v29.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - fmul v8.4s, v8.4s, v30.4s - fmul v9.4s, v9.4s, v30.4s - - fmul v18.4s, v18.4s, v29.4s - fmul v19.4s, v19.4s, v29.4s - fmul v20.4s, v20.4s, v29.4s - fmul v21.4s, v21.4s, v29.4s - fmul v24.4s, v24.4s, v30.4s - fmul v25.4s, v25.4s, v30.4s - fmul v26.4s, v26.4s, v30.4s - fmul v27.4s, v27.4s, v30.4s - - fmax v2.4s, v2.4s, v6.4s - fmax v3.4s, v3.4s, v7.4s - fmax v4.4s, v4.4s, v8.4s - fmax v5.4s, v5.4s, v9.4s - - fmax v18.4s, v18.4s, v24.4s - fmax v19.4s, v19.4s, v25.4s - fmax v20.4s, v20.4s, v26.4s - fmax v21.4s, v21.4s, v27.4s - - fmul v2.4s, v2.4s, v31.4s - fmul v3.4s, v3.4s, v31.4s - fmul v4.4s, v4.4s, v31.4s - fmul v5.4s, v5.4s, v31.4s - fmul v18.4s, v18.4s, v31.4s - fmul v19.4s, v19.4s, v31.4s - fmul v20.4s, v20.4s, v31.4s - fmul v21.4s, v21.4s, v31.4s - - fcvtzs v18.4s, v18.4s - fcvtzs v19.4s, v19.4s - fcvtzs v20.4s, v20.4s - fcvtzs v21.4s, v21.4s - fcvtzs v2.4s, v2.4s - fcvtzs v3.4s, v3.4s - fcvtzs v4.4s, v4.4s - fcvtzs v5.4s, v5.4s - - sqxtn v18.4h, v18.4s - sqxtn2 v18.8h, v19.4s - sqxtn v19.4h, v20.4s - sqxtn2 v19.8h, v21.4s - - sqxtn v2.4h, v2.4s - sqxtn2 v2.8h, v3.4s - sqxtn v3.4h, v4.4s - sqxtn2 v3.8h, v5.4s - - sqxtn v18.8b, v18.8h - sqxtn2 v18.16b, v19.8h - sqxtn v2.8b, v2.8h - sqxtn2 v2.16b, v3.8h - - st1 {v18.16b}, [x0], #16 - st1 {v2.16b}, [x0], #16 - - sub x6, x6, #8 - cmp x6, #8 + INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + ssubw v9.8h, v9.8h, v2.8b + ssubw v10.8h, v10.8h, v2.8b + + INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + ssubw v13.8h, v13.8h, v1.8b + ssubw v14.8h, v14.8h, v1.8b + + + L8SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + sxtl v19.4s, v9.4h + sxtl2 v20.4s, v9.8h + sxtl v21.4s, v10.4h + sxtl2 v22.4s, v10.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + sxtl v27.4s, v13.4h + sxtl2 v28.4s, v13.8h + sxtl v29.4s, v14.4h + sxtl2 v30.4s, v14.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + mul v19.4s, v19.4s, v0.s[0] + mul v20.4s, v20.4s, v0.s[0] + mul v21.4s, v21.4s, v0.s[0] + mul v22.4s, v22.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + mul v27.4s, v27.4s, v0.s[1] + mul v28.4s, v28.4s, v0.s[1] + mul v29.4s, v29.4s, v0.s[1] + mul v30.4s, v30.4s, v0.s[1] + + smax v15.4s, v15.4s, v23.4s + smax v16.4s, v16.4s, v24.4s + smax v17.4s, v17.4s, v25.4s + smax v18.4s, v18.4s, v26.4s + smax v19.4s, v19.4s, v27.4s + smax v20.4s, v20.4s, v28.4s + smax v21.4s, v21.4s, v29.4s + smax v22.4s, v22.4s, v30.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + sqrshrn v3.4h, v19.4s, #16 + sqrshrn2 v3.8h, v20.4s, #16 + sqrshrn v4.4h, v21.4s, #16 + sqrshrn2 v4.8h, v22.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + saddw v3.8h, v3.8h, v14.8b + saddw v4.8h, v4.8h, v14.8b + + 
SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + sqxtn v6.8b, v3.8h + sqxtn2 v6.16b, v4.8h + + st1 {v5.16b, v6.16b}, [x0], #32 + + sub x8, x8, #8 + cmp x8, #8 bge L8Loop - cmp x6, #0 + cmp x8, #0 ble End - cmp x6, #4 + cmp x8, #4 blt L1Loop L4Loop: - cmp x7, #0 + cmp x9, #0 beq L4NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 - ld1 {v28.16b}, [x2], #16 + ld1 {v3.16b}, [x1], #16 // input00, input01 + ld1 {v5.16b}, [x2], #16 // input10, input11 b L4Compute L4NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v28.16b}, [x2], #16 + ld1r {v3.16b}, [x1] + ld1 {v5.16b}, [x2], #16 b L4Compute L4NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v27.16b}, [x1], #16 + ld1 {v3.16b}, [x1], #16 + ld1r {v5.16b}, [x2] b L4Compute L4Compute: - sxtl v16.8h, v27.8b - sxtl2 v17.8h, v27.16b - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v18.4s, v16.4h - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - scvtf v0.4s, v18.4s - scvtf v1.4s, v19.4s - scvtf v2.4s, v20.4s - scvtf v3.4s, v21.4s - scvtf v4.4s, v24.4s - scvtf v5.4s, v25.4s - scvtf v6.4s, v26.4s - scvtf v7.4s, v27.4s - - fmul v0.4s, v0.4s, v29.4s - fmul v1.4s, v1.4s, v29.4s - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v30.4s - fmul v5.4s, v5.4s, v30.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - - fmax v0.4s, v0.4s, v4.4s - fmax v1.4s, v1.4s, v5.4s - fmax v2.4s, v2.4s, v6.4s - fmax v3.4s, v3.4s, v7.4s - - fmul v16.4s, v0.4s, v31.4s - fmul v17.4s, v1.4s, v31.4s - fmul v18.4s, v2.4s, v31.4s - fmul v19.4s, v3.4s, v31.4s - - fcvtzs v20.4s, v16.4s - fcvtzs v21.4s, v17.4s - fcvtzs v22.4s, v18.4s - fcvtzs v23.4s, v19.4s - - sqxtn v0.4h, v20.4s - sqxtn2 v0.8h, v21.4s - sqxtn v1.4h, v22.4s - sqxtn2 v1.8h, v23.4s - - sqxtn v2.8b, v0.8h - sqxtn v3.8b, v1.8h - - st1 {v2.8b, v3.8b}, [x0], #16 - sub x6, x6, #4 - cmp x6, #4 + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + + L4_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + + L4_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + + L4SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + + smax v15.4s, v15.4s, v23.4s + smax v16.4s, v16.4s, v24.4s + smax v17.4s, v17.4s, v25.4s + smax v18.4s, v18.4s, v26.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + + L4_SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + st1 {v5.16b}, [x0], #16 + sub x8, x8, #4 + cmp x8, #4 bge L4Loop L1: -cmp x6, #0 +cmp x8, #0 beq End L1Loop: - cmp x7, #0 + cmp x9, #0 beq L1NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - ld1 {v27.s}[0], [x1], #4 - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.s}[0], [x1], #4 // input00, input01 + ld1 {v5.s}[0], [x2], #4 // input10, input11 b L1Compute 
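The Max kernel now compares operands in the shared fixed-point domain: both sides are centered, multiplied by their int32 scales, and only then reduced with smax, replacing the old scvtf/fmul/fmax float chain. Per lane, in scalar C (assuming the products fit in int32):

```c
#include <stdint.h>

static int8_t max_one_fixed(int8_t x, int8_t y, int32_t scale0,
                            int32_t scale1, int8_t zp0, int8_t zp1,
                            int8_t zpOut) {
    int32_t a = (x - zp0) * scale0;        // ssubw + sxtl + mul v.4s
    int32_t b = (y - zp1) * scale1;
    int32_t m = a > b ? a : b;             // smax
    int32_t n = (m + (1 << 15)) >> 16;     // sqrshrn #16 rounds...
    if (n > 32767) n = 32767;              // ...and saturates to int16
    if (n < -32768) n = -32768;
    n += zpOut;                            // saddw with the output offset
    if (n > 127) n = 127;                  // sqxtn
    if (n < -128) n = -128;
    return (int8_t)n;
}
```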
L1NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.8b, v27.b[0] - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.b}[0], [x1] + dup v3.8b, v3.b[0] + ld1 {v5.s}[0], [x2], #4 b L1Compute L1NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.8b, v28.b[0] - ld1 {v27.s}[0], [x1], #4 + ld1 {v3.s}[0], [x1], #4 + ld1r {v5.8b}, [x2] b L1Compute L1Compute: - sxtl v16.8h, v27.8b - sxtl v18.8h, v28.8b - sxtl v17.4s, v16.4h - sxtl v19.4s, v18.4h + sxtl v7.8h, v3.8b + sxtl v11.8h, v5.8b - scvtf v0.4s, v17.4s - scvtf v2.4s, v19.4s - fmul v1.4s, v0.4s, v29.4s - fmul v3.4s, v2.4s, v30.4s + L1_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + L1_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b - fmax v4.4s, v1.4s, v3.4s - fmul v0.4s, v4.4s, v31.4s + L1SXTL_S32: + sxtl v15.4s, v7.4h + sxtl v23.4s, v11.4h - fcvtzs v5.4s, v0.4s - sqxtn v6.4h, v5.4s - sqxtn v7.8b, v6.8h - st1 {v7.s}[0], [x0], #4 + mul v15.4s, v15.4s, v0.s[0] + mul v23.4s, v23.4s, v0.s[1] - subs x6, x6, #1 + smax v15.4s, v15.4s, v23.4s + + sqrshrn v1.4h, v15.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + + L1_SQXTN_S8: + sqxtn v5.8b, v1.8h + st1 {v5.s}[0], [x0], #4 + + subs x8, x8, #1 bne L1Loop End: - +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 ret #endif - diff --git a/source/backend/cpu/arm/arm64/MNNBinaryMinInt8.S b/source/backend/cpu/arm/arm64/MNNBinaryMinInt8.S index 2ee771ec4..6de3faed2 100644 --- a/source/backend/cpu/arm/arm64/MNNBinaryMinInt8.S +++ b/source/backend/cpu/arm/arm64/MNNBinaryMinInt8.S @@ -1,6 +1,5 @@ // // MNNBinaryMinInt8.S -// MNN // // Created by MNN on 2019/08/14. // Copyright © 2018, Alibaba Group Holding Limited @@ -13,309 +12,305 @@ .align 5 asm_function MNNBinaryMinInt8 -// MNNBinaryMinInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// x0: dst, x1:src0, x2:src1, x3:scale0, x4:scale1, x5:outputScale, x6:size, x7: needBroadcast - -cmp x6, #0 +// MNNBinaryMinInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// x0: dst, x1:src0, x2:src1, x3:quantScalesInt32, x4:quantScalesFp32, x5: offset0, x6: offset1, x7: outputOffset +// Load from sp: +// x8: size, x9: needBroadcast + +ldr x8, [sp, #0] +ldr x9, [sp, #8] + +stp d14, d15, [sp, #-64]!
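A recurring micro-cleanup in the arm64 rewrites above: broadcasting a scalar with a lane load plus dup becomes a single ld1r load-replicate. In intrinsics terms (illustrative helpers):

```c
#include <arm_neon.h>

// Before: load one lane, then duplicate it across the vector.
static int8x16_t broadcast_old(const int8_t* p) {
    int8x16_t v = vld1q_lane_s8(p, vdupq_n_s8(0), 0); // ld1 {v.b}[0], [x]
    return vdupq_laneq_s8(v, 0);                      // dup v.16b, v.b[0]
}

// After: one load-replicate instruction.
static int8x16_t broadcast_new(const int8_t* p) {
    return vld1q_dup_s8(p);                           // ld1r {v.16b}, [x]
}
```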
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +cmp x8, #0 beq End -ld1 {v29.4s}, [x3] // scale0 -ld1 {v30.4s}, [x4] // scale1 -ld1 {v31.4s}, [x5] // outputScale +ldr w4, [x3, #8] +ldr w3, [x3] +mov v0.s[0], w3 +mov v0.s[1], w4 -cmp x6, #8 +cmp x8, #8 bge L8Loop -cmp x6, #4 +cmp x8, #4 bge L4Loop blt L1 L8Loop: - cmp x7, #0 + cmp x9, #0 beq L8NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L8NeedBroadcast1 L8NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 // input0[0]: int8x16 - ld1 {v28.16b}, [x2], #16 // input1[0]: int8x16 - ld1 {v14.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 // input00, input01 + ld1 {v5.16b, v6.16b}, [x2], #32 // input10, input11 b L8Compute L8NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v14.b}[0], [x1] - dup v14.16b, v14.b[0] - ld1 {v28.16b}, [x2], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1r {v3.16b}, [x1] + ld1r {v4.16b}, [x1] + ld1 {v5.16b, v6.16b}, [x2], #32 b L8Compute L8NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v15.b}[0], [x2] - dup v15.16b, v15.b[0] - ld1 {v27.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v14.16b}, [x1], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 + ld1r {v5.16b}, [x2] + ld1r {v6.16b}, [x2] b L8Compute L8Compute: - sxtl v16.8h, v27.8b // v16: int16x8 - sxtl2 v17.8h, v27.16b // v17: int16x8 - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v10.8h, v14.8b // v10: int16x8 - sxtl2 v11.8h, v14.16b // v11: int16x8 - sxtl v12.8h, v15.8b - sxtl2 v13.8h, v15.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v9.8h, v4.8b + sxtl2 v10.8h, v4.16b + + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + sxtl v13.8h, v6.8b + sxtl2 v14.8h, v6.16b + + INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + ssubw v9.8h, v9.8h, v2.8b + ssubw v10.8h, v10.8h, v2.8b + + INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + ssubw v13.8h, v13.8h, v1.8b + ssubw v14.8h, v14.8h, v1.8b - sxtl v18.4s, v16.4h // v18: int32x4 - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - sxtl v2.4s, v10.4h // v18: int32x4 - sxtl2 v3.4s, v10.8h - sxtl v4.4s, v11.4h - sxtl2 v5.4s, v11.8h - sxtl v6.4s, v12.4h - sxtl2 v7.4s, v12.8h - sxtl v8.4s, v13.4h - sxtl2 v9.4s, v13.8h - - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s - scvtf v24.4s, v24.4s - scvtf v25.4s, v25.4s - scvtf v26.4s, v26.4s - scvtf v27.4s, v27.4s - - scvtf v2.4s, v2.4s - scvtf v3.4s, v3.4s - scvtf v4.4s, v4.4s - scvtf v5.4s, v5.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v29.4s - fmul v5.4s, v5.4s, v29.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - fmul v8.4s, v8.4s, v30.4s - fmul v9.4s, v9.4s, v30.4s - - fmul v18.4s, v18.4s, v29.4s - fmul v19.4s, v19.4s, v29.4s - fmul v20.4s, v20.4s, v29.4s - fmul v21.4s, v21.4s, v29.4s - fmul v24.4s, v24.4s, v30.4s - fmul v25.4s, v25.4s, v30.4s - fmul v26.4s, v26.4s, v30.4s - fmul v27.4s, v27.4s, v30.4s - - fmin v2.4s, v2.4s, v6.4s - fmin v3.4s, v3.4s, v7.4s - fmin v4.4s, v4.4s, v8.4s - fmin v5.4s, v5.4s, v9.4s - - fmin v18.4s, v18.4s, v24.4s - fmin v19.4s, v19.4s, v25.4s - fmin v20.4s, v20.4s, v26.4s - fmin v21.4s, 
v21.4s, v27.4s - - fmul v2.4s, v2.4s, v31.4s - fmul v3.4s, v3.4s, v31.4s - fmul v4.4s, v4.4s, v31.4s - fmul v5.4s, v5.4s, v31.4s - fmul v18.4s, v18.4s, v31.4s - fmul v19.4s, v19.4s, v31.4s - fmul v20.4s, v20.4s, v31.4s - fmul v21.4s, v21.4s, v31.4s - - fcvtzs v18.4s, v18.4s - fcvtzs v19.4s, v19.4s - fcvtzs v20.4s, v20.4s - fcvtzs v21.4s, v21.4s - fcvtzs v2.4s, v2.4s - fcvtzs v3.4s, v3.4s - fcvtzs v4.4s, v4.4s - fcvtzs v5.4s, v5.4s - - sqxtn v18.4h, v18.4s - sqxtn2 v18.8h, v19.4s - sqxtn v19.4h, v20.4s - sqxtn2 v19.8h, v21.4s - - sqxtn v2.4h, v2.4s - sqxtn2 v2.8h, v3.4s - sqxtn v3.4h, v4.4s - sqxtn2 v3.8h, v5.4s - - sqxtn v18.8b, v18.8h - sqxtn2 v18.16b, v19.8h - sqxtn v2.8b, v2.8h - sqxtn2 v2.16b, v3.8h - - st1 {v18.16b}, [x0], #16 - st1 {v2.16b}, [x0], #16 - - sub x6, x6, #8 - cmp x6, #8 + + L8SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + sxtl v19.4s, v9.4h + sxtl2 v20.4s, v9.8h + sxtl v21.4s, v10.4h + sxtl2 v22.4s, v10.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + sxtl v27.4s, v13.4h + sxtl2 v28.4s, v13.8h + sxtl v29.4s, v14.4h + sxtl2 v30.4s, v14.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + mul v19.4s, v19.4s, v0.s[0] + mul v20.4s, v20.4s, v0.s[0] + mul v21.4s, v21.4s, v0.s[0] + mul v22.4s, v22.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + mul v27.4s, v27.4s, v0.s[1] + mul v28.4s, v28.4s, v0.s[1] + mul v29.4s, v29.4s, v0.s[1] + mul v30.4s, v30.4s, v0.s[1] + + smin v15.4s, v15.4s, v23.4s + smin v16.4s, v16.4s, v24.4s + smin v17.4s, v17.4s, v25.4s + smin v18.4s, v18.4s, v26.4s + smin v19.4s, v19.4s, v27.4s + smin v20.4s, v20.4s, v28.4s + smin v21.4s, v21.4s, v29.4s + smin v22.4s, v22.4s, v30.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + sqrshrn v3.4h, v19.4s, #16 + sqrshrn2 v3.8h, v20.4s, #16 + sqrshrn v4.4h, v21.4s, #16 + sqrshrn2 v4.8h, v22.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + saddw v3.8h, v3.8h, v14.8b + saddw v4.8h, v4.8h, v14.8b + + SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + sqxtn v6.8b, v3.8h + sqxtn2 v6.16b, v4.8h + + st1 {v5.16b, v6.16b}, [x0], #32 + + sub x8, x8, #8 + cmp x8, #8 bge L8Loop - cmp x6, #0 + cmp x8, #0 ble End - cmp x6, #4 + cmp x8, #4 blt L1Loop L4Loop: - cmp x7, #0 + cmp x9, #0 beq L4NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 - ld1 {v28.16b}, [x2], #16 + ld1 {v3.16b}, [x1], #16 // input00, input01 + ld1 {v5.16b}, [x2], #16 // input10, input11 b L4Compute L4NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v28.16b}, [x2], #16 + ld1r {v3.16b}, [x1] + ld1 {v5.16b}, [x2], #16 b L4Compute L4NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v27.16b}, [x1], #16 + ld1 {v3.16b}, [x1], #16 + ld1r {v5.16b}, [x2] b L4Compute L4Compute: - sxtl v16.8h, v27.8b - sxtl2 v17.8h, v27.16b - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + + L4_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] - sxtl v18.4s, v16.4h - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - scvtf v0.4s, v18.4s - 
scvtf v1.4s, v19.4s - scvtf v2.4s, v20.4s - scvtf v3.4s, v21.4s - scvtf v4.4s, v24.4s - scvtf v5.4s, v25.4s - scvtf v6.4s, v26.4s - scvtf v7.4s, v27.4s - - fmul v0.4s, v0.4s, v29.4s - fmul v1.4s, v1.4s, v29.4s - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v30.4s - fmul v5.4s, v5.4s, v30.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - - fmin v0.4s, v0.4s, v4.4s - fmin v1.4s, v1.4s, v5.4s - fmin v2.4s, v2.4s, v6.4s - fmin v3.4s, v3.4s, v7.4s - - fmul v16.4s, v0.4s, v31.4s - fmul v17.4s, v1.4s, v31.4s - fmul v18.4s, v2.4s, v31.4s - fmul v19.4s, v3.4s, v31.4s - - fcvtzs v20.4s, v16.4s - fcvtzs v21.4s, v17.4s - fcvtzs v22.4s, v18.4s - fcvtzs v23.4s, v19.4s - - sqxtn v0.4h, v20.4s - sqxtn2 v0.8h, v21.4s - sqxtn v1.4h, v22.4s - sqxtn2 v1.8h, v23.4s - - sqxtn v2.8b, v0.8h - sqxtn v3.8b, v1.8h - - st1 {v2.8b, v3.8b}, [x0], #16 - sub x6, x6, #4 - cmp x6, #4 + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + + L4_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + + L4SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + + smin v15.4s, v15.4s, v23.4s + smin v16.4s, v16.4s, v24.4s + smin v17.4s, v17.4s, v25.4s + smin v18.4s, v18.4s, v26.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + + L4_SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + st1 {v5.16b}, [x0], #16 + sub x8, x8, #4 + cmp x8, #4 bge L4Loop L1: -cmp x6, #0 +cmp x8, #0 beq End L1Loop: - cmp x7, #0 + cmp x9, #0 beq L1NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - ld1 {v27.s}[0], [x1], #4 - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.s}[0], [x1], #4 // input00, input01 + ld1 {v5.s}[0], [x2], #4 // input10, input11 b L1Compute L1NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.8b, v27.b[0] - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.b}[0], [x1] + dup v3.8b, v3.b[0] + ld1 {v5.s}[0], [x2], #4 b L1Compute L1NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.8b, v28.b[0] - ld1 {v27.s}[0], [x1], #4 + ld1 {v3.s}[0], [x1], #4 + ld1r {v5.8b}, [x2] b L1Compute L1Compute: - sxtl v16.8h, v27.8b - sxtl v18.8h, v28.8b - sxtl v17.4s, v16.4h - sxtl v19.4s, v18.4h + sxtl v7.8h, v3.8b + sxtl v11.8h, v5.8b + + L1_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + L1_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + + L1SXTL_S32: + sxtl v15.4s, v7.4h + sxtl v23.4s, v11.4h - scvtf v0.4s, v17.4s - scvtf v2.4s, v19.4s - fmul v1.4s, v0.4s, v29.4s - fmul v3.4s, v2.4s, v30.4s + mul v15.4s, v15.4s, v0.s[0] + mul v23.4s, v23.4s, v0.s[1] - fmin v4.4s, v1.4s, v3.4s - fmul v0.4s, v4.4s, v31.4s + smin v15.4s, v15.4s, v23.4s - fcvtzs v5.4s, v0.4s - sqxtn v6.4h, v5.4s - sqxtn v7.8b, v6.8h - st1 {v7.s}[0], [x0], #4 + sqrshrn v1.4h, v15.4s, #16 - subs x6, x6, #1 + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + + L1_SQXTN_S8: + sqxtn v5.8b, v1.8h + st1 {v5.s}[0], [x0], #4 + + subs x8, x8, #1 bne L1Loop End: - +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, 
d15, [sp], #64 ret #endif - - diff --git a/source/backend/cpu/arm/arm64/MNNBinaryMulInt8.S b/source/backend/cpu/arm/arm64/MNNBinaryMulInt8.S index 87ed67e18..5ca6edb7c 100644 --- a/source/backend/cpu/arm/arm64/MNNBinaryMulInt8.S +++ b/source/backend/cpu/arm/arm64/MNNBinaryMulInt8.S @@ -13,307 +13,371 @@ .align 5 asm_function MNNBinaryMulInt8 -// MNNBinaryMulInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// x0: dst, x1:src0, x2:src1, x3:scale0, x4:scale1, x5:outputScale, x6:size, x7: needBroadcast - -cmp x6, #0 +// MNNBinaryMulInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// x0: dst, x1:src0, x2:src1, x3:quantScalesInt32, x4:quantScalesFp32, x5: offset0, x6: offset1, x7: outputOffset +// Load from sp: +// x8: size, x9: needBroadcast + +ldr x8, [sp, #0] +ldr x9, [sp, #8] + +stp d14, d15, [sp, #-64]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +cmp x8, #0 beq End -ld1 {v29.4s}, [x3] // scale0 -ld1 {v30.4s}, [x4] // scale1 -ld1 {v31.4s}, [x5] // outputScale +ldr w3, [x4] +ldr w10, [x4, #8] +ldr w4, [x4, #4] +mov v0.s[0], w3 +mov v0.s[1], w4 +mov v0.s[2], w10 + -cmp x6, #8 +cmp x8, #8 bge L8Loop -cmp x6, #4 +cmp x8, #4 bge L4Loop blt L1 L8Loop: - cmp x7, #0 + cmp x9, #0 beq L8NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L8NeedBroadcast1 L8NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 // input0[0]: int8x16 - ld1 {v28.16b}, [x2], #16 // input1[0]: int8x16 - ld1 {v14.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 // input00, input01 + ld1 {v5.16b, v6.16b}, [x2], #32 // input10, input11 b L8Compute L8NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v14.b}[0], [x1] - dup v14.16b, v14.b[0] - ld1 {v28.16b}, [x2], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1r {v3.16b}, [x1] + ld1r {v4.16b}, [x1] + ld1 {v5.16b, v6.16b}, [x2], #32 b L8Compute L8NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v15.b}[0], [x2] - dup v15.16b, v15.b[0] - ld1 {v27.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v14.16b}, [x1], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 + ld1r {v5.16b}, [x2] + ld1r {v6.16b}, [x2] b L8Compute L8Compute: - sxtl v16.8h, v27.8b // v16: int16x8 - sxtl2 v17.8h, v27.16b // v17: int16x8 - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v10.8h, v14.8b // v10: int16x8 - sxtl2 v11.8h, v14.16b // v11: int16x8 - sxtl v12.8h, v15.8b - sxtl2 v13.8h, v15.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v9.8h, v4.8b + sxtl2 v10.8h, v4.16b + + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + sxtl v13.8h, v6.8b + sxtl2 v14.8h, v6.16b + + + INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + ssubw v9.8h, v9.8h, v2.8b + ssubw v10.8h, v10.8h, v2.8b + + INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + ssubw v13.8h, v13.8h, v1.8b + ssubw v14.8h, v14.8h, v1.8b - sxtl v18.4s, v16.4h // v18: int32x4 - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - sxtl v2.4s, v10.4h // v18: int32x4 - 
sxtl2 v3.4s, v10.8h - sxtl v4.4s, v11.4h - sxtl2 v5.4s, v11.8h - sxtl v6.4s, v12.4h - sxtl2 v7.4s, v12.8h - sxtl v8.4s, v13.4h - sxtl2 v9.4s, v13.8h - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s + L8SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + sxtl v19.4s, v9.4h + sxtl2 v20.4s, v9.8h + sxtl v21.4s, v10.4h + sxtl2 v22.4s, v10.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + sxtl v27.4s, v13.4h + sxtl2 v28.4s, v13.8h + sxtl v29.4s, v14.4h + sxtl2 v30.4s, v14.8h + + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + + scvtf v23.4s, v23.4s scvtf v24.4s, v24.4s scvtf v25.4s, v25.4s scvtf v26.4s, v26.4s scvtf v27.4s, v27.4s - - scvtf v2.4s, v2.4s - scvtf v3.4s, v3.4s - scvtf v4.4s, v4.4s - scvtf v5.4s, v5.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v29.4s - fmul v5.4s, v5.4s, v29.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - fmul v8.4s, v8.4s, v30.4s - fmul v9.4s, v9.4s, v30.4s - - fmul v18.4s, v18.4s, v29.4s - fmul v19.4s, v19.4s, v29.4s - fmul v20.4s, v20.4s, v29.4s - fmul v21.4s, v21.4s, v29.4s - fmul v24.4s, v24.4s, v30.4s - fmul v25.4s, v25.4s, v30.4s - fmul v26.4s, v26.4s, v30.4s - fmul v27.4s, v27.4s, v30.4s - - fmul v2.4s, v2.4s, v6.4s - fmul v3.4s, v3.4s, v7.4s - fmul v4.4s, v4.4s, v8.4s - fmul v5.4s, v5.4s, v9.4s - - fmul v18.4s, v18.4s, v24.4s - fmul v19.4s, v19.4s, v25.4s - fmul v20.4s, v20.4s, v26.4s - fmul v21.4s, v21.4s, v27.4s - - fmul v2.4s, v2.4s, v31.4s - fmul v3.4s, v3.4s, v31.4s - fmul v4.4s, v4.4s, v31.4s - fmul v5.4s, v5.4s, v31.4s - fmul v18.4s, v18.4s, v31.4s - fmul v19.4s, v19.4s, v31.4s - fmul v20.4s, v20.4s, v31.4s - fmul v21.4s, v21.4s, v31.4s - - fcvtzs v18.4s, v18.4s - fcvtzs v19.4s, v19.4s - fcvtzs v20.4s, v20.4s - fcvtzs v21.4s, v21.4s - fcvtzs v2.4s, v2.4s - fcvtzs v3.4s, v3.4s - fcvtzs v4.4s, v4.4s - fcvtzs v5.4s, v5.4s - - sqxtn v18.4h, v18.4s - sqxtn2 v18.8h, v19.4s - sqxtn v19.4h, v20.4s - sqxtn2 v19.8h, v21.4s - - sqxtn v2.4h, v2.4s - sqxtn2 v2.8h, v3.4s - sqxtn v3.4h, v4.4s - sqxtn2 v3.8h, v5.4s - - sqxtn v18.8b, v18.8h - sqxtn2 v18.16b, v19.8h - sqxtn v2.8b, v2.8h - sqxtn2 v2.16b, v3.8h - - st1 {v18.16b}, [x0], #16 - st1 {v2.16b}, [x0], #16 - - sub x6, x6, #8 - cmp x6, #8 + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + + fmul v15.4s, v15.4s, v0.s[0] + fmul v16.4s, v16.4s, v0.s[0] + fmul v17.4s, v17.4s, v0.s[0] + fmul v18.4s, v18.4s, v0.s[0] + fmul v19.4s, v19.4s, v0.s[0] + fmul v20.4s, v20.4s, v0.s[0] + fmul v21.4s, v21.4s, v0.s[0] + fmul v22.4s, v22.4s, v0.s[0] + + fmul v23.4s, v23.4s, v0.s[1] + fmul v24.4s, v24.4s, v0.s[1] + fmul v25.4s, v25.4s, v0.s[1] + fmul v26.4s, v26.4s, v0.s[1] + fmul v27.4s, v27.4s, v0.s[1] + fmul v28.4s, v28.4s, v0.s[1] + fmul v29.4s, v29.4s, v0.s[1] + fmul v30.4s, v30.4s, v0.s[1] + + fmul v15.4s, v15.4s, v23.4s + fmul v16.4s, v16.4s, v24.4s + fmul v17.4s, v17.4s, v25.4s + fmul v18.4s, v18.4s, v26.4s + fmul v19.4s, v19.4s, v27.4s + fmul v20.4s, v20.4s, v28.4s + fmul v21.4s, v21.4s, v29.4s + fmul v22.4s, v22.4s, v30.4s + + fmul v15.4s, v15.4s, v0.s[2] + fmul v16.4s, v16.4s, v0.s[2] + fmul v17.4s, v17.4s, v0.s[2] + fmul v18.4s, v18.4s, v0.s[2] + fmul v19.4s, v19.4s, v0.s[2] + fmul v20.4s, v20.4s, 
v0.s[2] + fmul v21.4s, v21.4s, v0.s[2] + fmul v22.4s, v22.4s, v0.s[2] + + fcvtas v15.4s, v15.4s + fcvtas v16.4s, v16.4s + fcvtas v17.4s, v17.4s + fcvtas v18.4s, v18.4s + fcvtas v19.4s, v19.4s + fcvtas v20.4s, v20.4s + fcvtas v21.4s, v21.4s + fcvtas v22.4s, v22.4s + + sqxtn v1.4h, v15.4s + sqxtn2 v1.8h, v16.4s + sqxtn v2.4h, v17.4s + sqxtn2 v2.8h, v18.4s + sqxtn v3.4h, v19.4s + sqxtn2 v3.8h, v20.4s + sqxtn v4.4h, v21.4s + sqxtn2 v4.8h, v22.4s + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + saddw v3.8h, v3.8h, v14.8b + saddw v4.8h, v4.8h, v14.8b + + SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + sqxtn v6.8b, v3.8h + sqxtn2 v6.16b, v4.8h + + st1 {v5.16b, v6.16b}, [x0], #32 + + sub x8, x8, #8 + cmp x8, #8 bge L8Loop - cmp x6, #0 + cmp x8, #0 ble End - cmp x6, #4 + cmp x8, #4 blt L1Loop L4Loop: - cmp x7, #0 + cmp x9, #0 beq L4NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 - ld1 {v28.16b}, [x2], #16 + ld1 {v3.16b}, [x1], #16 // input00, input01 + ld1 {v5.16b}, [x2], #16 // input10, input11 b L4Compute L4NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v28.16b}, [x2], #16 + ld1r {v3.16b}, [x1] + ld1 {v5.16b}, [x2], #16 b L4Compute L4NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v27.16b}, [x1], #16 + ld1 {v3.16b}, [x1], #16 + ld1r {v5.16b}, [x2] b L4Compute L4Compute: - sxtl v16.8h, v27.8b - sxtl2 v17.8h, v27.16b - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v18.4s, v16.4h - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - scvtf v0.4s, v18.4s - scvtf v1.4s, v19.4s - scvtf v2.4s, v20.4s - scvtf v3.4s, v21.4s - scvtf v4.4s, v24.4s - scvtf v5.4s, v25.4s - scvtf v6.4s, v26.4s - scvtf v7.4s, v27.4s - - fmul v0.4s, v0.4s, v29.4s - fmul v1.4s, v1.4s, v29.4s - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v30.4s - fmul v5.4s, v5.4s, v30.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - - fmul v0.4s, v0.4s, v4.4s - fmul v1.4s, v1.4s, v5.4s - fmul v2.4s, v2.4s, v6.4s - fmul v3.4s, v3.4s, v7.4s - - fmul v16.4s, v0.4s, v31.4s - fmul v17.4s, v1.4s, v31.4s - fmul v18.4s, v2.4s, v31.4s - fmul v19.4s, v3.4s, v31.4s - - fcvtzs v20.4s, v16.4s - fcvtzs v21.4s, v17.4s - fcvtzs v22.4s, v18.4s - fcvtzs v23.4s, v19.4s - - sqxtn v0.4h, v20.4s - sqxtn2 v0.8h, v21.4s - sqxtn v1.4h, v22.4s - sqxtn2 v1.8h, v23.4s - - sqxtn v2.8b, v0.8h - sqxtn v3.8b, v1.8h - - st1 {v2.8b, v3.8b}, [x0], #16 - sub x6, x6, #4 - cmp x6, #4 + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + + L4_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + + L4_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + + L4SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + + fmul v15.4s, v15.4s, v0.s[0] + fmul v16.4s, v16.4s, v0.s[0] + fmul v17.4s, v17.4s, v0.s[0] + fmul v18.4s, v18.4s, v0.s[0] + + fmul v23.4s, v23.4s, v0.s[1] + fmul v24.4s, v24.4s, v0.s[1] + fmul v25.4s, v25.4s, v0.s[1] + fmul v26.4s, v26.4s, 
v0.s[1] + + fmul v15.4s, v15.4s, v23.4s + fmul v16.4s, v16.4s, v24.4s + fmul v17.4s, v17.4s, v25.4s + fmul v18.4s, v18.4s, v26.4s + + fmul v15.4s, v15.4s, v0.s[2] + fmul v16.4s, v16.4s, v0.s[2] + fmul v17.4s, v17.4s, v0.s[2] + fmul v18.4s, v18.4s, v0.s[2] + + fcvtas v15.4s, v15.4s + fcvtas v16.4s, v16.4s + fcvtas v17.4s, v17.4s + fcvtas v18.4s, v18.4s + + sqxtn v1.4h, v15.4s + sqxtn2 v1.8h, v16.4s + sqxtn v2.4h, v17.4s + sqxtn2 v2.8h, v18.4s + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + + L4_SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + st1 {v5.16b}, [x0], #16 + sub x8, x8, #4 + cmp x8, #4 bge L4Loop L1: -cmp x6, #0 +cmp x8, #0 beq End L1Loop: - cmp x7, #0 + cmp x9, #0 beq L1NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - ld1 {v27.s}[0], [x1], #4 - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.s}[0], [x1], #4 // input00, input01 + ld1 {v5.s}[0], [x2], #4 // input10, input11 b L1Compute L1NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.8b, v27.b[0] - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.b}[0], [x1] + dup v3.8b, v3.b[0] + ld1 {v5.s}[0], [x2], #4 b L1Compute L1NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.8b, v28.b[0] - ld1 {v27.s}[0], [x1], #4 + ld1 {v3.s}[0], [x1], #4 + ld1r {v5.8b}, [x2] b L1Compute L1Compute: - sxtl v16.8h, v27.8b - sxtl v18.8h, v28.8b - sxtl v17.4s, v16.4h - sxtl v19.4s, v18.4h + sxtl v7.8h, v3.8b + sxtl v11.8h, v5.8b + + L1_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + L1_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + + L1SXTL_S32: + sxtl v15.4s, v7.4h + sxtl v23.4s, v11.4h + + scvtf v15.4s, v15.4s + scvtf v23.4s, v23.4s - scvtf v0.4s, v17.4s - scvtf v2.4s, v19.4s - fmul v1.4s, v0.4s, v29.4s - fmul v3.4s, v2.4s, v30.4s + fmul v15.4s, v15.4s, v0.s[0] + fmul v23.4s, v23.4s, v0.s[1] + fmul v15.4s, v15.4s, v23.4s - fmul v4.4s, v1.4s, v3.4s - fmul v0.4s, v4.4s, v31.4s + fmul v15.4s, v15.4s, v0.s[2] - fcvtzs v5.4s, v0.4s - sqxtn v6.4h, v5.4s - sqxtn v7.8b, v6.8h - st1 {v7.s}[0], [x0], #4 + fcvtas v15.4s, v15.4s - subs x6, x6, #1 + sqxtn v1.4h, v15.4s + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + + L1_SQXTN_S8: + sqxtn v5.8b, v1.8h + st1 {v5.s}[0], [x0], #4 + + subs x8, x8, #1 bne L1Loop End: - +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 ret #endif + diff --git a/source/backend/cpu/arm/arm64/MNNBinarySqdInt8.S b/source/backend/cpu/arm/arm64/MNNBinarySqdInt8.S index fd28286e2..96329a63e 100644 --- a/source/backend/cpu/arm/arm64/MNNBinarySqdInt8.S +++ b/source/backend/cpu/arm/arm64/MNNBinarySqdInt8.S @@ -13,322 +13,385 @@ .align 5 asm_function MNNBinarySqdInt8 -// MNNBinarySqdInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// x0: dst, x1:src0, x2:src1, x3:scale0, x4:scale1, x5:outputScale, x6:size, x7: needBroadcast - -cmp x6, #0 +// MNNBinarySqdInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset,, +// const size_t size, size_t needBroadcast) +// Auto load: +// x0: dst, x1:src0, x2:src1, x3:quantScalesInt32, x4:quantScalesFp32, x5: offset0, x6: offset1, x7: outputoffset +// Load from sp: +// x8: size, x9: needBroadcast + +ldr x8, [sp, #0] +ldr x9, [sp, #8] + +stp d14, d15, [sp, #-64]! 
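Note on the new quantization pipeline used by the fp32-scale kernels (Mul and Sqd): each lane is dequantized as (q - inputOffset) * scale, the lanes are combined, and the result is requantized through quantScalesFp32[2] with round-to-nearest (fcvtas) plus the output offset. A minimal scalar sketch of the squared-difference path, mirroring the non-SSE C++ reference later in this patch, where zeroPoint is 0 (the helper name sqdInt8Ref is hypothetical):

#include <algorithm>
#include <cmath>
#include <cstdint>

static inline int8_t sqdInt8Ref(int8_t a, int8_t b,
                                const float* scalesFp32,   // {scale0, scale1, outputScale}
                                int8_t offset0, int8_t offset1,
                                int8_t outputOffset) {
    float inp0 = (a - offset0) * scalesFp32[0];   // dequantize input 0
    float inp1 = (b - offset1) * scalesFp32[1];   // dequantize input 1
    float res  = (inp0 - inp1) * (inp0 - inp1);   // squared difference
    // Requantize: round-to-nearest-away (fcvtas in the vector path), add the
    // output offset, then clamp to the int8 range.
    int value = (int)roundf(res * scalesFp32[2]) + outputOffset;
    return (int8_t)std::min(127, std::max(-128, value));
}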
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +cmp x8, #0 beq End -ld1 {v29.4s}, [x3] // scale0 -ld1 {v30.4s}, [x4] // scale1 -ld1 {v31.4s}, [x5] // outputScale +ldr w3, [x4] +ldr w10, [x4, #8] +ldr w4, [x4, #4] +mov v0.s[0], w3 +mov v0.s[1], w4 +mov v0.s[2], w10 -cmp x6, #8 +cmp x8, #8 bge L8Loop -cmp x6, #4 +cmp x8, #4 bge L4Loop blt L1 L8Loop: - cmp x7, #0 + cmp x9, #0 beq L8NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L8NeedBroadcast1 L8NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 // input0[0]: int8x16 - ld1 {v28.16b}, [x2], #16 // input1[0]: int8x16 - ld1 {v14.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 // input00, input01 + ld1 {v5.16b, v6.16b}, [x2], #32 // input10, input11 b L8Compute L8NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v14.b}[0], [x1] - dup v14.16b, v14.b[0] - ld1 {v28.16b}, [x2], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1r {v3.16b}, [x1] + ld1r {v4.16b}, [x1] + ld1 {v5.16b, v6.16b}, [x2], #32 b L8Compute L8NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v15.b}[0], [x2] - dup v15.16b, v15.b[0] - ld1 {v27.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v14.16b}, [x1], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 + ld1r {v5.16b}, [x2] + ld1r {v6.16b}, [x2] b L8Compute L8Compute: - sxtl v16.8h, v27.8b // v16: int16x8 - sxtl2 v17.8h, v27.16b // v17: int16x8 - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v10.8h, v14.8b // v10: int16x8 - sxtl2 v11.8h, v14.16b // v11: int16x8 - sxtl v12.8h, v15.8b - sxtl2 v13.8h, v15.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v9.8h, v4.8b + sxtl2 v10.8h, v4.16b + + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + sxtl v13.8h, v6.8b + sxtl2 v14.8h, v6.16b + + + INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + ssubw v9.8h, v9.8h, v2.8b + ssubw v10.8h, v10.8h, v2.8b + + INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + ssubw v13.8h, v13.8h, v1.8b + ssubw v14.8h, v14.8h, v1.8b - sxtl v18.4s, v16.4h // v18: int32x4 - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - sxtl v2.4s, v10.4h // v18: int32x4 - sxtl2 v3.4s, v10.8h - sxtl v4.4s, v11.4h - sxtl2 v5.4s, v11.8h - sxtl v6.4s, v12.4h - sxtl2 v7.4s, v12.8h - sxtl v8.4s, v13.4h - sxtl2 v9.4s, v13.8h - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s + L8SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + sxtl v19.4s, v9.4h + sxtl2 v20.4s, v9.8h + sxtl v21.4s, v10.4h + sxtl2 v22.4s, v10.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + sxtl v27.4s, v13.4h + sxtl2 v28.4s, v13.8h + sxtl v29.4s, v14.4h + sxtl2 v30.4s, v14.8h + + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + scvtf v19.4s, v19.4s + scvtf v20.4s, v20.4s + scvtf v21.4s, v21.4s + scvtf v22.4s, v22.4s + + scvtf v23.4s, v23.4s scvtf v24.4s, v24.4s scvtf v25.4s, v25.4s scvtf v26.4s, v26.4s scvtf v27.4s, v27.4s - - scvtf v2.4s, v2.4s - scvtf v3.4s, v3.4s - scvtf v4.4s, v4.4s - scvtf v5.4s, v5.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, 
v4.4s, v29.4s - fmul v5.4s, v5.4s, v29.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - fmul v8.4s, v8.4s, v30.4s - fmul v9.4s, v9.4s, v30.4s - - fmul v18.4s, v18.4s, v29.4s - fmul v19.4s, v19.4s, v29.4s - fmul v20.4s, v20.4s, v29.4s - fmul v21.4s, v21.4s, v29.4s - fmul v24.4s, v24.4s, v30.4s - fmul v25.4s, v25.4s, v30.4s - fmul v26.4s, v26.4s, v30.4s - fmul v27.4s, v27.4s, v30.4s - - fsub v2.4s, v2.4s, v6.4s // (x-y) - fsub v3.4s, v3.4s, v7.4s - fsub v4.4s, v4.4s, v8.4s - fsub v5.4s, v5.4s, v9.4s - - fsub v18.4s, v18.4s, v24.4s - fsub v19.4s, v19.4s, v25.4s - fsub v20.4s, v20.4s, v26.4s - fsub v21.4s, v21.4s, v27.4s - - fmul v2.4s, v2.4s, v2.4s // (x-y)*(x-y) - fmul v3.4s, v3.4s, v3.4s - fmul v4.4s, v4.4s, v4.4s - fmul v5.4s, v5.4s, v5.4s - fmul v18.4s, v18.4s, v18.4s - fmul v19.4s, v19.4s, v19.4s - fmul v20.4s, v20.4s, v20.4s - fmul v21.4s, v21.4s, v21.4s - - fmul v2.4s, v2.4s, v31.4s - fmul v3.4s, v3.4s, v31.4s - fmul v4.4s, v4.4s, v31.4s - fmul v5.4s, v5.4s, v31.4s - fmul v18.4s, v18.4s, v31.4s - fmul v19.4s, v19.4s, v31.4s - fmul v20.4s, v20.4s, v31.4s - fmul v21.4s, v21.4s, v31.4s - - fcvtzs v18.4s, v18.4s - fcvtzs v19.4s, v19.4s - fcvtzs v20.4s, v20.4s - fcvtzs v21.4s, v21.4s - fcvtzs v2.4s, v2.4s - fcvtzs v3.4s, v3.4s - fcvtzs v4.4s, v4.4s - fcvtzs v5.4s, v5.4s - - sqxtn v18.4h, v18.4s - sqxtn2 v18.8h, v19.4s - sqxtn v19.4h, v20.4s - sqxtn2 v19.8h, v21.4s - - sqxtn v2.4h, v2.4s - sqxtn2 v2.8h, v3.4s - sqxtn v3.4h, v4.4s - sqxtn2 v3.8h, v5.4s - - sqxtn v18.8b, v18.8h - sqxtn2 v18.16b, v19.8h - sqxtn v2.8b, v2.8h - sqxtn2 v2.16b, v3.8h - - st1 {v18.16b}, [x0], #16 - st1 {v2.16b}, [x0], #16 - - sub x6, x6, #8 - cmp x6, #8 + scvtf v28.4s, v28.4s + scvtf v29.4s, v29.4s + scvtf v30.4s, v30.4s + + fmul v15.4s, v15.4s, v0.s[0] + fmul v16.4s, v16.4s, v0.s[0] + fmul v17.4s, v17.4s, v0.s[0] + fmul v18.4s, v18.4s, v0.s[0] + fmul v19.4s, v19.4s, v0.s[0] + fmul v20.4s, v20.4s, v0.s[0] + fmul v21.4s, v21.4s, v0.s[0] + fmul v22.4s, v22.4s, v0.s[0] + + fmul v23.4s, v23.4s, v0.s[1] + fmul v24.4s, v24.4s, v0.s[1] + fmul v25.4s, v25.4s, v0.s[1] + fmul v26.4s, v26.4s, v0.s[1] + fmul v27.4s, v27.4s, v0.s[1] + fmul v28.4s, v28.4s, v0.s[1] + fmul v29.4s, v29.4s, v0.s[1] + fmul v30.4s, v30.4s, v0.s[1] + + fsub v15.4s, v15.4s, v23.4s + fsub v16.4s, v16.4s, v24.4s + fsub v17.4s, v17.4s, v25.4s + fsub v18.4s, v18.4s, v26.4s + fsub v19.4s, v19.4s, v27.4s + fsub v20.4s, v20.4s, v28.4s + fsub v21.4s, v21.4s, v29.4s + fsub v22.4s, v22.4s, v30.4s + + fmul v15.4s, v15.4s, v15.4s + fmul v16.4s, v16.4s, v16.4s + fmul v17.4s, v17.4s, v17.4s + fmul v18.4s, v18.4s, v18.4s + fmul v19.4s, v19.4s, v19.4s + fmul v20.4s, v20.4s, v20.4s + fmul v21.4s, v21.4s, v21.4s + fmul v22.4s, v22.4s, v22.4s + + fmul v15.4s, v15.4s, v0.s[2] + fmul v16.4s, v16.4s, v0.s[2] + fmul v17.4s, v17.4s, v0.s[2] + fmul v18.4s, v18.4s, v0.s[2] + fmul v19.4s, v19.4s, v0.s[2] + fmul v20.4s, v20.4s, v0.s[2] + fmul v21.4s, v21.4s, v0.s[2] + fmul v22.4s, v22.4s, v0.s[2] + + fcvtas v15.4s, v15.4s + fcvtas v16.4s, v16.4s + fcvtas v17.4s, v17.4s + fcvtas v18.4s, v18.4s + fcvtas v19.4s, v19.4s + fcvtas v20.4s, v20.4s + fcvtas v21.4s, v21.4s + fcvtas v22.4s, v22.4s + + sqxtn v1.4h, v15.4s + sqxtn2 v1.8h, v16.4s + sqxtn v2.4h, v17.4s + sqxtn2 v2.8h, v18.4s + sqxtn v3.4h, v19.4s + sqxtn2 v3.8h, v20.4s + sqxtn v4.4h, v21.4s + sqxtn2 v4.8h, v22.4s + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + saddw v3.8h, v3.8h, v14.8b + saddw v4.8h, v4.8h, v14.8b + + SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + sqxtn 
v6.8b, v3.8h + sqxtn2 v6.16b, v4.8h + + st1 {v5.16b, v6.16b}, [x0], #32 + + sub x8, x8, #8 + cmp x8, #8 bge L8Loop - cmp x6, #0 + cmp x8, #0 ble End - cmp x6, #4 + cmp x8, #4 blt L1Loop L4Loop: - cmp x7, #0 + cmp x9, #0 beq L4NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 - ld1 {v28.16b}, [x2], #16 + ld1 {v3.16b}, [x1], #16 // input00, input01 + ld1 {v5.16b}, [x2], #16 // input10, input11 b L4Compute L4NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v28.16b}, [x2], #16 + ld1r {v3.16b}, [x1] + ld1 {v5.16b}, [x2], #16 b L4Compute L4NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v27.16b}, [x1], #16 + ld1 {v3.16b}, [x1], #16 + ld1r {v5.16b}, [x2] b L4Compute L4Compute: - sxtl v16.8h, v27.8b - sxtl2 v17.8h, v27.16b - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v18.4s, v16.4h - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - scvtf v0.4s, v18.4s - scvtf v1.4s, v19.4s - scvtf v2.4s, v20.4s - scvtf v3.4s, v21.4s - scvtf v4.4s, v24.4s - scvtf v5.4s, v25.4s - scvtf v6.4s, v26.4s - scvtf v7.4s, v27.4s - - fmul v0.4s, v0.4s, v29.4s - fmul v1.4s, v1.4s, v29.4s - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v30.4s - fmul v5.4s, v5.4s, v30.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - - fsub v0.4s, v0.4s, v4.4s - fsub v1.4s, v1.4s, v5.4s - fsub v2.4s, v2.4s, v6.4s - fsub v3.4s, v3.4s, v7.4s - - fmul v0.4s, v0.4s, v0.4s - fmul v1.4s, v1.4s, v1.4s - fmul v2.4s, v2.4s, v2.4s - fmul v3.4s, v3.4s, v3.4s - - fmul v16.4s, v0.4s, v31.4s - fmul v17.4s, v1.4s, v31.4s - fmul v18.4s, v2.4s, v31.4s - fmul v19.4s, v3.4s, v31.4s - - fcvtzs v20.4s, v16.4s - fcvtzs v21.4s, v17.4s - fcvtzs v22.4s, v18.4s - fcvtzs v23.4s, v19.4s - - sqxtn v0.4h, v20.4s - sqxtn2 v0.8h, v21.4s - sqxtn v1.4h, v22.4s - sqxtn2 v1.8h, v23.4s - - sqxtn v2.8b, v0.8h - sqxtn v3.8b, v1.8h - - st1 {v2.8b, v3.8b}, [x0], #16 - sub x6, x6, #4 - cmp x6, #4 + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + + L4_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + + L4_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + + L4SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + + scvtf v15.4s, v15.4s + scvtf v16.4s, v16.4s + scvtf v17.4s, v17.4s + scvtf v18.4s, v18.4s + + scvtf v23.4s, v23.4s + scvtf v24.4s, v24.4s + scvtf v25.4s, v25.4s + scvtf v26.4s, v26.4s + + fmul v15.4s, v15.4s, v0.s[0] + fmul v16.4s, v16.4s, v0.s[0] + fmul v17.4s, v17.4s, v0.s[0] + fmul v18.4s, v18.4s, v0.s[0] + + fmul v23.4s, v23.4s, v0.s[1] + fmul v24.4s, v24.4s, v0.s[1] + fmul v25.4s, v25.4s, v0.s[1] + fmul v26.4s, v26.4s, v0.s[1] + + fsub v15.4s, v15.4s, v23.4s + fsub v16.4s, v16.4s, v24.4s + fsub v17.4s, v17.4s, v25.4s + fsub v18.4s, v18.4s, v26.4s + + fmul v15.4s, v15.4s, v15.4s + fmul v16.4s, v16.4s, v16.4s + fmul v17.4s, v17.4s, v17.4s + fmul v18.4s, v18.4s, v18.4s + + fmul v15.4s, v15.4s, v0.s[2] + fmul v16.4s, v16.4s, v0.s[2] + fmul v17.4s, v17.4s, v0.s[2] + fmul v18.4s, v18.4s, v0.s[2] + + fcvtas v15.4s, v15.4s + fcvtas v16.4s, v16.4s + fcvtas v17.4s, v17.4s + fcvtas v18.4s, v18.4s + + sqxtn v1.4h, v15.4s + sqxtn2 v1.8h, v16.4s + sqxtn v2.4h, 
v17.4s + sqxtn2 v2.8h, v18.4s + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + + L4_SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + st1 {v5.16b}, [x0], #16 + sub x8, x8, #4 + cmp x8, #4 bge L4Loop L1: -cmp x6, #0 +cmp x8, #0 beq End L1Loop: - cmp x7, #0 + cmp x9, #0 beq L1NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - ld1 {v27.s}[0], [x1], #4 - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.s}[0], [x1], #4 // input00, input01 + ld1 {v5.s}[0], [x2], #4 // input10, input11 b L1Compute L1NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.8b, v27.b[0] - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.b}[0], [x1] + dup v3.8b, v3.b[0] + ld1 {v5.s}[0], [x2], #4 b L1Compute L1NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.8b, v28.b[0] - ld1 {v27.s}[0], [x1], #4 + ld1 {v3.s}[0], [x1], #4 + ld1r {v5.8b}, [x2] b L1Compute L1Compute: - sxtl v16.8h, v27.8b - sxtl v18.8h, v28.8b - sxtl v17.4s, v16.4h - sxtl v19.4s, v18.4h + sxtl v7.8h, v3.8b + sxtl v11.8h, v5.8b + + L1_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + L1_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + + L1SXTL_S32: + sxtl v15.4s, v7.4h + sxtl v23.4s, v11.4h + + scvtf v15.4s, v15.4s + scvtf v23.4s, v23.4s + + fmul v15.4s, v15.4s, v0.s[0] + fmul v23.4s, v23.4s, v0.s[1] + fsub v15.4s, v15.4s, v23.4s + fmul v15.4s, v15.4s, v15.4s + + fmul v15.4s, v15.4s, v0.s[2] + + fcvtas v15.4s, v15.4s + + sqxtn v1.4h, v15.4s + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + + L1_SQXTN_S8: + sqxtn v5.8b, v1.8h + st1 {v5.s}[0], [x0], #4 + + subs x8, x8, #1 bne L1Loop End: - +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 ret #endif diff --git a/source/backend/cpu/arm/arm64/MNNBinarySubInt8.S b/source/backend/cpu/arm/arm64/MNNBinarySubInt8.S index 9fb91c868..96c9f518d 100644 --- a/source/backend/cpu/arm/arm64/MNNBinarySubInt8.S +++ b/source/backend/cpu/arm/arm64/MNNBinarySubInt8.S @@ -1,6 +1,5 @@ // // MNNBinarySubInt8.S -// MNN // // Created by MNN on 2019/08/14. // Copyright © 2018, Alibaba Group Holding Limited @@ -13,307 +12,305 @@ .align 5 asm_function MNNBinarySubInt8 -// MNNBinarySubInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, -// const float* scale0, const float* scale1, const float* outputScale, const size_t size, size_t needBroadcast) -// x0: dst, x1:src0, x2:src1, x3:scale0, x4:scale1, x5:outputScale, x6:size, x7: needBroadcast - -cmp x6, #0 +// MNNBinarySubInt8(int8_t* dst, const int8_t* src0, const int8_t* src1, ssize_t* quantScalesInt32, +// float* quantScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, +// const size_t size, size_t needBroadcast) +// Auto load: +// x0: dst, x1:src0, x2:src1, x3:quantScalesInt32, x4:quantScalesFp32, x5: offset0, x6: offset1, x7: outputOffset +// Load from sp: +// x8: size, x9: needBroadcast + +ldr x8, [sp, #0] +ldr x9, [sp, #8] + +stp d14, d15, [sp, #-64]!
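Unlike the fp32-scale kernels, Sub stays in integer arithmetic end to end: each input is scaled by a fixed-point factor from quantScalesInt32 (Q16, judging by the sqrshrn #16 below), the difference is taken with saturation (sqsub), rounded back down by 2^16, and shifted by the output offset. A scalar sketch consistent with the C++ reference later in this patch (subInt8Ref is an illustrative name, and int32_t stands in for the header's ssize_t):

#include <algorithm>
#include <cstdint>

static inline int8_t subInt8Ref(int8_t a, int8_t b,
                                const int32_t* scalesInt32,  // {scale0, scale1}, Q16 assumed
                                int8_t offset0, int8_t offset1,
                                int8_t outputOffset) {
    // The NEON path uses sqsub (saturating); plain subtraction suffices here.
    int32_t res = (a - offset0) * scalesInt32[0] - (b - offset1) * scalesInt32[1];
    // Round-half-away-from-zero division by 2^16, matching the two-branch
    // rounding in the C++ reference; sqrshrn #16 implements the same rounding
    // in one instruction (up to tie-breaking on exact halves).
    int value = (res >= 0 ? res + (1 << 15) : res - (1 << 15)) / (1 << 16);
    value += outputOffset;
    return (int8_t)std::min(127, std::max(-128, value));
}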
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +cmp x8, #0 beq End -ld1 {v29.4s}, [x3] // scale0 -ld1 {v30.4s}, [x4] // scale1 -ld1 {v31.4s}, [x5] // outputScale +ldr w4, [x3, #8] +ldr w3, [x3] +mov v0.s[0], w3 +mov v0.s[1], w4 -cmp x6, #8 +cmp x8, #8 bge L8Loop -cmp x6, #4 +cmp x8, #4 bge L4Loop blt L1 L8Loop: - cmp x7, #0 + cmp x9, #0 beq L8NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L8NeedBroadcast1 L8NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 // input0[0]: int8x16 - ld1 {v28.16b}, [x2], #16 // input1[0]: int8x16 - ld1 {v14.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 // input00, input01 + ld1 {v5.16b, v6.16b}, [x2], #32 // input10, input11 b L8Compute L8NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v14.b}[0], [x1] - dup v14.16b, v14.b[0] - ld1 {v28.16b}, [x2], #16 // input0[1]: int8x16 - ld1 {v15.16b}, [x2], #16 // input1[1]: int8x16 + ld1r {v3.16b}, [x1] + ld1r {v4.16b}, [x1] + ld1 {v5.16b, v6.16b}, [x2], #32 b L8Compute L8NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v15.b}[0], [x2] - dup v15.16b, v15.b[0] - ld1 {v27.16b}, [x1], #16 // input0[1]: int8x16 - ld1 {v14.16b}, [x1], #16 // input1[1]: int8x16 + ld1 {v3.16b, v4.16b}, [x1], #32 + ld1r {v5.16b}, [x2] + ld1r {v6.16b}, [x2] b L8Compute L8Compute: - sxtl v16.8h, v27.8b // v16: int16x8 - sxtl2 v17.8h, v27.16b // v17: int16x8 - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v10.8h, v14.8b // v10: int16x8 - sxtl2 v11.8h, v14.16b // v11: int16x8 - sxtl v12.8h, v15.8b - sxtl2 v13.8h, v15.16b + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v9.8h, v4.8b + sxtl2 v10.8h, v4.16b + + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + sxtl v13.8h, v6.8b + sxtl2 v14.8h, v6.16b + + INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + ssubw v9.8h, v9.8h, v2.8b + ssubw v10.8h, v10.8h, v2.8b + + INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + ssubw v13.8h, v13.8h, v1.8b + ssubw v14.8h, v14.8h, v1.8b - sxtl v18.4s, v16.4h // v18: int32x4 - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - sxtl v2.4s, v10.4h // v18: int32x4 - sxtl2 v3.4s, v10.8h - sxtl v4.4s, v11.4h - sxtl2 v5.4s, v11.8h - sxtl v6.4s, v12.4h - sxtl2 v7.4s, v12.8h - sxtl v8.4s, v13.4h - sxtl2 v9.4s, v13.8h - - scvtf v18.4s, v18.4s - scvtf v19.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s - scvtf v24.4s, v24.4s - scvtf v25.4s, v25.4s - scvtf v26.4s, v26.4s - scvtf v27.4s, v27.4s - - scvtf v2.4s, v2.4s - scvtf v3.4s, v3.4s - scvtf v4.4s, v4.4s - scvtf v5.4s, v5.4s - scvtf v6.4s, v6.4s - scvtf v7.4s, v7.4s - scvtf v8.4s, v8.4s - scvtf v9.4s, v9.4s - - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v29.4s - fmul v5.4s, v5.4s, v29.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - fmul v8.4s, v8.4s, v30.4s - fmul v9.4s, v9.4s, v30.4s - - fmul v18.4s, v18.4s, v29.4s - fmul v19.4s, v19.4s, v29.4s - fmul v20.4s, v20.4s, v29.4s - fmul v21.4s, v21.4s, v29.4s - fmul v24.4s, v24.4s, v30.4s - fmul v25.4s, v25.4s, v30.4s - fmul v26.4s, v26.4s, v30.4s - fmul v27.4s, v27.4s, v30.4s - - fsub v2.4s, v2.4s, v6.4s - fsub v3.4s, v3.4s, v7.4s - fsub v4.4s, v4.4s, v8.4s - fsub v5.4s, v5.4s, v9.4s - - fsub v18.4s, v18.4s, v24.4s - fsub v19.4s, v19.4s, v25.4s - fsub v20.4s, v20.4s, v26.4s - fsub v21.4s, 
v21.4s, v27.4s - - fmul v2.4s, v2.4s, v31.4s - fmul v3.4s, v3.4s, v31.4s - fmul v4.4s, v4.4s, v31.4s - fmul v5.4s, v5.4s, v31.4s - fmul v18.4s, v18.4s, v31.4s - fmul v19.4s, v19.4s, v31.4s - fmul v20.4s, v20.4s, v31.4s - fmul v21.4s, v21.4s, v31.4s - - fcvtzs v18.4s, v18.4s - fcvtzs v19.4s, v19.4s - fcvtzs v20.4s, v20.4s - fcvtzs v21.4s, v21.4s - fcvtzs v2.4s, v2.4s - fcvtzs v3.4s, v3.4s - fcvtzs v4.4s, v4.4s - fcvtzs v5.4s, v5.4s - - sqxtn v18.4h, v18.4s - sqxtn2 v18.8h, v19.4s - sqxtn v19.4h, v20.4s - sqxtn2 v19.8h, v21.4s - - sqxtn v2.4h, v2.4s - sqxtn2 v2.8h, v3.4s - sqxtn v3.4h, v4.4s - sqxtn2 v3.8h, v5.4s - - sqxtn v18.8b, v18.8h - sqxtn2 v18.16b, v19.8h - sqxtn v2.8b, v2.8h - sqxtn2 v2.16b, v3.8h - - st1 {v18.16b}, [x0], #16 - st1 {v2.16b}, [x0], #16 - - sub x6, x6, #8 - cmp x6, #8 + + L8SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + sxtl v19.4s, v9.4h + sxtl2 v20.4s, v9.8h + sxtl v21.4s, v10.4h + sxtl2 v22.4s, v10.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + sxtl v27.4s, v13.4h + sxtl2 v28.4s, v13.8h + sxtl v29.4s, v14.4h + sxtl2 v30.4s, v14.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + mul v19.4s, v19.4s, v0.s[0] + mul v20.4s, v20.4s, v0.s[0] + mul v21.4s, v21.4s, v0.s[0] + mul v22.4s, v22.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + mul v27.4s, v27.4s, v0.s[1] + mul v28.4s, v28.4s, v0.s[1] + mul v29.4s, v29.4s, v0.s[1] + mul v30.4s, v30.4s, v0.s[1] + + sqsub v15.4s, v15.4s, v23.4s + sqsub v16.4s, v16.4s, v24.4s + sqsub v17.4s, v17.4s, v25.4s + sqsub v18.4s, v18.4s, v26.4s + sqsub v19.4s, v19.4s, v27.4s + sqsub v20.4s, v20.4s, v28.4s + sqsub v21.4s, v21.4s, v29.4s + sqsub v22.4s, v22.4s, v30.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + sqrshrn v3.4h, v19.4s, #16 + sqrshrn2 v3.8h, v20.4s, #16 + sqrshrn v4.4h, v21.4s, #16 + sqrshrn2 v4.8h, v22.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + saddw v3.8h, v3.8h, v14.8b + saddw v4.8h, v4.8h, v14.8b + + SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + sqxtn v6.8b, v3.8h + sqxtn2 v6.16b, v4.8h + + st1 {v5.16b, v6.16b}, [x0], #32 + + sub x8, x8, #8 + cmp x8, #8 bge L8Loop - cmp x6, #0 + cmp x8, #0 ble End - cmp x6, #4 + cmp x8, #4 blt L1Loop L4Loop: - cmp x7, #0 + cmp x9, #0 beq L4NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L4NeedBroadcast1 L4NotNeedBroadcast: - ld1 {v27.16b}, [x1], #16 - ld1 {v28.16b}, [x2], #16 + ld1 {v3.16b}, [x1], #16 // input00, input01 + ld1 {v5.16b}, [x2], #16 // input10, input11 b L4Compute L4NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.16b, v27.b[0] - ld1 {v28.16b}, [x2], #16 + ld1r {v3.16b}, [x1] + ld1 {v5.16b}, [x2], #16 b L4Compute L4NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.16b, v28.b[0] - ld1 {v27.16b}, [x1], #16 + ld1 {v3.16b}, [x1], #16 + ld1r {v5.16b}, [x2] b L4Compute L4Compute: - sxtl v16.8h, v27.8b - sxtl2 v17.8h, v27.16b - sxtl v22.8h, v28.8b - sxtl2 v23.8h, v28.16b - - sxtl v18.4s, v16.4h - sxtl2 v19.4s, v16.8h - sxtl v20.4s, v17.4h - sxtl2 v21.4s, v17.8h - sxtl v24.4s, v22.4h - sxtl2 v25.4s, v22.8h - sxtl v26.4s, v23.4h - sxtl2 v27.4s, v23.8h - - scvtf v0.4s, v18.4s - scvtf v1.4s, v19.4s - scvtf v2.4s, v20.4s - scvtf v3.4s, v21.4s - scvtf v4.4s, v24.4s - scvtf v5.4s, v25.4s - scvtf v6.4s, 
v26.4s - scvtf v7.4s, v27.4s - - fmul v0.4s, v0.4s, v29.4s - fmul v1.4s, v1.4s, v29.4s - fmul v2.4s, v2.4s, v29.4s - fmul v3.4s, v3.4s, v29.4s - fmul v4.4s, v4.4s, v30.4s - fmul v5.4s, v5.4s, v30.4s - fmul v6.4s, v6.4s, v30.4s - fmul v7.4s, v7.4s, v30.4s - - fsub v0.4s, v0.4s, v4.4s - fsub v1.4s, v1.4s, v5.4s - fsub v2.4s, v2.4s, v6.4s - fsub v3.4s, v3.4s, v7.4s - - fmul v16.4s, v0.4s, v31.4s - fmul v17.4s, v1.4s, v31.4s - fmul v18.4s, v2.4s, v31.4s - fmul v19.4s, v3.4s, v31.4s - - fcvtzs v20.4s, v16.4s - fcvtzs v21.4s, v17.4s - fcvtzs v22.4s, v18.4s - fcvtzs v23.4s, v19.4s - - sqxtn v0.4h, v20.4s - sqxtn2 v0.8h, v21.4s - sqxtn v1.4h, v22.4s - sqxtn2 v1.8h, v23.4s - - sqxtn v2.8b, v0.8h - sqxtn v3.8b, v1.8h - - st1 {v2.8b, v3.8b}, [x0], #16 - sub x6, x6, #4 - cmp x6, #4 + sxtl v7.8h, v3.8b + sxtl2 v8.8h, v3.16b + sxtl v11.8h, v5.8b + sxtl2 v12.8h, v5.16b + + L4_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + ssubw v8.8h, v8.8h, v2.8b + + L4_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b + ssubw v12.8h, v12.8h, v1.8b + + L4SXTL_S32: + sxtl v15.4s, v7.4h + sxtl2 v16.4s, v7.8h + sxtl v17.4s, v8.4h + sxtl2 v18.4s, v8.8h + + sxtl v23.4s,v11.4h + sxtl2 v24.4s, v11.8h + sxtl v25.4s, v12.4h + sxtl2 v26.4s, v12.8h + + mul v15.4s, v15.4s, v0.s[0] + mul v16.4s, v16.4s, v0.s[0] + mul v17.4s, v17.4s, v0.s[0] + mul v18.4s, v18.4s, v0.s[0] + + mul v23.4s, v23.4s, v0.s[1] + mul v24.4s, v24.4s, v0.s[1] + mul v25.4s, v25.4s, v0.s[1] + mul v26.4s, v26.4s, v0.s[1] + + sqsub v15.4s, v15.4s, v23.4s + sqsub v16.4s, v16.4s, v24.4s + sqsub v17.4s, v17.4s, v25.4s + sqsub v18.4s, v18.4s, v26.4s + + sqrshrn v1.4h, v15.4s, #16 + sqrshrn2 v1.8h, v16.4s, #16 + sqrshrn v2.4h, v17.4s, #16 + sqrshrn2 v2.8h, v18.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + saddw v2.8h, v2.8h, v14.8b + + L4_SQXTN_S8: + sqxtn v5.8b, v1.8h + sqxtn2 v5.16b, v2.8h + st1 {v5.16b}, [x0], #16 + sub x8, x8, #4 + cmp x8, #4 bge L4Loop L1: -cmp x6, #0 +cmp x8, #0 beq End L1Loop: - cmp x7, #0 + cmp x9, #0 beq L1NeedBroadcast0 - cmp x7, #1 + cmp x9, #1 beq L1NeedBroadcast1 L1NotNeedBroadcast: - ld1 {v27.s}[0], [x1], #4 - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.s}[0], [x1], #4 // input00, input01 + ld1 {v5.s}[0], [x2], #4 // input10, input11 b L1Compute L1NeedBroadcast0: - ld1 {v27.b}[0], [x1] - dup v27.8b, v27.b[0] - ld1 {v28.s}[0], [x2], #4 + ld1 {v3.b}[0], [x1] + dup v3.8b, v3.b[0] + ld1 {v5.s}[0], [x2], #4 b L1Compute L1NeedBroadcast1: - ld1 {v28.b}[0], [x2] - dup v28.8b, v28.b[0] - ld1 {v27.s}[0], [x1], #4 + ld1 {v3.s}[0], [x1], #4 + ld1r {v5.8b}, [x2] b L1Compute L1Compute: - sxtl v16.8h, v27.8b - sxtl v18.8h, v28.8b - sxtl v17.4s, v16.4h - sxtl v19.4s, v18.4h + sxtl v7.8h, v3.8b + sxtl v11.8h, v5.8b - scvtf v0.4s, v17.4s - scvtf v2.4s, v19.4s - fmul v1.4s, v0.4s, v29.4s - fmul v3.4s, v2.4s, v30.4s + L1_INPUT0_SUB_ZERO: + ld1r {v2.8b}, [x5] + ssubw v7.8h, v7.8h, v2.8b + L1_INPUT1_SUB_ZERO: + ld1r {v1.8b}, [x6] + ssubw v11.8h, v11.8h, v1.8b - fsub v4.4s, v1.4s, v3.4s - fmul v0.4s, v4.4s, v31.4s + L1SXTL_S32: + sxtl v15.4s, v7.4h + sxtl v23.4s, v11.4h - fcvtzs v5.4s, v0.4s - sqxtn v6.4h, v5.4s - sqxtn v7.8b, v6.8h - st1 {v7.s}[0], [x0], #4 + mul v15.4s, v15.4s, v0.s[0] + mul v23.4s, v23.4s, v0.s[1] - subs x6, x6, #1 + sqsub v15.4s, v15.4s, v23.4s + + sqrshrn v1.4h, v15.4s, #16 + + ld1r {v14.8b}, [x7] + saddw v1.8h, v1.8h, v14.8b + + L1_SQXTN_S8: + sqxtn v5.8b, v1.8h + st1 {v5.s}[0], [x0], #4 + + subs x8, x8, #1 bne L1Loop End: - +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, 
#16] +ldp d14, d15, [sp], #64 ret #endif diff --git a/source/backend/cpu/arm/arm64/MNNCubicLineC16.S b/source/backend/cpu/arm/arm64/MNNCubicLineC16.S index 2985f4813..5e59ba6a6 100644 --- a/source/backend/cpu/arm/arm64/MNNCubicLineC16.S +++ b/source/backend/cpu/arm/arm64/MNNCubicLineC16.S @@ -11,22 +11,25 @@ .text .align 5 asm_function MNNCubicLineC16 -// void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, -// size_t number); +// void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint, +// size_t number, ssize_t minValue, ssize_t maxValue); // Auto load: -// x0: dst, x1: A, x2: B, x3: C, x4: D, x5: t, x6: number - +// x0: dst, x1: A, x2: B, x3: C, x4: D, x5: t, x6: zeroPoint, x7: number +// Load from sp: x8: minValue, x9: maxValue +ldr x8, [sp, #0] +ldr x9, [sp, #8] stp d14, d15, [sp, #-64]! stp d12, d13, [sp, #16] stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] -cmp x6, #0 +cmp x7, #0 beq END ldr w5, [x5, #0] fmov s1, #1.0 +ld1r {v0.8b}, [x6] dup v31.4s, w5 // v31: t fmov s30, #1.0 @@ -100,25 +103,27 @@ fcvtas v5.4s, v5.4s fcvtas v6.4s, v6.4s fcvtas v7.4s, v7.4s -movi v18.16b, #0 -movi v19.16b, #127 -sub v18.16b, v18.16b, v19.16b +dup v18.16b, w8 +dup v19.16b, w9 sqxtn v4.4h, v4.4s sqxtn2 v4.8h, v5.4s sqxtn v6.4h, v6.4s sqxtn2 v6.8h, v7.4s -sqxtn v4.8b, v4.8h -sqxtn2 v4.16b, v6.8h +saddw v4.8h, v4.8h, v0.8b +saddw v6.8h, v6.8h, v0.8b + +sqxtn v10.8b, v4.8h +sqxtn2 v10.16b, v6.8h -smin v4.16b, v4.16b, v19.16b -smax v4.16b, v4.16b, v18.16b +smin v10.16b, v10.16b, v19.16b +smax v10.16b, v10.16b, v18.16b -st1 {v4.16b}, [x0], #16 +st1 {v10.16b}, [x0], #16 -sub x6, x6, #1 -cmp x6, #1 +sub x7, x7, #1 +cmp x7, #1 bge L1Loop END: diff --git a/source/backend/cpu/arm/arm64/MNNCubicSampleC16.S b/source/backend/cpu/arm/arm64/MNNCubicSampleC16.S index 5f9cc9915..a023d83bf 100644 --- a/source/backend/cpu/arm/arm64/MNNCubicSampleC16.S +++ b/source/backend/cpu/arm/arm64/MNNCubicSampleC16.S @@ -11,29 +11,28 @@ .text .align 5 asm_function MNNCubicSampleC16 -// void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const float* factor, size_t number) +// void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const float* factor, int8_t* zeroPoint, size_t number) // Auto load: -// x0: src, x1: dst, x2: position, x3: factor, x4: number +// x0: src, x1: dst, x2: position, x3: factor, x4: zeroPoint, x5: number stp d14, d15, [sp, #-64]! 
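The int8 cubic resize now threads a zero point through both stages: MNNCubicSampleC16 subtracts it right after sign-extending the samples (the new ssubw block), and MNNCubicLineC16 adds it back before clamping to the caller-supplied [minValue, maxValue]. Per the ResizeFunction.cpp hunk later in this patch, the store step of the line kernel reduces to the following sketch (cubicStore is an illustrative name; non-SSE branch, where offset is just *zeroPoint):

#include <cmath>
#include <cstdint>

static inline int8_t cubicStore(float interpolated, int8_t zeroPoint,
                                int minValue, int maxValue) {
    int val = (int)roundf(interpolated) + zeroPoint;  // re-apply output zero point
    if (val > maxValue) val = maxValue;               // clamp to caller's range
    if (val < minValue) val = minValue;
    return (int8_t)val;
}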
stp d12, d13, [sp, #16] stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] -cmp x4, #0 +cmp x5, #0 beq END mov w15, #16 uxtw x15, w15 L1Loop: -ldr w5, [x3, #0] -add x3, x3, #4 +ld1 {v31.s}[0], [x3], #4 fmov s1, #1.0 -dup v31.4s, w5 // v31: t +dup v31.4s, v31.s[0] // v31: t fmov s30, #1.0 fsub s30, s30, s31 // 1-t @@ -86,6 +85,8 @@ uxtw x8, w8 uxtw x9, w9 uxtw x10, w10 +ld1r {v11.8b}, [x4] + mul x7, x7, x15 mul x8, x8, x15 mul x9, x9, x15 @@ -109,6 +110,15 @@ sxtl2 v17.8h, v15.16b sxtl v23.8h, v22.8b sxtl2 v24.8h, v22.16b +ssubw v1.8h, v1.8h, v11.8b +ssubw v2.8h, v2.8h, v11.8b +ssubw v9.8h, v9.8h, v11.8b +ssubw v10.8h, v10.8h, v11.8b +ssubw v16.8h, v16.8h, v11.8b +ssubw v17.8h, v17.8h, v11.8b +ssubw v23.8h, v23.8h, v11.8b +ssubw v24.8h, v24.8h, v11.8b + sxtl v4.4s, v1.4h sxtl2 v5.4s, v1.8h sxtl v6.4s, v2.4h @@ -162,8 +172,8 @@ fmla v6.4s, v27.4s, v31.s[0] fmla v7.4s, v28.4s, v31.s[0] st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 -sub x4, x4, #1 -cmp x4, #1 +sub x5, x5, #1 +cmp x5, #1 bge L1Loop END: diff --git a/source/backend/cpu/arm/arm64/MNNScaleAndAddBiasInt8.S b/source/backend/cpu/arm/arm64/MNNScaleAndAddBiasInt8.S index acbd529d5..e0b37b611 100644 --- a/source/backend/cpu/arm/arm64/MNNScaleAndAddBiasInt8.S +++ b/source/backend/cpu/arm/arm64/MNNScaleAndAddBiasInt8.S @@ -14,15 +14,15 @@ asm_function MNNScaleAndAddBiasInt8 // MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const int32_t* alpha, int32_t mShiftBits, -// ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) +// ssize_t minValue, ssize_t maxValue, int8_t* inputZeroPoint, int8_t* outputZeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) -//Auto: x0:dst, x1:src, x2:bias, x3:alpha, x4:mShiftBits, x5:minValue, x6:maxValue, x7:zeroPoint -//Load from sp: x8:planeNumber, x9:biasNumber +//Auto: x0:dst, x1:src, x2:bias, x3:alpha, x4:mShiftBits, x5:minValue, x6:maxValue, x7:inputZeroPoint +//Load from sp: x11:outputZeroPoint, x8:planeNumber, x9:biasNumber //avoid to touch platform-register x-18 - -ldr x8, [sp, #0] -ldr x9, [sp, #8] +ldr x11, [sp, #0] +ldr x8, [sp, #8] +ldr x9, [sp, #16] stp d14, d15, [sp, #-64]! 
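MNNScaleAndAddBiasInt8 now takes separate input and output zero points as int8_t pointers instead of a single ssize_t zeroPoint: the input zero point is subtracted before the widening multiply, and the output zero point is added after the rounding narrow. A scalar sketch of one lane (scaleAddBiasRef is an illustrative name; treating alpha as a Q15 multiplier and using a shift of 15 follows the sqrshrn #15 in the NEON path):

#include <cstdint>

static inline int8_t scaleAddBiasRef(int8_t src, int32_t alpha, int32_t bias,
                                     int8_t inZero, int8_t outZero,
                                     int minValue, int maxValue) {
    int32_t val = (int32_t)(src - inZero) * alpha + bias;  // Q15 multiply-accumulate
    // Rounding right shift as in sqrshrn #15 (>> on negative values is assumed
    // arithmetic, which holds on the toolchains MNN targets).
    int out = ((val + (1 << 14)) >> 15) + outZero;
    if (out > maxValue) out = maxValue;
    if (out < minValue) out = minValue;
    return (int8_t)out;
}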
stp d12, d13, [sp, #16] @@ -38,9 +38,7 @@ beq BSEnd dup v27.16b, w5 // min dup v28.16b, w6 // max -dup v29.4s, w4 -neg v29.4s, v29.4s - +ld1r {v29.8b}, [x7] // inputZeroPoint BSLoopZ: mov x10, x8 @@ -66,6 +64,15 @@ BSLoopP16: sxtl v10.8h, v3.8b sxtl2 v11.8h, v3.16b + ssubw v4.8h, v4.8h, v29.8b + ssubw v5.8h, v5.8h, v29.8b + ssubw v6.8h, v6.8h, v29.8b + ssubw v7.8h, v7.8h, v29.8b + ssubw v8.8h, v8.8h, v29.8b + ssubw v9.8h, v9.8h, v29.8b + ssubw v10.8h, v10.8h, v29.8b + ssubw v11.8h, v11.8h, v29.8b + sxtl v12.4s, v4.4h sxtl2 v13.4s, v4.8h sxtl v14.4s, v5.4h @@ -83,6 +90,8 @@ BSLoopP16: sxtl v26.4s, v11.4h sxtl2 v11.4s, v11.8h + ld1r {v0.8b}, [x11] + mul v12.4s, v12.4s, v30.4s mul v13.4s, v13.4s, v30.4s mul v14.4s, v14.4s, v30.4s @@ -134,6 +143,15 @@ BSLoopP16: sqrshrn v26.4h, v26.4s, #15 sqrshrn2 v26.8h, v11.4s, #15 + saddw v12.8h, v12.8h, v0.8b + saddw v14.8h, v14.8h, v0.8b + saddw v16.8h, v16.8h, v0.8b + saddw v18.8h, v18.8h, v0.8b + saddw v20.8h, v20.8h, v0.8b + saddw v22.8h, v22.8h, v0.8b + saddw v24.8h, v24.8h, v0.8b + saddw v26.8h, v26.8h, v0.8b + sqxtn v12.8b, v12.8h sqxtn2 v12.16b, v14.8h sqxtn v13.8b, v16.8h @@ -172,6 +190,11 @@ BSLoopP16: sxtl v4.8h, v1.8b sxtl2 v5.8h, v1.16b + ssubw v2.8h, v2.8h, v29.8b + ssubw v3.8h, v3.8h, v29.8b + ssubw v4.8h, v4.8h, v29.8b + ssubw v5.8h, v5.8h, v29.8b + sxtl v16.4s, v2.4h sxtl2 v17.4s, v2.8h sxtl v18.4s, v3.4h @@ -180,6 +203,7 @@ BSLoopP16: sxtl2 v21.4s, v4.8h sxtl v22.4s, v5.4h sxtl2 v23.4s, v5.8h + ld1r {v24.8b}, [x11] mul v16.4s, v16.4s, v30.4s mul v17.4s, v17.4s, v30.4s @@ -208,6 +232,11 @@ BSLoopP16: sqrshrn v22.4h, v22.4s, #15 sqrshrn2 v22.8h, v23.4s, #15 + saddw v16.8h, v16.8h, v24.8b + saddw v18.8h, v18.8h, v24.8b + saddw v20.8h, v20.8h, v24.8b + saddw v22.8h, v22.8h, v24.8b + sqxtn v0.8b, v16.8h sqxtn2 v0.16b, v18.8h sqxtn v1.8b, v20.8h @@ -233,6 +262,9 @@ BSLoopP16: sxtl v2.8h, v0.8b sxtl2 v3.8h, v0.16b + + ssubw v2.8h, v2.8h, v29.8b + ssubw v3.8h, v2.8h, v29.8b sxtl v16.4s, v2.4h sxtl2 v17.4s, v2.8h sxtl v18.4s, v3.4h @@ -242,6 +274,7 @@ BSLoopP16: mul v17.4s, v17.4s, v30.4s mul v18.4s, v18.4s, v30.4s mul v19.4s, v19.4s, v30.4s + ld1r {v20.8b}, [x11] add v16.4s, v16.4s, v31.4s add v17.4s, v17.4s, v31.4s @@ -253,6 +286,8 @@ BSLoopP16: sqrshrn v18.4h, v18.4s, #15 sqrshrn2 v18.8h, v19.4s, #15 + saddw v16.8h, v16.8h, v20.8b + saddw v18.8h, v18.8h, v20.8b sqxtn v0.8b, v16.8h sqxtn2 v0.16b, v18.8h @@ -271,8 +306,10 @@ BSLoopP16: BSLoopP1: ld1 {v0.s}[0], [x1], #4 dup v0.4s, v0.s[0] + ld1r {v20.8b}, [x11] sxtl v2.8h, v0.8b + ssubw v2.8h, v2.8h, v29.8b sxtl v1.4s, v2.4h mul v1.4s, v1.4s, v30.4s @@ -280,6 +317,7 @@ BSLoopP16: sqrshrn v1.4h, v1.4s, #15 dup v1.2d, v1.d[0] + saddw v1.8h, v1.8h, v20.8b sqxtn v1.8b, v1.8h smax v1.8b, v1.8b, v27.8b diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index 10ef40ec3..5e23bb32e 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -117,12 +117,10 @@ void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bo void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNFunctionInit(); void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); -#ifdef MNN_LOW_MEMORY void MNNPackedMatMul_int4(float* C, const float* A, const 
float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); -#endif void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose); struct SparseMatMulParas @@ -178,7 +176,7 @@ void MNNRoiAlignAvg(float* dst, const float* src, const std::vector maxValue) { value = maxValue; } @@ -1568,7 +1571,7 @@ void MNNBinaryAddInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* } } -void MNNBinarySubInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) { +void MNNBinarySubInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast) { float res = 0; #ifdef MNN_USE_SSE const int zeroPoint = 128; @@ -1587,19 +1590,22 @@ void MNNBinarySubInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* #endif for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { - float inp0 = (inputData0[0] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; + int inp0 = (inputData0[0] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; res = inp0 - inp1; } else if (needBroadcast == 1) { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[0] - zeroPoint) * inputScale1[0]; + int inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[0] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; res = inp0 - inp1; } else { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; + int inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; res = inp0 - inp1; } - int value = (int)roundf(res * outputScale[0]) + zeroPoint; + int value = (res + (1<<15)) / (1 << 16) + zeroPoint + outputOffset[0]; + if (res < 0) { + value = (res - (1<<15)) / (1 << 16) + zeroPoint + outputOffset[0]; + } if (value > maxValue) { value = maxValue; } @@ -1610,7 +1616,7 @@ void MNNBinarySubInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* } } -void MNNBinaryMulInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) { +void MNNBinaryMulInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const 
int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast) { float res = 0; #ifdef MNN_USE_SSE const int zeroPoint = 128; @@ -1627,33 +1633,33 @@ void MNNBinaryMulInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif - for (int i = 0; i < elementSize; ++i) { - if (needBroadcast == 0) { - float inp0 = (inputData0[0] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; - res = inp0 * inp1; - } else if (needBroadcast == 1) { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[0] - zeroPoint) * inputScale1[0]; - res = inp0 * inp1; - } else { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; - res = inp0 * inp1; - } - int value = (int)roundf(res * outputScale[0]) + zeroPoint; - if (value > maxValue) { - value = maxValue; - } - if (value < minValue) { - value = minValue; - } - outputData[i] = value; + for (int i = 0; i < elementSize; ++i) { + if (needBroadcast == 0) { + float inp0 = (inputData0[0] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + float inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; + res = inp0 * inp1; + } else if (needBroadcast == 1) { + float inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + float inp1 = (inputData1[0] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; + res = inp0 * inp1; + } else { + float inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + float inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; + res = inp0 * inp1; } + int value = (int)roundf(res * inputScalesFp32[2]) + zeroPoint + outputOffset[0]; + if (value > maxValue) { + value = maxValue; + } + if (value < minValue) { + value = minValue; + } + outputData[i] = value; + } } -void MNNBinaryMinInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) { - float res = 0; +void MNNBinaryMinInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast) { + int res = 0; #ifdef MNN_USE_SSE const int zeroPoint = 128; const int maxValue = 255; @@ -1669,33 +1675,36 @@ void MNNBinaryMinInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif - for (int i = 0; i < elementSize; ++i) { - if (needBroadcast == 0) { - float inp0 = (inputData0[0] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; - res = std::min(inp0, inp1); - } else if (needBroadcast == 1) { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[0] - zeroPoint) * inputScale1[0]; - res = std::min(inp0, inp1); - } else { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; - res = std::min(inp0, inp1); - } - int value = (int)roundf(res * outputScale[0]) + zeroPoint; - if (value > maxValue) { - value = maxValue; - } - if (value < minValue) { - value = minValue; - } - outputData[i] = value; + for (int i = 0; i < elementSize; ++i) { + if (needBroadcast 
== 0) { + int inp0 = (inputData0[0] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; + res = std::min(inp0, inp1); + } else if (needBroadcast == 1) { + int inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[0] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; + res = std::min(inp0, inp1); + } else { + int inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; + res = std::min(inp0, inp1); } + int value = (res + (1<<15)) / (1 << 16) + zeroPoint + outputOffset[0]; + if (res < 0) { + value = (res - (1<<15)) / (1 << 16) + zeroPoint + outputOffset[0]; + } + if (value > maxValue) { + value = maxValue; + } + if (value < minValue) { + value = minValue; + } + outputData[i] = value; + } } -void MNNBinaryMaxInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) { - float res = 0; +void MNNBinaryMaxInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast) { + int32_t res = 0; #ifdef MNN_USE_SSE const int zeroPoint = 128; const int maxValue = 255; @@ -1711,31 +1720,34 @@ void MNNBinaryMaxInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* const int8_t* inputData1 = inputRaw1; int8_t* outputData = outputRaw; #endif - for (int i = 0; i < elementSize; ++i) { - if (needBroadcast == 0) { - float inp0 = (inputData0[0] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; - res = std::max(inp0, inp1); - } else if (needBroadcast == 1) { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[0] - zeroPoint) * inputScale1[0]; - res = std::max(inp0, inp1); - } else { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; - res = std::max(inp0, inp1); - } - int value = (int)roundf(res * outputScale[0]) + zeroPoint; - if (value > maxValue) { - value = maxValue; - } - if (value < minValue) { - value = minValue; - } - outputData[i] = value; + for (int i = 0; i < elementSize; ++i) { + if (needBroadcast == 0) { + int inp0 = (inputData0[0] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; + res = std::max(inp0, inp1); + } else if (needBroadcast == 1) { + int inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[0] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; + res = std::max(inp0, inp1); + } else { + int inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesInt32[0]; + int inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesInt32[1]; + res = std::max(inp0, inp1); } + int value = (res + (1<<15)) / (1 << 16) + zeroPoint + outputOffset[0]; + if (res < 0) { + value = (res - (1<<15)) / (1 << 16) + zeroPoint + outputOffset[0]; + } + if (value > maxValue) { + value = maxValue; + } + if (value < minValue) { + value = minValue; + } + outputData[i] = value; + } } -void MNNBinarySqdInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const 
float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast) { +void MNNBinarySqdInt8 (int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast) { float res = 0; #ifdef MNN_USE_SSE const int zeroPoint = 128; @@ -1754,19 +1766,19 @@ void MNNBinarySqdInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* #endif for (int i = 0; i < elementSize; ++i) { if (needBroadcast == 0) { - float inp0 = (inputData0[0] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; + float inp0 = (inputData0[0] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + float inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; res = (inp0 - inp1) * (inp0 - inp1); } else if (needBroadcast == 1) { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[0] - zeroPoint) * inputScale1[0]; + float inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + float inp1 = (inputData1[0] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; res = (inp0 - inp1) * (inp0 - inp1); } else { - float inp0 = (inputData0[i] - zeroPoint) * inputScale0[0]; - float inp1 = (inputData1[i] - zeroPoint) * inputScale1[0]; + float inp0 = (inputData0[i] - zeroPoint - inputOffset0[0]) * inputScalesFp32[0]; + float inp1 = (inputData1[i] - zeroPoint - inputOffset1[0]) * inputScalesFp32[1]; res = (inp0 - inp1) * (inp0 - inp1); } - int value = (int)roundf(res * outputScale[0]) + zeroPoint; + int value = (int)roundf(res * inputScalesFp32[2]) + zeroPoint + outputOffset[0]; if (value > maxValue) { value = maxValue; } @@ -1777,7 +1789,7 @@ void MNNBinarySqdInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* } } -void MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const int32_t* alpha, int32_t mShiftBits, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) { +void MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const int32_t* alpha, int32_t mShiftBits, ssize_t minValue, ssize_t maxValue, int8_t* inputZeroPoint, int8_t* outputZeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack) { #ifdef MNN_USE_SSE const uint8_t* srcPtr = (uint8_t*)src; uint8_t* dstPtr = (uint8_t*)dst; @@ -1787,7 +1799,8 @@ void MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, int8_t* dstPtr = dst; int offset = 0; #endif - ssize_t zeroPointValue = zeroPoint + offset; + int intputZeroPointValue = *inputZeroPoint + offset; + int outputZeroPointValue = *outputZeroPoint + offset; int d = mShiftBits - 1; for (int z = 0; z < biasNumber; ++z) { @@ -1803,11 +1816,11 @@ void MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const auto srcX = srcZ + pack * p; for (int i = 0; i < pack; ++i) { - int32_t val = static_cast(srcX[i] - zeroPointValue) * alphaZ[i] + biasZ[i]; + int32_t val = static_cast(srcX[i] - intputZeroPointValue) * alphaZ[i] + biasZ[i]; - int valOut = (val + (1< maxValue + offset) { diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.h b/source/backend/cpu/compute/Int8FunctionsOpt.h index 83fa5c78b..4b7f6bc0a 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.h +++ b/source/backend/cpu/compute/Int8FunctionsOpt.h @@ -52,13 +52,13 @@ void 
MNNInt8ScaleToFloat(float* dst, const int8_t* src, const float* scale, size void MNNInt8FunctionInit(); void MNNPackedSparseQuantMatMulEpx1(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap); void MNNPackedSparseQuantMatMulEpx4(int8_t* C, const int8_t* A, const int8_t* B, const size_t* sparseQuantParam, const QuanPostTreatParameters* post, unsigned int* NNZMap, int* dataOffsetMap); -void MNNBinaryAddInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast); -void MNNBinarySubInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast); -void MNNBinaryMulInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast); -void MNNBinarySqdInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast); -void MNNBinaryMaxInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast); -void MNNBinaryMinInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, const float* inputScale0, const float* inputScale1, const float* outputScale, int elementSize, int needBroadcast); -void MNNScaleAndAddBiasInt8(int8_t* dst, const int8_t* src, const int32_t* bias, const int32_t* alpha, int32_t mShiftBits, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint, ssize_t planeNumber, ssize_t biasNumber, ssize_t pack = 4); +void MNNBinaryAddInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast); +void MNNBinarySubInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast); +void MNNBinaryMulInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast); +void MNNBinarySqdInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast); +void MNNBinaryMaxInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t elementSize, size_t needBroadcast); +void MNNBinaryMinInt8(int8_t* outputRaw, const int8_t* inputRaw0, const int8_t* inputRaw1, ssize_t* inputScalesInt32, float* inputScalesFp32, const int8_t* inputOffset0, const int8_t* inputOffset1, const int8_t* outputOffset, size_t 
diff --git a/source/backend/cpu/compute/ResizeFunction.cpp b/source/backend/cpu/compute/ResizeFunction.cpp
index 6efa27c0e..efa89167b 100644
--- a/source/backend/cpu/compute/ResizeFunction.cpp
+++ b/source/backend/cpu/compute/ResizeFunction.cpp
@@ -39,7 +39,7 @@ static Vec CubicInterpolation2(Vec& A, Vec& B, Vec
 #ifdef MNN_USE_SSE
-    Vec16 zeroPointV(128);
+    Vec16 zeroPointV(128 + (*zeroPoint));
     const uint8_t* srcPtr = (uint8_t*)src;
 #else
-    Vec16 zeroPointV(0);
+    Vec16 zeroPointV(*zeroPoint);
     const int8_t* srcPtr = src;
 #endif
     for (int i = 0; i < number; ++i) {
@@ -108,20 +108,20 @@ void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const f
     }
 }
 
-void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t,
-                     size_t number) {
+void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint,
+                     size_t number, ssize_t minValue, ssize_t maxValue) {
     int pack = 16;
     using Vec16 = Vec<float, 16>;
 #ifdef MNN_USE_SSE
     uint8_t* dstPtr = (uint8_t*)dst;
-    int offset = 128;
-    int minValue = 0;
-    int maxValue = 255;
+    int offset = 128 + (*zeroPoint);
+    int minVal = 128 + minValue;
+    int maxVal = 128 + maxValue;
 #else
     int8_t* dstPtr = dst;
-    int offset = 0;
-    int minValue = -128;
-    int maxValue = 127;
+    int offset = *zeroPoint;
+    int minVal = (int)minValue;
+    int maxVal = (int)maxValue;
 #endif
     float f = *t;
     for (int i = 0; i < number; ++i) {
@@ -132,24 +132,24 @@ void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C
         auto val16 = CubicInterpolation2(a, b, c, d, f);
         for (int j = 0; j < pack; ++j) {
             int val = (int)roundf(val16[j]) + offset;
-            if (val > maxValue) {
-                val = maxValue;
+            if (val > maxVal) {
+                val = maxVal;
             }
-            if (val < minValue) {
-                val = minValue;
+            if (val < minVal) {
+                val = minVal;
             }
             *(dstPtr + pack * i + j) = val;
         }
     }
 }
 
-void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor,
+void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, int8_t* zeroPoint,
                          size_t number) {
 #ifdef MNN_USE_SSE
-    int offset = 128;
+    int offset = 128 + *zeroPoint;
     const uint8_t* srcPtr = (uint8_t*)src;
 #else
-    int offset = 0;
+    int offset = *zeroPoint;
     const int8_t* srcPtr = src;
 #endif
     int pack = 8;
@@ -167,12 +167,12 @@ void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* positio
     }
 }
 
-void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, size_t number) {
+void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, int8_t* zeroPoint, size_t number) {
 #ifdef MNN_USE_SSE
-    int offset = 128;
+    int offset = 128 + (*zeroPoint);
     uint8_t* dstPtr = (uint8_t*)dst;
 #else
-    int offset = 0;
+    int offset = *zeroPoint;
     int8_t* dstPtr = dst;
 #endif
     int pack = 8;
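Note: every resize kernel in this file gets the same treatment: the requantization offset now folds in the tensor's zero point, and the clamp bounds come from the caller instead of being hard-coded. A scalar restatement of that logic (the real kernels operate on packed 8- and 16-lane vectors):

    #include <cmath>
    #include <cstdint>

    // Names are local to this sketch; it mirrors the offset/clamp code above.
    int8_t requantize(float v, int8_t zeroPoint, int minValue, int maxValue) {
        int val = (int)roundf(v) + zeroPoint; // offset includes the zero point (plus 128 on the SSE/uint8 path)
        if (val > maxValue) {
            val = maxValue;                   // clamp range is caller-supplied,
        }
        if (val < minValue) {
            val = minValue;                   // no longer fixed to [-128, 127]
        }
        return (int8_t)val;
    }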
diff --git a/source/backend/cpu/compute/ResizeFunction.h b/source/backend/cpu/compute/ResizeFunction.h
index a8be4a655..69e3ecd57 100644
--- a/source/backend/cpu/compute/ResizeFunction.h
+++ b/source/backend/cpu/compute/ResizeFunction.h
@@ -11,20 +11,21 @@
 #include
 #include
+#include "core/Macro.h"
 #ifdef __cplusplus
 extern "C" {
 #endif
-void MNNCubicSampleC4(const float* src, float* dst, int32_t* position, const float* factor, size_t number);
-void MNNCubicLineC4(float* dst, const float* A, const float* B, const float* C, const float* D, float* t,
-                    size_t number);
-void CPUBilinearSampleC4(const float* src, float* dst, const int32_t* position, const float* factor, size_t number);
-void CPUBilinearLineC4(float* dst, const float* A, const float* B, const float* t, size_t number);
-void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const float* factor, size_t number);
-void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t,
-                     size_t number);
-void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, size_t number);
-void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, size_t number);
+void MNNCubicSampleC4(const float* src, float* dst, int32_t* position, const float* factor, int8_t* zeroPoint, size_t number);
+void MNNCubicLineC4(float* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint,
+                    size_t number, ssize_t minValue, ssize_t maxValue);
+void CPUBilinearSampleC4(const float* src, float* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number);
+void CPUBilinearLineC4(float* dst, const float* A, const float* B, const float* t, int8_t* zeroPoint, size_t number);
+void MNNCubicSampleC16(const int8_t* src, float* dst, int32_t* position, const float* factor, int8_t* zeroPoint, size_t number);
+void MNNCubicLineC16(int8_t* dst, const float* A, const float* B, const float* C, const float* D, float* t, int8_t* zeroPoint,
+                     size_t number, ssize_t minValue, ssize_t maxValue);
+void MNNBilinearSampleC8(const int8_t* src, int16_t* dst, const int32_t* position, const float* factor, int8_t* zeroPoint, size_t number);
+void MNNBilinearLineC8(int8_t* dst, const int16_t* A, const int16_t* B, const float* t, int8_t* zeroPoint, size_t number);
 #ifdef __cplusplus
 }
 #endif
diff --git a/source/backend/cuda/CMakeLists.txt b/source/backend/cuda/CMakeLists.txt
index 348897db1..1fbf15766 100644
--- a/source/backend/cuda/CMakeLists.txt
+++ b/source/backend/cuda/CMakeLists.txt
@@ -79,11 +79,16 @@ else()
 endif()
 
 option(MNN_CUDA_QUANT "Enable MNN CUDA Quant File" OFF)
+option(MNN_CUDA_BF16 "Enable MNN CUDA Bfloat16 File" OFF)
 
 IF (MNN_CUDA_QUANT)
     add_definitions(-DENABLE_CUDA_QUANT)
 ENDIF()
 
+IF (MNN_CUDA_BF16)
+    add_definitions(-DENABLE_CUDA_BF16)
+ENDIF()
+
 file(GLOB_RECURSE MNN_CUDA_SRC ${CMAKE_CURRENT_LIST_DIR}/core/* ${CMAKE_CURRENT_SOURCE_DIR}/execution/*)
 message(STATUS "message ${CUDA_NVCC_FLAGS} !!!!!!!!!!! ${CUDA_INCLUDE_DIRS}")
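Note: MNN_CUDA_BF16 does nothing beyond injecting the ENABLE_CUDA_BF16 definition; the rest of this patch fences every __nv_bfloat16 code path on that macro so builds without the option never touch BF16 types. The guard pattern, sketched with an illustrative kernel (the kernel itself is an example, not patch code):

    #ifdef ENABLE_CUDA_BF16
    #include "cuda_bf16.h" // only safe to include when BF16 support is compiled in

    // Example device kernel; name and body are illustrative only.
    __global__ void scaleBf16(__nv_bfloat16* data, float s, int n) {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            data[i] = __float2bfloat16(__bfloat162float(data[i]) * s);
        }
    }
    #endif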
${CUDA_INCLUDE_DIRS}") diff --git a/source/backend/cuda/execution/ConvBaseKernel.cu b/source/backend/cuda/execution/ConvBaseKernel.cu index 3a5f4c9db..7afd9899e 100644 --- a/source/backend/cuda/execution/ConvBaseKernel.cu +++ b/source/backend/cuda/execution/ConvBaseKernel.cu @@ -99,6 +99,7 @@ __global__ void Float22Half2(const float* param, } } +#ifdef ENABLE_CUDA_BF16 __global__ void Float22BFloat16(const float* param, __nv_bfloat16* output, const size_t maxCount @@ -112,7 +113,7 @@ __global__ void Float22BFloat16(const float* param, } #endif } - +#endif void callFloat2Half(const void* input, void* output, const int count, CUDARuntime* runtime) { int thread_count = count / 4; @@ -122,6 +123,7 @@ void callFloat2Half(const void* input, void* output, const int count, CUDARuntim checkKernelErrors; } +#ifdef ENABLE_CUDA_BF16 void callFloat2BFloat16(const void* input, void* output, const int count, CUDARuntime* runtime) { int thread_count = count / 4; int block_num = runtime->blocks_num(thread_count); @@ -129,7 +131,7 @@ void callFloat2BFloat16(const void* input, void* output, const int count, CUDARu Float22BFloat16<<>>((const float*)input, (__nv_bfloat16 *)output, thread_count); checkKernelErrors; } - +#endif void callWeightFill(const void* input, void* output, const int l, const int h, const int lp, const int hp, const int precision, CUDARuntime* runtime) { DivModFast lpD(lp); @@ -147,8 +149,10 @@ void callWeightFill(const void* input, void* output, const int l, const int h, c checkKernelErrors; } else { MNN_ASSERT(precision == 3); + #ifdef ENABLE_CUDA_BF16 WeightPackFill<<>>((const float*)input, (__nv_bfloat16*)output, lp*hp, l, h, lpD); checkKernelErrors; + #endif } } @@ -190,10 +194,12 @@ void callIm2ColPack(const void* input, void* output, const ConvolutionCommon::Im checkKernelErrors; } else { MNN_ASSERT(precision == 3); + #ifdef ENABLE_CUDA_BF16 Im2Col_packC<<>>(sw, sh, dw, dh, pw, ph, icDiv4, iw, ih, maxCount, PACK_NUMBER, e, l, (const __nv_bfloat16*)input, (__nv_bfloat16 *)output, \ lpD, owD, ohD, fxyD, fxD); checkKernelErrors; + #endif } } diff --git a/source/backend/cuda/execution/ConvBaseKernel.cuh b/source/backend/cuda/execution/ConvBaseKernel.cuh index 1fc53fe20..3972b1537 100644 --- a/source/backend/cuda/execution/ConvBaseKernel.cuh +++ b/source/backend/cuda/execution/ConvBaseKernel.cuh @@ -5,19 +5,22 @@ // Created by MNN on 2023/03/21. 
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
-
 #ifndef CONV_BASE_KERNEL_CUH
 #define CONV_BASE_KERNEL_CUH
 
 #include "core/Execution.hpp"
 #include "backend/cuda/core/CUDABackend.hpp"
+#ifdef ENABLE_CUDA_BF16
 #include "cuda_bf16.h"
+#endif
 
 namespace MNN {
 namespace CUDA {
 
 void callFloat2Half(const void* input, void* output, const int count, CUDARuntime* runtime);
+#ifdef ENABLE_CUDA_BF16
 void callFloat2BFloat16(const void* input, void* output, const int count, CUDARuntime* runtime);
+#endif
 void callWeightFill(const void* input, void* output, const int l, const int h, const int lp, const int hp, const int precision, CUDARuntime* runtime);
 void callIm2ColPack(const void* input, void* output, const ConvolutionCommon::Im2ColParameter* info, const int e, const int l, const int ep, const int lp, const int precision, CUDARuntime* runtime);
@@ -25,8 +28,9 @@ ErrorCode callCutlassGemmCudaCoreFloat16(const std::vector &inputs, con
 ErrorCode callCutlassGemmCudaCoreFloat32(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs);
 ErrorCode callCutlassGemmTensorCore884(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs);
 ErrorCode callCutlassGemmTensorCore(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs);
+#ifdef ENABLE_CUDA_BF16
 ErrorCode callCutlassGemmBf16TensorCore(const std::vector<Tensor*> &inputs, const std::vector<Tensor*> &outputs);
-
+#endif
 } //namespace CUDA
 } //namespace MNN
 #endif
\ No newline at end of file
diff --git a/source/backend/cuda/execution/ConvDepthWiseExecution.cu b/source/backend/cuda/execution/ConvDepthWiseExecution.cu
index 15c25bc74..ec28060c7 100755
--- a/source/backend/cuda/execution/ConvDepthWiseExecution.cu
+++ b/source/backend/cuda/execution/ConvDepthWiseExecution.cu
@@ -520,6 +520,7 @@ static std::shared_ptr _makeResource(const Op*
     auto offsetGpuStorage = static_cast<CUDABackend*>(bn)->getStaticBufferPool()->alloc(sizeof(offset));
     auto offsetGpu = (uint8_t*)offsetGpuStorage.first + offsetGpuStorage.second;
 
+    #ifdef ENABLE_CUDA_BF16
     if(static_cast<CUDABackend*>(bn)->getPrecision() == 3) {
         // [Oc, Kh*Kw] -> [Kh*Kw, Oc(p)]
         DivModFast d_ocp(depthC * PACK_NUMBER);
@@ -529,7 +530,10 @@ static std::shared_ptr _makeResource(const Op*
         WeightTransToBf16<<>>((const float*)tempWeight, (__nv_bfloat16*)res->mFilter, count,\
             kernelY * kernelX, depth, d_ocp);
         checkKernelErrors;
-    } else {
+    }
+    else
+    #endif
+    {
         reg.size[0] = 1;
         reg.size[1] = kernelY * kernelX;
         reg.size[2] = depthC * PACK_NUMBER;
@@ -565,13 +569,18 @@ static std::shared_ptr _makeResource(const Op*
     auto tempBias = (uint8_t*)tempBiasStorage.first + tempBiasStorage.second;
     cuda_check(cudaMemcpy(tempBias, conv->bias()->data(), conv->bias()->size()*sizeof(float), cudaMemcpyHostToDevice));
 
-    if(static_cast<CUDABackend*>(bn)->getPrecision() == 3) {
+    #ifdef ENABLE_CUDA_BF16
+    if(static_cast<CUDABackend*>(bn)->getPrecision() == 3)
+    {
         auto countBias = depthC * PACK_NUMBER;
         int block_num = runtime->blocks_num(countBias);
         int threads_num = runtime->threads_num();
         BiasTransToBf16<<>>((const float*)tempBias, (__nv_bfloat16*)res->mBias, countBias, depth);
         checkKernelErrors;
-    } else {
+    }
+    else
+    #endif
+    {
         reg.size[0] = 1;
         reg.size[1] = 1;
         reg.size[2] = depthC * PACK_NUMBER;
@@ -679,6 +688,7 @@ ErrorCode ConvDepthWiseExecution::onExecute(const std::vector &inputs,
     const int ph = parameters.pad[1];
     const int total = parameters.total;
 
+    #ifdef ENABLE_CUDA_BF16
    if (static_cast<CUDABackend*>(backend())->getPrecision() == 3) {
         if(kw==3 && kh==3 && sw==1 && sh==1 && pw==1 && ph==1 && ow % 2 ==0) {
             DivModFast d_ow2(ow/2);
@@ -715,6 +725,7 @@ ErrorCode ConvDepthWiseExecution::onExecute(const std::vector &inputs,
         return NO_ERROR;
     }
+    #endif
 
     if (static_cast<CUDABackend*>(backend())->useFp16()) {
         if(parameters.kernelSize[0]==3 && parameters.kernelSize[1]==3 && parameters.stride[0]==1 && parameters.stride[1]==1 &&
             parameters.pad[0]==1 && parameters.pad[1]==1 && parameters.outputSize[0] % 2 ==0) {
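Note: the rewrite of the two "} else {" lines above is deliberate: splitting the else onto its own line lets the #ifdef/#endif pair remove the whole if/else header when BF16 is disabled, leaving the second block to run unconditionally. The shape of the pattern, reduced to a sketch:

    // Sketch of the conditional-else pattern used in _makeResource above.
    void dispatchSketch(bool isBf16Mode) {
        (void)isBf16Mode; // silences the unused warning when BF16 is compiled out
    #ifdef ENABLE_CUDA_BF16
        if (isBf16Mode) {
            // BF16-only setup, compiled only with MNN_CUDA_BF16
        }
        else
    #endif
        {
            // generic setup; without the macro this block always runs
        }
    }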
&inputs, return NO_ERROR; } + #endif if (static_cast(backend())->useFp16()) { if(parameters.kernelSize[0]==3 && parameters.kernelSize[1]==3 && parameters.stride[0]==1 && parameters.stride[1]==1 && parameters.pad[0]==1 && parameters.pad[1]==1 && parameters.outputSize[0] % 2 ==0) { diff --git a/source/backend/cuda/execution/ConvSingleInputExecution.cu b/source/backend/cuda/execution/ConvSingleInputExecution.cu index 3172d953c..3fd3d8f40 100644 --- a/source/backend/cuda/execution/ConvSingleInputExecution.cu +++ b/source/backend/cuda/execution/ConvSingleInputExecution.cu @@ -51,10 +51,12 @@ public: return new ConvWinogradExecution(backend, op, resource); } + #ifdef ENABLE_CUDA_BF16 if (static_cast(backend)->getPrecision() == 3) { std::shared_ptr resource(new ConvCutlassBf16Execution::Resource(backend, op)); return new ConvCutlassBf16Execution(backend, op, resource); } + #endif std::shared_ptr resource(new ConvCutlassExecution::Resource(backend, op)); return new ConvCutlassExecution(backend, op, resource); #endif diff --git a/source/backend/cuda/execution/PoolExecution.cu b/source/backend/cuda/execution/PoolExecution.cu index 24a3604b1..878dbdda3 100755 --- a/source/backend/cuda/execution/PoolExecution.cu +++ b/source/backend/cuda/execution/PoolExecution.cu @@ -166,6 +166,7 @@ ErrorCode PoolExecution::onExecute(const std::vector &inputs, const st int threads_num = prop.maxThreadsPerBlock; int block_num = prop.multiProcessorCount; + #ifdef ENABLE_CUDA_BF16 if (static_cast(backend())->getPrecision() == 3) { auto inputPtr = (const __nv_bfloat16*)inputs[0]->deviceId(); auto outputPtr = (__nv_bfloat16*)outputs[0]->deviceId(); @@ -193,6 +194,7 @@ ErrorCode PoolExecution::onExecute(const std::vector &inputs, const st } return NO_ERROR; } + #endif if (static_cast(backend())->useFp16()) { auto inputPtr = (const half*)inputs[0]->deviceId(); diff --git a/source/backend/cuda/execution/Transpose.cu b/source/backend/cuda/execution/Transpose.cu index e0178f0d6..984b9b050 100644 --- a/source/backend/cuda/execution/Transpose.cu +++ b/source/backend/cuda/execution/Transpose.cu @@ -569,9 +569,11 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN maxCount, batch, area, channel, UP_DIV(channel, 8) * 8); checkKernelErrors; } else if(isBf16) { + #ifdef ENABLE_CUDA_BF16 C4NHW4_2_NHWC8<<>>((float *)input, (__nv_bfloat16 *)output, maxCount, batch, area, channel, UP_DIV(channel, 8) * 8); checkKernelErrors; + #endif } else { C4NHW4_2_NHWC8<<>>((float *)input, (float *)output, maxCount, batch, area, channel, UP_DIV(channel, 8) * 8); @@ -589,9 +591,11 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN maxCount, batch, channel, area, UP_DIV(channel, 4) * 4); checkKernelErrors; } else if(isBf16) { + #ifdef ENABLE_CUDA_BF16 NHWC8_2_C4NHW4<<>>((__nv_bfloat16 *)input, (float *)output, maxCount, batch, channel, area, UP_DIV(channel, 4) * 4); checkKernelErrors; + #endif } else { NHWC8_2_C4NHW4<<>>((float *)input, (float *)output, maxCount, batch, channel, area, UP_DIV(channel, 4) * 4); @@ -619,7 +623,9 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN if(isFp16) { insideFormatConvert((float *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel); } else if(isBf16) { + #ifdef ENABLE_CUDA_BF16 insideFormatConvert((float *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel); + #endif } else { insideFormatConvert((float *)input, (float *)output, srcDataFormat, 
diff --git a/source/backend/cuda/execution/Transpose.cu b/source/backend/cuda/execution/Transpose.cu
index e0178f0d6..984b9b050 100644
--- a/source/backend/cuda/execution/Transpose.cu
+++ b/source/backend/cuda/execution/Transpose.cu
@@ -569,9 +569,11 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN
             maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
         checkKernelErrors;
     } else if(isBf16) {
+        #ifdef ENABLE_CUDA_BF16
         C4NHW4_2_NHWC8<<>>((float *)input, (__nv_bfloat16 *)output,
             maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
         checkKernelErrors;
+        #endif
     } else {
         C4NHW4_2_NHWC8<<>>((float *)input, (float *)output,
             maxCount, batch, area, channel, UP_DIV(channel, 8) * 8);
@@ -589,9 +591,11 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN
             maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
         checkKernelErrors;
     } else if(isBf16) {
+        #ifdef ENABLE_CUDA_BF16
         NHWC8_2_C4NHW4<<>>((__nv_bfloat16 *)input, (float *)output,
             maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
         checkKernelErrors;
+        #endif
     } else {
         NHWC8_2_C4NHW4<<>>((float *)input, (float *)output,
             maxCount, batch, channel, area, UP_DIV(channel, 4) * 4);
@@ -619,7 +623,9 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN
     if(isFp16) {
         insideFormatConvert<float, half>((float *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
     } else if(isBf16) {
+        #ifdef ENABLE_CUDA_BF16
         insideFormatConvert<float, __nv_bfloat16>((float *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+        #endif
     } else {
         insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
     }
@@ -627,7 +633,9 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN
     if(isFp16) {
         insideFormatConvert<half, float>((half *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
     } else if(isBf16) {
+        #ifdef ENABLE_CUDA_BF16
         insideFormatConvert<__nv_bfloat16, float>((__nv_bfloat16 *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+        #endif
     } else {
         insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
     }
@@ -635,7 +643,9 @@ void FormatConvert(void* output, void* input, MNN_DATA_FORMAT srcDataFormat, MNN
     if(isFp16) {
         insideFormatConvert<half, half>((half *)input, (half *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
     } else if(isBf16) {
+        #ifdef ENABLE_CUDA_BF16
         insideFormatConvert<__nv_bfloat16, __nv_bfloat16>((__nv_bfloat16 *)input, (__nv_bfloat16 *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
+        #endif
     } else {
         insideFormatConvert<float, float>((float *)input, (float *)output, srcDataFormat, dstDataFormat, runtime, area, batch, channel);
     }
diff --git a/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.cu b/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.cu
index c541e772c..cf5f678aa 100644
--- a/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.cu
+++ b/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.cu
@@ -5,6 +5,7 @@
 //  Created by MNN on 2023/05/31.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
+#ifdef ENABLE_CUDA_BF16
 #include "ConvCutlassBf16Execution.hpp"
 #include "../ConvBaseKernel.cuh"
 
@@ -213,4 +214,5 @@ ErrorCode ConvCutlassBf16Execution::onExecute(const std::vector &inputs
 
 }// namespace CUDA
-}// namespace MNN
\ No newline at end of file
+}// namespace MNN
+#endif
\ No newline at end of file
diff --git a/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.hpp b/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.hpp
index c625d5274..be17b8ab2 100644
--- a/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.hpp
+++ b/source/backend/cuda/execution/bf16/ConvCutlassBf16Execution.hpp
@@ -5,6 +5,7 @@
 //  Created by MNN on 2023/05/29.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
+#ifdef ENABLE_CUDA_BF16
 #ifndef ConvCutlassBf16Execution_hpp
 #define ConvCutlassBf16Execution_hpp
 
@@ -43,4 +44,5 @@ class ConvCutlassBf16Execution : public CutlassConvCommonExecution {
 
 } // namespace CUDA
 } // namespace MNN
-#endif /* ConvCutlassBf16Execution */
\ No newline at end of file
+#endif /* ConvCutlassBf16Execution */
+#endif
\ No newline at end of file
diff --git a/source/backend/cuda/execution/bf16/ConvDepthWiseBf16.cuh b/source/backend/cuda/execution/bf16/ConvDepthWiseBf16.cuh
index c9a0bb523..3a1e2889f 100644
--- a/source/backend/cuda/execution/bf16/ConvDepthWiseBf16.cuh
+++ b/source/backend/cuda/execution/bf16/ConvDepthWiseBf16.cuh
@@ -5,6 +5,7 @@
 //  Created by MNN on 2023/05/30.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
+#ifdef ENABLE_CUDA_BF16
 #ifndef CONV_DEPTHWISE_BF16_CUH_
 #define CONV_DEPTHWISE_BF16_CUH_
 
@@ -402,4 +403,5 @@ __global__ void BiasTransToBf16(const T0* param,
 
 } //namespace CUDA
 } //namespace MNN
+#endif
 #endif
\ No newline at end of file
diff --git a/source/backend/cuda/execution/bf16/CutlassGemmBf16Param.hpp b/source/backend/cuda/execution/bf16/CutlassGemmBf16Param.hpp
index a11e39712..cf1ca9c8f 100644
--- a/source/backend/cuda/execution/bf16/CutlassGemmBf16Param.hpp
+++ b/source/backend/cuda/execution/bf16/CutlassGemmBf16Param.hpp
@@ -1,3 +1,5 @@
+#ifdef ENABLE_CUDA_BF16
+
 #ifndef CutlassGemmBF16Param_hpp
 #define CutlassGemmBF16Param_hpp
 
@@ -84,3 +86,4 @@ using GemmTensor_BF16_BF16_Relu6_AlignTensor_Sm80 = cutlass::gemm::device::Gemm<
 }
 }
 #endif
+#endif
diff --git a/source/backend/cuda/execution/bf16/PoolBf16.cuh b/source/backend/cuda/execution/bf16/PoolBf16.cuh
index 7e4dbe531..ce09f91ba 100644
--- a/source/backend/cuda/execution/bf16/PoolBf16.cuh
+++ b/source/backend/cuda/execution/bf16/PoolBf16.cuh
@@ -5,6 +5,7 @@
 //  Created by MNN on 2023/05/30.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
+#ifdef ENABLE_CUDA_BF16
 #ifndef CONV_DEPTHWISE_BF16_CUH_
 #define CONV_DEPTHWISE_BF16_CUH_
 
@@ -120,4 +121,5 @@ __global__ void avgpool_C8_BF16(const T* uInput, T* uOutput,
     }
 }
 } //namespace CUDA
 } //namespace MNN
+#endif
 #endif
\ No newline at end of file
diff --git a/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.cu b/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.cu
index 9e8540f17..57509ffd1 100644
--- a/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.cu
+++ b/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.cu
@@ -16,6 +16,7 @@ CutlassConvCommonExecution::CutlassConvCommonExecution(Backend *backend) : Execu
 }
 
 ErrorCode CutlassConvCommonExecution::runCutlassGemmFunc() {
+    #ifdef ENABLE_CUDA_BF16
     if(mBf16Infer) {
         if(mActivationType == 1) {
             cutlass::Status status = mGemmBF16BF16ReluSm80();
@@ -29,7 +30,7 @@ ErrorCode CutlassConvCommonExecution::runCutlassGemmFunc() {
         }
         return NO_ERROR;
     }
-
+    #endif
     if(mFp32Infer) {
         if(mActivationType == 1) {
             cutlass::Status status = mGemmCudaF32F32Relu();
diff --git a/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.hpp b/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.hpp
index 2c85be3b9..ef8552c1a 100644
--- a/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.hpp
+++ b/source/backend/cuda/execution/cutlass/CutlassConvCommonExecution.hpp
@@ -79,10 +79,11 @@ class CutlassConvCommonExecution : public Execution {
     GemmCuda_F32_F32_Relu6_AlignCuda mGemmCudaF32F32Relu6;
     GemmCuda_F32_F32_Linear_AlignCuda mGemmCudaF32F32Ln;
 
+    #ifdef ENABLE_CUDA_BF16
     GemmTensor_BF16_BF16_Linear_AlignTensor_Sm80 mGemmBF16BF16LnSm80;
     GemmTensor_BF16_BF16_Relu_AlignTensor_Sm80 mGemmBF16BF16ReluSm80;
     GemmTensor_BF16_BF16_Relu6_AlignTensor_Sm80 mGemmBF16BF16Relu6Sm80;
-
+    #endif
     int mGpuComputeCap = 75;
     int mActivationType = 0;
     bool mFp16Infer = false;
diff --git a/source/backend/cuda/execution/cutlass/CutlassGemmBf16TensorCore.cu b/source/backend/cuda/execution/cutlass/CutlassGemmBf16TensorCore.cu
index abe1a95c1..c3b4b7291 100644
--- a/source/backend/cuda/execution/cutlass/CutlassGemmBf16TensorCore.cu
+++ b/source/backend/cuda/execution/cutlass/CutlassGemmBf16TensorCore.cu
@@ -5,7 +5,7 @@
 //  Created by MNN on 2023/05/29.
 //  Copyright © 2018, Alibaba Group Holding Limited
 //
-
+#ifdef ENABLE_CUDA_BF16
 #include "CutlassConvCommonExecution.hpp"
 
 namespace MNN {
@@ -101,3 +101,4 @@ ErrorCode CutlassConvCommonExecution::callCutlassGemmBf16TensorCore(const std::v
 
 }
 }
+#endif
diff --git a/source/backend/opencl/core/runtime/OpenCLWrapper.hpp b/source/backend/opencl/core/runtime/OpenCLWrapper.hpp
index d58a1f6c9..c5deb1ec5 100644
--- a/source/backend/opencl/core/runtime/OpenCLWrapper.hpp
+++ b/source/backend/opencl/core/runtime/OpenCLWrapper.hpp
@@ -149,23 +149,23 @@ class OpenCLSymbols {
     using clGetImageInfoFunc = cl_int (CL_API_CALL *)(cl_mem, cl_image_info, size_t, void *, size_t *);
 
     // opencl 2.0 get sub group info and wave size.
-    using clCreateCommandQueueWithPropertiesFunc = cl_command_queue (*)(cl_context, cl_device_id,
+    using clCreateCommandQueueWithPropertiesFunc = cl_command_queue (CL_API_CALL *)(cl_context, cl_device_id,
                                                                         const cl_queue_properties *, cl_int *);
-    using clSVMAllocFunc = void *(*)(cl_context, cl_mem_flags, size_t size, cl_uint);
-    using clSVMFreeFunc = void (*)(cl_context, void *);
-    using clEnqueueSVMMapFunc = cl_int (*)(cl_command_queue, cl_bool, cl_map_flags,
+    using clSVMAllocFunc = void *(CL_API_CALL *)(cl_context, cl_mem_flags, size_t size, cl_uint);
+    using clSVMFreeFunc = void (CL_API_CALL *)(cl_context, void *);
+    using clEnqueueSVMMapFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_bool, cl_map_flags,
                                            void *, size_t, cl_uint, const cl_event *, cl_event *);
-    using clEnqueueSVMUnmapFunc = cl_int (*)(cl_command_queue, void *, cl_uint,
+    using clEnqueueSVMUnmapFunc = cl_int (CL_API_CALL *)(cl_command_queue, void *, cl_uint,
                                              const cl_event *, cl_event *);
-    using clSetKernelArgSVMPointerFunc = cl_int (*)(cl_kernel, cl_uint, const void *);
+    using clSetKernelArgSVMPointerFunc = cl_int (CL_API_CALL *)(cl_kernel, cl_uint, const void *);
 
-    using clNewRecordingQCOMFunc = cl_recording_qcom(*)(cl_command_queue, cl_int *);
-    using clEndRecordingQCOMFunc = cl_int (*)(cl_recording_qcom);
-    using clReleaseRecordingQCOMFunc = cl_int (*)(cl_recording_qcom);
-    using clRetainRecordingQCOMFunc = cl_int (*)(cl_recording_qcom);
-    using clEnqueueRecordingQCOMFunc = cl_int (*)(cl_command_queue, cl_recording_qcom, size_t, const cl_array_arg_qcom*, size_t, const cl_offset_qcom*,
+    using clNewRecordingQCOMFunc = cl_recording_qcom(CL_API_CALL *)(cl_command_queue, cl_int *);
+    using clEndRecordingQCOMFunc = cl_int (CL_API_CALL *)(cl_recording_qcom);
+    using clReleaseRecordingQCOMFunc = cl_int (CL_API_CALL *)(cl_recording_qcom);
+    using clRetainRecordingQCOMFunc = cl_int (CL_API_CALL *)(cl_recording_qcom);
+    using clEnqueueRecordingQCOMFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_recording_qcom, size_t, const cl_array_arg_qcom*, size_t, const cl_offset_qcom*,
                                          size_t, const cl_workgroup_qcom*, size_t, const cl_workgroup_qcom*, cl_uint, const cl_event*, cl_event*);
-    using clEnqueueRecordingSVMQCOMFunc = cl_int (*)(cl_command_queue, cl_recording_qcom, size_t, const cl_array_arg_qcom*, size_t, const cl_array_arg_qcom*,
+    using clEnqueueRecordingSVMQCOMFunc = cl_int (CL_API_CALL *)(cl_command_queue, cl_recording_qcom, size_t, const cl_array_arg_qcom*, size_t, const cl_array_arg_qcom*,
                                          size_t, const cl_offset_qcom*, size_t, const cl_workgroup_qcom*, size_t, const cl_workgroup_qcom*,
                                          size_t, const cl_array_kernel_exec_info_qcom*, cl_uint, const cl_event*, cl_event*);
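Note: the OpenCLWrapper change is a calling-convention fix rather than a behavioral one: OpenCL entry points are declared with CL_API_CALL, which expands to __stdcall on Windows and to nothing elsewhere, so function-pointer aliases written without it have a mismatched type on Windows. Before and after in miniature:

    #include <CL/cl.h>

    // Mismatched on Windows: defaults to __cdecl while the driver exports __stdcall.
    using clSVMFreeBad  = void (*)(cl_context, void*);
    // Portable: CL_API_CALL carries the correct convention on every platform.
    using clSVMFreeGood = void (CL_API_CALL *)(cl_context, void*);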
diff --git a/source/core/Pipeline.cpp b/source/core/Pipeline.cpp
index de9aaec5e..1b7411e48 100644
--- a/source/core/Pipeline.cpp
+++ b/source/core/Pipeline.cpp
@@ -70,7 +70,12 @@ static bool _supportQuant(const Op* op, const std::vector& inputs, cons
             return false;
         }
         case OpType_BinaryOp:
-            return true;
+            if (op->main_as_BinaryOp()->activationType() == 1) {
+                return false;
+            } else {
+                return true;
+            }
+
         case OpType_Softmax:
             return true;
         case OpType_Scale:
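Note: the Pipeline change narrows int8 execution for BinaryOp: when the op carries a fused activation (activationType() == 1), it now falls back to float. The added if/else is equivalent to a single return, sketched here (Op and BinaryOp are the repo's generated schema types):

    // Equivalent one-liner for the new BinaryOp branch above; sketch only.
    static bool supportQuantBinaryOp(const MNN::Op* op) {
        return op->main_as_BinaryOp()->activationType() != 1; // fused activation disables int8
    }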
diff --git a/test/op/BinaryOPTest.cpp b/test/op/BinaryOPTest.cpp
index 611412638..cc5ae81f4 100644
--- a/test/op/BinaryOPTest.cpp
+++ b/test/op/BinaryOPTest.cpp
@@ -101,11 +101,15 @@ class AddInt8Test : public BinaryTestCommon {
 public:
     virtual ~AddInt8Test() = default;
     virtual bool run(int precision) {
-        vector<float> inp2 = {1.1, 2.2, 3.3, 4.6}, inp1 = {2};
-        vector<float> rightResult = {3.1, 4.2, 5.3, 6.6};
-
-        return test(MNN::Express::_Add, "AddInt8Test", 0.01, inp1, inp2, rightResult, {1}, {4}, {4}, {0.4, 0.4, 0.4},
+        vector<float> inp2 = {1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6}, inp1 = {2};
+        vector<float> rightResult = {3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6, 3.1, 4.2, 5.3, 6.6};
+        printf("AddInt8 test zeropoint is zero\n");
+        bool res = test(MNN::Express::_Add, "AddInt8Test", 0.01, inp1, inp2, rightResult, {1}, {32}, {32}, {0.4, 0.4, 1.0},
                     {0., 0., 0.});
+        printf("AddInt8 test zeropoint is not zero\n");
+        res = test(MNN::Express::_Add, "AddInt8Test", 0.01, inp1, inp2, rightResult, {1}, {32}, {32}, {0.4, 0.4, 1.0},
+                    {1., 2., 3.});
+        return res;
     }
 };
@@ -129,9 +133,13 @@ class SubtractInt8Test : public BinaryTestCommon {
         vector<float> inp1 = {7.0, 28.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6}, inp2 = {5.7};
         vector<float> rightResult = {1.3, 22.5, -2.4, -1.1, -4.6, -3.5, -2.4, -1.1, -4.6, -3.5, -2.4, -1.1, -4.6, -3.5, -2.4, -1.1};
-
-        return test(MNN::Express::_Subtract, "SubtractInt8Test", 0.01, inp1, inp2, rightResult,
-                    {4, 4}, {1}, {4, 4}, {0.4, 0.4, 0.4}, {0., 0., 0.});
+        printf("SubtractInt8 test zeropoint is zero\n");
+        bool res = test(MNN::Express::_Subtract, "SubtractInt8Test", 0.01, inp1, inp2, rightResult,
+                    {4, 4}, {1}, {4, 4}, {0.4, 0.4, 1.0}, {0., 0., 0.});
+        printf("SubtractInt8 test zeropoint is not zero\n");
+        res = test(MNN::Express::_Subtract, "SubtractInt8Test", 0.01, inp1, inp2, rightResult,
+                    {4, 4}, {1}, {4, 4}, {0.4, 0.4, 1.0}, {1., 2., 3.});
+        return res;
     }
 };
@@ -148,10 +156,15 @@ class MultiplyInt8Test : public BinaryTestCommon {
 public:
     virtual ~MultiplyInt8Test() = default;
     virtual bool run(int precision) {
-        vector<float> inp1 = {1.1, 2.2, 3.3, 4.6}, inp2 = {5.7, 2.5, 0.25, 0.43};
-        vector<float> rightResult = {6.27, 5.5, 0.825, 1.978};
-        return test(MNN::Express::_Multiply, "MultiplyInt8Test", 0.01, inp1, inp2, rightResult,
-                    {4}, {4}, {4}, {0.4, 0.4, 0.16}, {0., 0., 0.});
+        vector<float> inp1 = {1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6, 1.1, 2.2, 3.3, 4.6}, inp2 = {5.7, 2.5, 0.25, 0.43, 5.7, 2.5, 0.25, 0.43, 5.7, 2.5, 0.25, 0.43, 5.7, 2.5, 0.25, 0.43};
+        vector<float> rightResult = {6.27, 5.5, 0.825, 1.978, 6.27, 5.5, 0.825, 1.978, 6.27, 5.5, 0.825, 1.978, 6.27, 5.5, 0.825, 1.978};
+        printf("MultiplyInt8 test zeropoint is zero\n");
+        bool res = test(MNN::Express::_Multiply, "MultiplyInt8Test", 0.01, inp1, inp2, rightResult,
+                    {16}, {16}, {16}, {0.4, 0.4, 0.16}, {0., 0., 0.});
+        printf("MultiplyInt8 test zeropoint is not zero\n");
+        res = test(MNN::Express::_Multiply, "MultiplyInt8Test", 0.01, inp1, inp2, rightResult,
+                    {16}, {16}, {16}, {0.4, 0.4, 0.16}, {1., 2., 3.});
+        return res;
     }
 };
@@ -170,8 +183,13 @@ class DivideInt8Test : public BinaryTestCommon {
     virtual bool run(int precision) {
         vector<float> inp1 = {1.1, 2.2, 3.3, 4.6}, inp2 = {5.7, 2.5, 2.6, 1.88};
         vector<float> rightResult = {0.19298, 0.88, 1.269, 2.4468};
-        return test(MNN::Express::_Divide, "DivideInt8Test", 0.01, inp1, inp2, rightResult,
+        printf("DivideInt8 test zero point is zero\n");
+        bool res = test(MNN::Express::_Divide, "DivideInt8Test", 0.01, inp1, inp2, rightResult,
                     {4}, {4}, {4}, {0.4, 0.4, 1.0}, {0., 0., 0.});
+        printf("DivideInt8 test zero point is not zero\n");
+        res = test(MNN::Express::_Divide, "DivideInt8Test", 0.01, inp1, inp2, rightResult,
+                    {4}, {4}, {4}, {0.4, 0.4, 1.0}, {0., 2., 0.});
+        return res;
     }
 };
@@ -215,8 +233,11 @@ class MinimumInt8Test : public BinaryTestCommon {
     virtual bool run(int precision) {
         vector<float> inp1 = {-1.2, -5.0, 8, 10}, inp2 = {9.3, 3.1, 11.0, 2.9};
         vector<float> rightResult = {-1.2, -5.0, 8, 2.9};
-        return test(MNN::Express::_Minimum, "MinimumInt8Test", 0.01, inp1, inp2, rightResult,
-                    {4}, {4}, {4}, {0.4, 0.4, 0.4}, {0., 0., 0.});
+        bool res = test(MNN::Express::_Minimum, "MinimumInt8Test", 0.01, inp1, inp2, rightResult,
+                    {4}, {4}, {4}, {0.4, 0.4, 1.0}, {0., 0., 0.});
+        res = test(MNN::Express::_Minimum, "MinimumInt8Test", 0.01, inp1, inp2, rightResult,
+                    {4}, {4}, {4}, {0.4, 0.4, 1.0}, {1., 2., 3.});
+        return res;
     }
 };
@@ -233,10 +254,15 @@ class MaximumInt8Test : public BinaryTestCommon {
 public:
     virtual ~MaximumInt8Test() = default;
     virtual bool run(int precision) {
-        vector<float> inp1 = {-1, -5, 8, 10}, inp2 = {9};
-        vector<float> rightResult = {9, 9, 9, 10};
-        return test(MNN::Express::_Maximum, "MaximumInt8Test", 0.01, inp1, inp2, rightResult,
-                    {4}, {1}, {4}, {0.4, 0.4, 0.4}, {0., 0., 0.});
+        vector<float> inp1 = {-1, -5, 8, 10, -1, -5, 8, 10, -1, -5, 8, 10, -1, -5, 8, 10, -1, -5, 8, 10, -1, -5, 8, 10, -1, -5, 8, 10, -1, -5, 8, 10}, inp2 = {9};
+        vector<float> rightResult = {9, 9, 9, 10, 9, 9, 9, 10, 9, 9, 9, 10, 9, 9, 9, 10, 9, 9, 9, 10, 9, 9, 9, 10, 9, 9, 9, 10, 9, 9, 9, 10};
+        printf("MaximumInt8 test zeropoint is zero\n");
+        bool res = test(MNN::Express::_Maximum, "MaximumInt8Test", 0.01, inp1, inp2, rightResult,
+                    {32}, {1}, {32}, {0.4, 0.4, 1.0}, {0., 0., 0.});
+        printf("MaximumInt8 test zeropoint is not zero\n");
+        res = test(MNN::Express::_Maximum, "MaximumInt8Test", 0.01, inp1, inp2, rightResult,
+                    {32}, {1}, {32}, {0.4, 0.4, 1.0}, {1., 2., 5.});
+        return res;
     }
 };
@@ -360,8 +386,13 @@ class SquaredDifferenceInt8Test : public BinaryTestCommon {
     virtual bool run(int precision) {
         vector<float> inp1 = {-1, -2, -3, -4, 5, 6, 7, 8, -1, -2, -3, -4, 5, 6, 7, 8, -1, -2, -3, -4, 5, 6, 7, 8, -1, -2, -3, -4, 5, 6, 7, 8}, inp2 = {3};
         vector<float> rightResult = {16, 25, 36, 49, 4, 9, 16, 25, 16, 25, 36, 49, 4, 9, 16, 25, 16, 25, 36, 49, 4, 9, 16, 25, 16, 25, 36, 49, 4, 9, 16, 25};
-        return test(MNN::Express::_SquaredDifference, "SquaredDifferenceInt8Test", 0.01, inp1, inp2, rightResult,
+        printf("SqdInt8 test zeropoint is zero\n");
+        bool res = test(MNN::Express::_SquaredDifference, "SquaredDifferenceInt8Test", 0.01, inp1, inp2, rightResult,
                     {8, 4}, {1}, {8, 4}, {1, 1, 1}, {0., 0., 0.});
+        printf("SqdInt8 test zeropoint is not zero\n");
+        res = test(MNN::Express::_SquaredDifference, "SquaredDifferenceInt8Test", 0.01, inp1, inp2, rightResult,
+                    {8, 4}, {1}, {8, 4}, {1, 1, 1}, {2., 0., 1.});
+        return res;
     }
 };
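Note: each int8 binary test above now runs twice, first with all zero points at 0 and then with nonzero zero points, so the new offset arguments are exercised end to end. As written, the second call overwrites res, so only the nonzero-zero-point pass decides the verdict; an accumulating variant would require both passes to succeed, for example:

    // Stand-in for the harness; the real tests call BinaryTestCommon::test(...).
    static bool runCase(float zeroPoint) { return zeroPoint >= 0.0f; }

    // Both passes must pass, instead of the second silently replacing the first.
    bool runBothPasses() {
        bool res = runCase(0.0f);   // zero points all zero
        res = res && runCase(1.0f); // nonzero zero points
        return res;
    }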
diff --git a/test/op/ResizeTest.cpp b/test/op/ResizeTest.cpp
index 72e7c54c5..cc0380f4b 100644
--- a/test/op/ResizeTest.cpp
+++ b/test/op/ResizeTest.cpp
@@ -131,7 +131,7 @@ class InterpInt8Test : public MNNTestCase {
     virtual bool run(int precision) {
         auto input = _Input({1, 2, 2, 1}, NHWC);
         input->setName("input_tensor");
-        input->writeScaleMap(0.05, 0.f);
+        input->writeScaleMap(0.032, 1.f);
         // set input data
         const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
         auto inputPtr = input->writeMap<float>();
@@ -148,9 +148,10 @@ class InterpInt8Test : public MNNTestCase {
 
         //Interp Type:1
         {
+            printf("InterpInt8 test: Type=1\n");
             auto output = _Interp({input, scaleVar}, wScale, hScale, outW, outH, 1, false);
             output = _Convert(output, NHWC);
-            output->writeScaleMap(0.032f, 0.f);
+            output->writeScaleMap(0.031372549, 0.f);
             const std::vector<float> expectedOutput = {-1.0, -1.0, -2.0, -2.0, -1.0, -1.0, -2.0, -2.0, 3.0, 3.0, 4.0, 4.0, 3.0, 3.0, 4.0, 4.0};
             auto gotOutput = output->readMap<float>();
@@ -170,9 +171,10 @@ class InterpInt8Test : public MNNTestCase {
 
         //Interp Type:2
         {
+            printf("InterpInt8 test: Type=2\n");
             auto output = _Interp({input, scaleVar}, wScale, hScale, outW, outH, 2, false);
             output = _Convert(output, NHWC);
-            output->writeScaleMap(0.032, 0.);
+            output->writeScaleMap(0.031372549, 2.);
             const std::vector<float> expectedOutput = {
                 -1.0000, -1.2500, -1.7500, -2.0000, 0.0000, -0.1250, -0.3750, -0.5000, 2.0000, 2.1250, 2.3750, 2.5000, 3.0000, 3.2500, 3.7500, 4.0000};
             auto gotOutput = output->readMap<float>();
@@ -191,9 +193,19 @@ class InterpInt8Test : public MNNTestCase {
 
         // Interp Type:3
         {
-            auto output = _Interp({input, scaleVar}, wScale, hScale, outW, outH, 3, false);
+            printf("InterpInt8 test: Type=3\n");
+            auto input0 = _Input({1, 2, 2, 1}, NHWC);
+            input0->setName("input_tensor");
+            input0->writeScaleMap(0.05, 1.f); // Fix quant scale.
+            // set input data
+            const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
+            auto inputPtr = input0->writeMap<float>();
+            memcpy(inputPtr, inpudata, 4 * sizeof(float));
+            input0->unMap();
+            input0 = _Convert(input0, NC4HW4);
+            auto output = _Interp({input0, scaleVar}, wScale, hScale, outW, outH, 3, false);
             output = _Convert(output, NHWC);
-            output->writeScaleMap(0.03967, 0.);
+            output->writeScaleMap(0.03967, 1.);
             const std::vector<float> expectedOutput = {
                 2.516724, 2.217651, 2.516724, 2.698303, -0.516724, -0.217651, -0.516724, -0.698303, 2.516724, 2.217651, 2.516724, 2.698303, 4.358459, 3.696228, 4.358459, 4.760529};
             auto gotOutput = output->readMap<float>();
             if (!checkVector<float>(gotOutput, expectedOutput.data(), 16, 0.02)) {
diff --git a/test/op/ScaleTest.cpp b/test/op/ScaleTest.cpp
index c3d3d7293..94115e0cb 100644
--- a/test/op/ScaleTest.cpp
+++ b/test/op/ScaleTest.cpp
@@ -39,7 +39,7 @@ class ScaleInt8Test : public MNNTestCase {
     virtual ~ScaleInt8Test() = default;
     virtual bool run(int precision) {
         auto input = _Input({1, 2, 2, 1}, NCHW);
-        input->writeScaleMap(0.0313725, 0.f);
+        input->writeScaleMap(0.0313725, 1.f);
         // set input data
         const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
         auto inputPtr = input->writeMap<float>();
@@ -47,10 +47,10 @@ class ScaleInt8Test : public MNNTestCase {
         input = _Convert(input, NC4HW4);
         auto output = _Scale(input, 2, {2.0, 1.0}, {3.0, 4.0});
         output = _Convert(output, NCHW);
-        output->writeScaleMap(0.063, 0.f);
+        output->writeScaleMap(0.0628, 1.f);
         const std::vector<float> expectedOutput = {1, -1, 7, 8};
         auto gotOutput = output->readMap<float>();
-        if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 1e-2)) {
+        if (!checkVector<float>(gotOutput, expectedOutput.data(), 4, 1e-1)) {
             MNN_ERROR("ScaleTestInt8 test failed!\n");
             return false;
         }
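Note: the test updates above all route nonzero zero points through writeScaleMap(scale, zeroPoint), which pins a variable's quantization parameters before int8 execution. The usage pattern as these tests employ it (a sketch mirroring the test code, not API documentation):

    #include <cstring>
    #include <MNN/expr/ExprCreator.hpp>
    using namespace MNN::Express;

    VARP makeQuantizedInput() {
        auto input = _Input({1, 2, 2, 1}, NHWC);
        input->writeScaleMap(0.032f, 1.0f); // quant scale and a deliberately nonzero zero point
        const float data[] = {-1.0f, -2.0f, 3.0f, 4.0f};
        auto ptr = input->writeMap<float>();
        memcpy(ptr, data, sizeof(data));
        input->unMap();
        return input;
    }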
diff --git a/test/op/SoftmaxTest.cpp b/test/op/SoftmaxTest.cpp
index bcba2276c..a7cfc568c 100644
--- a/test/op/SoftmaxTest.cpp
+++ b/test/op/SoftmaxTest.cpp
@@ -262,7 +262,7 @@ class SoftmaxInt8Test: public MNNTestCase {
                                    9.5616,5.1855,3.1104,3.7267,5.2157,9.8103,9.8155,6.3442,8.2376,3.6553,5.3901};
         const float quantScales[] = {0.102, 0.00784};
-        const float zeroPoints[] = {0., 0.};
+        const float zeroPoints[] = {1., 2.};
         input->writeScaleMap(quantScales[0], zeroPoints[0]);
         auto inputPtr = input->writeMap<float>();
         memcpy(inputPtr, inputData, 48 * sizeof(float));
@@ -289,7 +289,7 @@ class SoftmaxInt8Test: public MNNTestCase {
         // set input data
         const float inpudata[] = {1.0, 2.0, 3.0, 4.0, 5.0, -1.0, -2.0, -3.0, -4.0, -5.0};
         const float quantScales[] = {1.0, 0.00784};
-        const float zeroPoints[] = {0., 0.};
+        const float zeroPoints[] = {1., 2.};
         input->writeScaleMap(quantScales[0], zeroPoints[0]);
         auto inputPtr = input->writeMap<float>();
         memcpy(inputPtr, inpudata, 10 * sizeof(float));
@@ -312,7 +312,7 @@ class SoftmaxInt8Test: public MNNTestCase {
         // set input data
         const float inpudata[] = {-1.0, -2.0, 3.0, 4.0};
         const float quantScales[] = {1.0, 0.00784};
-        const float zeroPoints[] = {0., 0.};
+        const float zeroPoints[] = {1., 2.};
         input->writeScaleMap(quantScales[0], zeroPoints[0]);
         auto inputPtr = input->writeMap<float>();
         memcpy(inputPtr, inpudata, 4 * sizeof(float));