From 67eceb8abb481bd6f14347049b0905163bee17c2 Mon Sep 17 00:00:00 2001 From: "zhaode.wzd" Date: Tue, 27 Jun 2023 10:33:16 +0800 Subject: [PATCH] [MNN:Sync] Sync Internal code, support low_memory for conv. --- benchmark/benchmark.cpp | 8 +- docs/compile/cmake.md | 1 + source/backend/arm82/Arm82Backend.cpp | 2 +- source/backend/arm82/Arm82Backend.hpp | 2 +- source/backend/arm82/Arm82Functions.cpp | 15 +- source/backend/arm82/CMakeLists.txt | 6 + .../low_memory/MNNPackedMatMulFP16_int4.S | 519 ++++++ .../low_memory/MNNPackedMatMulFP16_int8.S | 480 +++++ .../MNNPackedMatMulRemainFP16_int4.S | 843 +++++++++ .../MNNPackedMatMulRemainFP16_int8.S | 750 ++++++++ source/backend/cpu/CMakeLists.txt | 6 + source/backend/cpu/CPUBackend.cpp | 2 +- source/backend/cpu/CPUSoftMaxInt8.cpp | 4 + source/backend/cpu/arm/CMakeLists.txt | 4 + source/backend/cpu/arm/FunctionSummary.hpp | 8 +- .../low_memory/MNNPackedMatMulRemain_int4.S | 1161 ++++++++++++ .../low_memory/MNNPackedMatMulRemain_int8.S | 1003 ++++++++++ .../arm64/low_memory/MNNPackedMatMul_int4.S | 630 +++++++ .../arm64/low_memory/MNNPackedMatMul_int8.S | 595 ++++++ source/backend/cpu/bf16/BF16Functions.cpp | 6 +- .../backend/cpu/compute/CommonOptFunction.cpp | 165 +- .../backend/cpu/compute/CommonOptFunction.h | 20 +- .../cpu/compute/ConvolutionFloatFactory.cpp | 10 +- .../compute/ConvolutionPackFreeWinograd.cpp | 4 +- .../cpu/compute/ConvolutionPackWinograd.cpp | 6 +- .../cpu/compute/ConvolutionTiledExecutor.hpp | 1 + .../cpu/compute/DeconvolutionWithStride.cpp | 4 +- .../compute/DenseConvolutionTiledExecutor.cpp | 209 +-- .../compute/DenseConvolutionTiledExecutor.hpp | 4 +- .../cpu/compute/StrassenMatmulComputor.cpp | 4 +- source/backend/cpu/x86_x64/AVX2Functions.cpp | 6 + source/backend/cpu/x86_x64/CMakeLists.txt | 6 + .../cpu/x86_x64/FunctionDispatcher.cpp | 6 + .../cpu/x86_x64/avx/FunctionSummary.hpp | 14 +- source/backend/cpu/x86_x64/avx/GemmAVX2.cpp | 27 +- .../backend/cpu/x86_x64/avx/GemmFunction.hpp | 1635 +++++++++++++++++ .../backend/cpu/x86_x64/avx/MathFunctions.cpp | 2 +- .../cpu/x86_x64/avx512/FunctionSummary.hpp | 4 +- .../backend/cpu/x86_x64/avx512/GemmCommon.cpp | 4 +- .../cpu/x86_x64/avxfma/FunctionSummary.hpp | 8 +- .../cpu/x86_x64/avxfma/GemmAVX2FMA.cpp | 4 +- .../cpu/x86_x64/avxfma/GemmAVX2FMABF16.cpp | 4 +- .../cpu/x86_x64/sse/FunctionSummary.hpp | 14 +- .../backend/cpu/x86_x64/sse/GemmFunction.hpp | 426 +++++ source/backend/cpu/x86_x64/sse/GemmSSE.cpp | 37 +- .../backend/cuda/execution/ArgMaxExecution.cu | 5 +- .../execution/buffer/BinaryBufExecution.cpp | 2 +- .../execution/image/EltwiseExecution.cpp | 2 +- source/core/ConvolutionCommon.cpp | 60 +- test/core/IDSTTest.cpp | 2 +- .../source/common/FullQuantAndCoding.cpp | 2 +- .../source/common/WeightQuantAndCoding.cpp | 4 +- tools/cpp/ConvertToFullQuant.hpp | 2 +- tools/cpp/IDSTEncoder.hpp | 6 +- tools/cpp/revertMNNModel.cpp | 11 +- tools/cpp/revertMNNModel.hpp | 2 +- tools/quantization/calibration.cpp | 25 +- 57 files changed, 8566 insertions(+), 226 deletions(-) create mode 100644 source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S create mode 100644 source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int8.S create mode 100644 source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S create mode 100644 source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int8.S create mode 100644 source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S create mode 100644 
source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int8.S create mode 100644 source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S create mode 100644 source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int8.S diff --git a/benchmark/benchmark.cpp b/benchmark/benchmark.cpp index ed6465d97..3f5061c0b 100644 --- a/benchmark/benchmark.cpp +++ b/benchmark/benchmark.cpp @@ -119,10 +119,12 @@ std::vector doBench(Model& model, int loop, int warmup = 10, int forward int numberThread = 4, int precision = 2, float sparsity = 0.0f, int sparseBlockOC = 1, bool testQuantModel=false) { auto revertor = std::unique_ptr(new Revert(model.model_file.c_str())); if (testQuantModel) { - float scale = 0.003, offset = 0.f; - revertor->writeExtraDescribeTensor(&scale, &offset); + printf("Auto set sparsity=0 when test quantized model in benchmark...\n"); + revertor->initialize(0, sparseBlockOC, false, true); + } else { + revertor->initialize(sparsity, sparseBlockOC); } - revertor->initialize(sparsity, sparseBlockOC); + auto modelBuffer = revertor->getBuffer(); const auto bufferSize = revertor->getBufferSize(); auto net = std::shared_ptr(MNN::Interpreter::createFromBuffer(modelBuffer, bufferSize)); diff --git a/docs/compile/cmake.md b/docs/compile/cmake.md index 0f68963a9..b3272e13a 100644 --- a/docs/compile/cmake.md +++ b/docs/compile/cmake.md @@ -80,3 +80,4 @@ MNN使用CMake构建项目,CMake中的宏定义列表如下: | MNN_OPENCV_TEST | 构建MNN的OpenCV功能是否开启单元测试,默认为`OFF` | | MNN_OPENCV_BENCH | 构建MNN的OpenCV功能是否开启性能benchmark,默认为`OFF` | | MNN_VULKAN_IMAGE | 构建MNN的Vulkan后端时采用Image内存模式,以便支持FP16和部分移动端上GPU的加速,默认为`ON` | +| MNN_LOW_MEMORY | 是否支持低内存模式,支持低内存模式使用权值量化模型并设置`low_memory`则会使用计算时反量化,默认为`OFF` | diff --git a/source/backend/arm82/Arm82Backend.cpp b/source/backend/arm82/Arm82Backend.cpp index 0d7ea4890..a6dc8ff93 100644 --- a/source/backend/arm82/Arm82Backend.cpp +++ b/source/backend/arm82/Arm82Backend.cpp @@ -40,7 +40,7 @@ bool Arm82Backend::addArm82Creator(OpType t, Arm82Creator* ct) { return true; } -Arm82Backend::Arm82Backend(const CPURuntime* runtime) : CPUBackend(runtime, BackendConfig::Precision_Low, BackendConfig::Memory_Normal, MNN_FORWARD_CPU_EXTENSION) { +Arm82Backend::Arm82Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory) : CPUBackend(runtime, BackendConfig::Precision_Low, memory, MNN_FORWARD_CPU_EXTENSION) { mCoreFunctions = Arm82Functions::get(); } diff --git a/source/backend/arm82/Arm82Backend.hpp b/source/backend/arm82/Arm82Backend.hpp index 9a8d8ddcf..11511615c 100644 --- a/source/backend/arm82/Arm82Backend.hpp +++ b/source/backend/arm82/Arm82Backend.hpp @@ -29,7 +29,7 @@ namespace MNN { class Arm82Backend : public CPUBackend { public: virtual ~Arm82Backend(); - Arm82Backend(const CPURuntime* runtime); + Arm82Backend(const CPURuntime* runtime, BackendConfig::MemoryMode memory); virtual Execution* onCreate(const std::vector& inputs, const std::vector& outputs, const MNN::Op* op) override; virtual Backend::MemObj* onAcquire(const Tensor* nativeTensor, StorageType storageType) override; diff --git a/source/backend/arm82/Arm82Functions.cpp b/source/backend/arm82/Arm82Functions.cpp index be35d5539..ee1aab2f2 100644 --- a/source/backend/arm82/Arm82Functions.cpp +++ b/source/backend/arm82/Arm82Functions.cpp @@ -26,7 +26,14 @@ void MNNPackedMatMulFP16(float* C, const float* A, const float* B, const size_t* // C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24, e >= 1 // parameter: [aStride, l, h, cStride, bExtraStride] -void MNNPackedMatMulRemainFP16(float* C, const float* A, const 
float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +void MNNPackedMatMulRemainFP16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + +#ifdef MNN_LOW_MEMORY +void MNNPackedMatMulFP16_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulRemainFP16_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulFP16_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulRemainFP16_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#endif void MNNConvDwF23MulTransUnitFP16(FLOAT16 **cacheLine, const FLOAT16 *weight, FLOAT16 *dest, size_t ow); @@ -700,6 +707,12 @@ bool Arm82Functions::init() { // MatMul FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16); +#ifdef MNN_LOW_MEMORY + FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int4, MNNPackedMatMulFP16_int4); + FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int4, MNNPackedMatMulRemainFP16_int4); + FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8); + FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int8, MNNPackedMatMulRemainFP16_int8); +#endif FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A); FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode); FUNC_PTR_ASSIGN(gInstance->MNNPackForMatMul_B, Arm82MNNPackForMatMul_B); diff --git a/source/backend/arm82/CMakeLists.txt b/source/backend/arm82/CMakeLists.txt index 665731dde..cc9fc0ab7 100644 --- a/source/backend/arm82/CMakeLists.txt +++ b/source/backend/arm82/CMakeLists.txt @@ -7,7 +7,13 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?") target_compile_options(MNN_Arm82 PRIVATE -march=armv8.2-a+fp16 -mfpu=neon-fp-armv8 -mfloat-abi=softfp -DENABLE_ARMV82) elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR ARCHS STREQUAL "arm64") file(GLOB MNN_ARM82_SRCS_ASM "${CMAKE_CURRENT_LIST_DIR}/asm/arm64/*") + if (MNN_LOW_MEMORY) + file(GLOB MNN_ARM82_SRCS_ASM ${MNN_ARM82_SRCS_ASM} ${CMAKE_CURRENT_LIST_DIR}/asm/arm64/low_memory/*) + endif() add_library(MNN_Arm82 OBJECT ${MNN_ARM82_SRCS} ${MNN_ARM82_SRCS_ASM}) + if (MNN_LOW_MEMORY) + target_compile_options(MNN_Arm82 PRIVATE -DMNN_LOW_MEMORY) + endif() target_compile_options(MNN_Arm82 PRIVATE -march=armv8.2-a+fp16 -DENABLE_ARMV82) else() # Building fat binary requires multiple separate builds and lipo-by-hand under CMake's design diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S new file mode 100644 index 000000000..fa6532eea --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int4.S @@ -0,0 +1,519 @@ +// +// MNNPackedMatMulFP16_int4.S +// MNN +// +// Created by MNN on 2023/05/29. 
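+// Note: B here holds packed 4-bit weights, two per byte. The kernel splits each byte with
+// ushr/and, recentres the nibbles with sub #7, re-interleaves them with zip and widens via
+// sxtl/scvtf, then dequantizes per output channel as w = alpha * q + bias using the extra
+// k (alpha) and b (bias) arguments before the usual 8 * 24 FP16 accumulation.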
+// Copyright © 2018, Alibaba Group Holding Limited +// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 8 * 24 MatMul +asm_function MNNPackedMatMulFP16_int4 +//void MNNPackedMatMulFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, const size_t* parameter, const FLOAT16* postParameters, const FLOAT16* bias, const FLOAT16* k, const FLOAT16* b); +// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias, x6: quant_alpha, x7: quant_bias +stp d14, d15, [sp, #-80]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x20, [sp, #64] + +ldr x9, [x3, #8] // l +ldr x10, [x3, #16] // h + +ldr x13, [x3, #24] // cStride +ldr x11, [x3, #40] // bExtraStride + +// v0, v1, v2: A +// v3, v4: B +// v8 - v31: C +add x10, x10, #7 +lsr x10, x10, #3 + +Start: + +cmp x10, #2 +blt LH4 + +LH8: +sub x14, x13, #128 +mov x19, x6 +mov x20, x7 +LoopH: + mov x15, x1 + ld1 {v4.8h, v5.8h}, [x19], #32 // alpha + ld1 {v6.8h, v7.8h}, [x20], #32 // bias + subs x12, x9, #2 + // ld1 {v3.8h, v4.8h}, [x2], #32 + ld1 {v0.8h}, [x2], #16 + ushr v1.16b, v0.16b, #4 + mov w17, #0x0f + dup v3.16b, w17 + and v2.16b, v0.16b, v3.16b + mov w17, #7 + dup v0.16b, w17 + sub v1.16b, v1.16b, v0.16b + sub v2.16b, v2.16b, v0.16b + zip1 v0.16b, v1.16b, v2.16b + zip2 v3.16b, v1.16b, v2.16b + sxtl v1.8h, v0.8b + sxtl2 v2.8h, v0.16b + scvtf v0.8h, v1.8h + scvtf v1.8h, v2.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmul v8.8h, v1.8h, v0.h[0] + fmul v9.8h, v1.8h, v0.h[1] + fmul v10.8h, v1.8h, v0.h[2] + fmul v11.8h, v1.8h, v0.h[3] + fmul v12.8h, v1.8h, v0.h[4] + fmul v13.8h, v1.8h, v0.h[5] + fmul v14.8h, v1.8h, v0.h[6] + fmul v15.8h, v1.8h, v0.h[7] + + fmul v20.8h, v2.8h, v0.h[0] + fmul v21.8h, v2.8h, v0.h[1] + fmul v22.8h, v2.8h, v0.h[2] + fmul v23.8h, v2.8h, v0.h[3] + fmul v24.8h, v2.8h, v0.h[4] + fmul v25.8h, v2.8h, v0.h[5] + fmul v26.8h, v2.8h, v0.h[6] + fmul v27.8h, v2.8h, v0.h[7] + + ld1 {v1.4h}, [x15], #8 + fmul v16.8h, v1.8h, v1.h[0] + fmul v17.8h, v1.8h, v1.h[1] + fmul v18.8h, v1.8h, v1.h[2] + fmul v19.8h, v1.8h, v1.h[3] + fmul v28.8h, v2.8h, v1.h[0] + fmul v29.8h, v2.8h, v1.h[1] + fmul v30.8h, v2.8h, v1.h[2] + fmul v31.8h, v2.8h, v1.h[3] + + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + scvtf v0.8h, v0.8h + scvtf v1.8h, v1.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmla v8.8h, v1.8h, v0.h[0] + fmla v9.8h, v1.8h, v0.h[1] + fmla v10.8h, v1.8h, v0.h[2] + fmla v11.8h, v1.8h, v0.h[3] + fmla v12.8h, v1.8h, v0.h[4] + fmla v13.8h, v1.8h, v0.h[5] + fmla v14.8h, v1.8h, v0.h[6] + fmla v15.8h, v1.8h, v0.h[7] + + fmla v20.8h, v2.8h, v0.h[0] + fmla v21.8h, v2.8h, v0.h[1] + fmla v22.8h, v2.8h, v0.h[2] + fmla v23.8h, v2.8h, v0.h[3] + fmla v24.8h, v2.8h, v0.h[4] + fmla v25.8h, v2.8h, v0.h[5] + fmla v26.8h, v2.8h, v0.h[6] + fmla v27.8h, v2.8h, v0.h[7] + + ld1 {v0.4h}, [x15], #8 + fmla v16.8h, v1.8h, v0.h[0] + fmla v17.8h, v1.8h, v0.h[1] + fmla v18.8h, v1.8h, v0.h[2] + fmla v19.8h, v1.8h, v0.h[3] + fmla v28.8h, v2.8h, v0.h[0] + fmla v29.8h, v2.8h, v0.h[1] + fmla v30.8h, v2.8h, v0.h[2] + fmla v31.8h, v2.8h, v0.h[3] + + beq LoopLEnd + + LoopL2: + // ld1 {v3.8h, v4.8h}, [x2], #32 + subs x12, x12, #2 + ld1 {v0.8h}, [x2], #16 + ushr v1.16b, v0.16b, #4 + mov w17, #0x0f + dup v3.16b, w17 + and v2.16b, v0.16b, v3.16b + mov w17, #7 + dup v0.16b, w17 + sub v1.16b, v1.16b, v0.16b + sub v2.16b, v2.16b, v0.16b + zip1 v0.16b, v1.16b, v2.16b + zip2 v3.16b, 
v1.16b, v2.16b + sxtl v1.8h, v0.8b + sxtl2 v2.8h, v0.16b + scvtf v0.8h, v1.8h + scvtf v1.8h, v2.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmla v8.8h, v1.8h, v0.h[0] + fmla v9.8h, v1.8h, v0.h[1] + fmla v10.8h, v1.8h, v0.h[2] + fmla v11.8h, v1.8h, v0.h[3] + fmla v12.8h, v1.8h, v0.h[4] + fmla v13.8h, v1.8h, v0.h[5] + fmla v14.8h, v1.8h, v0.h[6] + fmla v15.8h, v1.8h, v0.h[7] + + fmla v20.8h, v2.8h, v0.h[0] + fmla v21.8h, v2.8h, v0.h[1] + fmla v22.8h, v2.8h, v0.h[2] + fmla v23.8h, v2.8h, v0.h[3] + fmla v24.8h, v2.8h, v0.h[4] + fmla v25.8h, v2.8h, v0.h[5] + fmla v26.8h, v2.8h, v0.h[6] + fmla v27.8h, v2.8h, v0.h[7] + + ld1 {v0.4h}, [x15], #8 + fmla v16.8h, v1.8h, v0.h[0] + fmla v17.8h, v1.8h, v0.h[1] + fmla v18.8h, v1.8h, v0.h[2] + fmla v19.8h, v1.8h, v0.h[3] + fmla v28.8h, v2.8h, v0.h[0] + fmla v29.8h, v2.8h, v0.h[1] + fmla v30.8h, v2.8h, v0.h[2] + fmla v31.8h, v2.8h, v0.h[3] + + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + scvtf v0.8h, v0.8h + scvtf v1.8h, v1.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmla v8.8h, v1.8h, v0.h[0] + fmla v9.8h, v1.8h, v0.h[1] + fmla v10.8h, v1.8h, v0.h[2] + fmla v11.8h, v1.8h, v0.h[3] + fmla v12.8h, v1.8h, v0.h[4] + fmla v13.8h, v1.8h, v0.h[5] + fmla v14.8h, v1.8h, v0.h[6] + fmla v15.8h, v1.8h, v0.h[7] + + fmla v20.8h, v2.8h, v0.h[0] + fmla v21.8h, v2.8h, v0.h[1] + fmla v22.8h, v2.8h, v0.h[2] + fmla v23.8h, v2.8h, v0.h[3] + fmla v24.8h, v2.8h, v0.h[4] + fmla v25.8h, v2.8h, v0.h[5] + fmla v26.8h, v2.8h, v0.h[6] + fmla v27.8h, v2.8h, v0.h[7] + + ld1 {v0.4h}, [x15], #8 + fmla v16.8h, v1.8h, v0.h[0] + fmla v17.8h, v1.8h, v0.h[1] + fmla v18.8h, v1.8h, v0.h[2] + fmla v19.8h, v1.8h, v0.h[3] + fmla v28.8h, v2.8h, v0.h[0] + fmla v29.8h, v2.8h, v0.h[1] + fmla v30.8h, v2.8h, v0.h[2] + fmla v31.8h, v2.8h, v0.h[3] + bne LoopL2 + + LoopLEnd: + + add x2, x2, x11 + sub x10, x10, #2 + cmp x10, #2 + + cbz x4, StoreLH8 + + AddBiasLH8: + ld1 {v5.8h}, [x4] + fcvtn v5.4h, v5.4s + dup v6.8h, v5.h[2] // Min Value + dup v7.8h, v5.h[3] // Max Value + ld1 {v0.8h, v1.8h}, [x5], #32 + + fmla v8.8h, v0.8h, v5.h[1] + fmla v9.8h, v0.8h, v5.h[1] + fmla v10.8h, v0.8h, v5.h[1] + fmla v11.8h, v0.8h, v5.h[1] + + fmla v12.8h, v0.8h, v5.h[1] + fmla v13.8h, v0.8h, v5.h[1] + fmla v14.8h, v0.8h, v5.h[1] + fmla v15.8h, v0.8h, v5.h[1] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v1.8h, v5.h[1] + fmla v21.8h, v1.8h, v5.h[1] + fmla v22.8h, v1.8h, v5.h[1] + fmla v23.8h, v1.8h, v5.h[1] + + fmla v24.8h, v1.8h, v5.h[1] + fmla v25.8h, v1.8h, v5.h[1] + fmla v26.8h, v1.8h, v5.h[1] + fmla v27.8h, v1.8h, v5.h[1] + + fmla v28.8h, v1.8h, v5.h[1] + fmla v29.8h, v1.8h, v5.h[1] + fmla v30.8h, v1.8h, v5.h[1] + fmla v31.8h, v1.8h, v5.h[1] + + PostTreatLH8: + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, 
v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + StoreLH8: + + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], x14 + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x0], x14 + + bge LoopH + +LH4: +cbz x10, End +mov x19, x6 +mov x20, x7 +LoopHRemain: + mov x15, x1 + subs x12, x9, #2 + ld1 {v20.8h}, [x19], #16 // alpha + ld1 {v21.8h}, [x20], #16 // bias + mov w17, #0x0f + dup v22.16b, w17 + mov w17, #7 + dup v23.16b, w17 + // ld1 {v3.8h}, [x2] + ld1 {v3.8h}, [x2], #16 + // 01234567xxxxxxx89... => 0123456789... + uzp1 v0.4s, v3.4s, v3.4s + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v22.8b + sub v1.8b, v1.8b, v23.8b + sub v2.8b, v2.8b, v23.8b + zip1 v6.8b, v1.8b, v2.8b + zip2 v7.8b, v1.8b, v2.8b + sxtl v0.8h, v6.8b + sxtl v1.8h, v7.8b + scvtf v6.8h, v0.8h + scvtf v7.8h, v1.8h + mov v3.8h, v21.8h + mov v4.8h, v21.8h + fmla v3.8h, v6.8h, v20.8h + fmla v4.8h, v7.8h, v20.8h + + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmul v8.8h, v3.8h, v0.h[0] + fmul v9.8h, v3.8h, v0.h[1] + fmul v10.8h, v3.8h, v0.h[2] + fmul v11.8h, v3.8h, v0.h[3] + fmul v12.8h, v3.8h, v1.h[0] + fmul v13.8h, v3.8h, v1.h[1] + fmul v14.8h, v3.8h, v1.h[2] + fmul v15.8h, v3.8h, v1.h[3] + fmul v16.8h, v3.8h, v2.h[0] + fmul v17.8h, v3.8h, v2.h[1] + fmul v18.8h, v3.8h, v2.h[2] + fmul v19.8h, v3.8h, v2.h[3] + + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v0.h[3] + fmla v12.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v1.h[1] + fmla v14.8h, v4.8h, v1.h[2] + fmla v15.8h, v4.8h, v1.h[3] + fmla v16.8h, v4.8h, v2.h[0] + fmla v17.8h, v4.8h, v2.h[1] + fmla v18.8h, v4.8h, v2.h[2] + fmla v19.8h, v4.8h, v2.h[3] + + beq LoopLREnd + + LoopLR: + subs x12, x12, #2 + // ld1 {v3.8h}, [x2] + ld1 {v3.8h}, [x2], #16 + // 01234567xxxxxxx89... => 0123456789... 
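+    // Each 32-bit word of v3 carries eight packed int4 weights for one l-step of this
+    // 8-channel tail; uzp1 drops the padding words so both l-steps become contiguous.
+    // ushr/and then split the nibbles, sub #7 recentres them to signed values, zip restores
+    // the original order, and sxtl/scvtf widen to FP16 for the w = alpha * q + bias
+    // dequantization against v20 (alpha) and v21 (bias).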
+ uzp1 v0.4s, v3.4s, v3.4s + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v22.8b + sub v1.8b, v1.8b, v23.8b + sub v2.8b, v2.8b, v23.8b + zip1 v6.8b, v1.8b, v2.8b + zip2 v7.8b, v1.8b, v2.8b + sxtl v0.8h, v6.8b + sxtl v1.8h, v7.8b + scvtf v6.8h, v0.8h + scvtf v7.8h, v1.8h + mov v3.8h, v21.8h + mov v4.8h, v21.8h + fmla v3.8h, v6.8h, v20.8h + fmla v4.8h, v7.8h, v20.8h + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v3.8h, v0.h[1] + fmla v10.8h, v3.8h, v0.h[2] + fmla v11.8h, v3.8h, v0.h[3] + fmla v12.8h, v3.8h, v1.h[0] + fmla v13.8h, v3.8h, v1.h[1] + fmla v14.8h, v3.8h, v1.h[2] + fmla v15.8h, v3.8h, v1.h[3] + fmla v16.8h, v3.8h, v2.h[0] + fmla v17.8h, v3.8h, v2.h[1] + fmla v18.8h, v3.8h, v2.h[2] + fmla v19.8h, v3.8h, v2.h[3] + + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v0.h[3] + fmla v12.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v1.h[1] + fmla v14.8h, v4.8h, v1.h[2] + fmla v15.8h, v4.8h, v1.h[3] + fmla v16.8h, v4.8h, v2.h[0] + fmla v17.8h, v4.8h, v2.h[1] + fmla v18.8h, v4.8h, v2.h[2] + fmla v19.8h, v4.8h, v2.h[3] + + bne LoopLR + LoopLREnd: + + cbz x4, StoreLH4 + AddBiasLH4: + ld1 {v5.8h}, [x4] + fcvtn v5.4h, v5.4s + dup v6.8h, v5.h[2] // Min Value + dup v7.8h, v5.h[3] // Max Value + ld1 {v0.8h}, [x5], #16 + + fmla v8.8h, v0.8h, v5.h[1] + fmla v9.8h, v0.8h, v5.h[1] + fmla v10.8h, v0.8h, v5.h[1] + fmla v11.8h, v0.8h, v5.h[1] + + fmla v12.8h, v0.8h, v5.h[1] + fmla v13.8h, v0.8h, v5.h[1] + fmla v14.8h, v0.8h, v5.h[1] + fmla v15.8h, v0.8h, v5.h[1] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + PostTreatLH4: + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + + StoreLH4: + + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0] + sub x10, x10, #1 + + +End: +ldp x19, x20, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #80 + +ret + +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int8.S b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int8.S new file mode 100644 index 000000000..790b02b78 --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulFP16_int8.S @@ -0,0 +1,480 @@ +// +// MNNPackedMatMulFP16_int8.S +// MNN +// +// Created by MNN on 2023/06/06. 
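+// Note: same loop structure as the int4 variant, but B holds plain int8 weights, so each
+// 32-byte load is widened directly with sxtl/scvtf and dequantized as w = alpha * q + bias
+// (alpha and bias coming from the extra k/b arguments) before the FP16 accumulation.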
+// Copyright © 2018, Alibaba Group Holding Limited +// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 8 * 24 MatMul +asm_function MNNPackedMatMulFP16_int8 +//void MNNPackedMatMulFP16_int8(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, const size_t* parameter, const FLOAT16* postParameters, const FLOAT16* bias, const FLOAT16* k, const FLOAT16* b); +// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias, x6: quant_alpha, x7: quant_bias +stp d14, d15, [sp, #-80]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x20, [sp, #64] + +ldr x9, [x3, #8] // l +ldr x10, [x3, #16] // h + +ldr x13, [x3, #24] // cStride +ldr x11, [x3, #40] // bExtraStride + +// v0, v1, v2: A +// v3, v4: B +// v8 - v31: C +add x10, x10, #7 +lsr x10, x10, #3 + +Start: + +cmp x10, #2 +blt LH4 + +LH8: +sub x14, x13, #128 +mov x19, x6 +mov x20, x7 +LoopH: + mov x15, x1 + ld1 {v4.8h, v5.8h}, [x19], #32 // alpha + ld1 {v6.8h, v7.8h}, [x20], #32 // bias + subs x12, x9, #2 + ld1 {v2.16b, v3.16b}, [x2], #32 + sxtl v0.8h, v2.8b + sxtl2 v1.8h, v2.16b + scvtf v0.8h, v0.8h + scvtf v1.8h, v1.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmul v8.8h, v1.8h, v0.h[0] + fmul v9.8h, v1.8h, v0.h[1] + fmul v10.8h, v1.8h, v0.h[2] + fmul v11.8h, v1.8h, v0.h[3] + fmul v12.8h, v1.8h, v0.h[4] + fmul v13.8h, v1.8h, v0.h[5] + fmul v14.8h, v1.8h, v0.h[6] + fmul v15.8h, v1.8h, v0.h[7] + + fmul v20.8h, v2.8h, v0.h[0] + fmul v21.8h, v2.8h, v0.h[1] + fmul v22.8h, v2.8h, v0.h[2] + fmul v23.8h, v2.8h, v0.h[3] + fmul v24.8h, v2.8h, v0.h[4] + fmul v25.8h, v2.8h, v0.h[5] + fmul v26.8h, v2.8h, v0.h[6] + fmul v27.8h, v2.8h, v0.h[7] + + ld1 {v1.4h}, [x15], #8 + fmul v16.8h, v1.8h, v1.h[0] + fmul v17.8h, v1.8h, v1.h[1] + fmul v18.8h, v1.8h, v1.h[2] + fmul v19.8h, v1.8h, v1.h[3] + fmul v28.8h, v2.8h, v1.h[0] + fmul v29.8h, v2.8h, v1.h[1] + fmul v30.8h, v2.8h, v1.h[2] + fmul v31.8h, v2.8h, v1.h[3] + + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + scvtf v0.8h, v0.8h + scvtf v1.8h, v1.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmla v8.8h, v1.8h, v0.h[0] + fmla v9.8h, v1.8h, v0.h[1] + fmla v10.8h, v1.8h, v0.h[2] + fmla v11.8h, v1.8h, v0.h[3] + fmla v12.8h, v1.8h, v0.h[4] + fmla v13.8h, v1.8h, v0.h[5] + fmla v14.8h, v1.8h, v0.h[6] + fmla v15.8h, v1.8h, v0.h[7] + + fmla v20.8h, v2.8h, v0.h[0] + fmla v21.8h, v2.8h, v0.h[1] + fmla v22.8h, v2.8h, v0.h[2] + fmla v23.8h, v2.8h, v0.h[3] + fmla v24.8h, v2.8h, v0.h[4] + fmla v25.8h, v2.8h, v0.h[5] + fmla v26.8h, v2.8h, v0.h[6] + fmla v27.8h, v2.8h, v0.h[7] + + ld1 {v0.4h}, [x15], #8 + fmla v16.8h, v1.8h, v0.h[0] + fmla v17.8h, v1.8h, v0.h[1] + fmla v18.8h, v1.8h, v0.h[2] + fmla v19.8h, v1.8h, v0.h[3] + fmla v28.8h, v2.8h, v0.h[0] + fmla v29.8h, v2.8h, v0.h[1] + fmla v30.8h, v2.8h, v0.h[2] + fmla v31.8h, v2.8h, v0.h[3] + + beq LoopLEnd + + LoopL2: + // ld1 {v3.8h, v4.8h}, [x2], #32 + subs x12, x12, #2 + ld1 {v2.16b, v3.16b}, [x2], #32 + sxtl v0.8h, v2.8b + sxtl2 v1.8h, v2.16b + scvtf v0.8h, v0.8h + scvtf v1.8h, v1.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmla v8.8h, v1.8h, v0.h[0] + fmla v9.8h, v1.8h, v0.h[1] + fmla v10.8h, v1.8h, v0.h[2] + fmla v11.8h, v1.8h, v0.h[3] + fmla v12.8h, v1.8h, v0.h[4] + fmla v13.8h, v1.8h, v0.h[5] + fmla v14.8h, v1.8h, v0.h[6] + fmla v15.8h, v1.8h, v0.h[7] + + fmla v20.8h, v2.8h, v0.h[0] + fmla 
v21.8h, v2.8h, v0.h[1] + fmla v22.8h, v2.8h, v0.h[2] + fmla v23.8h, v2.8h, v0.h[3] + fmla v24.8h, v2.8h, v0.h[4] + fmla v25.8h, v2.8h, v0.h[5] + fmla v26.8h, v2.8h, v0.h[6] + fmla v27.8h, v2.8h, v0.h[7] + + ld1 {v0.4h}, [x15], #8 + fmla v16.8h, v1.8h, v0.h[0] + fmla v17.8h, v1.8h, v0.h[1] + fmla v18.8h, v1.8h, v0.h[2] + fmla v19.8h, v1.8h, v0.h[3] + fmla v28.8h, v2.8h, v0.h[0] + fmla v29.8h, v2.8h, v0.h[1] + fmla v30.8h, v2.8h, v0.h[2] + fmla v31.8h, v2.8h, v0.h[3] + + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + scvtf v0.8h, v0.8h + scvtf v1.8h, v1.8h + mov v2.8h, v7.8h + fmla v2.8h, v1.8h, v5.8h + mov v1.8h, v6.8h + fmla v1.8h, v0.8h, v4.8h + + ld1 {v0.8h}, [x15], #16 + fmla v8.8h, v1.8h, v0.h[0] + fmla v9.8h, v1.8h, v0.h[1] + fmla v10.8h, v1.8h, v0.h[2] + fmla v11.8h, v1.8h, v0.h[3] + fmla v12.8h, v1.8h, v0.h[4] + fmla v13.8h, v1.8h, v0.h[5] + fmla v14.8h, v1.8h, v0.h[6] + fmla v15.8h, v1.8h, v0.h[7] + + fmla v20.8h, v2.8h, v0.h[0] + fmla v21.8h, v2.8h, v0.h[1] + fmla v22.8h, v2.8h, v0.h[2] + fmla v23.8h, v2.8h, v0.h[3] + fmla v24.8h, v2.8h, v0.h[4] + fmla v25.8h, v2.8h, v0.h[5] + fmla v26.8h, v2.8h, v0.h[6] + fmla v27.8h, v2.8h, v0.h[7] + + ld1 {v0.4h}, [x15], #8 + fmla v16.8h, v1.8h, v0.h[0] + fmla v17.8h, v1.8h, v0.h[1] + fmla v18.8h, v1.8h, v0.h[2] + fmla v19.8h, v1.8h, v0.h[3] + fmla v28.8h, v2.8h, v0.h[0] + fmla v29.8h, v2.8h, v0.h[1] + fmla v30.8h, v2.8h, v0.h[2] + fmla v31.8h, v2.8h, v0.h[3] + bne LoopL2 + + LoopLEnd: + + add x2, x2, x11 + sub x10, x10, #2 + cmp x10, #2 + + cbz x4, StoreLH8 + + AddBiasLH8: + ld1 {v5.8h}, [x4] + fcvtn v5.4h, v5.4s + dup v6.8h, v5.h[2] // Min Value + dup v7.8h, v5.h[3] // Max Value + ld1 {v0.8h, v1.8h}, [x5], #32 + + fmla v8.8h, v0.8h, v5.h[1] + fmla v9.8h, v0.8h, v5.h[1] + fmla v10.8h, v0.8h, v5.h[1] + fmla v11.8h, v0.8h, v5.h[1] + + fmla v12.8h, v0.8h, v5.h[1] + fmla v13.8h, v0.8h, v5.h[1] + fmla v14.8h, v0.8h, v5.h[1] + fmla v15.8h, v0.8h, v5.h[1] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v1.8h, v5.h[1] + fmla v21.8h, v1.8h, v5.h[1] + fmla v22.8h, v1.8h, v5.h[1] + fmla v23.8h, v1.8h, v5.h[1] + + fmla v24.8h, v1.8h, v5.h[1] + fmla v25.8h, v1.8h, v5.h[1] + fmla v26.8h, v1.8h, v5.h[1] + fmla v27.8h, v1.8h, v5.h[1] + + fmla v28.8h, v1.8h, v5.h[1] + fmla v29.8h, v1.8h, v5.h[1] + fmla v30.8h, v1.8h, v5.h[1] + fmla v31.8h, v1.8h, v5.h[1] + + PostTreatLH8: + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h 
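+    // (the fmax/fmin pair clamps all 24 accumulators to [minValue, maxValue] taken from postParameters)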
+ fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + StoreLH8: + + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], x14 + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x0], x14 + + bge LoopH + +LH4: +cbz x10, End +mov x19, x6 +mov x20, x7 +LoopHRemain: + mov x15, x1 + subs x12, x9, #2 + ld1 {v20.8h}, [x19], #16 // alpha + ld1 {v21.8h}, [x20], #16 // bias + // ld1 {v3.8h}, [x2] + ld1 {v3.16b, v4.16b}, [x2], #32 + uzp1 v0.4s, v3.4s, v4.4s + sxtl v0.8h, v0.8b + sxtl2 v1.8h, v0.16b + scvtf v6.8h, v0.8h + scvtf v7.8h, v1.8h + mov v3.8h, v21.8h + mov v4.8h, v21.8h + fmla v3.8h, v6.8h, v20.8h + fmla v4.8h, v7.8h, v20.8h + + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmul v8.8h, v3.8h, v0.h[0] + fmul v9.8h, v3.8h, v0.h[1] + fmul v10.8h, v3.8h, v0.h[2] + fmul v11.8h, v3.8h, v0.h[3] + fmul v12.8h, v3.8h, v1.h[0] + fmul v13.8h, v3.8h, v1.h[1] + fmul v14.8h, v3.8h, v1.h[2] + fmul v15.8h, v3.8h, v1.h[3] + fmul v16.8h, v3.8h, v2.h[0] + fmul v17.8h, v3.8h, v2.h[1] + fmul v18.8h, v3.8h, v2.h[2] + fmul v19.8h, v3.8h, v2.h[3] + + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v0.h[3] + fmla v12.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v1.h[1] + fmla v14.8h, v4.8h, v1.h[2] + fmla v15.8h, v4.8h, v1.h[3] + fmla v16.8h, v4.8h, v2.h[0] + fmla v17.8h, v4.8h, v2.h[1] + fmla v18.8h, v4.8h, v2.h[2] + fmla v19.8h, v4.8h, v2.h[3] + + beq LoopLREnd + + LoopLR: + subs x12, x12, #2 + // ld1 {v3.8h}, [x2] + ld1 {v3.16b, v4.16b}, [x2], #32 + uzp1 v0.4s, v3.4s, v4.4s + sxtl v0.8h, v0.8b + sxtl2 v1.8h, v0.16b + scvtf v6.8h, v0.8h + scvtf v7.8h, v1.8h + mov v3.8h, v21.8h + mov v4.8h, v21.8h + fmla v3.8h, v6.8h, v20.8h + fmla v4.8h, v7.8h, v20.8h + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmla v8.8h, v3.8h, v0.h[0] + fmla v9.8h, v3.8h, v0.h[1] + fmla v10.8h, v3.8h, v0.h[2] + fmla v11.8h, v3.8h, v0.h[3] + fmla v12.8h, v3.8h, v1.h[0] + fmla v13.8h, v3.8h, v1.h[1] + fmla v14.8h, v3.8h, v1.h[2] + fmla v15.8h, v3.8h, v1.h[3] + fmla v16.8h, v3.8h, v2.h[0] + fmla v17.8h, v3.8h, v2.h[1] + fmla v18.8h, v3.8h, v2.h[2] + fmla v19.8h, v3.8h, v2.h[3] + + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 + fmla v8.8h, v4.8h, v0.h[0] + fmla v9.8h, v4.8h, v0.h[1] + fmla v10.8h, v4.8h, v0.h[2] + fmla v11.8h, v4.8h, v0.h[3] + fmla v12.8h, v4.8h, v1.h[0] + fmla v13.8h, v4.8h, v1.h[1] + fmla v14.8h, v4.8h, v1.h[2] + fmla v15.8h, v4.8h, v1.h[3] + fmla v16.8h, v4.8h, v2.h[0] + fmla v17.8h, v4.8h, v2.h[1] + fmla v18.8h, v4.8h, v2.h[2] + fmla v19.8h, v4.8h, v2.h[3] + + bne LoopLR + LoopLREnd: + + cbz x4, StoreLH4 + AddBiasLH4: + ld1 {v5.8h}, [x4] + fcvtn v5.4h, v5.4s + dup v6.8h, v5.h[2] // Min Value + dup v7.8h, v5.h[3] // Max Value + ld1 {v0.8h}, [x5], #16 + + fmla v8.8h, v0.8h, v5.h[1] + fmla v9.8h, v0.8h, v5.h[1] + fmla v10.8h, v0.8h, v5.h[1] + fmla v11.8h, v0.8h, v5.h[1] + + fmla v12.8h, v0.8h, v5.h[1] + fmla v13.8h, v0.8h, v5.h[1] + fmla v14.8h, v0.8h, v5.h[1] + fmla v15.8h, v0.8h, v5.h[1] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + PostTreatLH4: + fmax v8.8h, v8.8h, v6.8h + fmax v9.8h, 
v9.8h, v6.8h + fmax v10.8h, v10.8h, v6.8h + fmax v11.8h, v11.8h, v6.8h + fmax v12.8h, v12.8h, v6.8h + fmax v13.8h, v13.8h, v6.8h + fmax v14.8h, v14.8h, v6.8h + fmax v15.8h, v15.8h, v6.8h + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + + fmin v8.8h, v8.8h, v7.8h + fmin v9.8h, v9.8h, v7.8h + fmin v10.8h, v10.8h, v7.8h + fmin v11.8h, v11.8h, v7.8h + fmin v12.8h, v12.8h, v7.8h + fmin v13.8h, v13.8h, v7.8h + fmin v14.8h, v14.8h, v7.8h + fmin v15.8h, v15.8h, v7.8h + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + + StoreLH4: + + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 + st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0] + sub x10, x10, #1 + + +End: +ldp x19, x20, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #80 + +ret + +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S new file mode 100644 index 000000000..9045381af --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int4.S @@ -0,0 +1,843 @@ +// +// MNNPackedMatMulRemainFP16_int4.S +// MNN +// +// Created by MNN on 2023/05/22. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 8 * 24 MatMul, C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24 +// Remain meaning is eSize is any value +asm_function MNNPackedMatMulRemainFP16_int4 +//void MNNPackedMatMulRemainFP16_int4(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t eSize, const size_t* parameter, const FLOAT16* postParameters, const FLOAT16* bias); +//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias, x7: k, x8: b +// parameter: {aStride, l, h, cStride, bExtraStride} +ldr x8, [sp] // bias +stp d14, d15, [sp, #-112]! 
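+// Prologue: spill callee-saved d8-d15 and x19-x24; x22/x23 below keep the k (alpha) and
+// b (bias) pointers used to dequantize the packed int4 weights on the fly.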
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x20, [sp, #64] +stp x21, x22, [sp, #80] +stp x23, x24, [sp, #96] + +mov x22, x7 // alpha +mov x23, x8 // bias +ldr x11, [x4, #0] // aStride +ldr x9, [x4, #8] // l +ldr x10, [x4, #16] // h + +ldr x7, [x4, #24] // cStride +ldr x19, [x4, #40] // bExtraStride + +add x10, x10, #7 +lsr x10, x10, #3 + +cbz x5, Start +ld1 {v5.4s}, [x5] +fcvtn v5.4h, v5.4s +dup v6.8h, v5.h[2] // Min Value +dup v7.8h, v5.h[3] // Max Value + +Start: + +E8: +cmp x3, #8 +blt E4 + +// 8x16 +LoopE8: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + LH8: + cmp x8, #2 + blt LH4 + sub x24, x7, #64 + LoopH8x8: + mov x15, x1 + ld1 {v12.8h, v13.8h}, [x14], #32 // alpha + mov w17, #0x0f + dup v3.16b, w17 + mov w17, #7 + dup v4.16b, w17 + ld1 {v14.8h, v15.8h}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v0.8h}, [x13], #16 + ushr v1.16b, v0.16b, #4 + and v2.16b, v0.16b, v3.16b + sub v1.16b, v1.16b, v4.16b + sub v2.16b, v2.16b, v4.16b + zip1 v10.16b, v1.16b, v2.16b + zip2 v11.16b, v1.16b, v2.16b + sxtl v1.8h, v10.8b + sxtl2 v2.8h, v10.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + + ld1 {v0.8h}, [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v17.8h, v8.8h, v0.h[1] + fmul v18.8h, v8.8h, v0.h[2] + fmul v19.8h, v8.8h, v0.h[3] + + fmul v20.8h, v9.8h, v0.h[0] + fmul v21.8h, v9.8h, v0.h[1] + fmul v22.8h, v9.8h, v0.h[2] + fmul v23.8h, v9.8h, v0.h[3] + + fmul v24.8h, v8.8h, v0.h[4] + fmul v25.8h, v8.8h, v0.h[5] + fmul v26.8h, v8.8h, v0.h[6] + fmul v27.8h, v8.8h, v0.h[7] + + fmul v28.8h, v9.8h, v0.h[4] + fmul v29.8h, v9.8h, v0.h[5] + fmul v30.8h, v9.8h, v0.h[6] + fmul v31.8h, v9.8h, v0.h[7] + + sxtl v1.8h, v11.8b + sxtl2 v2.8h, v11.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v24.8h, v8.8h, v0.h[4] + fmla v25.8h, v8.8h, v0.h[5] + fmla v26.8h, v8.8h, v0.h[6] + fmla v27.8h, v8.8h, v0.h[7] + + fmla v28.8h, v9.8h, v0.h[4] + fmla v29.8h, v9.8h, v0.h[5] + fmla v30.8h, v9.8h, v0.h[6] + fmla v31.8h, v9.8h, v0.h[7] + beq LoopLEnd + + LoopL: + subs x12, x12, #2 + ld1 {v0.8h}, [x13], #16 + ushr v1.16b, v0.16b, #4 + and v2.16b, v0.16b, v3.16b + sub v1.16b, v1.16b, v4.16b + sub v2.16b, v2.16b, v4.16b + zip1 v10.16b, v1.16b, v2.16b + zip2 v11.16b, v1.16b, v2.16b + sxtl v1.8h, v10.8b + sxtl2 v2.8h, v10.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v24.8h, v8.8h, v0.h[4] + fmla v25.8h, v8.8h, v0.h[5] + fmla v26.8h, v8.8h, v0.h[6] + fmla v27.8h, v8.8h, v0.h[7] + + fmla v28.8h, v9.8h, v0.h[4] + fmla v29.8h, v9.8h, v0.h[5] + fmla v30.8h, v9.8h, v0.h[6] + fmla v31.8h, v9.8h, v0.h[7] + + sxtl v1.8h, v11.8b + sxtl2 v2.8h, v11.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, 
v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v24.8h, v8.8h, v0.h[4] + fmla v25.8h, v8.8h, v0.h[5] + fmla v26.8h, v8.8h, v0.h[6] + fmla v27.8h, v8.8h, v0.h[7] + + fmla v28.8h, v9.8h, v0.h[4] + fmla v29.8h, v9.8h, v0.h[5] + fmla v30.8h, v9.8h, v0.h[6] + fmla v31.8h, v9.8h, v0.h[7] + bne LoopL + + LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + + cbz x5, StoreLH8 + AddBiasLH8: + ld1 {v0.8h, v1.8h}, [x20], #32 + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v1.8h, v5.h[1] + fmla v21.8h, v1.8h, v5.h[1] + fmla v22.8h, v1.8h, v5.h[1] + fmla v23.8h, v1.8h, v5.h[1] + + fmla v24.8h, v0.8h, v5.h[1] + fmla v25.8h, v0.8h, v5.h[1] + fmla v26.8h, v0.8h, v5.h[1] + fmla v27.8h, v0.8h, v5.h[1] + + fmla v28.8h, v1.8h, v5.h[1] + fmla v29.8h, v1.8h, v5.h[1] + fmla v30.8h, v1.8h, v5.h[1] + fmla v31.8h, v1.8h, v5.h[1] + + PostTreatLH8: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + StoreLH8: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], x24 + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x0], x24 + cmp x8, #2 + bge LoopH8x8 + + LH4: + cbz x8, E8End + LoopHRemain: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.8h}, [x13] + ld1 {v0.8h}, [x15], x11 + fmul v16.8h, v3.8h, v0.h[0] + fmul v17.8h, v3.8h, v0.h[1] + add x13, x13, #32 + fmul v18.8h, v3.8h, v0.h[2] + fmul v19.8h, v3.8h, v0.h[3] + fmul v20.8h, v3.8h, v0.h[4] + fmul v21.8h, v3.8h, v0.h[5] + fmul v22.8h, v3.8h, v0.h[6] + fmul v23.8h, v3.8h, v0.h[7] + beq LoopLREnd + + LoopLR: + ld1 {v3.8h}, [x13] + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v3.8h, v0.h[0] + fmla v17.8h, v3.8h, v0.h[1] + fmla v18.8h, v3.8h, v0.h[2] + fmla v19.8h, v3.8h, v0.h[3] + add x13, x13, #32 + + fmla v20.8h, v3.8h, v0.h[4] + fmla v21.8h, v3.8h, v0.h[5] + fmla v22.8h, v3.8h, v0.h[6] + fmla v23.8h, v3.8h, v0.h[7] + + subs x12, x12, #1 + bne LoopLR + LoopLREnd: + + cbz x5, StoreLH8x4 + AddBiasLH8x4: + ld1 {v0.8h}, [x20] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v0.8h, v5.h[1] + fmla v21.8h, v0.8h, v5.h[1] + fmla v22.8h, v0.8h, v5.h[1] + fmla v23.8h, v0.8h, v5.h[1] + + PostTreatLH8x4: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, 
v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + StoreLH8x4: + + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + + E8End: + + sub x3, x3, #8 + add x0, x21, #128 + add x1, x1, #16 + +E4: +cmp x3, #4 +mov x20, x6 +blt E1 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E4LH4 + + E4LH8: + E4LoopH8: + mov x15, x1 + ld1 {v24.8h, v25.8h}, [x14], #32 // alpha + mov w17, #0x0f + dup v30.16b, w17 + mov w17, #7 + dup v31.16b, w17 + ld1 {v26.8h, v27.8h}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v0.8h}, [x13], #16 + ushr v1.16b, v0.16b, #4 + and v2.16b, v0.16b, v30.16b + sub v1.16b, v1.16b, v31.16b + sub v2.16b, v2.16b, v31.16b + zip1 v3.16b, v1.16b, v2.16b + zip2 v4.16b, v1.16b, v2.16b + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.4h}, [x15], x11 + ld1 {v1.4h}, [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v17.8h, v8.8h, v0.h[1] + fmul v18.8h, v8.8h, v0.h[2] + fmul v19.8h, v8.8h, v0.h[3] + + fmul v20.8h, v9.8h, v0.h[0] + fmul v21.8h, v9.8h, v0.h[1] + fmul v22.8h, v9.8h, v0.h[2] + fmul v23.8h, v9.8h, v0.h[3] + + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + + fmla v20.8h, v11.8h, v1.h[0] + fmla v21.8h, v11.8h, v1.h[1] + fmla v22.8h, v11.8h, v1.h[2] + fmla v23.8h, v11.8h, v1.h[3] + + beq E4LoopLEnd + + E4LoopL: +#if 0 + subs x12, x12, #1 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v8.8b, v1.8b, v2.8b + zip2 v9.8b, v1.8b, v2.8b + sxtl v10.8h, v8.8b + sxtl v11.8h, v9.8b + scvtf v12.8h, v10.8h + scvtf v13.8h, v11.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + ld1 {v0.4h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] +#else + subs x12, x12, #2 + ld1 {v0.8h}, [x13], #16 + ushr v1.16b, v0.16b, #4 + and v2.16b, v0.16b, v30.16b + sub v1.16b, v1.16b, v31.16b + sub v2.16b, v2.16b, v31.16b + zip1 v3.16b, v1.16b, v2.16b + zip2 v4.16b, v1.16b, v2.16b + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.4h}, [x15], x11 + ld1 {v1.4h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, 
v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + + fmla v20.8h, v11.8h, v1.h[0] + fmla v21.8h, v11.8h, v1.h[1] + fmla v22.8h, v11.8h, v1.h[2] + fmla v23.8h, v11.8h, v1.h[3] +#endif + bne E4LoopL + + E4LoopLEnd: + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH4x8 + + AddBiasLH4x8: + ld1 {v0.8h, v1.8h}, [x20], #32 + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v1.8h, v5.h[1] + fmla v21.8h, v1.8h, v5.h[1] + fmla v22.8h, v1.8h, v5.h[1] + fmla v23.8h, v1.8h, v5.h[1] + + PostTreatLH4x8: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + StoreLH4x8: + + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], x7 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], x7 + + bge E4LoopH8 + + E4LH4: + cbz x8, E4End + mov x15, x1 + ld1 {v4.8h}, [x14], #16 // alpha + // mov v4.d[1], v4.d[0] + mov w17, #0x0f + dup v30.8b, w17 + mov w17, #7 + dup v31.8b, w17 + ld1 {v14.8h}, [x16], #16 // bias + // mov v14.d[1], v14.d[0] + subs x12, x9, #2 + // load 16xint4 to 16xfloat + ld1 {v3.8h}, [x13] + // 01234567xxxxxxx89... => 0123456789... + uzp1 v0.4s, v3.4s, v3.4s + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v8.8b, v1.8b, v2.8b + zip2 v9.8b, v1.8b, v2.8b + sxtl v10.8h, v8.8b + sxtl v11.8h, v9.8b + scvtf v12.8h, v10.8h + scvtf v13.8h, v11.8h + mov v8.8h, v14.8h + mov v9.8h, v14.8h + fmla v8.8h, v12.8h, v4.8h + fmla v9.8h, v13.8h, v4.8h + // st1 {v8.8h, v9.8h}, [x0] + // b End + + ld1 {v0.4h}, [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v17.8h, v8.8h, v0.h[1] + fmul v18.8h, v8.8h, v0.h[2] + fmul v19.8h, v8.8h, v0.h[3] + ld1 {v1.4h}, [x15], x11 + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + add x13, x13, #16 + + beq E4LoopLREnd + + E4LoopLR: + ld1 {v3.8h}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v8.8b, v1.8b, v2.8b + zip2 v9.8b, v1.8b, v2.8b + sxtl v10.8h, v8.8b + sxtl v11.8h, v9.8b + scvtf v12.8h, v10.8h + scvtf v13.8h, v11.8h + mov v8.8h, v14.8h + mov v9.8h, v14.8h + fmla v8.8h, v12.8h, v4.8h + fmla v9.8h, v13.8h, v4.8h + + ld1 {v0.4h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + ld1 {v1.4h}, [x15], x11 + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + add x13, x13, #16 + + subs x12, x12, #2 + bne E4LoopLR + E4LoopLREnd: + + cbz x5, StoreLH4x4 + AddBiasLH4x4: + ld1 {v0.8h}, [x20] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + + PostTreatLH4x4: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, 
v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + + StoreLH4x4: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0] + + E4End: + + sub x3, x3, #4 + add x0, x21, #64 + add x1, x1, #8 + +E1: +cmp x3, #0 +beq End + +LoopE1: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E1LH4 + + E1LH8: + E1LoopH8: + mov x15, x1 + ld1 {v24.8h, v25.8h}, [x14], #32 // alpha + mov w17, #0x0f + dup v30.16b, w17 + mov w17, #7 + dup v31.16b, w17 + ld1 {v26.8h, v27.8h}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v0.8h}, [x13], #16 + ushr v1.16b, v0.16b, #4 + and v2.16b, v0.16b, v30.16b + sub v1.16b, v1.16b, v31.16b + sub v2.16b, v2.16b, v31.16b + zip1 v3.16b, v1.16b, v2.16b + zip2 v4.16b, v1.16b, v2.16b + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.h}[0], [x15], x11 + ld1 {v0.h}[1], [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v20.8h, v9.8h, v0.h[0] + + fmla v16.8h, v10.8h, v0.h[1] + fmla v20.8h, v11.8h, v0.h[1] + + beq E1LoopLEnd + + E1LoopL: + subs x12, x12, #2 + ld1 {v0.16b}, [x13], #16 + ushr v1.16b, v0.16b, #4 + and v2.16b, v0.16b, v30.16b + sub v1.16b, v1.16b, v31.16b + sub v2.16b, v2.16b, v31.16b + zip1 v3.16b, v1.16b, v2.16b + zip2 v4.16b, v1.16b, v2.16b + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + mov v8.8h, v26.8h + mov v9.8h, v27.8h + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + scvtf v14.8h, v2.8h + ld1 {v0.h}[0], [x15], x11 + scvtf v15.8h, v3.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + ld1 {v0.h}[1], [x15], x11 + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + fmla v16.8h, v8.8h, v0.h[0] + fmla v20.8h, v9.8h, v0.h[0] + fmla v16.8h, v10.8h, v0.h[1] + fmla v20.8h, v11.8h, v0.h[1] + bne E1LoopL + + E1LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH1x8 + AddBiasLH1x8: + ld1 {v0.8h, v1.8h}, [x20], #32 + + fmla v16.8h, v0.8h, v5.h[1] + fmla v20.8h, v1.8h, v5.h[1] + + PostTreatLH1x8: + fmax v16.8h, v16.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmin v16.8h, v16.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + + StoreLH1x8: + + st1 {v16.8h}, [x0], x7 + st1 {v20.8h}, [x0], x7 + + bge E1LoopH8 + + E1LH4: + cbz x8, E1End + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.8h}, [x13] + ld1 {v0.h}[0], [x15], x11 + fmul v16.8h, v3.8h, v0.h[0] + add x13, x13, #32 + + beq E1LoopLREnd + + E1LoopLR: + ld1 {v3.8h}, [x13] + ld1 {v0.h}[0], [x15], x11 + fmla v16.8h, v3.8h, v0.h[0] + add x13, x13, #32 + + subs x12, x12, #1 + bne E1LoopLR + E1LoopLREnd: + + cbz x5, StoreLH1x4 + AddBiasLH1x4: + ld1 {v0.8h}, [x20] + fmla v16.8h, v0.8h, v5.h[1] + + PostTreatLH1x4: + fmax v16.8h, v16.8h, v6.8h + fmin v16.8h, v16.8h, v7.8h + + StoreLH1x4: + st1 {v16.8h}, [x0] + + E1End: + + subs x3, x3, #1 + add x0, x21, #16 + add x1, x1, #2 + bne LoopE1 + + +End: +ldp x23, x24, [sp, #96] +ldp x21, x22, [sp, #80] +ldp x19, x20, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #112 +ret + + +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int8.S 
b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int8.S new file mode 100644 index 000000000..d1868f96c --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNPackedMatMulRemainFP16_int8.S @@ -0,0 +1,750 @@ +// +// MNNPackedMatMulRemainFP16_int8.S +// MNN +// +// Created by MNN on 2023/06/06. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 8 * 24 MatMul, C(UP_DIV(h,8), e, h8) = B(UP_DIV(h,hP), l, hP) * A(l, e), hP = 24 +// Remain meaning is eSize is any value +asm_function MNNPackedMatMulRemainFP16_int8 +//void MNNPackedMatMulRemainFP16_int8(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, size_t eSize, const size_t* parameter, const FLOAT16* postParameters, const FLOAT16* bias); +//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias, x7: k, x8: b +// parameter: {aStride, l, h, cStride, bExtraStride} +ldr x8, [sp] // bias +stp d14, d15, [sp, #-112]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x20, [sp, #64] +stp x21, x22, [sp, #80] +stp x23, x24, [sp, #96] + +mov x22, x7 // alpha +mov x23, x8 // bias +ldr x11, [x4, #0] // aStride +ldr x9, [x4, #8] // l +ldr x10, [x4, #16] // h + +ldr x7, [x4, #24] // cStride +ldr x19, [x4, #40] // bExtraStride + +add x10, x10, #7 +lsr x10, x10, #3 + +cbz x5, Start +ld1 {v5.4s}, [x5] +fcvtn v5.4h, v5.4s +dup v6.8h, v5.h[2] // Min Value +dup v7.8h, v5.h[3] // Max Value + +Start: + +E8: +cmp x3, #8 +blt E4 + +// 8x16 +LoopE8: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + LH8: + cmp x8, #2 + blt LH4 + sub x24, x7, #64 + LoopH8x8: + mov x15, x1 + ld1 {v12.8h, v13.8h}, [x14], #32 // alpha + ld1 {v14.8h, v15.8h}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v10.16b, v11.16b}, [x13], #32 + sxtl v1.8h, v10.8b + sxtl2 v2.8h, v10.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + + ld1 {v0.8h}, [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v17.8h, v8.8h, v0.h[1] + fmul v18.8h, v8.8h, v0.h[2] + fmul v19.8h, v8.8h, v0.h[3] + + fmul v20.8h, v9.8h, v0.h[0] + fmul v21.8h, v9.8h, v0.h[1] + fmul v22.8h, v9.8h, v0.h[2] + fmul v23.8h, v9.8h, v0.h[3] + + fmul v24.8h, v8.8h, v0.h[4] + fmul v25.8h, v8.8h, v0.h[5] + fmul v26.8h, v8.8h, v0.h[6] + fmul v27.8h, v8.8h, v0.h[7] + + fmul v28.8h, v9.8h, v0.h[4] + fmul v29.8h, v9.8h, v0.h[5] + fmul v30.8h, v9.8h, v0.h[6] + fmul v31.8h, v9.8h, v0.h[7] + + sxtl v1.8h, v11.8b + sxtl2 v2.8h, v11.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v24.8h, v8.8h, v0.h[4] + fmla v25.8h, v8.8h, v0.h[5] + fmla v26.8h, v8.8h, v0.h[6] + fmla v27.8h, v8.8h, v0.h[7] + + fmla v28.8h, v9.8h, v0.h[4] + fmla v29.8h, v9.8h, v0.h[5] + fmla v30.8h, v9.8h, v0.h[6] + fmla v31.8h, v9.8h, v0.h[7] + beq LoopLEnd + + LoopL: + subs x12, x12, #2 + ld1 {v10.16b, v11.16b}, [x13], #32 + sxtl v1.8h, v10.8b + sxtl2 v2.8h, v10.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + + ld1 {v0.8h}, [x15], x11 + fmla 
v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v24.8h, v8.8h, v0.h[4] + fmla v25.8h, v8.8h, v0.h[5] + fmla v26.8h, v8.8h, v0.h[6] + fmla v27.8h, v8.8h, v0.h[7] + + fmla v28.8h, v9.8h, v0.h[4] + fmla v29.8h, v9.8h, v0.h[5] + fmla v30.8h, v9.8h, v0.h[6] + fmla v31.8h, v9.8h, v0.h[7] + + sxtl v1.8h, v11.8b + sxtl2 v2.8h, v11.16b + scvtf v1.8h, v1.8h + scvtf v2.8h, v2.8h + mov v8.8h, v14.8h + mov v9.8h, v15.8h + fmla v8.8h, v1.8h, v12.8h + fmla v9.8h, v2.8h, v13.8h + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v24.8h, v8.8h, v0.h[4] + fmla v25.8h, v8.8h, v0.h[5] + fmla v26.8h, v8.8h, v0.h[6] + fmla v27.8h, v8.8h, v0.h[7] + + fmla v28.8h, v9.8h, v0.h[4] + fmla v29.8h, v9.8h, v0.h[5] + fmla v30.8h, v9.8h, v0.h[6] + fmla v31.8h, v9.8h, v0.h[7] + bne LoopL + + LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + + cbz x5, StoreLH8 + AddBiasLH8: + ld1 {v0.8h, v1.8h}, [x20], #32 + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v1.8h, v5.h[1] + fmla v21.8h, v1.8h, v5.h[1] + fmla v22.8h, v1.8h, v5.h[1] + fmla v23.8h, v1.8h, v5.h[1] + + fmla v24.8h, v0.8h, v5.h[1] + fmla v25.8h, v0.8h, v5.h[1] + fmla v26.8h, v0.8h, v5.h[1] + fmla v27.8h, v0.8h, v5.h[1] + + fmla v28.8h, v1.8h, v5.h[1] + fmla v29.8h, v1.8h, v5.h[1] + fmla v30.8h, v1.8h, v5.h[1] + fmla v31.8h, v1.8h, v5.h[1] + + PostTreatLH8: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + fmax v24.8h, v24.8h, v6.8h + fmax v25.8h, v25.8h, v6.8h + fmax v26.8h, v26.8h, v6.8h + fmax v27.8h, v27.8h, v6.8h + fmax v28.8h, v28.8h, v6.8h + fmax v29.8h, v29.8h, v6.8h + fmax v30.8h, v30.8h, v6.8h + fmax v31.8h, v31.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + fmin v24.8h, v24.8h, v7.8h + fmin v25.8h, v25.8h, v7.8h + fmin v26.8h, v26.8h, v7.8h + fmin v27.8h, v27.8h, v7.8h + fmin v28.8h, v28.8h, v7.8h + fmin v29.8h, v29.8h, v7.8h + fmin v30.8h, v30.8h, v7.8h + fmin v31.8h, v31.8h, v7.8h + + StoreLH8: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], x24 + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + st1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x0], x24 + cmp x8, #2 + bge LoopH8x8 + + LH4: + cbz x8, E8End + LoopHRemain: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.8h}, [x13] + ld1 {v0.8h}, [x15], x11 + fmul v16.8h, v3.8h, v0.h[0] + fmul v17.8h, v3.8h, v0.h[1] + add x13, x13, #32 + fmul v18.8h, v3.8h, v0.h[2] + fmul v19.8h, v3.8h, v0.h[3] + fmul v20.8h, v3.8h, v0.h[4] + fmul v21.8h, v3.8h, v0.h[5] + fmul v22.8h, v3.8h, v0.h[6] + fmul v23.8h, v3.8h, v0.h[7] + beq LoopLREnd + + LoopLR: + ld1 {v3.8h}, [x13] + ld1 {v0.8h}, [x15], x11 + fmla v16.8h, v3.8h, v0.h[0] + fmla v17.8h, v3.8h, v0.h[1] + fmla v18.8h, v3.8h, v0.h[2] + fmla v19.8h, v3.8h, 
v0.h[3] + add x13, x13, #32 + + fmla v20.8h, v3.8h, v0.h[4] + fmla v21.8h, v3.8h, v0.h[5] + fmla v22.8h, v3.8h, v0.h[6] + fmla v23.8h, v3.8h, v0.h[7] + + subs x12, x12, #1 + bne LoopLR + LoopLREnd: + + cbz x5, StoreLH8x4 + AddBiasLH8x4: + ld1 {v0.8h}, [x20] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v0.8h, v5.h[1] + fmla v21.8h, v0.8h, v5.h[1] + fmla v22.8h, v0.8h, v5.h[1] + fmla v23.8h, v0.8h, v5.h[1] + + PostTreatLH8x4: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + StoreLH8x4: + + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + + E8End: + + sub x3, x3, #8 + add x0, x21, #128 + add x1, x1, #16 + +E4: +cmp x3, #4 +mov x20, x6 +blt E1 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E4LH4 + + E4LH8: + E4LoopH8: + mov x15, x1 + ld1 {v24.8h, v25.8h}, [x14], #32 // alpha + ld1 {v26.8h, v27.8h}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v3.16b, v4.16b}, [x13], #32 + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.4h}, [x15], x11 + ld1 {v1.4h}, [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v17.8h, v8.8h, v0.h[1] + fmul v18.8h, v8.8h, v0.h[2] + fmul v19.8h, v8.8h, v0.h[3] + + fmul v20.8h, v9.8h, v0.h[0] + fmul v21.8h, v9.8h, v0.h[1] + fmul v22.8h, v9.8h, v0.h[2] + fmul v23.8h, v9.8h, v0.h[3] + + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + + fmla v20.8h, v11.8h, v1.h[0] + fmla v21.8h, v11.8h, v1.h[1] + fmla v22.8h, v11.8h, v1.h[2] + fmla v23.8h, v11.8h, v1.h[3] + + beq E4LoopLEnd + + E4LoopL: + subs x12, x12, #2 + ld1 {v3.16b, v4.16b}, [x13], #32 + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.4h}, [x15], x11 + ld1 {v1.4h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + + fmla v20.8h, v9.8h, v0.h[0] + fmla v21.8h, v9.8h, v0.h[1] + fmla v22.8h, v9.8h, v0.h[2] + fmla v23.8h, v9.8h, v0.h[3] + + fmla v16.8h, v10.8h, v1.h[0] + fmla v17.8h, v10.8h, v1.h[1] + fmla v18.8h, v10.8h, v1.h[2] + fmla v19.8h, v10.8h, v1.h[3] + + fmla v20.8h, v11.8h, v1.h[0] + fmla v21.8h, v11.8h, v1.h[1] + fmla v22.8h, v11.8h, v1.h[2] + fmla v23.8h, v11.8h, v1.h[3] + bne E4LoopL + + E4LoopLEnd: + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH4x8 + + AddBiasLH4x8: + ld1 {v0.8h, 
v1.8h}, [x20], #32 + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + fmla v20.8h, v1.8h, v5.h[1] + fmla v21.8h, v1.8h, v5.h[1] + fmla v22.8h, v1.8h, v5.h[1] + fmla v23.8h, v1.8h, v5.h[1] + + PostTreatLH4x8: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmax v21.8h, v21.8h, v6.8h + fmax v22.8h, v22.8h, v6.8h + fmax v23.8h, v23.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + fmin v21.8h, v21.8h, v7.8h + fmin v22.8h, v22.8h, v7.8h + fmin v23.8h, v23.8h, v7.8h + + StoreLH4x8: + + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], x7 + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], x7 + + bge E4LoopH8 + + E4LH4: + cbz x8, E4End + mov x15, x1 + ld1 {v4.8h}, [x14], #16 // alpha + // mov v4.d[1], v4.d[0] + ld1 {v14.8h}, [x16], #16 // bias + // mov v14.d[1], v14.d[0] + subs x12, x9, #2 + // load 16xint4 to 16xfloat + ld1 {v3.4s, v4.4s}, [x13] + uzp1 v0.4s, v3.4s, v4.4s + sxtl v10.8h, v0.8b + sxtl2 v11.8h, v0.16b + scvtf v12.8h, v10.8h + scvtf v13.8h, v11.8h + mov v8.8h, v14.8h + mov v9.8h, v14.8h + fmla v8.8h, v12.8h, v4.8h + fmla v9.8h, v13.8h, v4.8h + // st1 {v8.8h, v9.8h}, [x0] + // b End + + ld1 {v0.4h}, [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v17.8h, v8.8h, v0.h[1] + fmul v18.8h, v8.8h, v0.h[2] + fmul v19.8h, v8.8h, v0.h[3] + ld1 {v1.4h}, [x15], x11 + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + add x13, x13, #32 + + beq E4LoopLREnd + + E4LoopLR: + ld1 {v3.4s, v4.4s}, [x13] + uzp1 v0.4s, v3.4s, v4.4s + sxtl v10.8h, v0.8b + sxtl2 v11.8h, v0.16b + scvtf v12.8h, v10.8h + scvtf v13.8h, v11.8h + mov v8.8h, v14.8h + mov v9.8h, v14.8h + fmla v8.8h, v12.8h, v4.8h + fmla v9.8h, v13.8h, v4.8h + + ld1 {v0.4h}, [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v17.8h, v8.8h, v0.h[1] + fmla v18.8h, v8.8h, v0.h[2] + fmla v19.8h, v8.8h, v0.h[3] + ld1 {v1.4h}, [x15], x11 + fmla v16.8h, v9.8h, v1.h[0] + fmla v17.8h, v9.8h, v1.h[1] + fmla v18.8h, v9.8h, v1.h[2] + fmla v19.8h, v9.8h, v1.h[3] + add x13, x13, #16 + + subs x12, x12, #2 + bne E4LoopLR + E4LoopLREnd: + + cbz x5, StoreLH4x4 + AddBiasLH4x4: + ld1 {v0.8h}, [x20] + + fmla v16.8h, v0.8h, v5.h[1] + fmla v17.8h, v0.8h, v5.h[1] + fmla v18.8h, v0.8h, v5.h[1] + fmla v19.8h, v0.8h, v5.h[1] + + + PostTreatLH4x4: + fmax v16.8h, v16.8h, v6.8h + fmax v17.8h, v17.8h, v6.8h + fmax v18.8h, v18.8h, v6.8h + fmax v19.8h, v19.8h, v6.8h + + fmin v16.8h, v16.8h, v7.8h + fmin v17.8h, v17.8h, v7.8h + fmin v18.8h, v18.8h, v7.8h + fmin v19.8h, v19.8h, v7.8h + + StoreLH4x4: + st1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0] + + E4End: + + sub x3, x3, #4 + add x0, x21, #64 + add x1, x1, #8 + +E1: +cmp x3, #0 +beq End + +LoopE1: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E1LH4 + + E1LH8: + E1LoopH8: + mov x15, x1 + ld1 {v24.8h, v25.8h}, [x14], #32 // alpha + ld1 {v26.8h, v27.8h}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v3.16b, v4.16b}, [x13], #32 + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, 
v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.h}[0], [x15], x11 + ld1 {v0.h}[1], [x15], x11 + fmul v16.8h, v8.8h, v0.h[0] + fmul v20.8h, v9.8h, v0.h[0] + + fmla v16.8h, v10.8h, v0.h[1] + fmla v20.8h, v11.8h, v0.h[1] + + beq E1LoopLEnd + + E1LoopL: + subs x12, x12, #2 + ld1 {v3.16b, v4.16b}, [x13], #32 + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v2.8h, v4.8b + sxtl2 v3.8h, v4.16b + scvtf v12.8h, v0.8h + scvtf v13.8h, v1.8h + scvtf v14.8h, v2.8h + scvtf v15.8h, v3.8h + mov v8.8h, v26.8h + mov v9.8h, v27.8h + fmla v8.8h, v12.8h, v24.8h + fmla v9.8h, v13.8h, v25.8h + mov v10.8h, v26.8h + mov v11.8h, v27.8h + fmla v10.8h, v14.8h, v24.8h + fmla v11.8h, v15.8h, v25.8h + + ld1 {v0.h}[0], [x15], x11 + ld1 {v0.h}[1], [x15], x11 + fmla v16.8h, v8.8h, v0.h[0] + fmla v20.8h, v9.8h, v0.h[0] + + fmla v16.8h, v10.8h, v0.h[1] + fmla v20.8h, v11.8h, v0.h[1] + bne E1LoopL + + E1LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH1x8 + AddBiasLH1x8: + ld1 {v0.8h, v1.8h}, [x20], #32 + + fmla v16.8h, v0.8h, v5.h[1] + fmla v20.8h, v1.8h, v5.h[1] + + PostTreatLH1x8: + fmax v16.8h, v16.8h, v6.8h + fmax v20.8h, v20.8h, v6.8h + fmin v16.8h, v16.8h, v7.8h + fmin v20.8h, v20.8h, v7.8h + + StoreLH1x8: + + st1 {v16.8h}, [x0], x7 + st1 {v20.8h}, [x0], x7 + + bge E1LoopH8 + + E1LH4: + cbz x8, E1End + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.8h}, [x13] + ld1 {v0.h}[0], [x15], x11 + fmul v16.8h, v3.8h, v0.h[0] + add x13, x13, #32 + + beq E1LoopLREnd + + E1LoopLR: + ld1 {v3.8h}, [x13] + ld1 {v0.h}[0], [x15], x11 + fmla v16.8h, v3.8h, v0.h[0] + add x13, x13, #32 + + subs x12, x12, #1 + bne E1LoopLR + E1LoopLREnd: + + cbz x5, StoreLH1x4 + AddBiasLH1x4: + ld1 {v0.8h}, [x20] + fmla v16.8h, v0.8h, v5.h[1] + + PostTreatLH1x4: + fmax v16.8h, v16.8h, v6.8h + fmin v16.8h, v16.8h, v7.8h + + StoreLH1x4: + st1 {v16.8h}, [x0] + + E1End: + + subs x3, x3, #1 + add x0, x21, #16 + add x1, x1, #2 + bne LoopE1 + + +End: +ldp x23, x24, [sp, #96] +ldp x21, x22, [sp, #80] +ldp x19, x20, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #112 +ret + + +#endif diff --git a/source/backend/cpu/CMakeLists.txt b/source/backend/cpu/CMakeLists.txt index 953ff853d..8d5a60847 100644 --- a/source/backend/cpu/CMakeLists.txt +++ b/source/backend/cpu/CMakeLists.txt @@ -1,5 +1,7 @@ # CPU option(MNN_SUPPORT_BF16 "Enable MNN's bf16 op" OFF) +option(MNN_LOW_MEMORY "Build MNN support low memory for weight quant model." 
OFF) + FILE(GLOB MNN_CPU_SRC ${CMAKE_CURRENT_LIST_DIR}/* ${CMAKE_CURRENT_LIST_DIR}/compute/*) add_library(MNNCPU OBJECT ${MNN_CPU_SRC}) if (MNN_SUPPORT_BF16) @@ -17,6 +19,10 @@ if(MNN_USE_SPARSE_COMPUTE) target_compile_options(MNNCPU PRIVATE -DMNN_USE_SPARSE_COMPUTE) endif() +if(MNN_LOW_MEMORY) + target_compile_options(MNNCPU PRIVATE -DMNN_LOW_MEMORY) +endif() + # X86_64 AVX/SSE if (MNN_USE_SSE) include(${CMAKE_CURRENT_LIST_DIR}/x86_x64/CMakeLists.txt) diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 474e81985..84330dc78 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -131,7 +131,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config) const { #ifdef MNN_USE_ARMV82 auto core = MNNGetCoreFunctions(); if (core->supportFp16arith && precision == BackendConfig::Precision_Low) { - return new Arm82Backend(this); + return new Arm82Backend(this, memory); } #endif #ifdef MNN_SUPPORT_BF16 diff --git a/source/backend/cpu/CPUSoftMaxInt8.cpp b/source/backend/cpu/CPUSoftMaxInt8.cpp index f89ac20a9..b21bea795 100644 --- a/source/backend/cpu/CPUSoftMaxInt8.cpp +++ b/source/backend/cpu/CPUSoftMaxInt8.cpp @@ -209,6 +209,10 @@ void CPUSoftmaxInt8::QuantizedSoftmax(const uint8_t* inputData, int outerSize, i #endif int numBitsOverUnit = kAccumulationIntegerBits - headroomPlusOne; + + if (numBitsOverUnit + 31 - 8 > 31) { + numBitsOverUnit = 8; + } int32_t shiftedSumMinusOne = static_cast((static_cast(fixedSumOfExps) << headroomPlusOne) - (static_cast(1) << 31)); FixedPoint0 shiftedScale = one_over_one_plus_x_for_x_in_0_1(FixedPoint0::FromRaw(shiftedSumMinusOne)); diff --git a/source/backend/cpu/arm/CMakeLists.txt b/source/backend/cpu/arm/CMakeLists.txt index c61c79b7a..d23e5adb4 100644 --- a/source/backend/cpu/arm/CMakeLists.txt +++ b/source/backend/cpu/arm/CMakeLists.txt @@ -11,6 +11,10 @@ if (MNN_SUPPORT_BF16) FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/bf16/*.[sS]) endif() +if (MNN_LOW_MEMORY) + FILE(GLOB MNN_AArch64_SRC ${MNN_AArch64_SRC} ${CMAKE_CURRENT_LIST_DIR}/arm64/low_memory/*.[sS]) +endif() + if(CMAKE_SYSTEM_PROCESSOR MATCHES "^armv7" OR ARCHS MATCHES "^armv7(;armv7s)?") message(STATUS "Enabling AArch32 Assemblies") add_library(MNNARM32 OBJECT ${MNN_AArch32_SRC} ${MNN_NEON_SRC}) diff --git a/source/backend/cpu/arm/FunctionSummary.hpp b/source/backend/cpu/arm/FunctionSummary.hpp index e656a8705..4c9a3ad19 100644 --- a/source/backend/cpu/arm/FunctionSummary.hpp +++ b/source/backend/cpu/arm/FunctionSummary.hpp @@ -31,9 +31,9 @@ void NEON_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup void NEON_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose); void NEON_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); void NEON_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); void NEON_MNNConvRunForUnitDepthWise_BF16(float* dst, const float* src, const float* weight, size_t fw, size_t fh, size_t weight_y_step, size_t dilateX_step, size_t dilateY_step); @@ -50,9 +50,9 @@ void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP); void 
ARMV86_MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose); void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void ARMV86_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); void ARMV86_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); #endif #endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S new file mode 100644 index 000000000..53b805a97 --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int4.S @@ -0,0 +1,1161 @@ +// +// MNNPackedMatMulRemain_int4.S +// MNN +// +// Created by MNN on 2023/05/18. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function MNNPackedMatMulRemain_int4 +//void MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias, x7: k, x8: b +ldr x8, [sp] +sub sp, sp, #64 +str x19, [sp, #0] +str x20, [sp, #8] +str x21, [sp, #16] +str x22, [sp, #24] +str x23, [sp, #32] + +mov x22, x7 // alpha +mov x23, x8 // bias +ldr x11, [x4, #0] // aStride +ldr x9, [x4, #8] // l +ldr x10, [x4, #16] // h + +ldr x7, [x4, #24] // cStride +ldr x19, [x4, #40] // bExtraStride + +add x10, x10, #3 +lsr x10, x10, #2 + +cbz x5, Start +ld1 {v5.4s}, [x5] +dup v6.4s, v5.s[2] // Min Value +dup v7.4s, v5.s[3] // Max Value + +Start: + +E8: +cmp x3, #8 +blt E4 + +LoopE8: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + LH8: + cmp x8, #2 + blt LH4 + + // sub x14, x7, #64 + LoopH8x8: + mov x15, x1 + ld1 {v12.4s, v13.4s}, [x14], #32 // alpha + mov w17, #0x0f + dup v3.8b, w17 + mov w17, #7 + dup v4.8b, w17 + ld1 {v14.4s, v15.4s}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v3.8b + sub v1.8b, v1.8b, v4.8b + sub v2.8b, v2.8b, v4.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v0.8h, v9.8b + sxtl v1.8h, v10.8b + sxtl v8.4s, v0.4h + sxtl2 v9.4s, v0.8h + sxtl v10.4s, v1.4h + sxtl2 v11.4s, v1.8h + scvtf v0.4s, v8.4s + scvtf v1.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v15.4s + fmla v8.4s, v0.4s, v12.4s + fmla v9.4s, v1.4s, v13.4s + scvtf v0.4s, v10.4s + scvtf v1.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v15.4s + fmla v10.4s, v0.4s, v12.4s + fmla v11.4s, v1.4s, v13.4s + ld1 {v0.4s, v1.4s}, [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + fmul v18.4s, v8.4s, v0.s[2] + fmul v19.4s, v8.4s, v0.s[3] + + fmul v20.4s, v9.4s, v0.s[0] + fmul v21.4s, v9.4s, v0.s[1] + fmul v22.4s, v9.4s, v0.s[2] + fmul v23.4s, v9.4s, v0.s[3] + + fmul v24.4s, v8.4s, v1.s[0] + fmul v25.4s, v8.4s, v1.s[1] + fmul v26.4s, v8.4s, v1.s[2] + fmul v27.4s, v8.4s, v1.s[3] + + fmul v28.4s, v9.4s, v1.s[0] + fmul v29.4s, v9.4s, v1.s[1] + fmul v30.4s, v9.4s, v1.s[2] + fmul v31.4s, v9.4s, v1.s[3] + 
ld1 {v0.4s, v1.4s}, [x15], x11 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v0.s[2] + fmla v19.4s, v10.4s, v0.s[3] + + fmla v20.4s, v11.4s, v0.s[0] + fmla v21.4s, v11.4s, v0.s[1] + fmla v22.4s, v11.4s, v0.s[2] + fmla v23.4s, v11.4s, v0.s[3] + + fmla v24.4s, v10.4s, v1.s[0] + fmla v25.4s, v10.4s, v1.s[1] + fmla v26.4s, v10.4s, v1.s[2] + fmla v27.4s, v10.4s, v1.s[3] + + fmla v28.4s, v11.4s, v1.s[0] + fmla v29.4s, v11.4s, v1.s[1] + fmla v30.4s, v11.4s, v1.s[2] + fmla v31.4s, v11.4s, v1.s[3] + beq LoopLEnd + + LoopL: + subs x12, x12, #2 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v3.8b + sub v1.8b, v1.8b, v4.8b + sub v2.8b, v2.8b, v4.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v0.8h, v9.8b + sxtl v1.8h, v10.8b + sxtl v8.4s, v0.4h + sxtl2 v9.4s, v0.8h + sxtl v10.4s, v1.4h + sxtl2 v11.4s, v1.8h + scvtf v0.4s, v8.4s + scvtf v1.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v15.4s + fmla v8.4s, v0.4s, v12.4s + fmla v9.4s, v1.4s, v13.4s + scvtf v0.4s, v10.4s + scvtf v1.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v15.4s + fmla v10.4s, v0.4s, v12.4s + fmla v11.4s, v1.4s, v13.4s + ld1 {v0.4s, v1.4s}, [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + fmla v18.4s, v8.4s, v0.s[2] + fmla v19.4s, v8.4s, v0.s[3] + + fmla v20.4s, v9.4s, v0.s[0] + fmla v21.4s, v9.4s, v0.s[1] + fmla v22.4s, v9.4s, v0.s[2] + fmla v23.4s, v9.4s, v0.s[3] + + fmla v24.4s, v8.4s, v1.s[0] + fmla v25.4s, v8.4s, v1.s[1] + fmla v26.4s, v8.4s, v1.s[2] + fmla v27.4s, v8.4s, v1.s[3] + + fmla v28.4s, v9.4s, v1.s[0] + fmla v29.4s, v9.4s, v1.s[1] + fmla v30.4s, v9.4s, v1.s[2] + fmla v31.4s, v9.4s, v1.s[3] + + ld1 {v0.4s, v1.4s}, [x15], x11 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v0.s[2] + fmla v19.4s, v10.4s, v0.s[3] + + fmla v20.4s, v11.4s, v0.s[0] + fmla v21.4s, v11.4s, v0.s[1] + fmla v22.4s, v11.4s, v0.s[2] + fmla v23.4s, v11.4s, v0.s[3] + + fmla v24.4s, v10.4s, v1.s[0] + fmla v25.4s, v10.4s, v1.s[1] + fmla v26.4s, v10.4s, v1.s[2] + fmla v27.4s, v10.4s, v1.s[3] + + fmla v28.4s, v11.4s, v1.s[0] + fmla v29.4s, v11.4s, v1.s[1] + fmla v30.4s, v11.4s, v1.s[2] + fmla v31.4s, v11.4s, v1.s[3] + + bne LoopL + + LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH8 + AddBiasLH8: + ld1 {v0.4s, v1.4s}, [x20], #32 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + fmla v24.4s, v0.4s, v5.s[1] + fmla v25.4s, v0.4s, v5.s[1] + fmla v26.4s, v0.4s, v5.s[1] + fmla v27.4s, v0.4s, v5.s[1] + + fmla v28.4s, v1.4s, v5.s[1] + fmla v29.4s, v1.4s, v5.s[1] + fmla v30.4s, v1.4s, v5.s[1] + fmla v31.4s, v1.4s, v5.s[1] + + PostTreatLH8: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + fmax v24.4s, v24.4s, v6.4s + fmax v25.4s, v25.4s, v6.4s + fmax v26.4s, v26.4s, v6.4s + fmax v27.4s, v27.4s, v6.4s + fmax v28.4s, v28.4s, v6.4s + fmax v29.4s, v29.4s, v6.4s + fmax v30.4s, v30.4s, v6.4s + fmax v31.4s, v31.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin 
v23.4s, v23.4s, v7.4s + fmin v24.4s, v24.4s, v7.4s + fmin v25.4s, v25.4s, v7.4s + fmin v26.4s, v26.4s, v7.4s + fmin v27.4s, v27.4s, v7.4s + fmin v28.4s, v28.4s, v7.4s + fmin v29.4s, v29.4s, v7.4s + fmin v30.4s, v30.4s, v7.4s + fmin v31.4s, v31.4s, v7.4s + + StoreLH8: + stp q16, q17, [x0] + stp q18, q19, [x0, #(32 * 1)] + stp q24, q25, [x0, #(32 * 2)] + stp q26, q27, [x0, #(32 * 3)] + add x0, x0, x7 // stp donot support post-index offset in register + + stp q20, q21, [x0] + stp q22, q23, [x0, #(32 * 1)] + stp q28, q29, [x0, #(32 * 2)] + stp q30, q31, [x0, #(32 * 3)] + add x0, x0, x7 + + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + // st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], x14 + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + // st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x14 + + bge LoopH8x8 + + LH4: + cbz x8, E8End + LoopHRemain: + mov x15, x1 + ld1 {v4.4s}, [x14], #16 // alpha + mov w17, #0x0f + dup v30.8b, w17 + mov w17, #7 + dup v31.8b, w17 + ld1 {v14.4s}, [x16], #16 // bias + subs x12, x9, #4 + // load 16xint4 to 4xfloat + ld1 {v3.8h}, [x13] + uzp1 v0.8h, v3.8h, v3.8h + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmul v18.4s, v8.4s, v0.s[2] + sub x15, x15, #16 + fmul v19.4s, v8.4s, v0.s[3] + add x15, x15, x11 + fmul v20.4s, v8.4s, v1.s[0] + fmul v21.4s, v8.4s, v1.s[1] + fmul v22.4s, v8.4s, v1.s[2] + fmul v23.4s, v8.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v9.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v9.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v9.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v9.4s, v1.s[0] + fmla v21.4s, v9.4s, v1.s[1] + fmla v22.4s, v9.4s, v1.s[2] + fmla v23.4s, v9.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v10.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v10.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v10.4s, v1.s[0] + fmla v21.4s, v10.4s, v1.s[1] + fmla v22.4s, v10.4s, v1.s[2] + fmla v23.4s, v10.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v11.4s, v0.s[0] + fmla v17.4s, v11.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v11.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v11.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + + add x13, x13, #16 + beq LoopLREnd + + LoopLR: + ld1 {v3.8h}, [x13] + uzp1 v0.8h, v3.8h, v3.8h + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + 
scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v8.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v8.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v8.4s, v1.s[0] + fmla v21.4s, v8.4s, v1.s[1] + fmla v22.4s, v8.4s, v1.s[2] + fmla v23.4s, v8.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v9.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v9.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v9.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v9.4s, v1.s[0] + fmla v21.4s, v9.4s, v1.s[1] + fmla v22.4s, v9.4s, v1.s[2] + fmla v23.4s, v9.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v10.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v10.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v10.4s, v1.s[0] + fmla v21.4s, v10.4s, v1.s[1] + fmla v22.4s, v10.4s, v1.s[2] + fmla v23.4s, v10.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v11.4s, v0.s[0] + fmla v17.4s, v11.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v11.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v11.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + + add x13, x13, #16 + subs x12, x12, #4 + bne LoopLR + LoopLREnd: + + cbz x5, StoreLH8x4 + AddBiasLH8x4: + ld1 {v0.4s}, [x20] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v0.4s, v5.s[1] + fmla v21.4s, v0.4s, v5.s[1] + fmla v22.4s, v0.4s, v5.s[1] + fmla v23.4s, v0.4s, v5.s[1] + + PostTreatLH8x4: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + + StoreLH8x4: + + stp q16, q17, [x0] + stp q18, q19, [x0, #(32 * 1)] + stp q20, q21, [x0, #(32 * 2)] + stp q22, q23, [x0, #(32 * 3)] + add x0, x0, #(32 * 4) + + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + + E8End: + + sub x3, x3, #8 + cmp x3, #8 + add x0, x21, #128 + add x1, x1, #32 + bge LoopE8 + +E4: +cmp x3, #4 +mov x20, x6 +blt E1 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E4LH4 + + E4LH8: + E4LoopH8: + mov x15, x1 + ld1 {v24.4s, v25.4s}, [x14], #32 // alpha + mov w17, #0x0f + dup v30.8b, w17 + mov w17, #7 + dup v31.8b, w17 + ld1 {v26.4s, v27.4s}, [x16], #32 // bias + subs x12, x9, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, 
v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + + ld1 {v0.4s}, [x15], x11 + ld1 {v1.4s}, [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + fmul v18.4s, v8.4s, v0.s[2] + fmul v19.4s, v8.4s, v0.s[3] + + fmul v20.4s, v9.4s, v0.s[0] + fmul v21.4s, v9.4s, v0.s[1] + fmul v22.4s, v9.4s, v0.s[2] + fmul v23.4s, v9.4s, v0.s[3] + + fmla v16.4s, v10.4s, v1.s[0] + fmla v17.4s, v10.4s, v1.s[1] + fmla v18.4s, v10.4s, v1.s[2] + fmla v19.4s, v10.4s, v1.s[3] + + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + beq E4LoopLEnd + + E4LoopL: + subs x12, x12, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + ld1 {v0.4s}, [x15], x11 + ld1 {v1.4s}, [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + fmla v18.4s, v8.4s, v0.s[2] + fmla v19.4s, v8.4s, v0.s[3] + + fmla v20.4s, v9.4s, v0.s[0] + fmla v21.4s, v9.4s, v0.s[1] + fmla v22.4s, v9.4s, v0.s[2] + fmla v23.4s, v9.4s, v0.s[3] + + fmla v16.4s, v10.4s, v1.s[0] + fmla v17.4s, v10.4s, v1.s[1] + fmla v18.4s, v10.4s, v1.s[2] + fmla v19.4s, v10.4s, v1.s[3] + + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + bne E4LoopL + E4LoopLEnd: + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH4x8 + + AddBiasLH4x8: + ld1 {v0.4s, v1.4s}, [x20], #32 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + PostTreatLH4x8: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + + StoreLH4x8: + stp q16, q17, [x0] + stp q18, q19, [x0, #32] + add x0, x0, x7 // stp donot support post-index offset in register + stp q20, q21, [x0] + stp q22, q23, [x0, #32] + add x0, x0, x7 + + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x7 + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], x7 + + bge E4LoopH8 + + E4LH4: + cbz x8, E4End + mov x15, x1 +#if 0 + // l_tile = 1 loop + ld1 {v1.4s}, [x14], #16 // alpha + ld1 {v2.4s}, [x16], #16 // bias + subs x12, x9, #1 + ldrb w8, [x13] + ldrb w9, [x13, #1] + movi v8.4s, #15 + fmov s3, w8 + mov v3.s[1], w9 + zip1 v3.4s, v3.4s, v3.4s + ushr v4.4s, v3.4s, #4 + and v3.16b, v3.16b, v8.16b + rev64 v4.4s, v4.4s + trn2 v3.4s, v4.4s, v3.4s + mvni v9.4s, #6 + add v3.4s, v3.4s, v9.4s + scvtf v3.4s, v3.4s 
+ mov v4.4s, v2.4s + fmla v4.4s, v3.4s, v1.4s + + ld1 {v0.4s}, [x15], x11 + fmul v16.4s, v4.4s, v0.s[0] + fmul v17.4s, v4.4s, v0.s[1] + fmul v18.4s, v4.4s, v0.s[2] + fmul v19.4s, v4.4s, v0.s[3] + add x13, x13, #4 +#else + ld1 {v4.4s}, [x14], #16 // alpha + mov w17, #0x0f + dup v30.8b, w17 + mov w17, #7 + dup v31.8b, w17 + ld1 {v14.4s}, [x16], #16 // bias + subs x12, x9, #4 + // load 16xint4 to 16xfloat + ld1 {v3.8h}, [x13] + uzp1 v0.8h, v3.8h, v3.8h + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + ld1 {v0.4s}, [x15], x11 + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + ld1 {v1.4s}, [x15], x11 + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + ld1 {v2.4s}, [x15], x11 + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + ld1 {v3.4s}, [x15], x11 + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + // ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + fmul v18.4s, v8.4s, v0.s[2] + fmul v19.4s, v8.4s, v0.s[3] + + fmla v16.4s, v9.4s, v1.s[0] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v1.s[2] + fmla v19.4s, v9.4s, v1.s[3] + + fmla v16.4s, v10.4s, v2.s[0] + fmla v17.4s, v10.4s, v2.s[1] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v2.s[3] + + fmla v16.4s, v11.4s, v3.s[0] + fmla v17.4s, v11.4s, v3.s[1] + fmla v18.4s, v11.4s, v3.s[2] + fmla v19.4s, v11.4s, v3.s[3] + add x13, x13, #16 +#endif + beq E4LoopLREnd + + E4LoopLR: +#if 0 + // ld1 {v3.4s}, [x13] + ldrb w8, [x13] + ldrb w9, [x13, #1] + movi v8.4s, #15 + fmov s3, w8 + mov v3.s[1], w9 + zip1 v3.4s, v3.4s, v3.4s + ushr v4.4s, v3.4s, #4 + and v3.16b, v3.16b, v8.16b + rev64 v4.4s, v4.4s + trn2 v3.4s, v4.4s, v3.4s + mvni v9.4s, #6 + add v3.4s, v3.4s, v9.4s + scvtf v3.4s, v3.4s + mov v4.4s, v2.4s + fmla v4.4s, v3.4s, v1.4s + + ld1 {v0.4s}, [x15], x11 + fmla v16.4s, v4.4s, v0.s[0] + fmla v17.4s, v4.4s, v0.s[1] + fmla v18.4s, v4.4s, v0.s[2] + fmla v19.4s, v4.4s, v0.s[3] + add x13, x13, #4 + + subs x12, x12, #1 +#else + ld1 {v3.8h}, [x13] + // 0123xxxx4567xxxx => 01234567... 
+ uzp1 v0.8h, v3.8h, v3.8h + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + // ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], x11 + ld1 {v0.4s}, [x15], x11 + ld1 {v1.4s}, [x15], x11 + ld1 {v2.4s}, [x15], x11 + ld1 {v3.4s}, [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + fmla v18.4s, v8.4s, v0.s[2] + fmla v19.4s, v8.4s, v0.s[3] + + fmla v16.4s, v9.4s, v1.s[0] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v1.s[2] + fmla v19.4s, v9.4s, v1.s[3] + + fmla v16.4s, v10.4s, v2.s[0] + fmla v17.4s, v10.4s, v2.s[1] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v2.s[3] + + fmla v16.4s, v11.4s, v3.s[0] + fmla v17.4s, v11.4s, v3.s[1] + fmla v18.4s, v11.4s, v3.s[2] + fmla v19.4s, v11.4s, v3.s[3] + add x13, x13, #16 + subs x12, x12, #4 +#endif + bne E4LoopLR + E4LoopLREnd: + + cbz x5, StoreLH4x4 + AddBiasLH4x4: + ld1 {v0.4s}, [x20] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + + PostTreatLH4x4: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + + StoreLH4x4: + stp q16, q17, [x0] + stp q18, q19, [x0, #32] + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0] + + E4End: + + sub x3, x3, #4 + add x0, x21, #64 + add x1, x1, #16 + +E1: +cmp x3, #0 +beq End + +LoopE1: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E1LH4 + + E1LH8: + E1LoopH8: + mov x15, x1 + ld1 {v24.4s, v25.4s}, [x14], #32 // alpha + mov w17, #0x0f + dup v30.8b, w17 + mov w17, #7 + dup v31.8b, w17 + ld1 {v26.4s, v27.4s}, [x16], #32 // bias + subs x12, x9, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v20.4s, v9.4s, v0.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v20.4s, v11.4s, v0.s[1] + beq E1LoopLEnd + + E1LoopL: + subs x12, x12, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.4h}, [x13], #8 + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s 
+ scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v20.4s, v9.4s, v0.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v20.4s, v11.4s, v0.s[1] + bne E1LoopL + + E1LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH1x8 + AddBiasLH1x8: + ld1 {v0.4s, v1.4s}, [x20], #32 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v20.4s, v1.4s, v5.s[1] + + PostTreatLH1x8: + fmax v16.4s, v16.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmin v16.4s, v16.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + + StoreLH1x8: + + st1 {v16.4s}, [x0], x7 + st1 {v20.4s}, [x0], x7 + + bge E1LoopH8 + + E1LH4: + cbz x8, E1End + mov x15, x1 + ld1 {v4.4s}, [x14], #16 // alpha + mov w17, #0x0f + dup v30.8b, w17 + mov w17, #7 + dup v31.8b, w17 + ld1 {v14.4s}, [x16], #16 // bias + subs x12, x9, #4 + ld1 {v3.8h}, [x13] + uzp1 v0.8h, v3.8h, v3.8h + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + ld1 {v0.s}[2], [x15], x11 + ld1 {v0.s}[3], [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmla v16.4s, v9.4s, v0.s[1] + fmla v16.4s, v10.4s, v0.s[2] + fmla v16.4s, v11.4s, v0.s[3] + add x13, x13, #16 + + beq E1LoopLREnd + + E1LoopLR: + // weight: load 16 x int4 to 16 x float + ld1 {v3.8h}, [x13] + uzp1 v0.8h, v3.8h, v3.8h + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v30.8b + sub v1.8b, v1.8b, v31.8b + sub v2.8b, v2.8b, v31.8b + + zip1 v9.8b, v1.8b, v2.8b + zip2 v10.8b, v1.8b, v2.8b + sxtl v11.8h, v9.8b + sxtl v12.8h, v10.8b + + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v15.4s + mov v9.4s, v15.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v15.4s + mov v11.4s, v15.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + // input: load 4 x float + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + ld1 {v0.s}[2], [x15], x11 + ld1 {v0.s}[3], [x15], x11 + // compute + fmla v16.4s, v8.4s, v0.s[0] + fmla v16.4s, v9.4s, v0.s[1] + fmla v16.4s, v10.4s, v0.s[2] + fmla v16.4s, v11.4s, v0.s[3] + add x13, x13, #16 + subs x12, x12, #4 + bne E1LoopLR + E1LoopLREnd: + + cbz x5, StoreLH1x4 + AddBiasLH1x4: + ld1 {v0.4s}, [x20] + fmla v16.4s, v0.4s, v5.s[1] + + PostTreatLH1x4: + fmax v16.4s, v16.4s, v6.4s + fmin v16.4s, v16.4s, v7.4s + + StoreLH1x4: + st1 {v16.4s}, [x0] + + E1End: + + subs x3, x3, #1 + add x0, x21, #16 + add x1, x1, #4 + bne LoopE1 + + +End: +ldr x19, [sp, #0] +ldr x20, [sp, #8] +ldr x21, [sp, #16] +ldr x22, [sp, #24] +ldr x23, [sp, #32] +add sp, sp, #64 + +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int8.S 
b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int8.S new file mode 100644 index 000000000..56574e2ed --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMulRemain_int8.S @@ -0,0 +1,1003 @@ +// +// MNNPackedMatMulRemain_int8.S +// MNN +// +// Created by MNN on 2023/06/06. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function MNNPackedMatMulRemain_int8 +//void MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias, x7: k, x8: b +ldr x8, [sp] +sub sp, sp, #64 +str x19, [sp, #0] +str x20, [sp, #8] +str x21, [sp, #16] +str x22, [sp, #24] +str x23, [sp, #32] + +mov x22, x7 // alpha +mov x23, x8 // bias +ldr x11, [x4, #0] // aStride +ldr x9, [x4, #8] // l +ldr x10, [x4, #16] // h + +ldr x7, [x4, #24] // cStride +ldr x19, [x4, #40] // bExtraStride + +add x10, x10, #3 +lsr x10, x10, #2 + +cbz x5, Start +ld1 {v5.4s}, [x5] +dup v6.4s, v5.s[2] // Min Value +dup v7.4s, v5.s[3] // Max Value + +Start: + +E8: +cmp x3, #8 +blt E4 + +LoopE8: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + LH8: + cmp x8, #2 + blt LH4 + + // sub x14, x7, #64 + LoopH8x8: + mov x15, x1 + ld1 {v12.4s, v13.4s}, [x14], #32 // alpha + ld1 {v14.4s, v15.4s}, [x16], #32 // bias + subs x12, x9, #2 + ld1 {v3.16b}, [x13], #16 + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v8.4s, v0.4h + sxtl2 v9.4s, v0.8h + sxtl v10.4s, v1.4h + sxtl2 v11.4s, v1.8h + scvtf v0.4s, v8.4s + scvtf v1.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v15.4s + fmla v8.4s, v0.4s, v12.4s + fmla v9.4s, v1.4s, v13.4s + scvtf v0.4s, v10.4s + scvtf v1.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v15.4s + fmla v10.4s, v0.4s, v12.4s + fmla v11.4s, v1.4s, v13.4s + ld1 {v0.4s, v1.4s}, [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + fmul v18.4s, v8.4s, v0.s[2] + fmul v19.4s, v8.4s, v0.s[3] + + fmul v20.4s, v9.4s, v0.s[0] + fmul v21.4s, v9.4s, v0.s[1] + fmul v22.4s, v9.4s, v0.s[2] + fmul v23.4s, v9.4s, v0.s[3] + + fmul v24.4s, v8.4s, v1.s[0] + fmul v25.4s, v8.4s, v1.s[1] + fmul v26.4s, v8.4s, v1.s[2] + fmul v27.4s, v8.4s, v1.s[3] + + fmul v28.4s, v9.4s, v1.s[0] + fmul v29.4s, v9.4s, v1.s[1] + fmul v30.4s, v9.4s, v1.s[2] + fmul v31.4s, v9.4s, v1.s[3] + ld1 {v0.4s, v1.4s}, [x15], x11 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v0.s[2] + fmla v19.4s, v10.4s, v0.s[3] + + fmla v20.4s, v11.4s, v0.s[0] + fmla v21.4s, v11.4s, v0.s[1] + fmla v22.4s, v11.4s, v0.s[2] + fmla v23.4s, v11.4s, v0.s[3] + + fmla v24.4s, v10.4s, v1.s[0] + fmla v25.4s, v10.4s, v1.s[1] + fmla v26.4s, v10.4s, v1.s[2] + fmla v27.4s, v10.4s, v1.s[3] + + fmla v28.4s, v11.4s, v1.s[0] + fmla v29.4s, v11.4s, v1.s[1] + fmla v30.4s, v11.4s, v1.s[2] + fmla v31.4s, v11.4s, v1.s[3] + beq LoopLEnd + + LoopL: + subs x12, x12, #2 + ld1 {v3.16b}, [x13], #16 + sxtl v0.8h, v3.8b + sxtl2 v1.8h, v3.16b + sxtl v8.4s, v0.4h + sxtl2 v9.4s, v0.8h + sxtl v10.4s, v1.4h + sxtl2 v11.4s, v1.8h + scvtf v0.4s, v8.4s + scvtf v1.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v15.4s + fmla v8.4s, v0.4s, v12.4s + fmla v9.4s, v1.4s, v13.4s + scvtf v0.4s, v10.4s + scvtf v1.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v15.4s + fmla v10.4s, v0.4s, v12.4s + fmla v11.4s, v1.4s, v13.4s + ld1 {v0.4s, 
v1.4s}, [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + fmla v18.4s, v8.4s, v0.s[2] + fmla v19.4s, v8.4s, v0.s[3] + + fmla v20.4s, v9.4s, v0.s[0] + fmla v21.4s, v9.4s, v0.s[1] + fmla v22.4s, v9.4s, v0.s[2] + fmla v23.4s, v9.4s, v0.s[3] + + fmla v24.4s, v8.4s, v1.s[0] + fmla v25.4s, v8.4s, v1.s[1] + fmla v26.4s, v8.4s, v1.s[2] + fmla v27.4s, v8.4s, v1.s[3] + + fmla v28.4s, v9.4s, v1.s[0] + fmla v29.4s, v9.4s, v1.s[1] + fmla v30.4s, v9.4s, v1.s[2] + fmla v31.4s, v9.4s, v1.s[3] + + ld1 {v0.4s, v1.4s}, [x15], x11 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + fmla v18.4s, v10.4s, v0.s[2] + fmla v19.4s, v10.4s, v0.s[3] + + fmla v20.4s, v11.4s, v0.s[0] + fmla v21.4s, v11.4s, v0.s[1] + fmla v22.4s, v11.4s, v0.s[2] + fmla v23.4s, v11.4s, v0.s[3] + + fmla v24.4s, v10.4s, v1.s[0] + fmla v25.4s, v10.4s, v1.s[1] + fmla v26.4s, v10.4s, v1.s[2] + fmla v27.4s, v10.4s, v1.s[3] + + fmla v28.4s, v11.4s, v1.s[0] + fmla v29.4s, v11.4s, v1.s[1] + fmla v30.4s, v11.4s, v1.s[2] + fmla v31.4s, v11.4s, v1.s[3] + + bne LoopL + + LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH8 + AddBiasLH8: + ld1 {v0.4s, v1.4s}, [x20], #32 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + fmla v24.4s, v0.4s, v5.s[1] + fmla v25.4s, v0.4s, v5.s[1] + fmla v26.4s, v0.4s, v5.s[1] + fmla v27.4s, v0.4s, v5.s[1] + + fmla v28.4s, v1.4s, v5.s[1] + fmla v29.4s, v1.4s, v5.s[1] + fmla v30.4s, v1.4s, v5.s[1] + fmla v31.4s, v1.4s, v5.s[1] + + PostTreatLH8: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + fmax v24.4s, v24.4s, v6.4s + fmax v25.4s, v25.4s, v6.4s + fmax v26.4s, v26.4s, v6.4s + fmax v27.4s, v27.4s, v6.4s + fmax v28.4s, v28.4s, v6.4s + fmax v29.4s, v29.4s, v6.4s + fmax v30.4s, v30.4s, v6.4s + fmax v31.4s, v31.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + fmin v24.4s, v24.4s, v7.4s + fmin v25.4s, v25.4s, v7.4s + fmin v26.4s, v26.4s, v7.4s + fmin v27.4s, v27.4s, v7.4s + fmin v28.4s, v28.4s, v7.4s + fmin v29.4s, v29.4s, v7.4s + fmin v30.4s, v30.4s, v7.4s + fmin v31.4s, v31.4s, v7.4s + + StoreLH8: + stp q16, q17, [x0] + stp q18, q19, [x0, #(32 * 1)] + stp q24, q25, [x0, #(32 * 2)] + stp q26, q27, [x0, #(32 * 3)] + add x0, x0, x7 // stp donot support post-index offset in register + + stp q20, q21, [x0] + stp q22, q23, [x0, #(32 * 1)] + stp q28, q29, [x0, #(32 * 2)] + stp q30, q31, [x0, #(32 * 3)] + add x0, x0, x7 + + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + // st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], x14 + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + // st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x14 + + bge LoopH8x8 + + LH4: + cbz x8, E8End + LoopHRemain: + mov x15, x1 + ld1 {v4.4s}, [x14], #16 // alpha + ld1 {v14.4s}, [x16], #16 // bias + subs x12, x9, #4 + // load 16xint4 to 4xfloat + ld1 {v3.4s}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf 
v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmul v18.4s, v8.4s, v0.s[2] + sub x15, x15, #16 + fmul v19.4s, v8.4s, v0.s[3] + add x15, x15, x11 + fmul v20.4s, v8.4s, v1.s[0] + fmul v21.4s, v8.4s, v1.s[1] + fmul v22.4s, v8.4s, v1.s[2] + fmul v23.4s, v8.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v9.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v9.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v9.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v9.4s, v1.s[0] + fmla v21.4s, v9.4s, v1.s[1] + fmla v22.4s, v9.4s, v1.s[2] + fmla v23.4s, v9.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v10.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v10.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v10.4s, v1.s[0] + fmla v21.4s, v10.4s, v1.s[1] + fmla v22.4s, v10.4s, v1.s[2] + fmla v23.4s, v10.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v11.4s, v0.s[0] + fmla v17.4s, v11.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v11.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v11.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + + add x13, x13, #32 + beq LoopLREnd + + LoopLR: + ld1 {v3.4s}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v8.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v8.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v8.4s, v1.s[0] + fmla v21.4s, v8.4s, v1.s[1] + fmla v22.4s, v8.4s, v1.s[2] + fmla v23.4s, v8.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v9.4s, v0.s[0] + fmla v17.4s, v9.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v9.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v9.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v9.4s, v1.s[0] + fmla v21.4s, v9.4s, v1.s[1] + fmla v22.4s, v9.4s, v1.s[2] + fmla v23.4s, v9.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v10.4s, v0.s[0] + fmla v17.4s, v10.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v10.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v10.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v10.4s, v1.s[0] + fmla v21.4s, v10.4s, v1.s[1] + fmla v22.4s, v10.4s, v1.s[2] + fmla v23.4s, v10.4s, v1.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v11.4s, v0.s[0] + fmla v17.4s, v11.4s, v0.s[1] + ld1 {v1.4s}, [x15] + fmla v18.4s, v11.4s, v0.s[2] + sub x15, x15, #16 + fmla v19.4s, v11.4s, v0.s[3] + add x15, x15, x11 + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + + add x13, x13, #32 + subs x12, x12, #4 + bne LoopLR + LoopLREnd: + + cbz x5, StoreLH8x4 + 
AddBiasLH8x4: + ld1 {v0.4s}, [x20] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v0.4s, v5.s[1] + fmla v21.4s, v0.4s, v5.s[1] + fmla v22.4s, v0.4s, v5.s[1] + fmla v23.4s, v0.4s, v5.s[1] + + PostTreatLH8x4: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + + StoreLH8x4: + + stp q16, q17, [x0] + stp q18, q19, [x0, #(32 * 1)] + stp q20, q21, [x0, #(32 * 2)] + stp q22, q23, [x0, #(32 * 3)] + add x0, x0, #(32 * 4) + + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], #64 + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + + E8End: + + sub x3, x3, #8 + cmp x3, #8 + add x0, x21, #128 + add x1, x1, #32 + bge LoopE8 + +E4: +cmp x3, #4 +mov x20, x6 +blt E1 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E4LH4 + + E4LH8: + E4LoopH8: + mov x15, x1 + ld1 {v24.4s, v25.4s}, [x14], #32 // alpha + ld1 {v26.4s, v27.4s}, [x16], #32 // bias + subs x12, x9, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.16b}, [x13], #16 + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + // st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0] + // b End + + ld1 {v0.4s}, [x15], x11 + ld1 {v1.4s}, [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + fmul v18.4s, v8.4s, v0.s[2] + fmul v19.4s, v8.4s, v0.s[3] + + fmul v20.4s, v9.4s, v0.s[0] + fmul v21.4s, v9.4s, v0.s[1] + fmul v22.4s, v9.4s, v0.s[2] + fmul v23.4s, v9.4s, v0.s[3] + + fmla v16.4s, v10.4s, v1.s[0] + fmla v17.4s, v10.4s, v1.s[1] + fmla v18.4s, v10.4s, v1.s[2] + fmla v19.4s, v10.4s, v1.s[3] + + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + beq E4LoopLEnd + + E4LoopL: + subs x12, x12, #2 + ld1 {v0.16b}, [x13], #16 + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + ld1 {v0.4s}, [x15], x11 + ld1 {v1.4s}, [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + fmla v18.4s, v8.4s, v0.s[2] + fmla v19.4s, v8.4s, v0.s[3] + + fmla v20.4s, v9.4s, v0.s[0] + fmla v21.4s, v9.4s, v0.s[1] + fmla v22.4s, v9.4s, v0.s[2] + fmla v23.4s, v9.4s, v0.s[3] + + fmla v16.4s, v10.4s, v1.s[0] + fmla v17.4s, v10.4s, v1.s[1] + fmla v18.4s, v10.4s, v1.s[2] + fmla v19.4s, v10.4s, v1.s[3] + + fmla v20.4s, v11.4s, v1.s[0] + fmla v21.4s, v11.4s, v1.s[1] + fmla v22.4s, v11.4s, v1.s[2] + fmla v23.4s, v11.4s, v1.s[3] + bne 
E4LoopL + E4LoopLEnd: + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH4x8 + + AddBiasLH4x8: + ld1 {v0.4s, v1.4s}, [x20], #32 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + PostTreatLH4x8: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + + StoreLH4x8: + stp q16, q17, [x0] + stp q18, q19, [x0, #32] + add x0, x0, x7 // stp donot support post-index offset in register + stp q20, q21, [x0] + stp q22, q23, [x0, #32] + add x0, x0, x7 + + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x7 + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], x7 + + bge E4LoopH8 + + E4LH4: + cbz x8, E4End + mov x15, x1 + ld1 {v4.4s}, [x14], #16 // alpha + ld1 {v14.4s}, [x16], #16 // bias + subs x12, x9, #4 + // load 16xint4 to 16xfloat + ld1 {v3.4s}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + ld1 {v0.4s}, [x15], x11 + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + ld1 {v1.4s}, [x15], x11 + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + ld1 {v2.4s}, [x15], x11 + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + ld1 {v3.4s}, [x15], x11 + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + // ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v17.4s, v8.4s, v0.s[1] + fmul v18.4s, v8.4s, v0.s[2] + fmul v19.4s, v8.4s, v0.s[3] + + fmla v16.4s, v9.4s, v1.s[0] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v1.s[2] + fmla v19.4s, v9.4s, v1.s[3] + + fmla v16.4s, v10.4s, v2.s[0] + fmla v17.4s, v10.4s, v2.s[1] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v2.s[3] + + fmla v16.4s, v11.4s, v3.s[0] + fmla v17.4s, v11.4s, v3.s[1] + fmla v18.4s, v11.4s, v3.s[2] + fmla v19.4s, v11.4s, v3.s[3] + add x13, x13, #32 + beq E4LoopLREnd + + E4LoopLR: + ld1 {v3.4s}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + // ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x15], x11 + ld1 {v0.4s}, [x15], x11 + ld1 {v1.4s}, [x15], x11 + ld1 {v2.4s}, [x15], x11 + ld1 {v3.4s}, [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v17.4s, v8.4s, v0.s[1] + fmla v18.4s, v8.4s, v0.s[2] + fmla v19.4s, v8.4s, v0.s[3] + + fmla v16.4s, v9.4s, v1.s[0] + fmla v17.4s, v9.4s, v1.s[1] + fmla v18.4s, v9.4s, v1.s[2] + fmla v19.4s, v9.4s, v1.s[3] + + fmla v16.4s, v10.4s, v2.s[0] + fmla v17.4s, v10.4s, v2.s[1] + fmla v18.4s, v10.4s, v2.s[2] + fmla v19.4s, v10.4s, v2.s[3] + + fmla v16.4s, v11.4s, v3.s[0] + fmla 
v17.4s, v11.4s, v3.s[1] + fmla v18.4s, v11.4s, v3.s[2] + fmla v19.4s, v11.4s, v3.s[3] + add x13, x13, #32 + subs x12, x12, #4 + bne E4LoopLR + E4LoopLREnd: + + cbz x5, StoreLH4x4 + AddBiasLH4x4: + ld1 {v0.4s}, [x20] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + + PostTreatLH4x4: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + + StoreLH4x4: + stp q16, q17, [x0] + stp q18, q19, [x0, #32] + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0] + + E4End: + + sub x3, x3, #4 + add x0, x21, #64 + add x1, x1, #16 + +E1: +cmp x3, #0 +beq End + +LoopE1: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + mov x14, x22 + mov x16, x23 + + cmp x8, #2 + blt E1LH4 + + E1LH8: + E1LoopH8: + mov x15, x1 + ld1 {v24.4s, v25.4s}, [x14], #32 // alpha + ld1 {v26.4s, v27.4s}, [x16], #32 // bias + subs x12, x9, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.16b}, [x13], #16 + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmul v20.4s, v9.4s, v0.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v20.4s, v11.4s, v0.s[1] + beq E1LoopLEnd + + E1LoopL: + subs x12, x12, #2 + // ld1 {v3.4s, v4.4s}, [x13], #32 + ld1 {v0.16b}, [x13], #16 + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v26.4s + mov v9.4s, v27.4s + fmla v8.4s, v12.4s, v24.4s + fmla v9.4s, v13.4s, v25.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v26.4s + mov v11.4s, v27.4s + fmla v10.4s, v12.4s, v24.4s + fmla v11.4s, v13.4s, v25.4s + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + fmla v16.4s, v8.4s, v0.s[0] + fmla v20.4s, v9.4s, v0.s[0] + fmla v16.4s, v10.4s, v0.s[1] + fmla v20.4s, v11.4s, v0.s[1] + bne E1LoopL + + E1LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH1x8 + AddBiasLH1x8: + ld1 {v0.4s, v1.4s}, [x20], #32 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v20.4s, v1.4s, v5.s[1] + + PostTreatLH1x8: + fmax v16.4s, v16.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmin v16.4s, v16.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + + StoreLH1x8: + + st1 {v16.4s}, [x0], x7 + st1 {v20.4s}, [x0], x7 + + bge E1LoopH8 + + E1LH4: + cbz x8, E1End + mov x15, x1 + ld1 {v4.4s}, [x14], #16 // alpha + ld1 {v14.4s}, [x16], #16 // bias + subs x12, x9, #4 + ld1 {v3.4s}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v14.4s + mov v9.4s, v14.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v14.4s + mov v11.4s, v14.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + ld1 {v0.s}[2], 
[x15], x11 + ld1 {v0.s}[3], [x15], x11 + fmul v16.4s, v8.4s, v0.s[0] + fmla v16.4s, v9.4s, v0.s[1] + fmla v16.4s, v10.4s, v0.s[2] + fmla v16.4s, v11.4s, v0.s[3] + add x13, x13, #32 + + beq E1LoopLREnd + + E1LoopLR: + // weight: load 16 x int4 to 16 x float + ld1 {v3.4s}, [x13] + uzp1 v0.4s, v3.4s, v3.4s + sxtl v11.8h, v0.8b + sxtl2 v12.8h, v0.16b + sxtl v8.4s, v11.4h + sxtl2 v9.4s, v11.8h + sxtl v10.4s, v12.4h + sxtl2 v11.4s, v12.8h + scvtf v12.4s, v8.4s + scvtf v13.4s, v9.4s + mov v8.4s, v15.4s + mov v9.4s, v15.4s + fmla v8.4s, v12.4s, v4.4s + fmla v9.4s, v13.4s, v4.4s + scvtf v12.4s, v10.4s + scvtf v13.4s, v11.4s + mov v10.4s, v15.4s + mov v11.4s, v15.4s + fmla v10.4s, v12.4s, v4.4s + fmla v11.4s, v13.4s, v4.4s + + // input: load 4 x float + ld1 {v0.s}[0], [x15], x11 + ld1 {v0.s}[1], [x15], x11 + ld1 {v0.s}[2], [x15], x11 + ld1 {v0.s}[3], [x15], x11 + // compute + fmla v16.4s, v8.4s, v0.s[0] + fmla v16.4s, v9.4s, v0.s[1] + fmla v16.4s, v10.4s, v0.s[2] + fmla v16.4s, v11.4s, v0.s[3] + add x13, x13, #32 + subs x12, x12, #4 + bne E1LoopLR + E1LoopLREnd: + + cbz x5, StoreLH1x4 + AddBiasLH1x4: + ld1 {v0.4s}, [x20] + fmla v16.4s, v0.4s, v5.s[1] + + PostTreatLH1x4: + fmax v16.4s, v16.4s, v6.4s + fmin v16.4s, v16.4s, v7.4s + + StoreLH1x4: + st1 {v16.4s}, [x0] + + E1End: + + subs x3, x3, #1 + add x0, x21, #16 + add x1, x1, #4 + bne LoopE1 + + +End: +ldr x19, [sp, #0] +ldr x20, [sp, #8] +ldr x21, [sp, #16] +ldr x22, [sp, #24] +ldr x23, [sp, #32] +add sp, sp, #64 + +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S new file mode 100644 index 000000000..4f33c8bb7 --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int4.S @@ -0,0 +1,630 @@ +// +// MNNPackedMatMul_int4.S +// MNN +// +// Created by MNN on 2023/05/29. +// Copyright © 2018, Alibaba Group Holding Limited +// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function MNNPackedMatMul_int4 +//void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); +// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias, x6: k, x7: b +stp d14, d15, [sp, #-80]! 
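+// Int4 weight scheme, as read from the unpack sequence below (hedged summary): each byte of B
+// packs two 4-bit weights; they are split with ushr/and against 0x0f, recentred by subtracting 7,
+// converted to float, and dequantized per 4-output-channel block as
+//     w = alpha * (float)q + bias   // alpha loaded via x6 (k), bias via x7 (b)
+// before feeding the fmla accumulation.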
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x20, [sp, #64] + +//ldr x8, [x3, #0] // deprecated +ldr x9, [x3, #8] // l +ldr x10, [x3, #16] // h + +ldr x13, [x3, #24] // cStride +ldr x11, [x3, #40] // bExtraStride + +// v0, v1, v2: A +// v3, v4: B +// v8 - v31: C +add x10, x10, #3 +lsr x10, x10, #2 + +cbz x4, Start + + +Start: + +cmp x10, #2 +blt LH4 + +LH8: +// sub x14, x13, #160 +mov x19, x6 +mov x20, x7 +LoopH: + mov x15, x1 + ld1 {v4.4s, v5.4s}, [x19], #32 // alpha + ld1 {v6.4s, v7.4s}, [x20], #32 // bias + subs x12, x9, #2 + // ld1 {v3.4s, v4.4s}, [x2], #32 + ld1 {v0.4h}, [x2], #8 + ushr v1.8b, v0.8b, #4 + mov w17, #0x0f + dup v3.8b, w17 + and v2.8b, v0.8b, v3.8b + mov w17, #7 + dup v3.8b, w17 + sub v0.8b, v1.8b, v3.8b + sub v1.8b, v2.8b, v3.8b + zip1 v2.8b, v0.8b, v1.8b + zip2 v3.8b, v0.8b, v1.8b + sxtl v0.8h, v2.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmul v8.4s, v1.4s, v0.s[0] + fmul v9.4s, v1.4s, v0.s[1] + fmul v10.4s, v1.4s, v0.s[2] + fmul v11.4s, v1.4s, v0.s[3] + fmul v20.4s, v2.4s, v0.s[0] + fmul v21.4s, v2.4s, v0.s[1] + fmul v22.4s, v2.4s, v0.s[2] + fmul v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmul v12.4s, v1.4s, v0.s[0] + fmul v13.4s, v1.4s, v0.s[1] + fmul v14.4s, v1.4s, v0.s[2] + fmul v15.4s, v1.4s, v0.s[3] + fmul v24.4s, v2.4s, v0.s[0] + fmul v25.4s, v2.4s, v0.s[1] + fmul v26.4s, v2.4s, v0.s[2] + fmul v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmul v16.4s, v1.4s, v0.s[0] + fmul v17.4s, v1.4s, v0.s[1] + fmul v18.4s, v1.4s, v0.s[2] + fmul v19.4s, v1.4s, v0.s[3] + fmul v28.4s, v2.4s, v0.s[0] + fmul v29.4s, v2.4s, v0.s[1] + fmul v30.4s, v2.4s, v0.s[2] + fmul v31.4s, v2.4s, v0.s[3] + + sxtl v0.8h, v3.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v1.4s, v0.s[0] + fmla v9.4s, v1.4s, v0.s[1] + fmla v10.4s, v1.4s, v0.s[2] + fmla v11.4s, v1.4s, v0.s[3] + fmla v20.4s, v2.4s, v0.s[0] + fmla v21.4s, v2.4s, v0.s[1] + fmla v22.4s, v2.4s, v0.s[2] + fmla v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v12.4s, v1.4s, v0.s[0] + fmla v13.4s, v1.4s, v0.s[1] + fmla v14.4s, v1.4s, v0.s[2] + fmla v15.4s, v1.4s, v0.s[3] + fmla v24.4s, v2.4s, v0.s[0] + fmla v25.4s, v2.4s, v0.s[1] + fmla v26.4s, v2.4s, v0.s[2] + fmla v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v1.4s, v0.s[0] + fmla v17.4s, v1.4s, v0.s[1] + fmla v18.4s, v1.4s, v0.s[2] + fmla v19.4s, v1.4s, v0.s[3] + fmla v28.4s, v2.4s, v0.s[0] + fmla v29.4s, v2.4s, v0.s[1] + fmla v30.4s, v2.4s, v0.s[2] + fmla v31.4s, v2.4s, v0.s[3] + + beq LoopLEnd + + LoopL2: + subs x12, x12, #2 + // ld1 {v3.4s, v4.4s}, [x2], #32 + ld1 {v0.4h}, [x2], #8 + ushr v1.8b, v0.8b, #4 + mov w17, #0x0f + dup v3.8b, w17 + and v2.8b, v0.8b, v3.8b + mov w17, #7 + dup v3.8b, w17 + sub v0.8b, v1.8b, v3.8b + sub v1.8b, v2.8b, v3.8b + zip1 v2.8b, v0.8b, v1.8b + zip2 v3.8b, v0.8b, v1.8b + sxtl v0.8h, v2.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v1.4s, v0.s[0] + fmla v9.4s, v1.4s, v0.s[1] + fmla v10.4s, v1.4s, v0.s[2] + fmla v11.4s, v1.4s, v0.s[3] + fmla v20.4s, v2.4s, v0.s[0] + fmla 
v21.4s, v2.4s, v0.s[1] + fmla v22.4s, v2.4s, v0.s[2] + fmla v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v12.4s, v1.4s, v0.s[0] + fmla v13.4s, v1.4s, v0.s[1] + fmla v14.4s, v1.4s, v0.s[2] + fmla v15.4s, v1.4s, v0.s[3] + fmla v24.4s, v2.4s, v0.s[0] + fmla v25.4s, v2.4s, v0.s[1] + fmla v26.4s, v2.4s, v0.s[2] + fmla v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v1.4s, v0.s[0] + fmla v17.4s, v1.4s, v0.s[1] + fmla v18.4s, v1.4s, v0.s[2] + fmla v19.4s, v1.4s, v0.s[3] + fmla v28.4s, v2.4s, v0.s[0] + fmla v29.4s, v2.4s, v0.s[1] + fmla v30.4s, v2.4s, v0.s[2] + fmla v31.4s, v2.4s, v0.s[3] + + sxtl v0.8h, v3.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v1.4s, v0.s[0] + fmla v9.4s, v1.4s, v0.s[1] + fmla v10.4s, v1.4s, v0.s[2] + fmla v11.4s, v1.4s, v0.s[3] + fmla v20.4s, v2.4s, v0.s[0] + fmla v21.4s, v2.4s, v0.s[1] + fmla v22.4s, v2.4s, v0.s[2] + fmla v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v12.4s, v1.4s, v0.s[0] + fmla v13.4s, v1.4s, v0.s[1] + fmla v14.4s, v1.4s, v0.s[2] + fmla v15.4s, v1.4s, v0.s[3] + fmla v24.4s, v2.4s, v0.s[0] + fmla v25.4s, v2.4s, v0.s[1] + fmla v26.4s, v2.4s, v0.s[2] + fmla v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v1.4s, v0.s[0] + fmla v17.4s, v1.4s, v0.s[1] + fmla v18.4s, v1.4s, v0.s[2] + fmla v19.4s, v1.4s, v0.s[3] + fmla v28.4s, v2.4s, v0.s[0] + fmla v29.4s, v2.4s, v0.s[1] + fmla v30.4s, v2.4s, v0.s[2] + fmla v31.4s, v2.4s, v0.s[3] + bne LoopL2 + + LoopLEnd: + + add x2, x2, x11 + sub x10, x10, #2 + cmp x10, #2 + + cbz x4, StoreLH8 + + AddBiasLH8: + ld1 {v5.4s}, [x4] + dup v6.4s, v5.s[2] // Min Value + dup v7.4s, v5.s[3] // Max Value + ld1 {v0.4s, v1.4s}, [x5], #32 + + fmla v8.4s, v0.4s, v5.s[1] + fmla v9.4s, v0.4s, v5.s[1] + fmla v10.4s, v0.4s, v5.s[1] + fmla v11.4s, v0.4s, v5.s[1] + + fmla v12.4s, v0.4s, v5.s[1] + fmla v13.4s, v0.4s, v5.s[1] + fmla v14.4s, v0.4s, v5.s[1] + fmla v15.4s, v0.4s, v5.s[1] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + fmla v24.4s, v1.4s, v5.s[1] + fmla v25.4s, v1.4s, v5.s[1] + fmla v26.4s, v1.4s, v5.s[1] + fmla v27.4s, v1.4s, v5.s[1] + + fmla v28.4s, v1.4s, v5.s[1] + fmla v29.4s, v1.4s, v5.s[1] + fmla v30.4s, v1.4s, v5.s[1] + fmla v31.4s, v1.4s, v5.s[1] + + PostTreatLH8: + fmax v8.4s, v8.4s, v6.4s + fmax v9.4s, v9.4s, v6.4s + fmax v10.4s, v10.4s, v6.4s + fmax v11.4s, v11.4s, v6.4s + fmax v12.4s, v12.4s, v6.4s + fmax v13.4s, v13.4s, v6.4s + fmax v14.4s, v14.4s, v6.4s + fmax v15.4s, v15.4s, v6.4s + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + fmax v24.4s, v24.4s, v6.4s + fmax v25.4s, v25.4s, v6.4s + fmax v26.4s, v26.4s, v6.4s + fmax v27.4s, v27.4s, v6.4s + fmax v28.4s, v28.4s, v6.4s + fmax v29.4s, v29.4s, v6.4s + fmax v30.4s, v30.4s, v6.4s + fmax v31.4s, v31.4s, v6.4s + + fmin v8.4s, v8.4s, v7.4s + fmin v9.4s, v9.4s, v7.4s + fmin v10.4s, v10.4s, v7.4s + fmin v11.4s, v11.4s, v7.4s + fmin v12.4s, v12.4s, v7.4s + fmin v13.4s, v13.4s, v7.4s + fmin v14.4s, v14.4s, v7.4s + fmin v15.4s, v15.4s, v7.4s + fmin v16.4s, v16.4s, v7.4s 
+ fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + fmin v24.4s, v24.4s, v7.4s + fmin v25.4s, v25.4s, v7.4s + fmin v26.4s, v26.4s, v7.4s + fmin v27.4s, v27.4s, v7.4s + fmin v28.4s, v28.4s, v7.4s + fmin v29.4s, v29.4s, v7.4s + fmin v30.4s, v30.4s, v7.4s + fmin v31.4s, v31.4s, v7.4s + + StoreLH8: + stp q8, q9, [x0] + stp q10, q11, [x0, #(32 * 1)] // 2 * 4 * sizeof(int16_t) + stp q12, q13, [x0, #(32 * 2)] + stp q14, q15, [x0, #(32 * 3)] + stp q16, q17, [x0, #(32 * 4)] + stp q18, q19, [x0, #(32 * 5)] + add x0, x0, x13 // stp donot support post-index offset in register + stp q20, q21, [x0] + stp q22, q23, [x0, #(32 * 1)] + stp q24, q25, [x0, #(32 * 2)] + stp q26, q27, [x0, #(32 * 3)] + stp q28, q29, [x0, #(32 * 4)] + stp q30, q31, [x0, #(32 * 5)] + add x0, x0, x13 + + // st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + // st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x14 +// + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + // st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + // st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x14 + + bge LoopH + +LH4: +cbz x10, End +LoopHRemain: + mov x15, x1 + subs x12, x9, #4 + ld1 {v20.4s}, [x19], #16 // alpha + ld1 {v21.4s}, [x20], #16 // bias + mov w17, #0x0f + dup v22.16b, w17 + mov w17, #7 + dup v23.16b, w17 + // ld1 {v3.4s}, [x2] + ld1 {v3.8h}, [x2], #16 + uzp1 v0.4s, v3.4s, v3.4s + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v22.8b + sub v1.8b, v1.8b, v23.8b + sub v2.8b, v2.8b, v23.8b + zip1 v6.8b, v1.8b, v2.8b + zip2 v7.8b, v1.8b, v2.8b + sxtl v0.8h, v6.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + + ld1 {v0.4s}, [x15], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v9.4s, v3.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmul v10.4s, v3.4s, v0.s[2] + fmul v11.4s, v3.4s, v0.s[3] + fmul v12.4s, v3.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmul v13.4s, v3.4s, v1.s[1] + fmul v14.4s, v3.4s, v1.s[2] + fmul v15.4s, v3.4s, v1.s[3] + fmul v16.4s, v3.4s, v2.s[0] + fmul v17.4s, v3.4s, v2.s[1] + fmul v18.4s, v3.4s, v2.s[2] + fmul v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + + sxtl v0.8h, v7.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmla v10.4s, v4.4s, 
v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + + beq LoopLREnd + + LoopLR: + subs x12, x12, #4 + // ld1 {v3.4s}, [x2] + ld1 {v3.8h}, [x2], #16 + uzp1 v0.4s, v3.4s, v3.4s + ushr v1.8b, v0.8b, #4 + and v2.8b, v0.8b, v22.8b + sub v1.8b, v1.8b, v23.8b + sub v2.8b, v2.8b, v23.8b + zip1 v6.8b, v1.8b, v2.8b + zip2 v7.8b, v1.8b, v2.8b + sxtl v0.8h, v6.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + + sxtl v0.8h, v7.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + bne LoopLR + LoopLREnd: + + cbz x4, StoreLH4 + AddBiasLH4: + ld1 {v5.4s}, [x4] + dup v6.4s, v5.s[2] // Min Value + dup v7.4s, v5.s[3] // Max Value + ld1 {v0.4s}, [x5], #16 + + fmla v8.4s, v0.4s, v5.s[1] + fmla v9.4s, v0.4s, v5.s[1] + fmla v10.4s, v0.4s, v5.s[1] + fmla v11.4s, v0.4s, v5.s[1] + + fmla v12.4s, v0.4s, v5.s[1] + fmla v13.4s, v0.4s, v5.s[1] + fmla v14.4s, v0.4s, v5.s[1] + fmla v15.4s, v0.4s, v5.s[1] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + PostTreatLH4: + fmax v8.4s, v8.4s, v6.4s + fmax v9.4s, v9.4s, v6.4s + fmax v10.4s, v10.4s, v6.4s + fmax v11.4s, v11.4s, v6.4s + fmax v12.4s, v12.4s, v6.4s + fmax v13.4s, v13.4s, v6.4s + fmax v14.4s, v14.4s, v6.4s + fmax v15.4s, v15.4s, v6.4s + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + + fmin v8.4s, v8.4s, v7.4s + fmin v9.4s, v9.4s, v7.4s + fmin 
v10.4s, v10.4s, v7.4s + fmin v11.4s, v11.4s, v7.4s + fmin v12.4s, v12.4s, v7.4s + fmin v13.4s, v13.4s, v7.4s + fmin v14.4s, v14.4s, v7.4s + fmin v15.4s, v15.4s, v7.4s + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + + StoreLH4: + stp q8, q9, [x0] + stp q10, q11, [x0, #(32 * 1)] // 2 * 4 * sizeof(float) + stp q12, q13, [x0, #(32 * 2)] + stp q14, q15, [x0, #(32 * 3)] + stp q16, q17, [x0, #(32 * 4)] + stp q18, q19, [x0, #(32 * 5)] + + // st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + // st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0] + sub x10, x10, #1 + + +End: +ldp x19, x20, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #80 + +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int8.S b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int8.S new file mode 100644 index 000000000..65b98faff --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNPackedMatMul_int8.S @@ -0,0 +1,595 @@ +// +// MNNPackedMatMul_int8.S +// MNN +// +// Created by MNN on 2023/06/06. +// Copyright © 2018, Alibaba Group Holding Limited +// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function MNNPackedMatMul_int8 +//void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); +// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias, x6: k, x7: b +stp d14, d15, [sp, #-80]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x20, [sp, #64] + +//ldr x8, [x3, #0] // deprecated +ldr x9, [x3, #8] // l +ldr x10, [x3, #16] // h + +ldr x13, [x3, #24] // cStride +ldr x11, [x3, #40] // bExtraStride + +// v0, v1, v2: A +// v3, v4: B +// v8 - v31: C +add x10, x10, #3 +lsr x10, x10, #2 + +cbz x4, Start + + +Start: + +cmp x10, #2 +blt LH4 + +LH8: +// sub x14, x13, #160 +mov x19, x6 +mov x20, x7 +LoopH: + + mov x15, x1 + ld1 {v4.4s, v5.4s}, [x19], #32 // alpha + ld1 {v6.4s, v7.4s}, [x20], #32 // bias + subs x12, x9, #2 + // ld1 {v3.4s, v4.4s}, [x2], #32 + ld1 {v3.16b}, [x2], #16 + sxtl v0.8h, v3.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmul v8.4s, v1.4s, v0.s[0] + fmul v9.4s, v1.4s, v0.s[1] + fmul v10.4s, v1.4s, v0.s[2] + fmul v11.4s, v1.4s, v0.s[3] + fmul v20.4s, v2.4s, v0.s[0] + fmul v21.4s, v2.4s, v0.s[1] + fmul v22.4s, v2.4s, v0.s[2] + fmul v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmul v12.4s, v1.4s, v0.s[0] + fmul v13.4s, v1.4s, v0.s[1] + fmul v14.4s, v1.4s, v0.s[2] + fmul v15.4s, v1.4s, v0.s[3] + fmul v24.4s, v2.4s, v0.s[0] + fmul v25.4s, v2.4s, v0.s[1] + fmul v26.4s, v2.4s, v0.s[2] + fmul v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmul v16.4s, v1.4s, v0.s[0] + fmul v17.4s, v1.4s, v0.s[1] + fmul v18.4s, v1.4s, v0.s[2] + fmul v19.4s, v1.4s, v0.s[3] + fmul v28.4s, v2.4s, v0.s[0] + fmul v29.4s, v2.4s, v0.s[1] + fmul v30.4s, v2.4s, v0.s[2] + fmul v31.4s, v2.4s, v0.s[3] + + sxtl2 v0.8h, v3.16b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v1.4s, v0.s[0] + fmla v9.4s, v1.4s, v0.s[1] + fmla v10.4s, v1.4s, 
v0.s[2] + fmla v11.4s, v1.4s, v0.s[3] + fmla v20.4s, v2.4s, v0.s[0] + fmla v21.4s, v2.4s, v0.s[1] + fmla v22.4s, v2.4s, v0.s[2] + fmla v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v12.4s, v1.4s, v0.s[0] + fmla v13.4s, v1.4s, v0.s[1] + fmla v14.4s, v1.4s, v0.s[2] + fmla v15.4s, v1.4s, v0.s[3] + fmla v24.4s, v2.4s, v0.s[0] + fmla v25.4s, v2.4s, v0.s[1] + fmla v26.4s, v2.4s, v0.s[2] + fmla v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v1.4s, v0.s[0] + fmla v17.4s, v1.4s, v0.s[1] + fmla v18.4s, v1.4s, v0.s[2] + fmla v19.4s, v1.4s, v0.s[3] + fmla v28.4s, v2.4s, v0.s[0] + fmla v29.4s, v2.4s, v0.s[1] + fmla v30.4s, v2.4s, v0.s[2] + fmla v31.4s, v2.4s, v0.s[3] + + beq LoopLEnd + + LoopL2: + subs x12, x12, #2 + // ld1 {v3.4s, v4.4s}, [x2], #32 + ld1 {v3.16b}, [x2], #16 + sxtl v0.8h, v3.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v1.4s, v0.s[0] + fmla v9.4s, v1.4s, v0.s[1] + fmla v10.4s, v1.4s, v0.s[2] + fmla v11.4s, v1.4s, v0.s[3] + fmla v20.4s, v2.4s, v0.s[0] + fmla v21.4s, v2.4s, v0.s[1] + fmla v22.4s, v2.4s, v0.s[2] + fmla v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v12.4s, v1.4s, v0.s[0] + fmla v13.4s, v1.4s, v0.s[1] + fmla v14.4s, v1.4s, v0.s[2] + fmla v15.4s, v1.4s, v0.s[3] + fmla v24.4s, v2.4s, v0.s[0] + fmla v25.4s, v2.4s, v0.s[1] + fmla v26.4s, v2.4s, v0.s[2] + fmla v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v1.4s, v0.s[0] + fmla v17.4s, v1.4s, v0.s[1] + fmla v18.4s, v1.4s, v0.s[2] + fmla v19.4s, v1.4s, v0.s[3] + fmla v28.4s, v2.4s, v0.s[0] + fmla v29.4s, v2.4s, v0.s[1] + fmla v30.4s, v2.4s, v0.s[2] + fmla v31.4s, v2.4s, v0.s[3] + + sxtl2 v0.8h, v3.16b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v0.4s, v1.4s + scvtf v1.4s, v2.4s + mov v2.4s, v7.4s + fmla v2.4s, v1.4s, v5.4s + mov v1.4s, v6.4s + fmla v1.4s, v0.4s, v4.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v1.4s, v0.s[0] + fmla v9.4s, v1.4s, v0.s[1] + fmla v10.4s, v1.4s, v0.s[2] + fmla v11.4s, v1.4s, v0.s[3] + fmla v20.4s, v2.4s, v0.s[0] + fmla v21.4s, v2.4s, v0.s[1] + fmla v22.4s, v2.4s, v0.s[2] + fmla v23.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v12.4s, v1.4s, v0.s[0] + fmla v13.4s, v1.4s, v0.s[1] + fmla v14.4s, v1.4s, v0.s[2] + fmla v15.4s, v1.4s, v0.s[3] + fmla v24.4s, v2.4s, v0.s[0] + fmla v25.4s, v2.4s, v0.s[1] + fmla v26.4s, v2.4s, v0.s[2] + fmla v27.4s, v2.4s, v0.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v16.4s, v1.4s, v0.s[0] + fmla v17.4s, v1.4s, v0.s[1] + fmla v18.4s, v1.4s, v0.s[2] + fmla v19.4s, v1.4s, v0.s[3] + fmla v28.4s, v2.4s, v0.s[0] + fmla v29.4s, v2.4s, v0.s[1] + fmla v30.4s, v2.4s, v0.s[2] + fmla v31.4s, v2.4s, v0.s[3] + bne LoopL2 + + LoopLEnd: + + add x2, x2, x11 + sub x10, x10, #2 + cmp x10, #2 + + cbz x4, StoreLH8 + + AddBiasLH8: + ld1 {v5.4s}, [x4] + dup v6.4s, v5.s[2] // Min Value + dup v7.4s, v5.s[3] // Max Value + ld1 {v0.4s, v1.4s}, [x5], #32 + + fmla v8.4s, v0.4s, v5.s[1] + fmla v9.4s, v0.4s, v5.s[1] + fmla v10.4s, v0.4s, v5.s[1] + fmla v11.4s, v0.4s, v5.s[1] + + fmla v12.4s, v0.4s, v5.s[1] + fmla v13.4s, v0.4s, v5.s[1] + fmla v14.4s, v0.4s, v5.s[1] + fmla v15.4s, v0.4s, v5.s[1] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + fmla v24.4s, v1.4s, 
v5.s[1] + fmla v25.4s, v1.4s, v5.s[1] + fmla v26.4s, v1.4s, v5.s[1] + fmla v27.4s, v1.4s, v5.s[1] + + fmla v28.4s, v1.4s, v5.s[1] + fmla v29.4s, v1.4s, v5.s[1] + fmla v30.4s, v1.4s, v5.s[1] + fmla v31.4s, v1.4s, v5.s[1] + + PostTreatLH8: + fmax v8.4s, v8.4s, v6.4s + fmax v9.4s, v9.4s, v6.4s + fmax v10.4s, v10.4s, v6.4s + fmax v11.4s, v11.4s, v6.4s + fmax v12.4s, v12.4s, v6.4s + fmax v13.4s, v13.4s, v6.4s + fmax v14.4s, v14.4s, v6.4s + fmax v15.4s, v15.4s, v6.4s + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + fmax v24.4s, v24.4s, v6.4s + fmax v25.4s, v25.4s, v6.4s + fmax v26.4s, v26.4s, v6.4s + fmax v27.4s, v27.4s, v6.4s + fmax v28.4s, v28.4s, v6.4s + fmax v29.4s, v29.4s, v6.4s + fmax v30.4s, v30.4s, v6.4s + fmax v31.4s, v31.4s, v6.4s + + fmin v8.4s, v8.4s, v7.4s + fmin v9.4s, v9.4s, v7.4s + fmin v10.4s, v10.4s, v7.4s + fmin v11.4s, v11.4s, v7.4s + fmin v12.4s, v12.4s, v7.4s + fmin v13.4s, v13.4s, v7.4s + fmin v14.4s, v14.4s, v7.4s + fmin v15.4s, v15.4s, v7.4s + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + fmin v24.4s, v24.4s, v7.4s + fmin v25.4s, v25.4s, v7.4s + fmin v26.4s, v26.4s, v7.4s + fmin v27.4s, v27.4s, v7.4s + fmin v28.4s, v28.4s, v7.4s + fmin v29.4s, v29.4s, v7.4s + fmin v30.4s, v30.4s, v7.4s + fmin v31.4s, v31.4s, v7.4s + + StoreLH8: + stp q8, q9, [x0] + stp q10, q11, [x0, #(32 * 1)] // 2 * 4 * sizeof(int16_t) + stp q12, q13, [x0, #(32 * 2)] + stp q14, q15, [x0, #(32 * 3)] + stp q16, q17, [x0, #(32 * 4)] + stp q18, q19, [x0, #(32 * 5)] + add x0, x0, x13 // stp donot support post-index offset in register + stp q20, q21, [x0] + stp q22, q23, [x0, #(32 * 1)] + stp q24, q25, [x0, #(32 * 2)] + stp q26, q27, [x0, #(32 * 3)] + stp q28, q29, [x0, #(32 * 4)] + stp q30, q31, [x0, #(32 * 5)] + add x0, x0, x13 + + // st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + // st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x14 +// + // st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + // st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + // st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x14 + + bge LoopH + +LH4: +cbz x10, End +LoopHRemain: + mov x15, x1 + subs x12, x9, #4 + ld1 {v20.4s}, [x19], #16 // alpha + ld1 {v21.4s}, [x20], #16 // bias + // ld1 {v3.4s}, [x2] + ld1 {v0.4s}, [x2], #16 + uzp1 v7.4s, v0.4s, v0.4s + sxtl v0.8h, v7.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + + ld1 {v0.4s}, [x15], #16 + fmul v8.4s, v3.4s, v0.s[0] + fmul v9.4s, v3.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmul v10.4s, v3.4s, v0.s[2] + fmul v11.4s, v3.4s, v0.s[3] + fmul v12.4s, v3.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmul v13.4s, v3.4s, v1.s[1] + fmul v14.4s, v3.4s, v1.s[2] + fmul v15.4s, v3.4s, v1.s[3] + fmul v16.4s, v3.4s, v2.s[0] + fmul v17.4s, v3.4s, v2.s[1] + fmul v18.4s, v3.4s, v2.s[2] + fmul v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmla v13.4s, v4.4s, 
v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + + sxtl2 v0.8h, v7.16b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s}, [x15], #16 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + ld1 {v1.4s}, [x15], #16 + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + ld1 {v2.4s}, [x15], #16 + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + + beq LoopLREnd + + LoopLR: + subs x12, x12, #4 + // ld1 {v3.4s}, [x2] + ld1 {v3.4s}, [x2], #16 + uzp1 v7.4s, v3.4s, v3.4s + sxtl v0.8h, v7.8b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + + sxtl2 v0.8h, v7.16b + sxtl v1.4s, v0.4h + sxtl2 v2.4s, v0.8h + scvtf v1.4s, v1.4s + scvtf v2.4s, v2.4s + mov v3.4s, v21.4s + mov v4.4s, v21.4s + fmla v3.4s, v1.4s, v20.4s + fmla v4.4s, v2.4s, v20.4s + + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + ld1 {v0.4s, v1.4s, v2.4s}, [x15], #48 + fmla v8.4s, v4.4s, v0.s[0] + fmla v9.4s, v4.4s, v0.s[1] + fmla v10.4s, v4.4s, v0.s[2] + fmla v11.4s, v4.4s, v0.s[3] + fmla v12.4s, v4.4s, v1.s[0] + fmla v13.4s, v4.4s, v1.s[1] + fmla v14.4s, v4.4s, v1.s[2] + fmla v15.4s, v4.4s, v1.s[3] + fmla v16.4s, v4.4s, v2.s[0] + fmla v17.4s, v4.4s, v2.s[1] + fmla v18.4s, v4.4s, v2.s[2] + fmla v19.4s, v4.4s, v2.s[3] + bne LoopLR + LoopLREnd: + + cbz x4, StoreLH4 + AddBiasLH4: + ld1 {v5.4s}, [x4] + dup v6.4s, v5.s[2] // Min Value + dup v7.4s, v5.s[3] // Max Value + ld1 
{v0.4s}, [x5], #16 + + fmla v8.4s, v0.4s, v5.s[1] + fmla v9.4s, v0.4s, v5.s[1] + fmla v10.4s, v0.4s, v5.s[1] + fmla v11.4s, v0.4s, v5.s[1] + + fmla v12.4s, v0.4s, v5.s[1] + fmla v13.4s, v0.4s, v5.s[1] + fmla v14.4s, v0.4s, v5.s[1] + fmla v15.4s, v0.4s, v5.s[1] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + PostTreatLH4: + fmax v8.4s, v8.4s, v6.4s + fmax v9.4s, v9.4s, v6.4s + fmax v10.4s, v10.4s, v6.4s + fmax v11.4s, v11.4s, v6.4s + fmax v12.4s, v12.4s, v6.4s + fmax v13.4s, v13.4s, v6.4s + fmax v14.4s, v14.4s, v6.4s + fmax v15.4s, v15.4s, v6.4s + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + + fmin v8.4s, v8.4s, v7.4s + fmin v9.4s, v9.4s, v7.4s + fmin v10.4s, v10.4s, v7.4s + fmin v11.4s, v11.4s, v7.4s + fmin v12.4s, v12.4s, v7.4s + fmin v13.4s, v13.4s, v7.4s + fmin v14.4s, v14.4s, v7.4s + fmin v15.4s, v15.4s, v7.4s + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + + StoreLH4: + stp q8, q9, [x0] + stp q10, q11, [x0, #(32 * 1)] // 2 * 4 * sizeof(float) + stp q12, q13, [x0, #(32 * 2)] + stp q14, q15, [x0, #(32 * 3)] + stp q16, q17, [x0, #(32 * 4)] + stp q18, q19, [x0, #(32 * 5)] + + // st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + // st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + // st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0] + sub x10, x10, #1 + + +End: +ldp x19, x20, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #80 + +ret + +#endif diff --git a/source/backend/cpu/bf16/BF16Functions.cpp b/source/backend/cpu/bf16/BF16Functions.cpp index ea5a75af9..e62494bef 100644 --- a/source/backend/cpu/bf16/BF16Functions.cpp +++ b/source/backend/cpu/bf16/BF16Functions.cpp @@ -211,7 +211,7 @@ void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t void MNNPackedMatMulRemain_BF16(float* CFloat, const float* AFloat, const float* BFloat, size_t eSize, const size_t* parameter, float* cacheFloat, const float* postParameters, - const float* biasFloat) { + const float* biasFloat, const float* k, const float* b) { int16_t* C = (int16_t*)CFloat; int16_t* A = (int16_t*)AFloat; int16_t* B = (int16_t*)BFloat; @@ -271,8 +271,8 @@ void MNNPackedMatMulRemain_BF16(float* CFloat, const float* AFloat, const float* } void MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, float* cache, - const float* postParameters, const float* bias) { - return MNNPackedMatMulRemain_BF16(C, A, B, 16, parameter, cache, postParameters, bias); + const float* postParameters, const float* bias, const float* k, const float* b) { + return MNNPackedMatMulRemain_BF16(C, A, B, 16, parameter, cache, postParameters, bias, nullptr, nullptr); // return _AVX_MNNPackedMatMulFMA(C, A, B, parameter, cache); } diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index 71c67159a..44cb5b8e9 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -275,15 +275,170 @@ static void _MNNPackedMatMulRemain(float* C, const float* A, const float* B, siz } } } -void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias) { + +void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, 
const float* bias, const float* k, const float* b) { return _MNNPackedMatMulRemain(C, A, B, 16, parameter, postParameters, bias, 16); } -void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias) { +void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { auto aStride = parameter[0] / sizeof(float); _MNNPackedMatMulRemain(C, A, B, eSize, parameter, postParameters, bias, aStride); } +#ifdef MNN_LOW_MEMORY +static void _MNNPackedMatMulRemain_int4(float* C, const float* A, const float* fB, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, int aStride, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto hRemain = parameter[4]; + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y=0; y().max(); + float maxValue = std::numeric_limits().max(); + if (nullptr != postParameters) { + minValue = postParameters[2]; + maxValue = postParameters[3]; + alpha = postParameters[0]; + beta = postParameters[1]; + } + + for (int x=0; x(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto hRemain = parameter[4]; + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y=0; y().max(); + float maxValue = std::numeric_limits().max(); + if (nullptr != postParameters) { + minValue = postParameters[2]; + maxValue = postParameters[3]; + alpha = postParameters[0]; + beta = postParameters[1]; + } + + for (int x=0; xMNNPackForMatMul_B = MNNPackForMatMul_B; gCoreFunction->MNNPackedMatMul = MNNPackedMatMul; gCoreFunction->MNNPackedMatMulRemain = MNNPackedMatMulRemain; +#ifdef MNN_LOW_MEMORY + gCoreFunction->MNNPackedMatMul_int4 = MNNPackedMatMul_int4; + gCoreFunction->MNNPackedMatMulRemain_int4 = MNNPackedMatMulRemain_int4; + gCoreFunction->MNNPackedMatMul_int8 = MNNPackedMatMul_int8; + gCoreFunction->MNNPackedMatMulRemain_int8 = MNNPackedMatMulRemain_int8; +#endif gCoreFunction->MNNGetSparseMatMulPackMode = MNNGetSparseMatMulPackMode; gCoreFunction->MNNAdjustOptimalSparseKernel = _MNNAdjustOptimalSparseKernel; diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index 6c181822e..10ef40ec3 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -114,9 +114,15 @@ void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const in void MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose); // parameters: e, l, h, CStride, AStride, BStride -void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); +void MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNFunctionInit(); -void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +void MNNPackedMatMulRemain(float* C, const float* A, const float* 
B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#ifdef MNN_LOW_MEMORY +void MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#endif void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose); struct SparseMatMulParas @@ -188,8 +194,14 @@ struct CoreFunctions { void(*MNNPackC4ForMatMul_A)(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void(*MNNPackForMatMul_B)(float* dest, const float* source, size_t h, size_t l, bool transpose); // parameters: e, l, h, CStride, AStride, BStride - void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); - void(*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); + void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + void(*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#ifdef MNN_LOW_MEMORY + void(*MNNPackedMatMul_int4)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + void(*MNNPackedMatMulRemain_int4)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + void(*MNNPackedMatMul_int8)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + void(*MNNPackedMatMulRemain_int8)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +#endif void(*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void(*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); diff --git a/source/backend/cpu/compute/ConvolutionFloatFactory.cpp b/source/backend/cpu/compute/ConvolutionFloatFactory.cpp index e372b4a74..41416b77c 100644 --- a/source/backend/cpu/compute/ConvolutionFloatFactory.cpp +++ b/source/backend/cpu/compute/ConvolutionFloatFactory.cpp @@ -28,7 +28,11 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend const Convolution2D* conv2d, const float* originWeight, 
size_t originWeightSize, const float* bias, size_t biasSize, std::shared_ptr weightQuantInfo, bool supportSparse) { auto cpuBackend = (CPUBackend*)backend; +#ifdef MNN_LOW_MEMORY bool lowMemory = cpuBackend->memoryMode() == BackendConfig::Memory_Low; +#else + bool lowMemory = false; +#endif auto common = conv2d->common(); #ifdef MNN_USE_ONEDNN return OneDNN::createConvolution(common, backend, originWeight, originWeightSize, bias, biasSize); @@ -70,7 +74,11 @@ Execution* ConvolutionFloatFactory::create(const std::vector& inputs, c // Multi Input return new ConvolutionTiledExecutorMultiInput(conv2d->common(), backend); } - bool lowMemory = static_cast(backend)->memoryMode() == BackendConfig::Memory_Low && static_cast(backend)->functions()->bytes == 4; +#ifdef MNN_LOW_MEMORY + bool lowMemory = static_cast(backend)->memoryMode() == BackendConfig::Memory_Low; +#else + bool lowMemory = false; +#endif const float* originWeight = nullptr; const float* originBias = nullptr; int originWeightSize = 0; diff --git a/source/backend/cpu/compute/ConvolutionPackFreeWinograd.cpp b/source/backend/cpu/compute/ConvolutionPackFreeWinograd.cpp index a1139ef7a..bb97db324 100644 --- a/source/backend/cpu/compute/ConvolutionPackFreeWinograd.cpp +++ b/source/backend/cpu/compute/ConvolutionPackFreeWinograd.cpp @@ -328,7 +328,7 @@ ErrorCode ConvolutionPackFreeWinograd::onExecute(const std::vector &in auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes); auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0)); auto gemmBufferPtr = (const float*)(gemmBuffer + i * ePack * ic_roundup * bytes); - core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBufferPtr, _weightFloatPtr, parameters.data(), nullptr, nullptr); + core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBufferPtr, _weightFloatPtr, parameters.data(), nullptr, nullptr, nullptr, nullptr); } } else { for (int i = tId; i < srcUnit2; i+=threadNumber) { @@ -340,7 +340,7 @@ ErrorCode ConvolutionPackFreeWinograd::onExecute(const std::vector &in auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes); auto _weightFloatPtr = (const float*)(weight + i * mResource->mWeight->stride(0)); auto gemmBufferPtr = (const float*)(gemmBuffer + i * ePack * ic_roundup * bytes); - core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBufferPtr, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr); + core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBufferPtr, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr, nullptr, nullptr); } } }; diff --git a/source/backend/cpu/compute/ConvolutionPackWinograd.cpp b/source/backend/cpu/compute/ConvolutionPackWinograd.cpp index 05d7b011c..0883b1ac1 100644 --- a/source/backend/cpu/compute/ConvolutionPackWinograd.cpp +++ b/source/backend/cpu/compute/ConvolutionPackWinograd.cpp @@ -416,7 +416,7 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, auto unitsGemmbuffer = gemmBuffer + iNh * ic_4 * pack * ePack * bytes; auto _dstFloatPtr = (float*)(_dstOrigin + (iNh * srcUnit + iNw) * dc_4 * pack * ePack * bytes); auto _weightFloatPtr = (const float*)(weight + (iNh * srcUnit + iNw) * weightStride); - core->MNNPackedMatMul(_dstFloatPtr, (float*)unitsGemmbuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr); + core->MNNPackedMatMul(_dstFloatPtr, (float*)unitsGemmbuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr, nullptr, nullptr); } } } else { @@ -441,7 +441,7 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, auto 
_weightFloatPtr = (const float*)(weight + i * weightStride); core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el); - core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr); + core->MNNPackedMatMul(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, parameters.data(), nullptr, nullptr, nullptr, nullptr); } } else { for (int i = 0; i < srcUnit2; ++i) { @@ -449,7 +449,7 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, auto _dstFloatPtr = (float*)(_dstOrigin + i * dc_4 * pack * xC * bytes); auto _weightFloatPtr = (const float*)(weight + i * weightStride); core->MNNPackC4ForMatMul_A((float*)gemmBuffer, &srcTemp, info, el); - core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr); + core->MNNPackedMatMulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, parametersRemain.data(), nullptr, nullptr, nullptr, nullptr); } } } diff --git a/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp b/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp index 071784be8..1d83fa3b3 100644 --- a/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp @@ -28,6 +28,7 @@ class ConvolutionTiledImpl : public CPUConvolution { Tensor mTempBufferTranspose; ConvolutionCommon::Im2ColParameter mIm2ColParameters; std::pair> mFunction; + const CPUConvolution::Resource* mResource = nullptr; }; class ConvolutionTiledExecutor : public Execution { diff --git a/source/backend/cpu/compute/DeconvolutionWithStride.cpp b/source/backend/cpu/compute/DeconvolutionWithStride.cpp index e76af5bc6..5dc7718ad 100644 --- a/source/backend/cpu/compute/DeconvolutionWithStride.cpp +++ b/source/backend/cpu/compute/DeconvolutionWithStride.cpp @@ -82,7 +82,7 @@ static void _winograd(const DeconvolutionWithStride::ComputeUnit& unit, int thre auto tempColAddr = destAddr + i * unit.dstBuffer->stride(1); auto weightAddr = unit.weight->host() + unit.weight->stride(0) * i; MNNPackC4ForMatMul_A(cachePackBuffer, &tempSourceAddr, info, el); - MNNPackedMatMul(tempColAddr, cachePackBuffer,weightAddr, parameters, nullptr, nullptr); + MNNPackedMatMul(tempColAddr, cachePackBuffer,weightAddr, parameters, nullptr, nullptr, nullptr, nullptr); } auto B = unit.winogradInfo.B.get(); auto midAddr = unit.winogradInfo.dstTransformedBuffer->host() + @@ -144,7 +144,7 @@ static void _gemmAndIm2col(const DeconvolutionWithStride::ComputeUnit& unit, int for (int fx = 0; fx < unit.xUnit; ++fx) { auto ucolAddr = tempColAddr + dc_4 * eP * 4 * (fx + fy * unit.xUnit); auto uwAddr = weightAddr + unit.weight->stride(0) * (fx + fy * unit.xUnit); - MNNPackedMatMul(ucolAddr, cachePackBuffer, uwAddr, parameters, nullptr, nullptr); + MNNPackedMatMul(ucolAddr, cachePackBuffer, uwAddr, parameters, nullptr, nullptr, nullptr, nullptr); } } // FUNC_PRINT_ALL(tempColAddr[0], f); diff --git a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp index 2644ef101..a6310e52d 100644 --- a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp +++ b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp @@ -19,7 +19,6 @@ #include "common/MemoryFormater.h" #define PARAMETERSIZE 6 -#define MNN_ALLOC_MEMORY_INDIRECTLY using Vec4 = MNN::Math::Vec; namespace MNN { @@ -28,7 +27,7 @@ void DenseConvolutionTiledExecutor::initWeight(float *dest, const float *source, 
function->MNNPackForMatMul_B(dest, cache, outputCount, kernelSize * depth, true); } -static bool _initQuantizeResource(std::shared_ptr int8Info, std::shared_ptr resource, int hU, int hP, int lU, int lP, int outputCount, int srcChannel, int kernelSize) { +static bool _initQuantizeResource(std::shared_ptr int8Info, std::shared_ptr resource, int hU, int hP, int lU, int lP, int outputCount, int srcChannel, int kernelSize, int bytes) { int weightLength = hU * lU * hP * lP; resource->mWeight.reset(Tensor::createDevice( {weightLength})); @@ -66,17 +65,32 @@ static bool _initQuantizeResource(std::shared_ptr } auto alphaPtr = resource->mDequantize.mScaleBias->host(); auto biasPtr = resource->mDequantize.mScaleBias->host() + hU * hP; - ::memset(alphaPtr, 0, 2 * hU * hP * sizeof(float)); + ::memset(alphaPtr, 0, 2 * hU * hP * bytes); int h = int8Info->alpha.size(); - if (int8Info->asymmetric) { - h = h / 2; - for (int i=0; ialpha.get()[2 * i + 1]; - biasPtr[i] = int8Info->alpha.get()[2 * i]; + if (bytes == 2) { + auto core = static_cast(resource->backend)->functions(); + if (int8Info->asymmetric) { + std::unique_ptr tmp(new int16_t[h]); + core->MNNFp32ToLowp(int8Info->alpha.get(), tmp.get(), h); + for (int i=0; i< h/2; ++i) { + reinterpret_cast(alphaPtr)[i] = tmp[2 * i + 1]; + reinterpret_cast(biasPtr)[i] = tmp[2 * i]; + } + } else { + core->MNNFp32ToLowp(int8Info->alpha.get(), reinterpret_cast(alphaPtr), h); } } else { - for (int i=0; ialpha.get()[i]; + if (int8Info->asymmetric) { + h = h / 2; + for (int i=0; ialpha.get()[2 * i + 1]; + biasPtr[i] = int8Info->alpha.get()[2 * i]; + } + } else { + for (int i=0; ialpha.get()[i]; + biasPtr[i] = 0.f; + } } } if (int8Info->canUseInt4) { @@ -95,9 +109,7 @@ static bool _initQuantizeResource(std::shared_ptr for (int i=0; iweightReverseMap[(int)s0 + 128]; - s1 = int8Info->weightReverseMap[(int)s1 + 128]; - int d = s0 * 16 + s1; + int d = (s0 + 7) * 16 + (s1 + 7); dstPtr[i] = d; } resource->mWeight = weightLow; @@ -127,7 +139,7 @@ DenseConvolutionTiledExecutor::DenseConvolutionTiledExecutor(const Convolution2D auto lU = UP_DIV(lSize, lP); if (useInt8Weight) { // Quantize weight to int8 - auto allocSuccess = _initQuantizeResource(int8Info, mResource, hU, hP, lU, lP, outputCount, srcCount, common->kernelX() * common->kernelY()); + auto allocSuccess = _initQuantizeResource(int8Info, mResource, hU, hP, lU, lP, outputCount, srcCount, common->kernelX() * common->kernelY(), bytes); if (!allocSuccess) { mValid = false; return; @@ -149,11 +161,11 @@ DenseConvolutionTiledExecutor::DenseConvolutionTiledExecutor(const Convolution2D // formatMatrix(mResource->mWeight->host(), {UP_DIV(outputCount, hP), lSize, hP}); backend()->onReleaseBuffer(cache.get(), Backend::STATIC); } - mProxy.reset(new DenseConvolutionTiledImpl(common, b)); + mProxy.reset(new DenseConvolutionTiledImpl(common, b, mResource.get())); } DenseConvolutionTiledExecutor::DenseConvolutionTiledExecutor(std::shared_ptr res, const Convolution2DCommon* common, Backend* b) : ConvolutionTiledExecutor(res, b) { - mProxy.reset(new DenseConvolutionTiledImpl(common, b)); + mProxy.reset(new DenseConvolutionTiledImpl(common, b, mResource.get())); } DenseConvolutionTiledExecutor::~DenseConvolutionTiledExecutor() { @@ -173,117 +185,15 @@ bool DenseConvolutionTiledExecutor::onClone(Backend* bn, const Op* op, Execution } ErrorCode DenseConvolutionTiledExecutor::onExecute(const std::vector &inputs, const std::vector &outputs) { - bool needDequantize = mResource->mDequantize.bits <= 8; - if (needDequantize) { -#ifndef 
MNN_ALLOC_MEMORY_INDIRECTLY - auto res = backend()->onAcquireBuffer(mWeightCache.weight.get(), Backend::STATIC); - if (!res) { - return OUT_OF_MEMORY; - } - if (nullptr != mWeightCache.weightInt8) { - res = backend()->onAcquireBuffer(mWeightCache.weightInt8.get(), Backend::STATIC); - if (!res) { - return OUT_OF_MEMORY; - } - } -#endif - auto hU = mResource->hU; - auto hP = mResource->hP; - auto mid = mResource->lU * mResource->lP; - auto srcInt8 = mResource->mWeight->host(); - if (mResource->mDequantize.bits == 4) { - int weightLength = hU * hP * mid; - weightLength = UP_DIV(weightLength, 2); - auto srcPtr = mResource->mWeight->host(); - auto dstPtr = mWeightCache.weightInt8->host(); - for (int i=0; imDequantize.mLowBitWeightMap[s0]; - s1 = mResource->mDequantize.mLowBitWeightMap[s1]; - dstPtr[2 * i + 0] = s0; - dstPtr[2 * i + 1] = s1; - } - srcInt8 = mWeightCache.weightInt8->host(); - } - auto alpha = mResource->mDequantize.mScaleBias->host(); - auto bias = mResource->mDequantize.mScaleBias->host() + hU * hP; - auto dstFloat = mWeightCache.weight->host(); - for (int yo=0; yoonReleaseBuffer(mWeightCache.weightInt8.get(), Backend::STATIC); - } -#endif - } auto code = mProxy->onExecute(mInputs, outputs); -#ifndef MNN_ALLOC_MEMORY_INDIRECTLY - if (needDequantize) { - backend()->onReleaseBuffer(mWeightCache.weight.get(), Backend::STATIC); - } - ((Runtime*)(static_cast(backend())->getRuntime()))->onGabageCollect(0); -#endif return code; } ErrorCode DenseConvolutionTiledExecutor::onResize(const std::vector &inputs, const std::vector &outputs) { mInputs = {inputs[0], mResource->mWeight.get(), mResource->mBias.get()}; - bool needDequantize = mResource->mDequantize.bits <= 8; - if (needDequantize) { - if (mWeightCache.weight == nullptr) { - int weightLength = mResource->hU * mResource->lU * mResource->hP * mResource->lP; - mWeightCache.weight.reset(new Tensor); - mWeightCache.weight->buffer().type = halide_type_of(); - TensorUtils::getDescribe(mWeightCache.weight.get())->dimensionFormat = MNN_DATA_FORMAT_NCHW; - mWeightCache.weight->buffer().dimensions = 1; - mWeightCache.weight->setLength(0, weightLength); - if (mWeightCache.weightInt8 == nullptr && mResource->mDequantize.bits == 4) { - mWeightCache.weightInt8.reset(new Tensor); - mWeightCache.weightInt8->buffer().type = halide_type_of(); - mWeightCache.weightInt8->buffer().dimensions = 1; - mWeightCache.weightInt8->setLength(0, weightLength); - TensorUtils::getDescribe(mWeightCache.weightInt8.get())->dimensionFormat = MNN_DATA_FORMAT_NCHW; - } - } - mInputs[1] = mWeightCache.weight.get(); -#ifdef MNN_ALLOC_MEMORY_INDIRECTLY - bool res = false; - if (nullptr != mWeightCache.weightInt8) { - res = backend()->onAcquireBuffer(mWeightCache.weightInt8.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - } - res = backend()->onAcquireBuffer(mWeightCache.weight.get(), Backend::DYNAMIC); - if (!res) { - return OUT_OF_MEMORY; - } - if (nullptr != mWeightCache.weightInt8) { - backend()->onReleaseBuffer(mWeightCache.weightInt8.get(), Backend::DYNAMIC); - } -#endif - } auto code = mProxy->onResize(mInputs, outputs); if (NO_ERROR != code) { return code; } - if (needDequantize) { -#ifdef MNN_ALLOC_MEMORY_INDIRECTLY - backend()->onReleaseBuffer(mWeightCache.weight.get(), Backend::DYNAMIC); -#endif - } return NO_ERROR; } @@ -375,7 +285,7 @@ void DenseConvolutionTiledImpl::getPackParameter(int* eP, int* lP, int* hP, cons // #define PROFILE_DETAIL PerfConfig DenseConvolutionTiledImpl::bestTileConvolutionConfig(const Convolution2DCommon *common, 
const Tensor *inputTensor, - const Tensor *outputTensor, int threadNumber, Backend* b) { + const Tensor *outputTensor, int threadNumber, Backend* b) { auto input = inputTensor; Tensor *bias = nullptr; auto core = static_cast(b)->functions(); @@ -492,12 +402,34 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs Tensor *bias = nullptr; auto core = static_cast(backend())->functions(); int bytes = core->bytes; + float weightBytes = bytes; int unit = core->pack; auto packA = core->MNNPackC4ForMatMul_A; int eP, lP, hP; getPackParameter(&eP, &lP, &hP, core); auto matmulUnit = core->MNNPackedMatMul; auto matmulRemain = core->MNNPackedMatMulRemain; + auto weightType = weight->getType(); + const uint8_t* dequantAlpha = nullptr; + const uint8_t* dequantBias = nullptr; +#ifdef MNN_LOW_MEMORY + if (mResource && mResource->mDequantize.bits <= 8) { + if (mResource->mDequantize.bits == 8) { + matmulUnit = core->MNNPackedMatMul_int8; + matmulRemain = core->MNNPackedMatMulRemain_int8; + weightBytes = 1; + } + if (mResource->mDequantize.bits == 4) { + matmulUnit = core->MNNPackedMatMul_int4; + matmulRemain = core->MNNPackedMatMulRemain_int4; + weightBytes = 0.5; + } + dequantAlpha = mResource->mDequantize.mScaleBias->host(); + dequantBias = dequantAlpha + mResource->hU * mResource->hP * bytes; + } +#endif + auto kernel_width = mCommon->kernelX(); + auto kernel_height = mCommon->kernelY(); auto output = outputs[0]; auto batch = output->batch(); int threadNumber = ((CPUBackend *)backend())->threadNumber(); @@ -522,7 +454,8 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs TensorUtils::setLinearLayout(&mTempBufferTranspose); auto plane = mIm2ColParameters.ow * mIm2ColParameters.oh * batch; int tileCount = UP_DIV(plane, eP); - auto oC4 = UP_DIV(outputChannel, unit); + auto tileC = std::max(unit, hP); + auto oC4 = UP_DIV(outputChannel, tileC); mConvPerfconfig = bestTileConvolutionConfig(mCommon, input, output, threadNumber, backend()); auto threadNumberFirst = mConvPerfconfig.isParallelInner ? 
threadNumber : std::min(threadNumber, tileCount); @@ -573,11 +506,11 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs parameters[5] = 0; #ifdef PROFILE_DETAIL - uint64_t durationMul[threadNumberFirst] = {0}; - uint64_t packATime[threadNumberFirst] = {0}; - uint64_t indexTime[threadNumberFirst] = {0}; + std::vector durationMul(threadNumberFirst, 0); + std::vector packATime(threadNumberFirst, 0); + std::vector indexTime(threadNumberFirst, 0); Timer timer[threadNumberFirst]; - double macs[threadNumberFirst] = {0}; + std::vector macs(threadNumberFirst, 0); #endif auto dstOrigin = output->host(); @@ -621,9 +554,6 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs timer[0].reset(); #endif - auto tileC = std::max(unit, hP); - auto oC4 = UP_DIV(outputChannel, tileC); - auto weightBytes = core->bytes; if (xC == eP) { MNN_CONCURRENCY_BEGIN(tId, threadNumberFirst) { size_t paraParameters[PARAMETERSIZE]; @@ -634,7 +564,9 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs auto _weightFloatPtr = reinterpret_cast(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes)); auto _biasFloatPtr = reinterpret_cast(reinterpret_cast(biasPtr) + ocIndex * bytes); paraParameters[2] = std::min(outputChannel - ocIndex, tileC); - matmulUnit(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, paraParameters, postParameters.data(), _biasFloatPtr); + auto k = reinterpret_cast(dequantAlpha + ocIndex * bytes); + auto b = reinterpret_cast(dequantBias + ocIndex * bytes); + matmulUnit(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, paraParameters, postParameters.data(), _biasFloatPtr, k, b); } } MNN_CONCURRENCY_END(); @@ -648,7 +580,9 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs auto _weightFloatPtr = reinterpret_cast(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes)); auto _biasFloatPtr = reinterpret_cast(reinterpret_cast(biasPtr) + ocIndex * bytes); paraParameters[2] = std::min(outputChannel - ocIndex, tileC); - matmulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, paraParameters, postParameters.data(), _biasFloatPtr); + auto k = reinterpret_cast(dequantAlpha + ocIndex * bytes); + auto b = reinterpret_cast(dequantBias + ocIndex * bytes); + matmulRemain(_dstFloatPtr, (float*)gemmBuffer, _weightFloatPtr, xC, paraParameters, postParameters.data(), _biasFloatPtr, k, b); } } MNN_CONCURRENCY_END(); @@ -699,11 +633,11 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs parameters[5] = 0; #ifdef PROFILE_DETAIL - uint64_t durationMul[threadNumberFirst] = {0}; - uint64_t packATime[threadNumberFirst] = {0}; - uint64_t indexTime[threadNumberFirst] = {0}; + std::vector durationMul(threadNumberFirst, 0); + std::vector packATime(threadNumberFirst, 0); + std::vector indexTime(threadNumberFirst, 0); Timer timer[threadNumberFirst]; - double macs[threadNumberFirst] = {0}; + std::vector macs(threadNumberFirst, 0); #endif auto dstOrigin = output->host(); @@ -732,14 +666,17 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs packATime[tId] += timer[tId].durationInUs(); timer[tId].reset(); #endif + auto k = reinterpret_cast(dequantAlpha); + auto b = reinterpret_cast(dequantBias); + auto _dstFloatPtr = reinterpret_cast(dstOrigin + start * unit * bytes); if (xC == eP) { - matmulUnit((float*)(dstOrigin + start * unit * bytes), (float*)gemmBuffer, (float*)weightPtr, parameters,postParameters.data(), biasPtr); + matmulUnit(_dstFloatPtr, (float*)gemmBuffer, 
(float*)weightPtr, parameters, postParameters.data(), biasPtr, k, b); } else { - matmulRemain((float*)(dstOrigin + start * unit * bytes), (float*)gemmBuffer, (float*)weightPtr, xC, parameters,postParameters.data(), biasPtr); + matmulRemain(_dstFloatPtr, (float*)gemmBuffer, (float*)weightPtr, xC, parameters, postParameters.data(), biasPtr, k, b); } #ifdef PROFILE_DETAIL - macs[tId] += 2.0 * xC * L * oC4 * unit; // bias + macs[tId] += 2.0 * xC * L * oC4 * unit; // bias durationMul[tId] += timer[tId].durationInUs(); timer[tId].reset(); #endif diff --git a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp index 21adcd7cf..1141473a3 100644 --- a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp +++ b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp @@ -17,8 +17,8 @@ namespace MNN { class DenseConvolutionTiledImpl : public ConvolutionTiledImpl { public: - DenseConvolutionTiledImpl(const Convolution2DCommon *common, Backend *b) : ConvolutionTiledImpl(common, b) { - // Do nothing + DenseConvolutionTiledImpl(const Convolution2DCommon *common, Backend *b, CPUConvolution::Resource* resource = nullptr) : ConvolutionTiledImpl(common, b) { + mResource = resource; } ErrorCode onResize(const std::vector& inputs, const std::vector& outputs) override; diff --git a/source/backend/cpu/compute/StrassenMatmulComputor.cpp b/source/backend/cpu/compute/StrassenMatmulComputor.cpp index d9b2c4900..98cf924fa 100644 --- a/source/backend/cpu/compute/StrassenMatmulComputor.cpp +++ b/source/backend/cpu/compute/StrassenMatmulComputor.cpp @@ -106,7 +106,7 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, con int xStart = i * eP; auto aStart = aHost + xStart * packUnit; core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride); - core->MNNPackedMatMul((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, parameters, postParametersPtr, (const float*)biasPtr); + core->MNNPackedMatMul((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, parameters, postParametersPtr, (const float*)biasPtr, nullptr, nullptr); } if (tId != numberThread -1) { return; @@ -120,7 +120,7 @@ ErrorCode StrassenMatrixComputor::_generateTrivalMatMul(int e, int l, int h, con auto aStart = aHost + xStart * packUnit; // Copy core->MNNPackC4ForMatMul_A((float*)(tileHost), (const float**)(&aStart), info, stride); - core->MNNPackedMatMulRemain((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, xCount, parameters, postParametersPtr, (const float*)biasPtr); + core->MNNPackedMatMulRemain((float*)(cHost + xStart * packUnit), (float*)tileHost, (float*)bHost, xCount, parameters, postParametersPtr, (const float*)biasPtr, nullptr, nullptr); } }, numberThread)); static_cast(backend())->getBufferAllocator()->free(tileBufferBasic); diff --git a/source/backend/cpu/x86_x64/AVX2Functions.cpp b/source/backend/cpu/x86_x64/AVX2Functions.cpp index 7bb523a96..592c0a8bc 100644 --- a/source/backend/cpu/x86_x64/AVX2Functions.cpp +++ b/source/backend/cpu/x86_x64/AVX2Functions.cpp @@ -39,6 +39,12 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->MNNPackedMatMul = _AVX_MNNPackedMatMul; coreFunction->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemain; +#ifdef MNN_LOW_MEMORY + coreFunction->MNNPackedMatMul_int4 = _AVX_MNNPackedMatMul_int4; + coreFunction->MNNPackedMatMulRemain_int4 = _AVX_MNNPackedMatMulRemain_int4; + coreFunction->MNNPackedMatMul_int8 = 
_AVX_MNNPackedMatMul_int8; + coreFunction->MNNPackedMatMulRemain_int8 = _AVX_MNNPackedMatMulRemain_int8; +#endif coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A; coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B; coreFunction->MNNComputeMatMulForE_1 = _AVX_MNNComputeMatMulForE_1; diff --git a/source/backend/cpu/x86_x64/CMakeLists.txt b/source/backend/cpu/x86_x64/CMakeLists.txt index 2ded7bdda..631f12069 100644 --- a/source/backend/cpu/x86_x64/CMakeLists.txt +++ b/source/backend/cpu/x86_x64/CMakeLists.txt @@ -89,6 +89,12 @@ if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(X86_64)|(x64)|(X64)|(amd64)|(AMD64) target_compile_options(MNNAVXFMA PRIVATE -DMNN_SSE_USE_FP16_INSTEAD -mf16c) endif() endif() + if (MNN_LOW_MEMORY) + target_compile_options(MNNX8664 PRIVATE -DMNN_LOW_MEMORY) + target_compile_options(MNNSSE PRIVATE -DMNN_LOW_MEMORY) + target_compile_options(MNNAVX PRIVATE -DMNN_LOW_MEMORY) + target_compile_options(MNNAVXFMA PRIVATE -DMNN_LOW_MEMORY) + endif() list(APPEND MNN_OBJECTS_TO_LINK $ $ $ $) if (MSVC AND WIN_USE_ASM) target_compile_options(MNNAVX PRIVATE -DMNN_X86_USE_ASM) diff --git a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp index 2d1cfe2db..7cab6d1a4 100644 --- a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp +++ b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp @@ -50,6 +50,12 @@ void MNNFunctionInit() { coreFunction->MNNGetMatMulPackMode = _SSEMNNGetMatMulPackMode; coreFunction->MNNPackedMatMul = _SSE_MNNPackedMatMul; coreFunction->MNNPackedMatMulRemain = _SSE_MNNPackedMatMulRemain; +#ifdef MNN_LOW_MEMORY + coreFunction->MNNPackedMatMul_int4 = _SSE_MNNPackedMatMul_int4; + coreFunction->MNNPackedMatMulRemain_int4 = _SSE_MNNPackedMatMulRemain_int4; + coreFunction->MNNPackedMatMul_int8 = _SSE_MNNPackedMatMul_int8; + coreFunction->MNNPackedMatMulRemain_int8 = _SSE_MNNPackedMatMulRemain_int8; +#endif coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A; coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B; } diff --git a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp index ba42f9229..1b355d80c 100644 --- a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp @@ -34,9 +34,19 @@ // ========= CommonOptFunction.cpp =========== extern "C" { void _AVX_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); +#ifdef MNN_LOW_MEMORY +void _AVX_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX_MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, 
const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +#endif void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _AVX_MNNExpC8(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8); diff --git a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp index a4f32f1ab..3ff5dc955 100644 --- a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp @@ -20,17 +20,40 @@ #include "GemmFunction.hpp" void _AVX_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias) { + const float* postParameters, const float* bias, const float* k, const float* b) { _AVX_MNNPackedMatMul_Main(C, A, B, parameter); AVX2GemmPostTreat(C, MNN_UNIT_E, parameter, postParameters, bias); } void _AVX_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, - const float* postParameters, const float* bias) { + const float* postParameters, const float* bias, const float* k, const float* b) { _AVX_MNNPackednMatMulRemainCommon(C, A, B, eSize, parameter); AVX2GemmPostTreat(C, eSize, parameter, postParameters, bias); } +#ifdef MNN_LOW_MEMORY +void _AVX_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + _AVX_MNNPackedMatMul_Main_int4(C, A, B, parameter, k, b); + AVX2GemmPostTreat(C, MNN_UNIT_E, parameter, postParameters, bias); +} +void _AVX_MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + _AVX_MNNPackednMatMulRemainCommon_int4(C, A, B, eSize, parameter, k, b); + AVX2GemmPostTreat(C, eSize, parameter, postParameters, bias); +} +void _AVX_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + _AVX_MNNPackedMatMul_Main_int8(C, A, B, parameter, k, b); + AVX2GemmPostTreat(C, MNN_UNIT_E, parameter, postParameters, bias); +} +void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + _AVX_MNNPackednMatMulRemainCommon_int8(C, A, B, eSize, parameter, k, b); + AVX2GemmPostTreat(C, eSize, parameter, postParameters, bias); +} +#endif + void _AVX_MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { auto l = param->l; auto h = param->h; diff --git a/source/backend/cpu/x86_x64/avx/GemmFunction.hpp b/source/backend/cpu/x86_x64/avx/GemmFunction.hpp index 3a73fcb35..e699ee615 100644 --- a/source/backend/cpu/x86_x64/avx/GemmFunction.hpp +++ b/source/backend/cpu/x86_x64/avx/GemmFunction.hpp @@ -623,6 +623,7 @@ static void _AVX_MNNPackedMatMul_4(TYPE* C, const TYPE* A, const TYPE* B, const STORE_4(dst + 8 * 3, z9); } } + template static void _AVX_MNNPackednMatMulRemainCommon(TYPE* C, const TYPE* A, const TYPE* B, size_t eSize, const size_t* parameter) { @@ -793,3 +794,1637 @@ static void _AVX_MNNPackednMatMulRemainCommon(TYPE* C, const TYPE* A, const TYPE STORE_4(dst, sum); } } + +#ifdef MNN_LOW_MEMORY 
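// NOTE (editorial sketch, not part of the original patch): throughout this change the
// MNNPackedMatMul / MNNPackedMatMulRemain entry points gain two trailing parameters,
// `k` (per-output-channel dequant scale, "alpha") and `b` (per-output-channel dequant
// bias). Plain float callers pass nullptr for both, while the low-memory int4/int8
// kernels below use them to dequantize weights on the fly.
// For the int4 path, each byte of B packs two 4-bit weights (high nibble first) stored
// with an offset of 7, so a weight dequantizes as w = (q - 7) * alpha + bias; the int8
// path later in this file skips the nibble unpacking and computes w = q * alpha + bias.
// A scalar reference of what _load_int4x4 computes, for illustration only; the helper
// name _ref_load_int4x4 is hypothetical and not defined elsewhere in MNN.
static inline void _ref_load_int4x4(const uint8_t* src, const float* alpha,
                                    const float* bias, float* out) {
    // two 4-bit weights per byte: high nibble first, then low nibble
    int q[4] = {src[0] >> 4, src[0] & 0xF, src[1] >> 4, src[1] & 0xF};
    for (int i = 0; i < 4; ++i) {
        out[i] = (q[i] - 7) * alpha[i] + bias[i]; // per-output-channel scale and bias
    }
}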
+//----------------------- MatMul(float, int4) Functions ---------------------------// + +#define LOAD_WEIGHT_ALPHA_BIAS_int4x4 \ + auto weight0 = B + (hC4Unit * y + 0) * bStride / 2;\ + auto weight1 = B + (hC4Unit * y + 1) * bStride / 2;\ + auto weight2 = B + (hC4Unit * y + 2) * bStride / 2;\ + auto weight3 = B + (hC4Unit * y + 3) * bStride / 2;\ + auto alpha0 = _mm_loadu_ps(k + y * 16 + 0);\ + auto alpha1 = _mm_loadu_ps(k + y * 16 + 4);\ + auto alpha2 = _mm_loadu_ps(k + y * 16 + 8);\ + auto alpha3 = _mm_loadu_ps(k + y * 16 + 12);\ + auto bias0 = _mm_loadu_ps(b + y * 16 + 0);\ + auto bias1 = _mm_loadu_ps(b + y * 16 + 4);\ + auto bias2 = _mm_loadu_ps(b + y * 16 + 8);\ + auto bias3 = _mm_loadu_ps(b + y * 16 + 12); + +#define LOAD_ALPHA_BIAS_DOUBLE \ + auto alpha0_2 = _mm256_set_m128(alpha0, alpha0);\ + auto alpha1_2 = _mm256_set_m128(alpha1, alpha1);\ + auto alpha2_2 = _mm256_set_m128(alpha2, alpha2);\ + auto alpha3_2 = _mm256_set_m128(alpha3, alpha3);\ + auto bias0_2 = _mm256_set_m128(bias0, bias0);\ + auto bias1_2 = _mm256_set_m128(bias1, bias1);\ + auto bias2_2 = _mm256_set_m128(bias2, bias2);\ + auto bias3_2 = _mm256_set_m128(bias3, bias3); + +static inline __m128 _load_int4x4(const uint8_t* src, __m128 alpha, __m128 bias) { + auto w01 = src[0]; + auto w23 = src[1]; + int iw01 = w01; + int iw23 = w23; + int iw0 = iw01 / 16; + int iw1 = iw01 % 16; + int iw2 = iw23 / 16; + int iw3 = iw23 % 16; + auto ws = _mm_set_ps(iw3, iw2, iw1, iw0); + ws = _mm_sub_ps(ws, _mm_set1_ps(7)); + ws = _mm_add_ps(_mm_mul_ps(ws, alpha), bias); + return ws; +} + +static inline __m256 _load_int4x8(const uint8_t* src, __m256 alpha, __m256 bias) { + float w[8]; + for (int i = 0; i < 4; i++) { + int x = src[i]; + int a = x / 16; + int b = x % 16; + w[i * 2] = a - 7; + w[i * 2 + 1] = b - 7; + } + auto w8 = LOAD8(w); + return _mm256_add_ps(_mm256_mul_ps(w8, alpha), bias); +} + + +template +static void _AVX_MNNPackedMatMul_Main_int4(TYPE* C, const TYPE* A, const TYPE* fB, const size_t* parameter, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD8(A + 0 * 24); + auto s1 = LOAD8(A + 0 * 24 + 8); + auto s2 = LOAD8(A + 0 * 24 + 16); + auto ws = _load_int4x4(weight, alpha, bias); + auto w0 = _mm256_set1_ps(ws[0]); + auto w1 = _mm256_set1_ps(ws[1]); + auto w2 = _mm256_set1_ps(ws[2]); + auto w3 = _mm256_set1_ps(ws[3]); + auto z0 = _mm256_mul_ps(s0, w0); + auto z1 = _mm256_mul_ps(s1, w0); + auto z2 = _mm256_mul_ps(s2, w0); + auto z3 = _mm256_mul_ps(s0, w1); + auto z4 = _mm256_mul_ps(s1, w1); + auto z5 = _mm256_mul_ps(s2, w1); + auto z6 = _mm256_mul_ps(s0, w2); + auto z7 = _mm256_mul_ps(s1, w2); + auto z8 = _mm256_mul_ps(s2, w2); + auto z9 = _mm256_mul_ps(s0, w3); + auto z10 = _mm256_mul_ps(s1, w3); + auto z11 = _mm256_mul_ps(s2, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD8(A + sy * 24); + s1 = LOAD8(A + sy * 24 + 8); + s2 = LOAD8(A + sy * 24 + 16); + ws = _load_int4x4(weight + sy * 2, alpha, bias); + w0 = _mm256_set1_ps(ws[0]); + w1 = _mm256_set1_ps(ws[1]); + w2 = _mm256_set1_ps(ws[2]); + w3 = _mm256_set1_ps(ws[3]); + z0 = MNNAVXFMA(s0, w0, z0); + z1 = MNNAVXFMA(s1, w0, z1); + 
z2 = MNNAVXFMA(s2, w0, z2); + z3 = MNNAVXFMA(s0, w1, z3); + z4 = MNNAVXFMA(s1, w1, z4); + z5 = MNNAVXFMA(s2, w1, z5); + z6 = MNNAVXFMA(s0, w2, z6); + z7 = MNNAVXFMA(s1, w2, z7); + z8 = MNNAVXFMA(s2, w2, z8); + z9 = MNNAVXFMA(s0, w3, z9); + z10 = MNNAVXFMA(s1, w3, z10); + z11 = MNNAVXFMA(s2, w3, z11); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(1, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(1, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(0, 2, z2, z5, z8, z11); + TRANPOSE_SAVE(1, 2, z2, z5, z8, z11); + } +} + + +template +static void _AVX_MNNPackedMatMul_int4_20(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD8(A + 0 * aStride); + auto s1 = LOAD8(A + 0 * aStride + 8); + auto s2 = EXPAND_128(LOAD4(A + 0 * aStride + 16)); + auto ws = _load_int4x4(weight, alpha, bias); + auto w0 = _mm256_set1_ps(ws[0]); + auto w1 = _mm256_set1_ps(ws[1]); + auto w2 = _mm256_set1_ps(ws[2]); + auto w3 = _mm256_set1_ps(ws[3]); + auto z0 = _mm256_mul_ps(s0, w0); + auto z1 = _mm256_mul_ps(s1, w0); + auto z2 = _mm256_mul_ps(s2, w0); + auto z3 = _mm256_mul_ps(s0, w1); + auto z4 = _mm256_mul_ps(s1, w1); + auto z5 = _mm256_mul_ps(s2, w1); + auto z6 = _mm256_mul_ps(s0, w2); + auto z7 = _mm256_mul_ps(s1, w2); + auto z8 = _mm256_mul_ps(s2, w2); + auto z9 = _mm256_mul_ps(s0, w3); + auto z10 = _mm256_mul_ps(s1, w3); + auto z11 = _mm256_mul_ps(s2, w3); + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD8(A + sy * aStride); + s1 = LOAD8(A + sy * aStride + 8); + s2 = EXPAND_128(LOAD4(A + sy * aStride + 16)); + ws = _load_int4x4(weight + sy * 2, alpha, bias); + w0 = _mm256_set1_ps(ws[0]); + w1 = _mm256_set1_ps(ws[1]); + w2 = _mm256_set1_ps(ws[2]); + w3 = _mm256_set1_ps(ws[3]); + z0 = MNNAVXFMA(s0, w0, z0); + z1 = MNNAVXFMA(s1, w0, z1); + z2 = MNNAVXFMA(s2, w0, z2); + z3 = MNNAVXFMA(s0, w1, z3); + z4 = MNNAVXFMA(s1, w1, z4); + z5 = MNNAVXFMA(s2, w1, z5); + z6 = MNNAVXFMA(s0, w2, z6); + z7 = MNNAVXFMA(s1, w2, z7); + z8 = MNNAVXFMA(s2, w2, z8); + z9 = MNNAVXFMA(s0, w3, z9); + z10 = MNNAVXFMA(s1, w3, z10); + z11 = MNNAVXFMA(s2, w3, z11); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(1, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(1, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(0, 2, z2, z5, z8, z11); + } +} + +template +static void _AVX_MNNPackedMatMul_int4_16(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD8(A + 0 * aStride); + auto s1 = LOAD8(A + 0 * aStride + 8); + auto ws = _load_int4x4(weight, alpha, bias); + auto w0 = 
_mm256_set1_ps(ws[0]); + auto w1 = _mm256_set1_ps(ws[1]); + auto w2 = _mm256_set1_ps(ws[2]); + auto w3 = _mm256_set1_ps(ws[3]); + auto z0 = _mm256_mul_ps(s0, w0); + auto z1 = _mm256_mul_ps(s1, w0); + auto z3 = _mm256_mul_ps(s0, w1); + auto z4 = _mm256_mul_ps(s1, w1); + auto z6 = _mm256_mul_ps(s0, w2); + auto z7 = _mm256_mul_ps(s1, w2); + auto z9 = _mm256_mul_ps(s0, w3); + auto z10 = _mm256_mul_ps(s1, w3); + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD8(A + sy * aStride); + s1 = LOAD8(A + sy * aStride + 8); + ws = _load_int4x4(weight + sy * 2, alpha, bias); + w0 = _mm256_set1_ps(ws[0]); + w1 = _mm256_set1_ps(ws[1]); + w2 = _mm256_set1_ps(ws[2]); + w3 = _mm256_set1_ps(ws[3]); + z0 = MNNAVXFMA(s0, w0, z0); + z1 = MNNAVXFMA(s1, w0, z1); + z3 = MNNAVXFMA(s0, w1, z3); + z4 = MNNAVXFMA(s1, w1, z4); + z6 = MNNAVXFMA(s0, w2, z6); + z7 = MNNAVXFMA(s1, w2, z7); + z9 = MNNAVXFMA(s0, w3, z9); + z10 = MNNAVXFMA(s1, w3, z10); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(1, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(1, 1, z1, z4, z7, z10); + } +} + +template +static void _AVX_MNNPackedMatMul_int4_5(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int4x4 + DST_ADDR_UNPACK4(0); + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + auto sumAvx30 = _mm256_setzero_ps(); + auto sumAvx31 = _mm256_setzero_ps(); + + auto sumAvx40 = _mm256_setzero_ps(); + auto sumAvx41 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto S2 = BROAD_LOAD(srcUse + 2); + auto S3 = BROAD_LOAD(srcUse + 3); + auto S4 = BROAD_LOAD(srcUse + 4); + auto w0 = _load_int4x4(weight0, alpha0, bias0); + auto w1 = _load_int4x4(weight1, alpha1, bias1); + auto w2 = _load_int4x4(weight2, alpha2, bias2); + auto w3 = _load_int4x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20); + sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21); + + sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30); + sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31); + + sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40); + sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41); + + srcUse += aStride; + weight0 += 2; + weight1 += 2; + weight2 += 2; + weight3 += 2; + } + STORE_8(dst0, sumAvx00); + STORE_8(dst0 + 8, sumAvx10); + STORE_8(dst0 + 16, sumAvx20); + STORE_8(dst0 + 24, sumAvx30); + STORE_8(dst0 + 32, sumAvx40); + + STORE_8(dst2, sumAvx01); + STORE_8(dst2 + 8, sumAvx11); + STORE_8(dst2 + 16, sumAvx21); + STORE_8(dst2 + 24, sumAvx31); + STORE_8(dst2 + 32, sumAvx41); + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + 
auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0); + auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1); + auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2); + auto s3 = BROAD_LOAD_4(A + 0 * aStride + 3); + auto s4 = BROAD_LOAD_4(A + 0 * aStride + 4); + auto w0 = _load_int4x4(weight, alpha, bias); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + auto z2 = _mm_mul_ps(s2, w0); + auto z3 = _mm_mul_ps(s3, w0); + auto z4 = _mm_mul_ps(s4, w0); + + for (int sy = 1; sy < l; ++sy) { + s0 = BROAD_LOAD_4(A + sy * aStride + 0); + s1 = BROAD_LOAD_4(A + sy * aStride + 1); + s2 = BROAD_LOAD_4(A + sy * aStride + 2); + s3 = BROAD_LOAD_4(A + sy * aStride + 3); + s4 = BROAD_LOAD_4(A + sy * aStride + 4); + w0 = _load_int4x4(weight + sy * 2, alpha, bias); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + z2 = MNNSSEFMA(s2, w0, z2); + z3 = MNNSSEFMA(s3, w0, z3); + z4 = MNNSSEFMA(s4, w0, z4); + } + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z1); + STORE_4(dst + 8 * 2, z2); + STORE_4(dst + 8 * 3, z3); + STORE_4(dst + 8 * 4, z4); + } +} + + +template +static void _AVX_MNNPackedMatMul_int4_4(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int4x4 + DST_ADDR_UNPACK4(0); + + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + auto sumAvx30 = _mm256_setzero_ps(); + auto sumAvx31 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto S2 = BROAD_LOAD(srcUse + 2); + auto S3 = BROAD_LOAD(srcUse + 3); + auto w0 = _load_int4x4(weight0, alpha0, bias0); + auto w1 = _load_int4x4(weight1, alpha1, bias1); + auto w2 = _load_int4x4(weight2, alpha2, bias2); + auto w3 = _load_int4x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20); + sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21); + + sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30); + sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31); + + srcUse += aStride; + weight0 += 2; + weight1 += 2; + weight2 += 2; + weight3 += 2; + } + STORE_8(dst0, sumAvx00); + STORE_8(dst0 + 8, sumAvx10); + STORE_8(dst0 + 16, sumAvx20); + STORE_8(dst0 + 24, sumAvx30); + + STORE_8(dst2, sumAvx01); + STORE_8(dst2 + 8, sumAvx11); + STORE_8(dst2 + 16, sumAvx21); + STORE_8(dst2 + 24, sumAvx31); + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD4(A + 0 * aStride); + auto ws = 
_load_int4x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z9 = _mm_mul_ps(s0, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD4(A + sy * aStride); + ws = _load_int4x4(weight + sy * 2, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z3 = MNNSSEFMA(s0, w1, z3); + z6 = MNNSSEFMA(s0, w2, z6); + z9 = MNNSSEFMA(s0, w3, z9); + } + _MM_TRANSPOSE4_PS(z0, z3, z6, z9); + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z3); + STORE_4(dst + 8 * 2, z6); + STORE_4(dst + 8 * 3, z9); + } +} +template +static void _AVX_MNNPackedMatMul_int4_3(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int4x4 + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + DST_ADDR_UNPACK4(0); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto S2 = BROAD_LOAD(srcUse + 2); + auto w0 = _load_int4x4(weight0, alpha0, bias0); + auto w1 = _load_int4x4(weight1, alpha1, bias1); + auto w2 = _load_int4x4(weight2, alpha2, bias2); + auto w3 = _load_int4x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20); + sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21); + + srcUse += aStride; + weight0 += 2; + weight1 += 2; + weight2 += 2; + weight3 += 2; + } + STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0)); + STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0)); + STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0)); + + STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1)); + STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1)); + STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1)); + + STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0)); + STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0)); + STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0)); + + STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1)); + STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1)); + STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1)); + + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0); + auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1); + auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2); + auto w0 = _load_int4x4(weight, alpha, bias); + 
auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + auto z2 = _mm_mul_ps(s2, w0); + + for (int sy = 1; sy < l; ++sy) { + s0 = BROAD_LOAD_4(A + sy * aStride + 0); + s1 = BROAD_LOAD_4(A + sy * aStride + 1); + s2 = BROAD_LOAD_4(A + sy * aStride + 2); + w0 = _load_int4x4(weight + sy * 2, alpha, bias); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + z2 = MNNSSEFMA(s2, w0, z2); + } + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z1); + STORE_4(dst + 8 * 2, z2); + } +} + +template +static void _AVX_MNNPackedMatMul_int4_2(TYPE* C, const TYPE* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int4x4 + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + DST_ADDR_UNPACK4(0); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto w0 = _load_int4x4(weight0, alpha0, bias0); + auto w1 = _load_int4x4(weight1, alpha1, bias1); + auto w2 = _load_int4x4(weight2, alpha2, bias2); + auto w3 = _load_int4x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + srcUse += aStride; + weight0 += 2; + weight1 += 2; + weight2 += 2; + weight3 += 2; + } + STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0)); + STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0)); + + STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1)); + STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1)); + + STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0)); + STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0)); + + STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1)); + STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1)); + + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0); + auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1); + auto w0 = _load_int4x4(weight, alpha, bias); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + + for (int sy = 1; sy < l; ++sy) { + s0 = BROAD_LOAD_4(A + sy * aStride + 0); + s1 = BROAD_LOAD_4(A + sy * aStride + 1); + w0 = _load_int4x4(weight + sy * 2, alpha, bias); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + } + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z1); + } +} + +template +static void _AVX_MNNPackednMatMulRemainCommon_int4(TYPE* C, const TYPE* A, const TYPE* fB, size_t eSize, + const size_t* parameter, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 
4); + auto es = eSize; + auto oC = C; + auto aStride = parameter[0] / sizeof(TYPE); + if (eSize >= 20) { + _AVX_MNNPackedMatMul_int4_20(C, A, B, parameter, k, b); + eSize -= 20; + C += 20 * 8; + A += 20; + } + if (eSize >= 16) { + _AVX_MNNPackedMatMul_int4_16(C, A, B, parameter, k, b); + eSize -= 16; + C += 16 * 8; + A += 16; + } + while (eSize >= 5) { + _AVX_MNNPackedMatMul_int4_5(C, A, B, parameter, k, b); + eSize -= 5; + C += 5 * 8; + A += 5; + } + if (eSize == 4) { + _AVX_MNNPackedMatMul_int4_4(C, A, B, parameter, k, b); + return; + } + if (eSize == 3) { + _AVX_MNNPackedMatMul_int4_3(C, A, B, parameter, k, b); + return; + } + if (eSize == 2) { + _AVX_MNNPackedMatMul_int4_2(C, A, B, parameter, k, b); + return; + } + if (eSize == 0) { + return; + } + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + int x = 0; + for (int y = 0; y < hC16; ++y) { + auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8; + auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4; + auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8; + auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4; + LOAD_WEIGHT_ALPHA_BIAS_int4x4 + LOAD_ALPHA_BIAS_DOUBLE + + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + auto sumAvx30 = _mm256_setzero_ps(); + auto sumAvx31 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < lC4; ++sy) { + auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride)); + auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride)); + auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1)); + auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride)); + auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride)); + auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1)); + + auto W00 = _load_int4x8(weight0 + 8 * sy + 0, alpha0_2, bias0_2); + auto W01 = _load_int4x8(weight0 + 8 * sy + 4, alpha0_2, bias0_2); + auto W10 = _load_int4x8(weight1 + 8 * sy + 0, alpha1_2, bias1_2); + auto W11 = _load_int4x8(weight1 + 8 * sy + 4, alpha1_2, bias1_2); + + auto W20 = _load_int4x8(weight2 + 8 * sy + 0, alpha2_2, bias2_2); + auto W21 = _load_int4x8(weight2 + 8 * sy + 4, alpha2_2, bias2_2); + auto W30 = _load_int4x8(weight3 + 8 * sy + 0, alpha3_2, bias3_2); + auto W31 = _load_int4x8(weight3 + 8 * sy + 4, alpha3_2, bias3_2); + + sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00); + sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01); + + sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11); + + sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20); + sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21); + + sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30); + sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31); + srcUse += 4 * aStride; + } + sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01); + sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11); + sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21); + sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31); + auto sum00 = _mm256_extractf128_ps(sumAvx00, 0); + auto sum01 = _mm256_extractf128_ps(sumAvx00, 1); + auto sum0 = _mm_add_ps(sum00, sum01); + auto sum10 = _mm256_extractf128_ps(sumAvx10, 0); + auto sum11 = _mm256_extractf128_ps(sumAvx10, 1); + auto sum1 = _mm_add_ps(sum10, sum11); + + auto sum20 = _mm256_extractf128_ps(sumAvx20, 0); + auto sum21 = _mm256_extractf128_ps(sumAvx20, 1); + auto 
sum2 = _mm_add_ps(sum20, sum21); + auto sum30 = _mm256_extractf128_ps(sumAvx30, 0); + auto sum31 = _mm256_extractf128_ps(sumAvx30, 1); + auto sum3 = _mm_add_ps(sum30, sum31); + for (int sy = lR; sy < l; ++sy) { + auto s = BROAD_LOAD_4(srcUse); + auto w0 = _load_int4x4(weight0 + 2 * sy, alpha0, bias0); + auto w1 = _load_int4x4(weight1 + 2 * sy, alpha1, bias1); + auto w2 = _load_int4x4(weight2 + 2 * sy, alpha2, bias2); + auto w3 = _load_int4x4(weight3 + 2 * sy, alpha3, bias3); + sum0 = MNNSSEFMA(s, w0, sum0); + sum1 = MNNSSEFMA(s, w1, sum1); + sum2 = MNNSSEFMA(s, w2, sum2); + sum3 = MNNSSEFMA(s, w3, sum3); + srcUse += aStride; + } + STORE_4(dst0, sum0); + STORE_4(dst1, sum1); + STORE_4(dst2, sum2); + STORE_4(dst3, sum3); + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + x * 8 + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto alpha_2 = _mm256_set_m128(alpha, alpha); + auto bias_2 = _mm256_set_m128(bias, bias); + + auto sumAvx0 = _mm256_setzero_ps(); + auto sumAvx1 = _mm256_setzero_ps(); + auto srcUse = src; + for (int sy = 0; sy < lC4; ++sy) { + auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride)); + auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride)); + auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1)); + auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride)); + auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride)); + auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1)); + auto W0 = _load_int4x8(weight + 8 * sy + 0, alpha_2, bias_2); + auto W1 = _load_int4x8(weight + 8 * sy + 4, alpha_2, bias_2); + sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0); + sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1); + srcUse += 4 * aStride; + } + sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1); + auto sum0 = _mm256_extractf128_ps(sumAvx0, 0); + auto sum1 = _mm256_extractf128_ps(sumAvx0, 1); + auto sum = _mm_add_ps(sum0, sum1); + for (int sy = lR; sy < l; ++sy) { + auto s = BROAD_LOAD_4(srcUse); + auto w = _load_int4x4(weight + sy * 2, alpha, bias); + sum = MNNSSEFMA(s, w, sum); + srcUse += aStride; + } + STORE_4(dst, sum); + } +} + +//----------------------- MatMul(float, int8) Functions ---------------------------// + +#define LOAD_WEIGHT_ALPHA_BIAS_int8x4 \ + auto weight0 = B + (hC4Unit * y + 0) * bStride;\ + auto weight1 = B + (hC4Unit * y + 1) * bStride;\ + auto weight2 = B + (hC4Unit * y + 2) * bStride;\ + auto weight3 = B + (hC4Unit * y + 3) * bStride;\ + auto alpha0 = _mm_loadu_ps(k + y * 16 + 0);\ + auto alpha1 = _mm_loadu_ps(k + y * 16 + 4);\ + auto alpha2 = _mm_loadu_ps(k + y * 16 + 8);\ + auto alpha3 = _mm_loadu_ps(k + y * 16 + 12);\ + auto bias0 = _mm_loadu_ps(b + y * 16 + 0);\ + auto bias1 = _mm_loadu_ps(b + y * 16 + 4);\ + auto bias2 = _mm_loadu_ps(b + y * 16 + 8);\ + auto bias3 = _mm_loadu_ps(b + y * 16 + 12); + +static inline __m128 _load_int8x4(const int8_t* src, __m128 alpha, __m128 bias) { + int iw0 = src[0]; + int iw1 = src[1]; + int iw2 = src[2]; + int iw3 = src[3]; + auto ws = _mm_set_ps(iw3, iw2, iw1, iw0); + ws = _mm_add_ps(_mm_mul_ps(ws, alpha), bias); + return ws; +} + +static inline __m256 _load_int8x8(const int8_t* src, __m256 alpha, __m256 bias) { + float w[8]; + for (int i = 0; i < 8; i++) { + w[i] = int(src[i]); + } + auto w8 = LOAD8(w); + return _mm256_add_ps(_mm256_mul_ps(w8, alpha), bias); +} + + +template +static void _AVX_MNNPackedMatMul_Main_int8(TYPE* C, const TYPE* A, const TYPE* fB, const size_t* parameter, 
const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD8(A + 0 * 24); + auto s1 = LOAD8(A + 0 * 24 + 8); + auto s2 = LOAD8(A + 0 * 24 + 16); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm256_set1_ps(ws[0]); + auto w1 = _mm256_set1_ps(ws[1]); + auto w2 = _mm256_set1_ps(ws[2]); + auto w3 = _mm256_set1_ps(ws[3]); + auto z0 = _mm256_mul_ps(s0, w0); + auto z1 = _mm256_mul_ps(s1, w0); + auto z2 = _mm256_mul_ps(s2, w0); + auto z3 = _mm256_mul_ps(s0, w1); + auto z4 = _mm256_mul_ps(s1, w1); + auto z5 = _mm256_mul_ps(s2, w1); + auto z6 = _mm256_mul_ps(s0, w2); + auto z7 = _mm256_mul_ps(s1, w2); + auto z8 = _mm256_mul_ps(s2, w2); + auto z9 = _mm256_mul_ps(s0, w3); + auto z10 = _mm256_mul_ps(s1, w3); + auto z11 = _mm256_mul_ps(s2, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD8(A + sy * 24); + s1 = LOAD8(A + sy * 24 + 8); + s2 = LOAD8(A + sy * 24 + 16); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm256_set1_ps(ws[0]); + w1 = _mm256_set1_ps(ws[1]); + w2 = _mm256_set1_ps(ws[2]); + w3 = _mm256_set1_ps(ws[3]); + z0 = MNNAVXFMA(s0, w0, z0); + z1 = MNNAVXFMA(s1, w0, z1); + z2 = MNNAVXFMA(s2, w0, z2); + z3 = MNNAVXFMA(s0, w1, z3); + z4 = MNNAVXFMA(s1, w1, z4); + z5 = MNNAVXFMA(s2, w1, z5); + z6 = MNNAVXFMA(s0, w2, z6); + z7 = MNNAVXFMA(s1, w2, z7); + z8 = MNNAVXFMA(s2, w2, z8); + z9 = MNNAVXFMA(s0, w3, z9); + z10 = MNNAVXFMA(s1, w3, z10); + z11 = MNNAVXFMA(s2, w3, z11); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(1, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(1, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(0, 2, z2, z5, z8, z11); + TRANPOSE_SAVE(1, 2, z2, z5, z8, z11); + } +} + + +template +static void _AVX_MNNPackedMatMul_int8_20(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD8(A + 0 * aStride); + auto s1 = LOAD8(A + 0 * aStride + 8); + auto s2 = EXPAND_128(LOAD4(A + 0 * aStride + 16)); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm256_set1_ps(ws[0]); + auto w1 = _mm256_set1_ps(ws[1]); + auto w2 = _mm256_set1_ps(ws[2]); + auto w3 = _mm256_set1_ps(ws[3]); + auto z0 = _mm256_mul_ps(s0, w0); + auto z1 = _mm256_mul_ps(s1, w0); + auto z2 = _mm256_mul_ps(s2, w0); + auto z3 = _mm256_mul_ps(s0, w1); + auto z4 = _mm256_mul_ps(s1, w1); + auto z5 = _mm256_mul_ps(s2, w1); + auto z6 = _mm256_mul_ps(s0, w2); + auto z7 = _mm256_mul_ps(s1, w2); + auto z8 = _mm256_mul_ps(s2, w2); + auto z9 = _mm256_mul_ps(s0, w3); + auto z10 = _mm256_mul_ps(s1, w3); + auto z11 = _mm256_mul_ps(s2, w3); + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD8(A + sy * aStride); + s1 = LOAD8(A + sy * aStride + 8); + s2 
= EXPAND_128(LOAD4(A + sy * aStride + 16)); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm256_set1_ps(ws[0]); + w1 = _mm256_set1_ps(ws[1]); + w2 = _mm256_set1_ps(ws[2]); + w3 = _mm256_set1_ps(ws[3]); + z0 = MNNAVXFMA(s0, w0, z0); + z1 = MNNAVXFMA(s1, w0, z1); + z2 = MNNAVXFMA(s2, w0, z2); + z3 = MNNAVXFMA(s0, w1, z3); + z4 = MNNAVXFMA(s1, w1, z4); + z5 = MNNAVXFMA(s2, w1, z5); + z6 = MNNAVXFMA(s0, w2, z6); + z7 = MNNAVXFMA(s1, w2, z7); + z8 = MNNAVXFMA(s2, w2, z8); + z9 = MNNAVXFMA(s0, w3, z9); + z10 = MNNAVXFMA(s1, w3, z10); + z11 = MNNAVXFMA(s2, w3, z11); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(1, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(1, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(0, 2, z2, z5, z8, z11); + } +} + +template +static void _AVX_MNNPackedMatMul_int8_16(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD8(A + 0 * aStride); + auto s1 = LOAD8(A + 0 * aStride + 8); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm256_set1_ps(ws[0]); + auto w1 = _mm256_set1_ps(ws[1]); + auto w2 = _mm256_set1_ps(ws[2]); + auto w3 = _mm256_set1_ps(ws[3]); + auto z0 = _mm256_mul_ps(s0, w0); + auto z1 = _mm256_mul_ps(s1, w0); + auto z3 = _mm256_mul_ps(s0, w1); + auto z4 = _mm256_mul_ps(s1, w1); + auto z6 = _mm256_mul_ps(s0, w2); + auto z7 = _mm256_mul_ps(s1, w2); + auto z9 = _mm256_mul_ps(s0, w3); + auto z10 = _mm256_mul_ps(s1, w3); + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD8(A + sy * aStride); + s1 = LOAD8(A + sy * aStride + 8); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm256_set1_ps(ws[0]); + w1 = _mm256_set1_ps(ws[1]); + w2 = _mm256_set1_ps(ws[2]); + w3 = _mm256_set1_ps(ws[3]); + z0 = MNNAVXFMA(s0, w0, z0); + z1 = MNNAVXFMA(s1, w0, z1); + z3 = MNNAVXFMA(s0, w1, z3); + z4 = MNNAVXFMA(s1, w1, z4); + z6 = MNNAVXFMA(s0, w2, z6); + z7 = MNNAVXFMA(s1, w2, z7); + z9 = MNNAVXFMA(s0, w3, z9); + z10 = MNNAVXFMA(s1, w3, z10); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(1, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(1, 1, z1, z4, z7, z10); + } +} + +template +static void _AVX_MNNPackedMatMul_int8_5(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int8x4 + DST_ADDR_UNPACK4(0); + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + auto sumAvx30 = _mm256_setzero_ps(); + auto 
sumAvx31 = _mm256_setzero_ps(); + + auto sumAvx40 = _mm256_setzero_ps(); + auto sumAvx41 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto S2 = BROAD_LOAD(srcUse + 2); + auto S3 = BROAD_LOAD(srcUse + 3); + auto S4 = BROAD_LOAD(srcUse + 4); + auto w0 = _load_int8x4(weight0, alpha0, bias0); + auto w1 = _load_int8x4(weight1, alpha1, bias1); + auto w2 = _load_int8x4(weight2, alpha2, bias2); + auto w3 = _load_int8x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20); + sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21); + + sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30); + sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31); + + sumAvx40 = MNNAVXFMA(S4, W0, sumAvx40); + sumAvx41 = MNNAVXFMA(S4, W1, sumAvx41); + + srcUse += aStride; + weight0 += 4; + weight1 += 4; + weight2 += 4; + weight3 += 4; + } + STORE_8(dst0, sumAvx00); + STORE_8(dst0 + 8, sumAvx10); + STORE_8(dst0 + 16, sumAvx20); + STORE_8(dst0 + 24, sumAvx30); + STORE_8(dst0 + 32, sumAvx40); + + STORE_8(dst2, sumAvx01); + STORE_8(dst2 + 8, sumAvx11); + STORE_8(dst2 + 16, sumAvx21); + STORE_8(dst2 + 24, sumAvx31); + STORE_8(dst2 + 32, sumAvx41); + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0); + auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1); + auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2); + auto s3 = BROAD_LOAD_4(A + 0 * aStride + 3); + auto s4 = BROAD_LOAD_4(A + 0 * aStride + 4); + auto w0 = _load_int8x4(weight, alpha, bias); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + auto z2 = _mm_mul_ps(s2, w0); + auto z3 = _mm_mul_ps(s3, w0); + auto z4 = _mm_mul_ps(s4, w0); + + for (int sy = 1; sy < l; ++sy) { + s0 = BROAD_LOAD_4(A + sy * aStride + 0); + s1 = BROAD_LOAD_4(A + sy * aStride + 1); + s2 = BROAD_LOAD_4(A + sy * aStride + 2); + s3 = BROAD_LOAD_4(A + sy * aStride + 3); + s4 = BROAD_LOAD_4(A + sy * aStride + 4); + w0 = _load_int8x4(weight + sy * 4, alpha, bias); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + z2 = MNNSSEFMA(s2, w0, z2); + z3 = MNNSSEFMA(s3, w0, z3); + z4 = MNNSSEFMA(s4, w0, z4); + } + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z1); + STORE_4(dst + 8 * 2, z2); + STORE_4(dst + 8 * 3, z3); + STORE_4(dst + 8 * 4, z4); + } +} + + +template +static void _AVX_MNNPackedMatMul_int8_4(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int8x4 + DST_ADDR_UNPACK4(0); + + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto 
sumAvx21 = _mm256_setzero_ps(); + + auto sumAvx30 = _mm256_setzero_ps(); + auto sumAvx31 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto S2 = BROAD_LOAD(srcUse + 2); + auto S3 = BROAD_LOAD(srcUse + 3); + auto w0 = _load_int8x4(weight0, alpha0, bias0); + auto w1 = _load_int8x4(weight1, alpha1, bias1); + auto w2 = _load_int8x4(weight2, alpha2, bias2); + auto w3 = _load_int8x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20); + sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21); + + sumAvx30 = MNNAVXFMA(S3, W0, sumAvx30); + sumAvx31 = MNNAVXFMA(S3, W1, sumAvx31); + + srcUse += aStride; + weight0 += 4; + weight1 += 4; + weight2 += 4; + weight3 += 4; + } + STORE_8(dst0, sumAvx00); + STORE_8(dst0 + 8, sumAvx10); + STORE_8(dst0 + 16, sumAvx20); + STORE_8(dst0 + 24, sumAvx30); + + STORE_8(dst2, sumAvx01); + STORE_8(dst2 + 8, sumAvx11); + STORE_8(dst2 + 16, sumAvx21); + STORE_8(dst2 + 24, sumAvx31); + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = LOAD4(A + 0 * aStride); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z9 = _mm_mul_ps(s0, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = LOAD4(A + sy * aStride); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z3 = MNNSSEFMA(s0, w1, z3); + z6 = MNNSSEFMA(s0, w2, z6); + z9 = MNNSSEFMA(s0, w3, z9); + } + _MM_TRANSPOSE4_PS(z0, z3, z6, z9); + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z3); + STORE_4(dst + 8 * 2, z6); + STORE_4(dst + 8 * 3, z9); + } +} +template +static void _AVX_MNNPackedMatMul_int8_3(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int8x4 + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + DST_ADDR_UNPACK4(0); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto S2 = BROAD_LOAD(srcUse + 2); + auto w0 = _load_int8x4(weight0, alpha0, bias0); + auto w1 = _load_int8x4(weight1, alpha1, bias1); + auto w2 = _load_int8x4(weight2, alpha2, bias2); + auto w3 = _load_int8x4(weight3, alpha3, 
bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + sumAvx20 = MNNAVXFMA(S2, W0, sumAvx20); + sumAvx21 = MNNAVXFMA(S2, W1, sumAvx21); + + srcUse += aStride; + weight0 += 4; + weight1 += 4; + weight2 += 4; + weight3 += 4; + } + STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0)); + STORE_4(dst0 + 8, _mm256_extractf128_ps(sumAvx10, 0)); + STORE_4(dst0 + 16, _mm256_extractf128_ps(sumAvx20, 0)); + + STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1)); + STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1)); + STORE_4(dst1 + 16, _mm256_extractf128_ps(sumAvx20, 1)); + + STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0)); + STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0)); + STORE_4(dst2 + 16, _mm256_extractf128_ps(sumAvx21, 0)); + + STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1)); + STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1)); + STORE_4(dst3 + 16, _mm256_extractf128_ps(sumAvx21, 1)); + + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0); + auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1); + auto s2 = BROAD_LOAD_4(A + 0 * aStride + 2); + auto w0 = _load_int8x4(weight, alpha, bias); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + auto z2 = _mm_mul_ps(s2, w0); + + for (int sy = 1; sy < l; ++sy) { + s0 = BROAD_LOAD_4(A + sy * aStride + 0); + s1 = BROAD_LOAD_4(A + sy * aStride + 1); + s2 = BROAD_LOAD_4(A + sy * aStride + 2); + w0 = _load_int8x4(weight + sy * 4, alpha, bias); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + z2 = MNNSSEFMA(s2, w0, z2); + } + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z1); + STORE_4(dst + 8 * 2, z2); + } +} + +template +static void _AVX_MNNPackedMatMul_int8_2(TYPE* C, const TYPE* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(TYPE); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + for (int y = 0; y < hC16; ++y) { + LOAD_WEIGHT_ALPHA_BIAS_int8x4 + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + DST_ADDR_UNPACK4(0); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < l; ++sy) { + auto S0 = BROAD_LOAD(srcUse + 0); + auto S1 = BROAD_LOAD(srcUse + 1); + auto w0 = _load_int8x4(weight0, alpha0, bias0); + auto w1 = _load_int8x4(weight1, alpha1, bias1); + auto w2 = _load_int8x4(weight2, alpha2, bias2); + auto w3 = _load_int8x4(weight3, alpha3, bias3); + auto W0 = _mm256_set_m128(w1, w0); + auto W1 = _mm256_set_m128(w3, w2); + + sumAvx00 = MNNAVXFMA(S0, W0, sumAvx00); + sumAvx01 = MNNAVXFMA(S0, W1, sumAvx01); + + sumAvx10 = MNNAVXFMA(S1, W0, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W1, sumAvx11); + + srcUse += aStride; + weight0 += 4; + weight1 += 4; + weight2 += 4; + weight3 += 4; + } + STORE_4(dst0 + 0, _mm256_extractf128_ps(sumAvx00, 0)); + STORE_4(dst0 + 
8, _mm256_extractf128_ps(sumAvx10, 0)); + + STORE_4(dst1 + 0, _mm256_extractf128_ps(sumAvx00, 1)); + STORE_4(dst1 + 8, _mm256_extractf128_ps(sumAvx10, 1)); + + STORE_4(dst2 + 0, _mm256_extractf128_ps(sumAvx01, 0)); + STORE_4(dst2 + 8, _mm256_extractf128_ps(sumAvx11, 0)); + + STORE_4(dst3 + 0, _mm256_extractf128_ps(sumAvx01, 1)); + STORE_4(dst3 + 8, _mm256_extractf128_ps(sumAvx11, 1)); + + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = BROAD_LOAD_4(A + 0 * aStride + 0); + auto s1 = BROAD_LOAD_4(A + 0 * aStride + 1); + auto w0 = _load_int8x4(weight, alpha, bias); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + + for (int sy = 1; sy < l; ++sy) { + s0 = BROAD_LOAD_4(A + sy * aStride + 0); + s1 = BROAD_LOAD_4(A + sy * aStride + 1); + w0 = _load_int8x4(weight + sy * 4, alpha, bias); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + } + STORE_4(dst + 8 * 0, z0); + STORE_4(dst + 8 * 1, z1); + } +} + +template +static void _AVX_MNNPackednMatMulRemainCommon_int8(TYPE* C, const TYPE* A, const TYPE* fB, size_t eSize, + const size_t* parameter, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(TYPE); + auto bExtraStride = parameter[5] / sizeof(TYPE); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + auto es = eSize; + auto oC = C; + auto aStride = parameter[0] / sizeof(TYPE); + if (eSize >= 20) { + _AVX_MNNPackedMatMul_int8_20(C, A, B, parameter, k, b); + eSize -= 20; + C += 20 * 8; + A += 20; + } + if (eSize >= 16) { + _AVX_MNNPackedMatMul_int8_16(C, A, B, parameter, k, b); + eSize -= 16; + C += 16 * 8; + A += 16; + } + while (eSize >= 5) { + _AVX_MNNPackedMatMul_int8_5(C, A, B, parameter, k, b); + eSize -= 5; + C += 5 * 8; + A += 5; + } + if (eSize == 4) { + _AVX_MNNPackedMatMul_int8_4(C, A, B, parameter, k, b); + return; + } + if (eSize == 3) { + _AVX_MNNPackedMatMul_int8_3(C, A, B, parameter, k, b); + return; + } + if (eSize == 2) { + _AVX_MNNPackedMatMul_int8_2(C, A, B, parameter, k, b); + return; + } + if (eSize == 0) { + return; + } + int lC4 = l / 4; + int lR = lC4 * 4; + const int hC4Unit = 4; + int hC16 = hC4 / hC4Unit; + int hR = hC16 * hC4Unit; + auto src = A; + int x = 0; + for (int y = 0; y < hC16; ++y) { + auto dst0 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8; + auto dst1 = C + (hC4Unit * y / 2 + 0) * cStride + x * 8 + 4; + auto dst2 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8; + auto dst3 = C + (hC4Unit * y / 2 + 1) * cStride + x * 8 + 4; + LOAD_WEIGHT_ALPHA_BIAS_int8x4 + LOAD_ALPHA_BIAS_DOUBLE + + auto sumAvx00 = _mm256_setzero_ps(); + auto sumAvx01 = _mm256_setzero_ps(); + + auto sumAvx10 = _mm256_setzero_ps(); + auto sumAvx11 = _mm256_setzero_ps(); + + auto sumAvx20 = _mm256_setzero_ps(); + auto sumAvx21 = _mm256_setzero_ps(); + + auto sumAvx30 = _mm256_setzero_ps(); + auto sumAvx31 = _mm256_setzero_ps(); + + auto srcUse = src; + for (int sy = 0; sy < lC4; ++sy) { + auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride)); + auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride)); + auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1)); + auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride)); + auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride)); + auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1)); 
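Note: the _AVX_MNNPackedMatMul_int8_* kernels and the remainder dispatcher above differ only in how many output columns they block per call (20, 16, 5, 4, 3, 2, with a scalar tail); the arithmetic is the same dequantize-then-FMA dot product in every variant. As a reading aid, a plain scalar sketch of that computation follows. It assumes the same parameter[] layout the kernels read (parameter[0]=aStride in bytes, [1]=l, [2]=h, [3]=cStride in bytes, [5]=bExtraStride in bytes) and the same per-channel alpha/bias arrays k and b, but simplifies the destination addressing to a plain channel-by-4 pack, so it mirrors the math rather than the exact C8 store pattern; none of the names are MNN API.

#include <cstddef>
#include <cstdint>

// Scalar reference (illustrative only): C = A * dequant(B), where each int8 weight
// is dequantized as w * alpha + bias with per-output-channel alpha/bias.
static void packedMatMulInt8Ref(float* C, const float* A, const int8_t* B, size_t eSize,
                                const size_t* parameter, const float* k, const float* b) {
    const size_t aStride = parameter[0] / sizeof(float);
    const size_t l       = parameter[1];
    const size_t h       = parameter[2];
    const size_t cStride = parameter[3] / sizeof(float);
    const size_t bStride = parameter[5] / sizeof(float) + l * 4; // bExtraStride + l * 4, as above
    const size_t hC4     = (h + 3) / 4;                          // UP_DIV(h, 4)
    for (size_t y = 0; y < hC4; ++y) {
        const int8_t* weight = B + y * bStride;
        const float* alpha   = k + y * 4;
        const float* bias    = b + y * 4;
        for (size_t x = 0; x < eSize; ++x) {
            for (size_t j = 0; j < 4; ++j) {                     // lane inside the channel pack
                float acc = 0.0f;
                for (size_t sy = 0; sy < l; ++sy) {
                    float w = (float)weight[sy * 4 + j] * alpha[j] + bias[j];
                    acc += A[sy * aStride + x] * w;
                }
                C[y * cStride + x * 4 + j] = acc;                // simplified C4 output layout
            }
        }
    }
}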
+ + auto W00 = _load_int8x8(weight0 + 16 * sy + 0, alpha0_2, bias0_2); + auto W01 = _load_int8x8(weight0 + 16 * sy + 8, alpha0_2, bias0_2); + auto W10 = _load_int8x8(weight1 + 16 * sy + 0, alpha1_2, bias1_2); + auto W11 = _load_int8x8(weight1 + 16 * sy + 8, alpha1_2, bias1_2); + + auto W20 = _load_int8x8(weight2 + 16 * sy + 0, alpha2_2, bias2_2); + auto W21 = _load_int8x8(weight2 + 16 * sy + 8, alpha2_2, bias2_2); + auto W30 = _load_int8x8(weight3 + 16 * sy + 0, alpha3_2, bias3_2); + auto W31 = _load_int8x8(weight3 + 16 * sy + 8, alpha3_2, bias3_2); + + sumAvx00 = MNNAVXFMA(S0, W00, sumAvx00); + sumAvx01 = MNNAVXFMA(S1, W01, sumAvx01); + + sumAvx10 = MNNAVXFMA(S0, W10, sumAvx10); + sumAvx11 = MNNAVXFMA(S1, W11, sumAvx11); + + sumAvx20 = MNNAVXFMA(S0, W20, sumAvx20); + sumAvx21 = MNNAVXFMA(S1, W21, sumAvx21); + + sumAvx30 = MNNAVXFMA(S0, W30, sumAvx30); + sumAvx31 = MNNAVXFMA(S1, W31, sumAvx31); + srcUse += 4 * aStride; + } + sumAvx00 = _mm256_add_ps(sumAvx00, sumAvx01); + sumAvx10 = _mm256_add_ps(sumAvx10, sumAvx11); + sumAvx20 = _mm256_add_ps(sumAvx20, sumAvx21); + sumAvx30 = _mm256_add_ps(sumAvx30, sumAvx31); + auto sum00 = _mm256_extractf128_ps(sumAvx00, 0); + auto sum01 = _mm256_extractf128_ps(sumAvx00, 1); + auto sum0 = _mm_add_ps(sum00, sum01); + auto sum10 = _mm256_extractf128_ps(sumAvx10, 0); + auto sum11 = _mm256_extractf128_ps(sumAvx10, 1); + auto sum1 = _mm_add_ps(sum10, sum11); + + auto sum20 = _mm256_extractf128_ps(sumAvx20, 0); + auto sum21 = _mm256_extractf128_ps(sumAvx20, 1); + auto sum2 = _mm_add_ps(sum20, sum21); + auto sum30 = _mm256_extractf128_ps(sumAvx30, 0); + auto sum31 = _mm256_extractf128_ps(sumAvx30, 1); + auto sum3 = _mm_add_ps(sum30, sum31); + for (int sy = lR; sy < l; ++sy) { + auto s = BROAD_LOAD_4(srcUse); + auto w0 = _load_int8x4(weight0 + 4 * sy, alpha0, bias0); + auto w1 = _load_int8x4(weight1 + 4 * sy, alpha1, bias1); + auto w2 = _load_int8x4(weight2 + 4 * sy, alpha2, bias2); + auto w3 = _load_int8x4(weight3 + 4 * sy, alpha3, bias3); + sum0 = MNNSSEFMA(s, w0, sum0); + sum1 = MNNSSEFMA(s, w1, sum1); + sum2 = MNNSSEFMA(s, w2, sum2); + sum3 = MNNSSEFMA(s, w3, sum3); + srcUse += aStride; + } + STORE_4(dst0, sum0); + STORE_4(dst1, sum1); + STORE_4(dst2, sum2); + STORE_4(dst3, sum3); + } + for (int y = hR; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + (y / 2) * cStride + x * 8 + 4 * (y % 2); + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto alpha_2 = _mm256_set_m128(alpha, alpha); + auto bias_2 = _mm256_set_m128(bias, bias); + + auto sumAvx0 = _mm256_setzero_ps(); + auto sumAvx1 = _mm256_setzero_ps(); + auto srcUse = src; + for (int sy = 0; sy < lC4; ++sy) { + auto s0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (0) * aStride)); + auto s1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (1) * aStride)); + auto S0 = _mm256_castsi256_ps(_mm256_insertf128_si256(s0, s1, 1)); + auto d0 = _mm256_castps_si256(BROAD_LOAD(srcUse + (2) * aStride)); + auto d1 = _mm_castps_si128(BROAD_LOAD_4(srcUse + (3) * aStride)); + auto S1 = _mm256_castsi256_ps(_mm256_insertf128_si256(d0, d1, 1)); + auto W0 = _load_int8x8(weight + 16 * sy + 0, alpha_2, bias_2); + auto W1 = _load_int8x8(weight + 16 * sy + 8, alpha_2, bias_2); + sumAvx0 = MNNAVXFMA(S0, W0, sumAvx0); + sumAvx1 = MNNAVXFMA(S1, W1, sumAvx1); + srcUse += 4 * aStride; + } + sumAvx0 = _mm256_add_ps(sumAvx0, sumAvx1); + auto sum0 = _mm256_extractf128_ps(sumAvx0, 0); + auto sum1 = _mm256_extractf128_ps(sumAvx0, 1); + auto sum = _mm_add_ps(sum0, sum1); + for (int sy = lR; sy < l; 
++sy) { + auto s = BROAD_LOAD_4(srcUse); + auto w = _load_int8x4(weight + sy * 4, alpha, bias); + sum = MNNSSEFMA(s, w, sum); + srcUse += aStride; + } + STORE_4(dst, sum); + } +} + +#endif diff --git a/source/backend/cpu/x86_x64/avx/MathFunctions.cpp b/source/backend/cpu/x86_x64/avx/MathFunctions.cpp index 4e971fb94..9a6afa305 100644 --- a/source/backend/cpu/x86_x64/avx/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/avx/MathFunctions.cpp @@ -263,4 +263,4 @@ void _AVX_MNNNorm(float *dst, const float *src, const float *gamma, const float dst[i] = (src[i] - mean) * variable; } } -} +} \ No newline at end of file diff --git a/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp index d277e1789..b7e776da3 100644 --- a/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp @@ -36,8 +36,8 @@ extern "C" { void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose); void _AVX512_MNNPackC8ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); -void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); -void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void _AVX512_MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP); void _AVX512_MNNPackedSparseMatMulEpx8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, diff --git a/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp b/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp index cf9120904..0911129fc 100644 --- a/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp +++ b/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp @@ -1131,7 +1131,7 @@ static void _AVX512_MNNPackednMatMulRemainCommon(float* C, const float* A, const } } -void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias) { +void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { #ifdef MNN_X86_USE_ASM if (nullptr == postParameters) { _AVX512_MNNGemmFloatUnit48x8(C, A, B, parameter); @@ -1147,7 +1147,7 @@ void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const siz } //#define MNN_X86_DEBUG -void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias) { +void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { #ifdef MNN_X86_DEBUG static std::set gSize; if (gSize.find(eSize) == gSize.end()) { diff --git a/source/backend/cpu/x86_x64/avxfma/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avxfma/FunctionSummary.hpp index 2ac3d3147..577449db1 
100644 --- a/source/backend/cpu/x86_x64/avxfma/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avxfma/FunctionSummary.hpp @@ -35,12 +35,12 @@ extern "C" { void _AVX_MNNPackedMatMulFMA(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias); -void _AVX_MNNPackedMatMulRemainFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX_MNNPackedMatMulRemainFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNComputeMatMulForE_1FMA(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void _AVX_MNNPackedMatMulFMA_BF16(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias); -void _AVX_MNNPackedMatMulRemainFMA_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX_MNNPackedMatMulRemainFMA_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNComputeMatMulForH_1FMA(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void _AVX_MNNGeluFMA(float *dst, const float *src, size_t size, float* parameters); void _AVX_MNNExpC8FMA(float* dest, const float* source, const float* offset, const float* parameters, size_t countC8); diff --git a/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMA.cpp b/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMA.cpp index 63bda0d05..404d4755e 100644 --- a/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMA.cpp +++ b/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMA.cpp @@ -28,7 +28,7 @@ void _AVX_MNNGemmFloatUnitMainFMA_Fused(float* C, const float* A, const float* B #endif void _AVX_MNNPackedMatMulFMA(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias) { + const float* postParameters, const float* bias, const float* k, const float* b) { auto h = parameter[2]; auto cStride = parameter[3] / sizeof(float); #ifdef MNN_X86_USE_ASM @@ -54,7 +54,7 @@ void _AVX_MNNPackedMatMulFMA(float* C, const float* A, const float* B, const siz #endif } -void _AVX_MNNPackedMatMulRemainFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias) { +void _AVX_MNNPackedMatMulRemainFMA(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { _AVX_MNNPackednMatMulRemainCommon(C, A, B, eSize, parameter); AVX2GemmPostTreat(C, eSize, parameter, postParameters, bias); } diff --git a/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMABF16.cpp b/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMABF16.cpp index 982722932..827ab542b 100644 --- a/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMABF16.cpp +++ b/source/backend/cpu/x86_x64/avxfma/GemmAVX2FMABF16.cpp @@ -120,11 +120,11 @@ void AVX2GemmPostTreatBF16(float* CO, size_t eSize, const size_t* parameter, con } void 
_AVX_MNNPackedMatMulFMA_BF16(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias) { + const float* postParameters, const float* bias, const float* k, const float* b) { _AVX_MNNPackedMatMul_3((int16_t*)C, (const int16_t*)A, (const int16_t*)B, parameter); AVX2GemmPostTreatBF16(C, 3, parameter, postParameters, bias); } -void _AVX_MNNPackedMatMulRemainFMA_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias) { +void _AVX_MNNPackedMatMulRemainFMA_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { _AVX_MNNPackednMatMulRemainCommon((int16_t*)C, (const int16_t*)A, (const int16_t*)B, eSize, parameter); AVX2GemmPostTreatBF16(C, eSize, parameter, postParameters, bias); } diff --git a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp index c8b0d9a64..3d0c6b595 100644 --- a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp @@ -46,9 +46,19 @@ void _SSE_MNNStrassenMergeCFunction(float* c11, float* c12, float* c21, float* c size_t length, size_t hSub); void _SSE_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); void _SSE_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, - const float* postParameters, const float* bias); + const float* postParameters, const float* bias, const float* k, const float* b); +#ifdef MNN_LOW_MEMORY +void _SSE_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +void _SSE_MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +void _SSE_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b); +#endif void _SSE_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, diff --git a/source/backend/cpu/x86_x64/sse/GemmFunction.hpp b/source/backend/cpu/x86_x64/sse/GemmFunction.hpp index aa988cb5b..b7507abaa 100644 --- a/source/backend/cpu/x86_x64/sse/GemmFunction.hpp +++ b/source/backend/cpu/x86_x64/sse/GemmFunction.hpp @@ -200,3 +200,429 @@ static void _SSE_MNNPackednMatMulRemainCommon(float* C, const float* A, const fl } } } + +#ifdef MNN_LOW_MEMORY +//----------------------- MatMul(float, int4) Functions ---------------------------// +static inline __m128 _load_int4x4(const uint8_t* src, __m128 alpha, __m128 bias) { + auto w01 = src[0]; + auto w23 = src[1]; + int iw01 = w01; + int iw23 = w23; + 
int iw0 = iw01 / 16; + int iw1 = iw01 % 16; + int iw2 = iw23 / 16; + int iw3 = iw23 % 16; + auto ws = _mm_set_ps(iw3, iw2, iw1, iw0); + ws = _mm_sub_ps(ws, _mm_set1_ps(7)); + ws = _mm_add_ps(_mm_mul_ps(ws, alpha), bias); + return ws; +} + +static void _SSE_MNNPackedMatMul_12_int4(float* C, const float* A, const float* fB, const size_t* parameter, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + y * cStride; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = _mm_loadu_ps(A + 0 * 12); + auto s1 = _mm_loadu_ps(A + 0 * 12 + 4); + auto s2 = _mm_loadu_ps(A + 0 * 12 + 8); + auto ws = _load_int4x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + auto z2 = _mm_mul_ps(s2, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z4 = _mm_mul_ps(s1, w1); + auto z5 = _mm_mul_ps(s2, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z7 = _mm_mul_ps(s1, w2); + auto z8 = _mm_mul_ps(s2, w2); + auto z9 = _mm_mul_ps(s0, w3); + auto z10 = _mm_mul_ps(s1, w3); + auto z11 = _mm_mul_ps(s2, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = _mm_loadu_ps(A + sy * 12); + s1 = _mm_loadu_ps(A + sy * 12 + 4); + s2 = _mm_loadu_ps(A + sy * 12 + 8); + ws = _load_int4x4(weight + sy * 2, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + z2 = MNNSSEFMA(s2, w0, z2); + z3 = MNNSSEFMA(s0, w1, z3); + z4 = MNNSSEFMA(s1, w1, z4); + z5 = MNNSSEFMA(s2, w1, z5); + z6 = MNNSSEFMA(s0, w2, z6); + z7 = MNNSSEFMA(s1, w2, z7); + z8 = MNNSSEFMA(s2, w2, z8); + z9 = MNNSSEFMA(s0, w3, z9); + z10 = MNNSSEFMA(s1, w3, z10); + z11 = MNNSSEFMA(s2, w3, z11); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(0, 2, z2, z5, z8, z11); + } +} + +static void _SSE_MNNPackedMatMul_8_int4(float* C, const float* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(float); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + y * cStride; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = _mm_loadu_ps(A + 0 * aStride); + auto s1 = _mm_loadu_ps(A + 0 * aStride + 4); + auto ws = _load_int4x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z9 = _mm_mul_ps(s0, w3); + auto z1 = _mm_mul_ps(s1, w0); + auto z4 = _mm_mul_ps(s1, w1); + auto z7 = _mm_mul_ps(s1, w2); + auto z10 = _mm_mul_ps(s1, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = _mm_loadu_ps(A + sy * aStride); + s1 = _mm_loadu_ps(A + sy * aStride + 4); + ws = _load_int4x4(weight + sy 
* 2, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z3 = MNNSSEFMA(s0, w1, z3); + z6 = MNNSSEFMA(s0, w2, z6); + z9 = MNNSSEFMA(s0, w3, z9); + z1 = MNNSSEFMA(s1, w0, z1); + z4 = MNNSSEFMA(s1, w1, z4); + z7 = MNNSSEFMA(s1, w2, z7); + z10 = MNNSSEFMA(s1, w3, z10); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + } +} + +static void _SSE_MNNPackedMatMul_4_int4(float* C, const float* A, const uint8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(float); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + y * cStride; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = _mm_loadu_ps(A + 0 * aStride); + auto ws = _load_int4x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z9 = _mm_mul_ps(s0, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = _mm_loadu_ps(A + sy * aStride); + ws = _load_int4x4(weight + sy * 2, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z3 = MNNSSEFMA(s0, w1, z3); + z6 = MNNSSEFMA(s0, w2, z6); + z9 = MNNSSEFMA(s0, w3, z9); + } + _MM_TRANSPOSE4_PS(z0, z3, z6, z9); + _mm_storeu_ps(dst + 4 * 0, z0); + _mm_storeu_ps(dst + 4 * 1, z3); + _mm_storeu_ps(dst + 4 * 2, z6); + _mm_storeu_ps(dst + 4 * 3, z9); + } +} + +static void _SSE_MNNPackednMatMulRemainCommon_int4(float* C, const float* A, const float* fB, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + auto es = eSize; + auto oC = C; + auto aStride = parameter[0] / sizeof(float); + if (eSize >= 8) { + _SSE_MNNPackedMatMul_8_int4(C, A, B, parameter, k, b); + eSize -= 8; + C += 8 * 4; + A += 8; + } + if (eSize >= 4) { + _SSE_MNNPackedMatMul_4_int4(C, A, B, parameter, k, b); + eSize -= 4; + C += 4 * 4; + A += 4; + } + for (int x = 0; x < eSize; ++x) { + auto src = A + x; + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride / 2; + auto dst = C + y * cStride + x * 4; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto sum = _mm_set1_ps(0.0f); + for (int sy = 0; sy < l; ++sy) { + auto s = _mm_set1_ps(src[sy * aStride]); + auto w = _load_int4x4(weight + sy * 2, alpha, bias); + sum = MNNSSEFMA(s, w, sum); + } + _mm_storeu_ps(dst, sum); + } + } +} +//----------------------- MatMul(float, int8) Functions ---------------------------// +static inline __m128 _load_int8x4(const int8_t* src, __m128 alpha, __m128 bias) { + int iw0 = src[0]; + int iw1 = src[1]; + int iw2 = src[2]; + int iw3 = src[3]; + auto ws = _mm_set_ps(iw3, iw2, iw1, iw0); + ws = _mm_add_ps(_mm_mul_ps(ws, alpha), bias); + return ws; 
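For readers new to this packing: _load_int4x4 above pulls two 4-bit weights out of each byte (high nibble first), recenters them by subtracting 7, and applies the per-channel affine dequantization w * alpha + bias; _load_int8x4 does the same without the nibble split. A scalar sketch of both helpers, assuming nothing beyond what the intrinsics above already do (names are illustrative, not MNN API):

#include <cstdint>

// Scalar equivalents of the SSE dequantization helpers (illustrative only).
static inline void dequantInt4x4Ref(const uint8_t* src, const float* alpha,
                                    const float* bias, float* out) {
    // Two weights per byte: the high nibble fills the even lane, the low nibble the
    // odd lane; both are stored with a +7 offset, so subtracting 7 recenters them.
    const int q[4] = { src[0] >> 4, src[0] & 0x0F, src[1] >> 4, src[1] & 0x0F };
    for (int i = 0; i < 4; ++i) {
        out[i] = (float)(q[i] - 7) * alpha[i] + bias[i];
    }
}

static inline void dequantInt8x4Ref(const int8_t* src, const float* alpha,
                                    const float* bias, float* out) {
    for (int i = 0; i < 4; ++i) {
        out[i] = (float)src[i] * alpha[i] + bias[i];
    }
}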
+} + +static void _SSE_MNNPackedMatMul_12_int8(float* C, const float* A, const float* fB, const size_t* parameter, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + y * cStride; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = _mm_loadu_ps(A + 0 * 12); + auto s1 = _mm_loadu_ps(A + 0 * 12 + 4); + auto s2 = _mm_loadu_ps(A + 0 * 12 + 8); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z1 = _mm_mul_ps(s1, w0); + auto z2 = _mm_mul_ps(s2, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z4 = _mm_mul_ps(s1, w1); + auto z5 = _mm_mul_ps(s2, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z7 = _mm_mul_ps(s1, w2); + auto z8 = _mm_mul_ps(s2, w2); + auto z9 = _mm_mul_ps(s0, w3); + auto z10 = _mm_mul_ps(s1, w3); + auto z11 = _mm_mul_ps(s2, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = _mm_loadu_ps(A + sy * 12); + s1 = _mm_loadu_ps(A + sy * 12 + 4); + s2 = _mm_loadu_ps(A + sy * 12 + 8); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z1 = MNNSSEFMA(s1, w0, z1); + z2 = MNNSSEFMA(s2, w0, z2); + z3 = MNNSSEFMA(s0, w1, z3); + z4 = MNNSSEFMA(s1, w1, z4); + z5 = MNNSSEFMA(s2, w1, z5); + z6 = MNNSSEFMA(s0, w2, z6); + z7 = MNNSSEFMA(s1, w2, z7); + z8 = MNNSSEFMA(s2, w2, z8); + z9 = MNNSSEFMA(s0, w3, z9); + z10 = MNNSSEFMA(s1, w3, z10); + z11 = MNNSSEFMA(s2, w3, z11); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + TRANPOSE_SAVE(0, 2, z2, z5, z8, z11); + } +} + +static void _SSE_MNNPackedMatMul_8_int8(float* C, const float* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(float); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + y * cStride; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = _mm_loadu_ps(A + 0 * aStride); + auto s1 = _mm_loadu_ps(A + 0 * aStride + 4); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z9 = _mm_mul_ps(s0, w3); + auto z1 = _mm_mul_ps(s1, w0); + auto z4 = _mm_mul_ps(s1, w1); + auto z7 = _mm_mul_ps(s1, w2); + auto z10 = _mm_mul_ps(s1, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = _mm_loadu_ps(A + sy * aStride); + s1 = _mm_loadu_ps(A + sy * aStride + 4); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z3 = MNNSSEFMA(s0, w1, z3); + z6 = MNNSSEFMA(s0, w2, z6); + z9 = MNNSSEFMA(s0, w3, 
z9); + z1 = MNNSSEFMA(s1, w0, z1); + z4 = MNNSSEFMA(s1, w1, z4); + z7 = MNNSSEFMA(s1, w2, z7); + z10 = MNNSSEFMA(s1, w3, z10); + } + TRANPOSE_SAVE(0, 0, z0, z3, z6, z9); + TRANPOSE_SAVE(0, 1, z1, z4, z7, z10); + } +} + +static void _SSE_MNNPackedMatMul_4_int8(float* C, const float* A, const int8_t* B, const size_t* parameter, const float* k, const float* b) { + auto aStride = parameter[0] / sizeof(float); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + y * cStride; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto s0 = _mm_loadu_ps(A + 0 * aStride); + auto ws = _load_int8x4(weight, alpha, bias); + auto w0 = _mm_set1_ps(ws[0]); + auto w1 = _mm_set1_ps(ws[1]); + auto w2 = _mm_set1_ps(ws[2]); + auto w3 = _mm_set1_ps(ws[3]); + auto z0 = _mm_mul_ps(s0, w0); + auto z3 = _mm_mul_ps(s0, w1); + auto z6 = _mm_mul_ps(s0, w2); + auto z9 = _mm_mul_ps(s0, w3); + + for (int sy = 1; sy < l; ++sy) { + s0 = _mm_loadu_ps(A + sy * aStride); + ws = _load_int8x4(weight + sy * 4, alpha, bias); + w0 = _mm_set1_ps(ws[0]); + w1 = _mm_set1_ps(ws[1]); + w2 = _mm_set1_ps(ws[2]); + w3 = _mm_set1_ps(ws[3]); + z0 = MNNSSEFMA(s0, w0, z0); + z3 = MNNSSEFMA(s0, w1, z3); + z6 = MNNSSEFMA(s0, w2, z6); + z9 = MNNSSEFMA(s0, w3, z9); + } + _MM_TRANSPOSE4_PS(z0, z3, z6, z9); + _mm_storeu_ps(dst + 4 * 0, z0); + _mm_storeu_ps(dst + 4 * 1, z3); + _mm_storeu_ps(dst + 4 * 2, z6); + _mm_storeu_ps(dst + 4 * 3, z9); + } +} + +static void _SSE_MNNPackednMatMulRemainCommon_int8(float* C, const float* A, const float* fB, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto bExtraStride = parameter[5] / sizeof(float); + auto bStride = bExtraStride + l * 4; + auto hC4 = UP_DIV(h, 4); + auto es = eSize; + auto oC = C; + auto aStride = parameter[0] / sizeof(float); + if (eSize >= 8) { + _SSE_MNNPackedMatMul_8_int8(C, A, B, parameter, k, b); + eSize -= 8; + C += 8 * 4; + A += 8; + } + if (eSize >= 4) { + _SSE_MNNPackedMatMul_4_int8(C, A, B, parameter, k, b); + eSize -= 4; + C += 4 * 4; + A += 4; + } + for (int x = 0; x < eSize; ++x) { + auto src = A + x; + for (int y = 0; y < hC4; ++y) { + auto weight = B + y * bStride; + auto dst = C + y * cStride + x * 4; + auto alpha = _mm_loadu_ps(k + y * 4); + auto bias = _mm_loadu_ps(b + y * 4); + auto sum = _mm_set1_ps(0.0f); + for (int sy = 0; sy < l; ++sy) { + auto s = _mm_set1_ps(src[sy * aStride]); + auto w = _load_int8x4(weight + sy * 4, alpha, bias); + sum = MNNSSEFMA(s, w, sum); + } + _mm_storeu_ps(dst, sum); + } + } +} +#endif diff --git a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp index 0ca7580ed..499591b0e 100644 --- a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp @@ -13,7 +13,7 @@ #include "GemmFunction.hpp" void _SSE_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, - const float* postParameters, const float* bias) { + const float* postParameters, const float* bias, const float* k, const float* b) { auto h = parameter[2]; auto hC4 = UP_DIV(h, 4); auto cStride = parameter[3] / 
sizeof(float); @@ -22,7 +22,40 @@ void _SSE_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t } void _SSE_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, - const float* postParameters, const float* bias) { + const float* postParameters, const float* bias, const float* k, const float* b) { _SSE_MNNPackednMatMulRemainCommon(C, A, B, eSize, parameter, postParameters, bias); _SSE_GemmPostTreat(C, eSize, parameter, postParameters, bias); } + +#ifdef MNN_LOW_MEMORY +//----------------------- MatMul(float, int4) Functions ---------------------------// +void _SSE_MNNPackedMatMul_int4(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + auto h = parameter[2]; + auto hC4 = UP_DIV(h, 4); + auto cStride = parameter[3] / sizeof(float); + _SSE_MNNPackedMatMul_12_int4(C, A, B, parameter, k, b); + _SSE_GemmPostTreat(C, 12, parameter, postParameters, bias); +} + +void _SSE_MNNPackedMatMulRemain_int4(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + _SSE_MNNPackednMatMulRemainCommon_int4(C, A, B, eSize, parameter, postParameters, bias, k, b); + _SSE_GemmPostTreat(C, eSize, parameter, postParameters, bias); +} + +void _SSE_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + auto h = parameter[2]; + auto hC4 = UP_DIV(h, 4); + auto cStride = parameter[3] / sizeof(float); + _SSE_MNNPackedMatMul_12_int8(C, A, B, parameter, k, b); + _SSE_GemmPostTreat(C, 12, parameter, postParameters, bias); +} + +void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, + const float* postParameters, const float* bias, const float* k, const float* b) { + _SSE_MNNPackednMatMulRemainCommon_int8(C, A, B, eSize, parameter, postParameters, bias, k, b); + _SSE_GemmPostTreat(C, eSize, parameter, postParameters, bias); +} +#endif diff --git a/source/backend/cuda/execution/ArgMaxExecution.cu b/source/backend/cuda/execution/ArgMaxExecution.cu index a32c326ee..b8fa13182 100644 --- a/source/backend/cuda/execution/ArgMaxExecution.cu +++ b/source/backend/cuda/execution/ArgMaxExecution.cu @@ -48,12 +48,13 @@ __global__ void ARGMAX_SECOND_STEP(const int count, const int outside, const int int idx_output = idx_out * inside + idx_in; const T* inpPtr = inputData + idx_out * dims * inside + idx_in; - int maxIndex = inputIndex[0]; + const int* baseInputIndex = inputIndex + idx_out * dims * inside + idx_in; + int maxIndex = baseInputIndex[0]; T maxValue = inpPtr[0 * inside]; for(int j=1; j(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001))", op, backend); + return new BinaryBufExecution(inputs, "in0-floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))*in1", op, backend); default: break; } diff --git a/source/backend/opencl/execution/image/EltwiseExecution.cpp b/source/backend/opencl/execution/image/EltwiseExecution.cpp index 22f001516..a1b262b3c 100644 --- a/source/backend/opencl/execution/image/EltwiseExecution.cpp +++ b/source/backend/opencl/execution/image/EltwiseExecution.cpp @@ -207,7 +207,7 @@ class EltwiseCreator : public OpenCLBackend::Creator { case BinaryOpOperation_NOTEQUAL: return new EltwiseExecution(inputs, 
"convert_float4(-isnotequal(in0,in1))", op, backend); case BinaryOpOperation_MOD: - return new EltwiseExecution(inputs, "in0-sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001))", op, backend); + return new EltwiseExecution(inputs, "in0-floor(sign(in1)*in0/(fabs(in1)>(FLOAT4)((FLOAT)0.0000001)?fabs(in1):(FLOAT4)((FLOAT)0.0000001)))*in1", op, backend); default: break; } diff --git a/source/core/ConvolutionCommon.cpp b/source/core/ConvolutionCommon.cpp index a85c6d4f8..245631335 100644 --- a/source/core/ConvolutionCommon.cpp +++ b/source/core/ConvolutionCommon.cpp @@ -10,6 +10,7 @@ #include #include "backend/cpu/compute/CommonOptFunction.h" #include "half.hpp" + namespace MNN { static inline void *MNNMemoryAllocAlignZeroAlign(size_t size) { return MNNMemoryCallocAlign(size, MNN_MEMORY_ALIGN_DEFAULT); @@ -185,6 +186,21 @@ static void StreamSizeRead(void *dst, int unit, size_t count, unsigned char *&fi file += (unit * count); } +static bool isFastSample(const std::vector& sample, int bit) { + if (bit != 4) { + return false; + } + if (sample.size() != (1 << bit) - 1) { + return false; + } + for (int i = 0; i < sample.size(); i++) { + if (static_cast(sample[i]) != i - 7) { + return false; + } + } + return true; +} + static int8_t *ReadQuanData_c(unsigned char *&s, size_t* len, ConvolutionCommon::Int8Common* result, bool shapeInt32) { int8_t *blob = nullptr; uint8_t *idxBuf = nullptr; @@ -222,28 +238,40 @@ static int8_t *ReadQuanData_c(unsigned char *&s, size_t* len, ConvolutionCommon: break; } StreamSizeRead(idxBuf, 1, idxBufSize, s); - // split index value into bytes - idxBytes = (uint8_t *)MNNMemoryAllocAlignZeroAlign(dataCnt * sizeof(uint8_t)); - if (idxBitsCnt == 0 || nullptr == idxBytes) { - break; - } - SplitBufToArray(idxBuf, (uint32_t)idxBufSize, idxBytes, (uint32_t)dataCnt, (uint32_t)idxBitsCnt); - int i = 0; blob = (int8_t *)MNNMemoryAllocAlignZeroAlign((size_t)dataCnt); if (nullptr == blob) { break; } - for (i = 0; i < dataCnt; i++) { - if (idxBytes[i] >= sampleCnt) { - MNN_PRINT("iNeedBits is %u\nRead quan weights error with idx:%d\n", idxBitsCnt, (int)idxBytes[i]); + if (isFastSample(result->weightMap, idxBitsCnt)) { + for (int i = 0; i < idxBufSize; i++) { + int val = idxBuf[i]; + int x1 = val / 16; + int x2 = val % 16; + blob[2 * i] = x1 - 7; + blob[2 * i + 1] = x2 - 7; + } + } else { + // split index value into bytes + idxBytes = (uint8_t *)MNNMemoryAllocAlignZeroAlign(dataCnt * sizeof(uint8_t)); + if (idxBitsCnt == 0 || nullptr == idxBytes) { break; } - blob[i] = samples[idxBytes[i]]; - } - if (i < dataCnt) { - MNNMemoryFreeAlign(blob); - blob = nullptr; - break; + SplitBufToArray(idxBuf, (uint32_t)idxBufSize, idxBytes, (uint32_t)dataCnt, (uint32_t)idxBitsCnt); + int i = 0; + for (; i < dataCnt; i++) { + if (idxBytes[i] >= sampleCnt) { + MNN_PRINT("iNeedBits is %u\nRead quan weights error with idx:%d\n", idxBitsCnt, (int)idxBytes[i]); + break; + } + blob[i] = samples[idxBytes[i]]; + } + if (i < dataCnt) { + MNNMemoryFreeAlign(blob); + blob = nullptr; + break; + } + MNNMemoryFreeAlign(idxBytes); + idxBytes = nullptr; } } while (0); diff --git a/test/core/IDSTTest.cpp b/test/core/IDSTTest.cpp index 6939e972b..7dd973159 100644 --- a/test/core/IDSTTest.cpp +++ b/test/core/IDSTTest.cpp @@ -22,7 +22,7 @@ class IDSTTest : public MNNTestCase { std::vector scale(kernelNum, 0.f); std::vector quantWeight(kernelNum * kernelSize, 0); // IDST encode - std::unique_ptr idstQuantT = IDSTEncoder::encode(weight, scale, kernelSize, kernelNum, false, 
quantWeight.data(), -127); + std::unique_ptr idstQuantT = IDSTEncoder::encode(weight.data(), scale, kernelSize, kernelNum, false, quantWeight.data(), -127); flatbuffers::FlatBufferBuilder builder; auto lastOffset = IDSTQuan::Pack(builder, idstQuantT.get()); builder.Finish(lastOffset); diff --git a/tools/converter/source/common/FullQuantAndCoding.cpp b/tools/converter/source/common/FullQuantAndCoding.cpp index 23d4b8f3c..2f31db6e7 100644 --- a/tools/converter/source/common/FullQuantAndCoding.cpp +++ b/tools/converter/source/common/FullQuantAndCoding.cpp @@ -144,7 +144,7 @@ void FullQuantAndCoding(std::unique_ptr& netT, std::unique_ptr fakeScales(kernelNum, 1.0f); - convParams->quanParameter = IDSTEncoder::encode(quantWeightFloat, fakeScales, kernelSize, kernelNum, asymmetricQuantFlag, quantWeights.data(), wClampMin); + convParams->quanParameter = IDSTEncoder::encode(quantWeightFloat.data(), fakeScales, kernelSize, kernelNum, asymmetricQuantFlag, quantWeights.data(), wClampMin); convParams->weight.clear(); convParams->quanParameter->alpha = std::move(scale); convParams->quanParameter->scaleIn = inputParams.scales(0); diff --git a/tools/converter/source/common/WeightQuantAndCoding.cpp b/tools/converter/source/common/WeightQuantAndCoding.cpp index ed1459d22..f92782246 100644 --- a/tools/converter/source/common/WeightQuantAndCoding.cpp +++ b/tools/converter/source/common/WeightQuantAndCoding.cpp @@ -151,11 +151,11 @@ void WeightQuantAndCoding(std::unique_ptr& op, const modelConfig& conf } if (opType == MNN::OpType_ConvInt8 || opType == MNN::OpType_DepthwiseConvInt8) { - param->quanParameter = IDSTEncoder::encode(weightData, scales, kernelSize, kernelNum, false, param->symmetricQuan->weight.data(), int(clampMin)); + param->quanParameter = IDSTEncoder::encode(weightData.data(), scales, kernelSize, kernelNum, false, param->symmetricQuan->weight.data(), int(clampMin)); param->symmetricQuan->weight.clear(); param->quanParameter->alpha = {1.0f}; // fake scales } else { - param->quanParameter = IDSTEncoder::encode(weightData, scales, kernelSize, kernelNum, asymmetricQuantFlag, quantWeights.data(), int(clampMin)); + param->quanParameter = IDSTEncoder::encode(weightData.data(), scales, kernelSize, kernelNum, asymmetricQuantFlag, quantWeights.data(), int(clampMin)); param->weight.clear(); } }; diff --git a/tools/cpp/ConvertToFullQuant.hpp b/tools/cpp/ConvertToFullQuant.hpp index 206c02251..a7bd45ce3 100644 --- a/tools/cpp/ConvertToFullQuant.hpp +++ b/tools/cpp/ConvertToFullQuant.hpp @@ -133,7 +133,7 @@ void ConvertOp(std::unique_ptr& op, int opIndex, NetT* net, SubGraphProtoT* weightFloat.emplace_back(weight[i] * weightScale[i / ks]); } - conv2D->quanParameter = IDSTEncoder::encode(weightFloat, weightScale, ks, kn, false, weight.data(), aMin); + conv2D->quanParameter = IDSTEncoder::encode(weightFloat.data(), weightScale, ks, kn, false, weight.data(), aMin); conv2D->quanParameter->scaleIn = scaleIn; conv2D->quanParameter->scaleOut = scaleOut; conv2D->symmetricQuan->weight.clear(); diff --git a/tools/cpp/IDSTEncoder.hpp b/tools/cpp/IDSTEncoder.hpp index 908469fe1..afef60404 100644 --- a/tools/cpp/IDSTEncoder.hpp +++ b/tools/cpp/IDSTEncoder.hpp @@ -395,12 +395,12 @@ static void WriteSparseQuanBlobs(std::ostream &out, const float* weightData, con delete[] data_buf; } -static std::unique_ptr encode(const std::vector& weight, const std::vector& scale, int kernelSize, int kernelNum, +static std::unique_ptr encode(const float* weight, const std::vector& scale, int kernelSize, int kernelNum, bool 
asymmetricQuantFlag, const int8_t* quantWeightPtr, const int clampMin) { std::ostringstream outputStringStreamCQ, outputStringStreamSQ; bool shapeUseInt32 = false; - WriteCQBlobs(outputStringStreamCQ, weight.data(), scale.data(), kernelSize, kernelNum, asymmetricQuantFlag, shapeUseInt32); - WriteSparseQuanBlobs(outputStringStreamSQ, weight.data(), scale.data(), kernelSize, kernelNum, asymmetricQuantFlag, shapeUseInt32); + WriteCQBlobs(outputStringStreamCQ, weight, scale.data(), kernelSize, kernelNum, asymmetricQuantFlag, shapeUseInt32); + WriteSparseQuanBlobs(outputStringStreamSQ, weight, scale.data(), kernelSize, kernelNum, asymmetricQuantFlag, shapeUseInt32); std::unique_ptr idst(new IDSTQuanT); auto cqStr = outputStringStreamCQ.str(); auto sqStr = outputStringStreamSQ.str(); diff --git a/tools/cpp/revertMNNModel.cpp b/tools/cpp/revertMNNModel.cpp index c86816af9..cfe32352b 100644 --- a/tools/cpp/revertMNNModel.cpp +++ b/tools/cpp/revertMNNModel.cpp @@ -77,8 +77,8 @@ void Revert::writeExtraDescribeTensor(float* scale, float* offset) { const int weightSize = param->weight.size(); param->common->inputCount = weightSize / (channels * param->common->kernelX * param->common->kernelY); std::vector quantizedWeight(weightSize, 1); - std::vector quantizedWeightScale(outputChannel, 0.008); - param->quanParameter = IDSTEncoder::encode(param->weight, quantizedWeightScale, weightSize/channels, channels, false, quantizedWeight.data(), -127.0f); + std::vector quantizedWeightScale(channels, 0.008); + param->quanParameter = IDSTEncoder::encode(param->weight.data(), quantizedWeightScale, weightSize/channels, channels, false, quantizedWeight.data(), -127.0f); param->quanParameter->scaleIn = *scale; param->quanParameter->scaleOut = *scale; if (param->common->relu6) { @@ -99,7 +99,7 @@ void Revert::packMNNNet() { mMNNNet.reset(); } -void Revert::initialize(float spasity, int sparseBlockOC, bool rewrite) { +void Revert::initialize(float spasity, int sparseBlockOC, bool rewrite, bool quantizedModel) { if (mMNNNet->bizCode == "benchmark" || rewrite) { randStart(); bool useSparse = spasity > 0.5f; @@ -177,7 +177,10 @@ void Revert::initialize(float spasity, int sparseBlockOC, bool rewrite) { } } } - + if (quantizedModel) { + float scale = 0.008, offset = 0; + writeExtraDescribeTensor(&scale, &offset); + } packMNNNet(); } diff --git a/tools/cpp/revertMNNModel.hpp b/tools/cpp/revertMNNModel.hpp index 7371ed658..62f4a61c0 100644 --- a/tools/cpp/revertMNNModel.hpp +++ b/tools/cpp/revertMNNModel.hpp @@ -17,7 +17,7 @@ class Revert { ~Revert(); void* getBuffer() const; const size_t getBufferSize() const; - void initialize(float sparsity = 0.0f, int sparseBlockOC = 1, bool rewrite = false); + void initialize(float sparsity = 0.0f, int sparseBlockOC = 1, bool rewrite = false, bool quantizedModel = false); static void fillRandValue(float * data, size_t size); void writeExtraDescribeTensor(float* scales, float* offsets); private: diff --git a/tools/quantization/calibration.cpp b/tools/quantization/calibration.cpp index 5169ac56f..9d3aa1a9e 100644 --- a/tools/quantization/calibration.cpp +++ b/tools/quantization/calibration.cpp @@ -36,6 +36,8 @@ #include "train/source/optimizer/SGD.hpp" #include "train/source/transformer/Transformer.hpp" #include "cpp/ConvertToFullQuant.hpp" +#include "core/ConvolutionCommon.hpp" + using namespace MNN::CV; using namespace MNN::Train; @@ -704,15 +706,30 @@ void Calibration::_insertScale() { const int channles = param->common->outputCount; param->symmetricQuan.reset(new 
MNN::QuantizedFloatParamT); param->symmetricQuan->nbits = _quant_bits; - const int weightSize = param->weight.size(); + const float* originWeight = param->weight.data(); + int originWeightSize = param->weight.size(); + auto conv2d = param; + std::shared_ptr quanCommon; + std::unique_ptr externalWeightTensor, externalBiasTensor; + if (nullptr != conv2d->quanParameter.get()) { + flatbuffers::FlatBufferBuilder tempBuilder; + tempBuilder.Finish(IDSTQuan::Pack(tempBuilder, conv2d->quanParameter.get())); + auto quanP = flatbuffers::GetRoot( tempBuilder.GetBufferPointer()); + bool forceFloat = true; + quanCommon = ConvolutionCommon::load(quanP, true, true); + // Back to float + originWeight = quanCommon->weightFloat.get(); + originWeightSize = quanCommon->weightFloat.size(); + } + const int weightSize = originWeightSize; std::vector quantizedWeight(weightSize); std::vector quantizedWeightScale(outputChannel); if (_weightQuantizeMethod == "MAX_ABS"){ - SymmetricQuantizeWeight(param->weight.data(), weightSize, quantizedWeight.data(), quantizedWeightScale.data(), outputChannel, _weightClampValue); + SymmetricQuantizeWeight(originWeight, weightSize, quantizedWeight.data(), quantizedWeightScale.data(), outputChannel, _weightClampValue); } else if (_weightQuantizeMethod == "ADMM") { - QuantizeWeightADMM(param->weight.data(), weightSize, quantizedWeight.data(), quantizedWeightScale.data(), outputChannel, _weightClampValue); + QuantizeWeightADMM(originWeight, weightSize, quantizedWeight.data(), quantizedWeightScale.data(), outputChannel, _weightClampValue); } - param->quanParameter = IDSTEncoder::encode(param->weight, quantizedWeightScale, weightSize/channles, channles, false, quantizedWeight.data(), -_weightClampValue); + param->quanParameter = IDSTEncoder::encode(originWeight, quantizedWeightScale, weightSize/channles, channles, false, quantizedWeight.data(), -_weightClampValue); param->quanParameter->scaleIn = inputScale; param->quanParameter->scaleOut = outputScale; if (param->common->relu6) {
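The IDSTEncoder::encode change above (taking a raw const float* weight pointer instead of a std::vector) is what all of these call sites are adapting to; in calibration, weights that already carry a quanParameter are first decoded back to float via ConvolutionCommon::load before being re-quantized and re-encoded. A minimal sketch of a call against the new signature; the include path and the helper name are assumptions for illustration, not code from this patch:

#include <cstdint>
#include <vector>
#include "cpp/IDSTEncoder.hpp"   // include path assumed; adjust to the tools/ layout

// Hypothetical helper showing the new argument order (sizes and scale value made up).
void encodeWeightsExample(const std::vector<float>& weightFloat, int kernelNum, int kernelSize) {
    std::vector<float> weightScale(kernelNum, 0.008f);            // per-output-channel scales
    std::vector<int8_t> quantWeight(kernelNum * kernelSize, 0);   // pre-quantized int8 weights
    // Weights now go in as a raw pointer; the scale vector and remaining arguments are unchanged.
    auto quanParam = IDSTEncoder::encode(weightFloat.data(), weightScale, kernelSize, kernelNum,
                                         /*asymmetricQuantFlag=*/false, quantWeight.data(),
                                         /*clampMin=*/-127);
    (void)quanParam;  // real callers assign this to the op's quanParameter
}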