diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..59119ab --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*build +*install +.nfs* +toolchain* diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 0000000..396239c --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,102 @@ +cmake_minimum_required(VERSION 3.7) + +if (NOT CMAKE_BUILD_TYPE OR CMAKE_BUILD_TYPE STREQUAL "") + set(CMAKE_BUILD_TYPE "Release" CACHE STRING "" FORCE) +endif() + +set(CMAKE_C_STANDARD 99) +if(CMAKE_BUILD_TYPE STREQUAL "Release" AND NOT CMAKE_C_FLAGS) + set(CMAKE_C_FLAGS "-O2" CACHE STRING "" FORCE) +endif() + +option(EML_ARMV7A "build for armv7a architecture instead of armv8a" OFF) + +if(ANDROID) +#Android build +#ANDROID_NDK must be provided +#ANDROID_PLATFORM is optional + if(NOT DEFINED ANDROID_PLATFORM) + set(ANDROID_PLATFORM 27) + endif() + if(EML_ARMV7A) + set(ANDROID_ABI "armeabi-v7a") + set(ANDROID_ARM_MODE arm) #not to use thumb + set(ANDROID_ARM_NEON ON) #enable NEON on armv7a + else() #armv8a + set(ANDROID_ABI "arm64-v8a") + endif() + include(${ANDROID_NDK}/build/cmake/android.toolchain.cmake) + if(CMAKE_BUILD_TYPE STREQUAL "Release") + add_compile_options(-g0) #disable NDK debug info generation + endif() + set(RUNTIME_LIB dl log) +else() +#Linux build +#CMAKE_C_COMPILER must be provided +#CMAKE_SYSROOT is optional + set(CMAKE_SYSTEM_NAME Linux) + if(EML_ARMV7A) + set(CMAKE_SYSTEM_PROCESSOR arm) + add_compile_options(-marm -march=armv7ve) + add_compile_options(-mfpu=neon-vfpv4 -mfp16-format=ieee) + else() + set(CMAKE_SYSTEM_PROCESSOR aarch64) + endif() + set(RUNTIME_LIB pthread) +endif() + +project(emll + VERSION 1.0 + LANGUAGES C) + +try_compile(EMLL_COMPILER_OPENMP_SUPPORTED ${CMAKE_BINARY_DIR} + "${PROJECT_SOURCE_DIR}/test/TestCompilerOpenMP.c" + COMPILE_DEFINITIONS -fopenmp + LINK_LIBRARIES -fopenmp) + +if (EMLL_COMPILER_OPENMP_SUPPORTED) + add_compile_options(-fopenmp) + list(APPEND RUNTIME_LIB -fopenmp) +else() + message(STATUS "The compiler doesn't support OpenMP. 
Build serial version only.") + add_definitions(-DEMLL_SERIAL_ONLY) +endif() + +include_directories("${PROJECT_SOURCE_DIR}/include") + +file(GLOB interface_header "${PROJECT_SOURCE_DIR}/include/*.h") +file(GLOB arm_src "${PROJECT_SOURCE_DIR}/src/arm_neon/*.c") +if(EML_ARMV7A) + file(GLOB_RECURSE neon_src "${PROJECT_SOURCE_DIR}/src/neon_armv7a/*.c") + add_library(eml-armneon ${arm_src} ${neon_src}) +else() + file(GLOB neon_src "${PROJECT_SOURCE_DIR}/src/neon_armv8a/*.c") + file(GLOB skinny_dot_src + "${PROJECT_SOURCE_DIR}/src/neon_armv8a/sgemm_skinny_dot_kernel/*.c") + file(GLOB ext_src "${PROJECT_SOURCE_DIR}/src/neon_armv8a/extension/*.c") + set_source_files_properties(${arm_src} ${ext_src} + PROPERTIES COMPILE_FLAGS "-march=armv8.2-a+dotprod+fp16") + add_library(eml-armneon ${arm_src} ${ext_src} ${skinny_dot_src} ${neon_src}) +endif() + +option(EML_TEST "build test programs for the library" ON) + +if(EML_TEST) + message(STATUS "Build testing executables for EML") + set(EML_TEST_EXECUTABLES test_emll_gemm test_emll_bias test_emll_quant) + add_executable(test_emll_gemm "${PROJECT_SOURCE_DIR}/test/TestGemm.c") + add_executable(test_emll_bias "${PROJECT_SOURCE_DIR}/test/TestBias.c") + add_executable(test_emll_quant "${PROJECT_SOURCE_DIR}/test/TestQuant.c") + target_link_libraries(test_emll_gemm eml-armneon ${RUNTIME_LIB}) + target_link_libraries(test_emll_bias eml-armneon ${RUNTIME_LIB}) + target_link_libraries(test_emll_quant eml-armneon ${RUNTIME_LIB}) +endif() + +set_target_properties(eml-armneon PROPERTIES PUBLIC_HEADER "${interface_header}") +install(TARGETS eml-armneon ${EML_TEST_EXECUTABLES} + EXPORT EMLLTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + PUBLIC_HEADER DESTINATION include) + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..169249f --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright YouDao, Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/ReadMe.md b/ReadMe.md new file mode 100644 index 0000000..2eb5614 --- /dev/null +++ b/ReadMe.md @@ -0,0 +1,70 @@ +![logo](doc/EMLL.png) + +[中文介绍](ReadMe_ZH.md) + +# Edge ML Library - High-performance Compute Library for On-device Machine Learning Inference + +Edge ML Library (EMLL) offers optimized basic routines like general matrix multiplications (GEMM) and quantizations, to speed up machine learning (ML) inference on ARM-based devices. EMLL supports fp32, fp16 and int8 data types. EMLL accelerates on-device NMT, ASR and OCR engines of Youdao, Inc. + +## Features + +### Performance-Oriented Design + +The matrix-multiplication routines are heavily-optimized for matrix shapes common in on-device ML tasks, including "skinny" ones. The matrix-multiplication kernels are tuned for specific CPUs with a large portion of inline assembly codes. 
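+
+The calling convention stays minimal even for these tuned kernels. As a quick sketch (hypothetical sizes and values; it assumes the library has been built and linked as described in [Usage_EN.md](doc/Usage_EN.md), and uses the sgemm signature from include/Gemm.h), a "skinny" fp32 multiplication is a single call:
+
+```c
+#include "Gemm.h"   /* EMLL GEMM interface */
+#include <stdio.h>
+
+int main(void) {
+    /* C[M x N] = A[M x K] * B[K x N], with a small N ("skinny" shape). */
+    enum { M = 64, N = 4, K = 64 };
+    static float A[M * K], B[K * N], C[M * N];   /* column-major, zero-initialized */
+    for (int i = 0; i < M * K; ++i) A[i] = 1.0f;
+    for (int i = 0; i < K * N; ++i) B[i] = 2.0f;
+    /* (0, 0): A and B are column-major; beta = 0.0f so nothing of the old C is added;
+       the final 0 lets EMLL choose the number of threads. */
+    int status = sgemm(0, 0, A, B, C, M, N, K, 0.0f, 0);
+    printf("sgemm returned %d, C[0] = %.1f\n", status, C[0]);   /* expect 128.0 */
+    return status;
+}
+```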
+ +Here are benchmarks of SGEMM on 2 machines[1]: + +| armv8a cortex-A35 4-thread | armv8a cortex-A53 4-thread | +| -------------------------- | -------------------------- | +| ![test1](bench/test_sgemm_en1.png) | ![test2](bench/test_sgemm_en2.png) | + +[1] The formula of GEMM: C[MxN] = A[MxK] B[KxN]; for each test case, the better of the all-row-major and all-column-major results is reported. + +### Facile Interface + +Data and parameters are passed directly, without wrapper structures. Matrices and arrays are passed as a base address plus dimensions. GEMM parameters seldom used in on-device inference, such as LDA-LDC, are excluded from the interface. There is no dependency on any third-party compute library. + + +### Extensibility + +EMLL abstracts the core structures of CPU-based high-performance matrix multiplication algorithms, as well as the bias/quantization functions, into general macros (see the files under include/common), which can be applied to a variety of processors. When developing for a new architecture, much of the coding work can be saved by reusing these macros. + +## EMLL APIs + +EMLL provides a series of C functions. See [Usage_EN.md](doc/Usage_EN.md) for details. + +| Type | Name | Parameters | +| ---- | ---- | ---------- | +| Matrix Multiplication | data_type + "gemm" | matrix_orders, addresses of matrices, M, N, K, beta, number of threads | +| Fully-connected Layer (fp32) | "fc" | addresses of src/weight/bias/output, dimensions M/K/N, orders of source matrices, (number of threads) | +| Quantization | "quantize_" + "symmetric"/"asymmetric" + input_type + output_type | input array, output array, (zero point), scale, size of array, input range | +| Requantization | "requantize_" + "symmetric/asymmetric" + "_XtoY" | input array, output array, (zero point), output scale, size of array, input range | +| Bias | "bias" + data_type | the matrix to be biased, scalar bias to all elements, vector bias along the major direction, vector bias along the minor direction, dimensions of the matrix | + +## Supported Architectures and Data Types + +| Target CPU | Matrix Multiplication | Bias | Quantization | Requantization | +| -------------- | ------------------------------------------------ | ----------------- | ------------------------ | ------------------------------------------- | +| ARMv7a 32-bit | fp32 -> fp32, (u)int8 -> (u)int32 | fp32, int32 | fp32 -> (u)int8/(u)int16 | int32 -> (u)int8/(u)int16, int16 -> (u)int8 | +| ARMv8a 64-bit | fp32 -> fp32, (u)int8 -> (u)int32, fp16 -> fp16 | fp32, fp16, int32 | fp32 -> (u)int8/(u)int16 | int32 -> (u)int8/(u)int16, int16 -> (u)int8 | + +Supported OS: Linux & Android + +Supported Compilers: GCC & Clang + +## Future Plan + +Depending on business requirements, EMLL may add support for on-device GPUs and NPUs and expand the set of available functions in the future. 
+ +## License + +Apache 2.0 + +## Reference + +Eigen: [https://eigen.tuxfamily.org] + +OpenBLAS: [https://github.com/xianyi/OpenBLAS] + + + diff --git a/ReadMe_ZH.md b/ReadMe_ZH.md new file mode 100644 index 0000000..b37621c --- /dev/null +++ b/ReadMe_ZH.md @@ -0,0 +1,65 @@ +# EMLL - 高性能端侧机器学习计算库 + +EMLL(Edge ML Library)为加速终端侧设备上机器学习的推理而设计,提供基于端侧处理器的高性能机器学习计算函数库。EMLL支持fp32、fp16、int8等数据类型,已在有道词典笔、翻译王和超级词典等硬件产品的机器翻译和语音识别引擎中应用,大幅降低了推理延迟。 + +## 特点 + +### 高性能 + +EMLL实现的矩阵乘法函数,为端侧人工智能中常见的扁平矩阵作了专门的优化,为各常见ARM处理器作了特定的优化。对于cortex-A35/A53/A55处理器,本库针对它们的流水线特点,使用了汇编级别的优化。 + +下面是单精度矩阵乘法的测试结果[1]: + +| ARMv8A Cortex-A35 四线程 | ARMv8A Cortex-A53 四线程 | +| ------------------------ | ------------------------ | +| ![结果1](bench/test_sgemm_zh1.png) | ![结果2](bench/test_sgemm_zh2.png) | + +[1]矩阵乘法的通式为 C[MxN] = A[MxK] B[KxN];所列数据为全行主序和全列主序的最好性能。 + +### 易用性 + +EMLL使用的函数接口在参数设计上力求简洁直接,矩阵乘法去掉了不常用的LD*参数,矩阵和向量的传递通过指针和整数维度分别传递。本库的构建和运行不依赖第三方计算库。 + +### 扩展性 + +对于矩阵乘法和量化函数,EMLL 库提取了它们和架构无关的代码作为通用的宏,这些宏可以在支持新的CPU架构时大大节省所需的代码量。 + +## EMLL 应用接口 + +EMLL提供基于 C 的接口,详情请见 [Usage_ZH.md](doc/Usage_ZH.md)。 + +| 函数类型 | 函数名称 | 函数参数 | +| -------- | -------- | -------- | +| 矩阵乘法 | data_type + "gemm" | 源矩阵排列顺序,各矩阵地址,M,N,K,beta,并行线程数 | +| 全连接层(单精度) | "fc" | src/weight/bias/output的地址,M,K,N,源矩阵排列顺序,(并行线程数) | +| 量化 | "quantize_" + "symmetric"/"asymmetric" + input_type + output_type | 输入数组,输出数组,(输出零点值),缩放值,数组大小,输入范围 | +| 重量化 | "requantize_" + "symmetric/asymmetric" + "_XtoY" | 输入数组,输出数组,(输出零点值),输出缩放值,数组大小,输入范围 | +| 偏置 | "bias" + data_type | 被偏置的矩阵,标量偏置,平行于主方向的向量偏置,平行于次方向的向量偏置,矩阵大小 | + +## 各函数支持的数据类型 + +| 处理器 | 矩阵乘法 | 偏置 | 量化 | 重量化 | +| -------------- | ------------------------ | ---------------- | --------------- | ------------- | +| ARMv7a 32-bit | fp32,(u)int8 | fp32,int32 | fp32 -> (u)int16/(u)int8 | int32 -> (u)int16/(u)int8,int16 -> (u)int8 | +| ARMv8a 64-bit | fp32,fp16,(u)int8 | fp32,int32 | fp32 -> (u)int16/(u)int8 | int32 -> (u)int16/(u)int8,int16 -> (u)int8 | + +EMLL 支持在 Linux 和安卓系统上运行。 + + +EMLL 支持用 GCC 和 Clang 编译。 + +## 展望 + +EMLL 将来会根据需求,增加对端侧 GPU 和 NPU 的支持,并拓展支持的算子范围(卷积、激活函数等)。 + +## 许可证 + +Apache 2.0 + +## 参考 + +Eigen: https://eigen.tuxfamily.org/ + +OpenBLAS: https://github.com/xianyi/OpenBLAS/ + + diff --git a/bench/test_sgemm_en1.png b/bench/test_sgemm_en1.png new file mode 100644 index 0000000..7373a43 Binary files /dev/null and b/bench/test_sgemm_en1.png differ diff --git a/bench/test_sgemm_en2.png b/bench/test_sgemm_en2.png new file mode 100644 index 0000000..a335514 Binary files /dev/null and b/bench/test_sgemm_en2.png differ diff --git a/bench/test_sgemm_zh1.png b/bench/test_sgemm_zh1.png new file mode 100644 index 0000000..fa83e7b Binary files /dev/null and b/bench/test_sgemm_zh1.png differ diff --git a/bench/test_sgemm_zh2.png b/bench/test_sgemm_zh2.png new file mode 100644 index 0000000..66aba94 Binary files /dev/null and b/bench/test_sgemm_zh2.png differ diff --git a/doc/EMLL.png b/doc/EMLL.png new file mode 100644 index 0000000..a7ebffd Binary files /dev/null and b/doc/EMLL.png differ diff --git a/doc/Usage_EN.md b/doc/Usage_EN.md new file mode 100644 index 0000000..9e3e901 --- /dev/null +++ b/doc/Usage_EN.md @@ -0,0 +1,141 @@ +## Building the Library + +### Compilers + + +| Tested Compilers | ARMv7A | ARMv8A | +| ---------------- | ------ | ------ | +| Linux target| Linaro-GCC-gnueabihf 201912 | Linaro-GCC-aarch64 201912 | +| Android target | NDK-r20 clang | NDK-r20 clang | + +### CMake + + +The CMake version should be 3.7 or newer. + + +### Linux + +A cross-compiling gcc toolchain (7.5.0 or later) is required. 
+ +``` +git clone https://github.com/netease-youdao/EMLL.git +cd EMLL +mkdir install +mkdir build && cd build +cmake .. -DCMAKE_INSTALL_PREFIX=../install -DCMAKE_C_COMPILER=/path/to/gcc [-DCMAKE_SYSROOT=/path/to/toolchain/sysroot] [-DEML_ARMV7A=ON #if built for armv7a 32-bit target] +make install +``` + +### Android + + +NDK r19 or newer is required. + + +``` +git clone https://github.com/netease-youdao/EMLL.git +cd EMLL +mkdir install +mkdir build && cd build +cmake .. -DCMAKE_INSTALL_PREFIX=../install -DANDROID=ON -DANDROID_NDK=/path/to/ndk [-DANDROID_PLATFORM=XX #SDK version of the target device] [-DEML_ARMV7A=ON #if built for armv7a 32-bit target] +make +make install +``` + + +### Linking with your application + + +The static library "libeml-armneon.a" is generated under EMLL/install/lib when the library is built and installed. There are three headers under EMLL/install/include (Gemm.h, Quant.h, Layer.h) that summarize the C interfaces provided by the library. + +## Testing + +When the test option is enabled in CMake, additional executables for checking results and performance are generated under EMLL/install/bin. They can be run from the command line (terminal/adb) on the target device. + + +| Executable | Command-Line Usage | Notes | +| ---------- | ------------------ | ----- | +| test_gemm | test_gemm < M > < N > < K > < matrix_order > < num_threads > < gemm_type > | matrix_order: 0-3; gemm_type: sgemm, hgemm, u8u32, s8s32 | +| test_bias | test_bias < major_dim > < minor_dim > < bias_type > | bias_type: 0-7 for bias, 8-9 for summing of rows/cols | +| test_quant | test_quant < array_size > < job_type > < extra_args > | array_size: the number of elements; job_type: qs/qu/d/rs/ru | + +## API + +The library provides C functions for GEMM, bias and quantization. + + +| Functions | Header | +| --------- | ------ | +| General Matrix Multiplication (GEMM) | include/Gemm.h | +| Fully-Connected Layer (FC) with bias | include/Layer.h | +| Quantization, Dequantization, Requantization | include/Quant.h | + +### GEMM + +For simplicity, the GEMM interface does not include LDA-LDC or alpha (alpha is assumed to be 1.0). + +The storage order of output matrix C is fixed to column-major. The storage orders of the input matrices are specified via function parameters. An element of a matrix is addressed by a column_id in [0, column_number) and a row_id in [0, row_number); the two can be combined into a 1D index once the storage order is known: + +| Storage Order | Element Index | +| ------------- | ------------- | +| Column-Major | column_id * row_numbers + row_id | +| Row-Major | row_id * column_numbers + column_id | + +The GEMM interface is summarized in [include/Gemm.h](../include/Gemm.h). + +#### Function Name + +| Data Types | Function Name | +| ---------- | ------------- | +| fp32 -> fp32 | sgemm | +| fp16 -> fp16 | hgemm [1] | +| int8 -> int32 | s8s32gemm [2] | +| uint8 -> uint32 | u8u32gemm [2] | + + +[1] Currently not implemented for Aarch32. 
Return error code 2 when the processor has no support for ARMv8.2a-fp16 ISA + + +[2] Aarch64 version: Use dot instructions automatically on processors supporting ARMv8.2a-dotprod, use mla-long instructions otherwise + + +#### Function Parameters + +The operation of GEMM: C[MxN] = A[MxK] B[KxN] + beta * C[MxN] + +| Parameters | Description | +| ---------- | ----------- | +| a_rowmajor | The storage order of matrix A, row-major if not 0 | +| b_rowmajor | The storage order of matrix B, row-major if not 0 | +| A | The address of the first element in source matrix A | +| B | The address of the first element in source matrix B | +| C | The address of the first element in output matrix C[1] | +| M | The number of rows in source matrix A | +| N | The number of columns in source matrix B | +| K | The number of columns in A, must be equal to the number of rows in B | +| beta | The scaling factor on C prior to the addition of AB product | +| num_threads | The (maximum) number of threads to use in parallel run | + + +[1] The output matrix C is fixed to column-major. + +### Quantization + +Please refer to [include/Quant.h](../include/Quant.h) for details. + +#### Function Name + + +| Name | Description | +| ---- | ----------- | +| bias_int32_t | Perform bias on a 32-bit integer matrix, can be used as a component in asymmetric quantitized GEMM | +| u8u32_sum | Perform row-wise or column-wise sum on the input 8-bit unsigned integer matrix, can be used as a component in asymmetric quantitized GEMM | +| quantize_asymmetric_fX_uY | Asymmetric quantization of X-bit float data to unsigned Y-bit values | +| quantize_symmetric_fX_sY | Symmetric quantization of X-bit float data to signed Y-bit values | +| dequantize_symmetric_fX_sY | Symmetric dequantization of Y-bit integer results to X-bit float ones | +| requantize_asymmetric_XtoY | Asymmetric requantization of X-bit integer values to unsigned Y-bit values | +| requantize_symmetric_XtoY | Symmetric requantization of X-bit integer values to signed Y-bit values | + + diff --git a/doc/Usage_ZH.md b/doc/Usage_ZH.md new file mode 100644 index 0000000..da846fa --- /dev/null +++ b/doc/Usage_ZH.md @@ -0,0 +1,145 @@ +## 如何构建 Edge ML 库 + +### 测试过的编译器 + + +| 端侧设备 | ARMv7A | ARMv8A | +| -------- | ------ | ------ | +| Linux | Linaro-GCC-gnueabihf 201912 | Linaro-GCC-aarch64 201912 | +| Android | NDK-r20 clang | NDK-r20 clang | + +目前支持在Linux系统上交叉编译。 + +### CMake 版本 + + +CMake 需要 3.7 或更新的版本。 + + +### 为运行 Linux 系统的端侧设备构建 + +需要 7.5.0 及以后的 GCC 交叉编译工具链。 + +以下为在 Linux 系统开发机上的构建命令 + +``` +git clone https://github.com/netease-youdao/EMLL.git +mkdir install +mkdir build && cd build +cmake .. -DCMAKE_INSTALL_PREFIX=../install -DCMAKE_C_COMPILER=GCC编译器的目录 [-DCMAKE_SYSROOT=GCC工具链中sysroot的目录] [-DEML_ARMV7A=ON #若端侧为32位请开此选项] +make +make install +``` + +### 为运行 Android 系统的端侧设备构建 + + +需要 r19 或更高版本的 Android NDK。 + +以下为在 Linux 系统开发机上的构建命令 + +``` +git clone https://github.com/netease-youdao/EMLL.git +mkdir install +mkdir build && cd build +cmake .. 
-DCMAKE_INSTALL_PREFIX=../install -DANDROID=ON -DANDROID_NDK=NDK的安装目录 [-DANDROID_PLATFORM=目标安卓SDK版本] [-DEML_ARMV7A=ON #若端侧为32位请开此选项] +make +make install +``` + + +### 使用构建好的库 + +在 EMLL/install 下会生成 bin,lib 和 include 文件夹,其中 lib 下包含了生成的静态库 libeml-armneon.a,include 下包含了定义 EMLL 对外接口的头文件。应用程序只需在源码中包含对应的头文件,链接时静态链接 libeml-armneon.a 即可。 + +## 如何测试 Edge ML 库 + +构建过程中,默认会在 EMLL/install/bin 下生成三个用于测试的可执行文件:test_gemm,test_bias 和 test_quant。把它们拷贝到端侧设备上,命令行 (adb/ssh) 运行它们即可。 + + +| 测试程序 | 命令行参数 | 说明 | +| --------- | ------------------ | ----- | +| test_gemm | test_gemm < M > < N > < K > <源矩阵排列顺序> <并行线程数> <数据类型> | 源矩阵排列顺序:0-3;数据类型:sgemm、hgemm、u8u32、s8s32 | +| test_bias | test_bias <主维度长> <次维度长> <任务种类> | 任务种类:0-7 偏置,8-9 按行或列求和 | +| test_quant | test_quant <测试数组大小> <任务类型> <其他参数> | 任务类型:qs/qu/d/rs/ru | + +## 应用程序接口 + +Edge ML 库提供基于 C 的矩阵乘法和量化接口 + + +| 函数类别 | 头文件 | +| --------- | ------ | +| 矩阵乘法 | include/Gemm.h | +| 全连接层 | include/Layer.h | +| 量化、反量化、重量化 | include/Quant.h | + +### 矩阵乘法 + +为了简便,矩阵乘法接口去掉了 LDA-LDC 参数,固定 alpha = 1.0。 + +输出矩阵的排列顺序固定为列主序。输入矩阵的排列顺序由函数参数确定。矩阵中的每个元素位置可以通过行号 ([0,行数)) 和列号 ([0,列数)) 确定。当矩阵的排列顺序确定时,其元素地址的偏移量是确定的: + +| 排列顺序 | 元素偏移量(相对于首元素)| +| -------- | ------------------------- | +| 列主序 | 列号 * 行数 + 行号 | +| 行主序 | 行号 * 列数 + 列号 | + +具体接口定义详见[include/Gemm.h](../include/Gemm.h)。 + +#### 函数名称 + +| 数据类型 | 函数名称 | +| ---------- | ------------- | +| fp32 -> fp32 | sgemm | +| fp16 -> fp16 | hgemm [1] | +| int8 -> int32 | s8s32gemm [2] | +| uint8 -> uint32 | u8u32gemm [2] | + + +[1] 目前不支持 Aarch32 设备;当目标处理器不支持 ARMv8.2-a 半精扩展时,返回错误 2 。 + + +[2] Aarch64 版本:在支持 ARMv8.2a 点积扩展的处理器上自动使用点积指令运算,其他处理器上使用变长乘加指令运算。 + + +#### 函数参数 + + + +矩阵乘法通式:C[MxN] = A[MxK] B[KxN] + C[MxN] * beta + +| 参数 | 描述 | +| ---------- | ----------- | +| a_rowmajor | 源矩阵 A 的排列顺序,非零表示行主序 | +| b_rowmajor | 源矩阵 B 的排列顺序,非零表示行主序 | +| A | 源矩阵 A 的地址 | +| B | 源矩阵 B 的地址 | +| C | 输出矩阵 C 的地址 | +| M | 矩阵 A 的行数 | +| N | 矩阵 B 的列数 | +| K | A的列数,必须等于 B 的行数 | +| beta | 作用于矩阵 C 的预乘因子 | +| num_threads | 并行时能够使用的线程数 [2] | + + +[1] 输出矩阵 C 固定为列主序。 + + +[2] 等于 1 时运行串行版本;等于 0 时使用所有 OpenMP 运行时提供的线程。 + +### 量化相关函数 + +详见[include/Quant.h](../include/Quant.h)。 + +| 函数名 | 描述 | +| ---- | ----------- | +| bias_int32_t | 对32位整数的矩阵施加偏置;可用于非对称量化的整数乘法的后处理 | +| u8u32_sum | 对8位整数的矩阵按行或按列求和,结果存于32位向量 | +| quantize_asymmetric_fX_uY | 非对称量化,从X位浮点到Y位整数 | +| quantize_symmetric_fX_sY | 对称量化,从X位浮点到Y位整数 | +| dequantize_symmetric_fX_sY | 对称反量化,从Y位整数到X位浮点 | +| requantize_asymmetric_XtoY | 非对称重量化,从X位整数到Y位整数 | +| requantize_symmetric_XtoY | 对称重量化,从X位整数到Y位整数 | + + diff --git a/example/CMakeLists.txt b/example/CMakeLists.txt new file mode 100644 index 0000000..6fdbd28 --- /dev/null +++ b/example/CMakeLists.txt @@ -0,0 +1,42 @@ +#Command line for Android NDK: +# cmake -DANDROID=ON -DANDROID_NDK=/path/to/ndk \ +# -DEMLL_DIR=emll/installation/path [-DEML_ARMV7A=ON] +# make + +#Command line for GCC: +# cmake [-DCMAKE_SYSROOT=/path/to/gcc/sysroot] \ +# -DCMAKE_C_COMPILER=/path/to/gcc \ +# -DEMLL_DIR=emll/installation/path [-DEML_ARMV7A=ON] +# make + +cmake_minimum_required(VERSION 3.7) +set(CMAKE_BUILD_TYPE Release) + +set(CMAKE_C_STANDARD 99) +set(CMAKE_C_FLAGS_RELEASE "-O2") + +if(ANDROID) #variable ANDROID_NDK must be provided prior to this section + set(ANDROID_PLATFORM 27) + if(EML_ARMV7A) + set(ANDROID_ABI "armeabi-v7a") + else() #armv8a + set(ANDROID_ABI "arm64-v8a") + endif() + include(${ANDROID_NDK}/build/cmake/android.toolchain.cmake) + set(RUNTIME_LIB dl log -fopenmp) +else() #Linux. 
Variables CMAKE_C_COMPILER must be provided, CMAKE_SYSROOT is optional + set(CMAKE_SYSTEM_NAME Linux) + if(EML_ARMV7A) + set(CMAKE_SYSTEM_PROCESSOR arm) + else() + set(CMAKE_SYSTEM_PROCESSOR aarch64) + endif() + set(RUNTIME_LIB pthread -fopenmp -lm) +endif() + +# variable EMLL_DIR must be provided +project(example_emll C) +include_directories(${EMLL_DIR}/include) +add_executable(example_emll_gemm Gemm.c) +target_link_libraries(example_emll_gemm ${EMLL_DIR}/lib/libeml-armneon.a ${RUNTIME_LIB}) + diff --git a/example/Gemm.c b/example/Gemm.c new file mode 100644 index 0000000..7550b36 --- /dev/null +++ b/example/Gemm.c @@ -0,0 +1,195 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: example/Gemm.c + * Description: This file is an example of using EMLL in your application. + * In this example we do 3 types of fp32 GEMM: + * (1) direct SGEMM + * (2) asymmetrically quantize to uint8, do GEMM to int32, + * finally dequantize to fp32. + * (3) symmetrically quantize to int8, do GEMM to int32, + * finally dequantize to fp32. + * Users should tell compilers to include the "include" + * directory of the library and link to the static + * library of EMLL. 
+ *****************************************************************************/ + +#include "Gemm.h" +#include "Quant.h" + +#include +#include +#include +#include +#include +#include + +int main(int argc, char **argv) { + + if (argc == 1 || !strcmp(argv[1], "-h") || !strcmp(argv[1], "--help")) { + printf("Usage: %s [M] [N] [K]\n", argv[0]); + return 0; + } + + uint16_t M = 300, N = 400, K = 500; + if (argc > 1) M = atoi(argv[1]); + if (argc > 2) N = atoi(argv[2]); + if (argc > 3) K = atoi(argv[3]); + + if (!M || !N || !K) { + fprintf(stderr, "Invalid (zero or negative) M, N or K.\n"); + return -1; + } + + printf("Test matmul C=AB with fp32, symmetric & asymmetric quantizations.\n"); + printf("matrix A (column-major): %u x %u\n", M, K); + printf("matrix B (column-major): %u x %u\n", K, N); + printf("matrix C (column-major): %u x %u\n", M, N); + const uint32_t size_a = (uint32_t)M * (uint32_t)K; + const uint32_t size_b = (uint32_t)N * (uint32_t)K; + const uint32_t size_c = (uint32_t)M * (uint32_t)N; + + /* allocate fp32 matrices */ + float * const A_f = (float *)malloc(size_a * 4); + float * const B_f = (float *)malloc(size_b * 4); + float * const C_f = (float *)malloc(size_c * 4); + + /* allocate quant-u8 matrices and arrays */ + uint8_t * const A_u = (uint8_t *)malloc(size_a); + uint8_t * const B_u = (uint8_t *)malloc(size_b); + int32_t * const C_qu = (int32_t *)malloc(size_c * 4); + float * const C_fqu = (float *)malloc(size_c * 4); + uint32_t * const A_sum = (uint32_t *)malloc(M * 4); + uint32_t * const B_sum = (uint32_t *)malloc(N * 4); + + /* allocate quant-s8 matrices and arrays */ + int8_t * const A_s = (int8_t *)malloc(size_a); + int8_t * const B_s = (int8_t *)malloc(size_b); + int32_t * const C_qs = (int32_t *)malloc(size_c * 4); + float * const C_fqs = (float *)malloc(size_c * 4); + + int ret_status = 0; + do { + if (!A_f || !B_f || !C_f || !A_u || !B_u || !C_qu || !C_fqu || + !A_sum || !B_sum || !A_s || !B_s || !C_qs || !C_fqs) { + fprintf(stderr, "Memory allocation failed.\n"); + ret_status = -1; + break; + } + + /* prepare data */ + srand(time(NULL)); + for (uint32_t i = 0; i < size_a; ++i) { + A_f[i] = (float)rand() / (float)RAND_MAX - 0.3; + } + for (uint32_t i = 0; i < size_b; ++i) { + B_f[i] = (float)rand() / (float)RAND_MAX - 0.3; + } + printf("Matrix preparation done. 
rand [-0.3, 0.7)\n"); + + /* all matrices are column-major */ + /* example 1: do normal fp32 GEMM */ + /* gemm(a_rowmajor, b_rowmajor, a_addr, b_addr, c_addr, m, n, k, beta, threads) */ + int sgemm_status = sgemm(0, 0, A_f, B_f, C_f, M, N, K, 1, 0); + if (sgemm_status != 0) { + fprintf(stderr, "sgemm returns error code %d\n", sgemm_status); + ret_status = -1; + break; + } + printf("Normal SGEMM done.\n"); + + /* example 2: do asymmetric quant 8-bit GEMM */ + float scale_a, scale_b; + uint8_t zero_point_a, zero_point_b; + /* quantitize the source matrices */ + /* quant_asym(input_addr, output_addr, &zero_point, &scale, array_length, input_min, input_Max) */ + quantize_asymmetric_f32_u8(A_f, A_u, &zero_point_a, &scale_a, size_a, 0, -1); + quantize_asymmetric_f32_u8(B_f, B_u, &zero_point_b, &scale_b, size_b, 0, -1); + /* do unsigned 8->32 bit GEMM */ + /* gemm(a_rowmajor, b_rowmajor, a_addr, b_addr, c_addr, m, n, k, beta, threads) */ + int u8u32_status = u8u32gemm(0, 0, A_u, B_u, (uint32_t *)C_qu, M, N, K, 1, 0); + if (u8u32_status != 0) { + fprintf(stderr, "u8u32gemm returns error code %d\n", u8u32_status); + ret_status = -1; + break; + } + /* sum row/col of source matrices (along K dim) */ + u8u32_sum(A_u, A_sum, M, K, 0); + u8u32_sum(B_u, B_sum, K, N, 1); + /* bias the result of 8->32 bit GEMM */ + bias_int32_t(C_qu, + (int32_t)zero_point_a * (int32_t)zero_point_b * (int32_t)K, + (int32_t *)A_sum, -(int32_t)zero_point_b, + (int32_t *)B_sum, -(int32_t)zero_point_a, M, N); + /* dequantitize the result */ + /* dequant(input_addr, output_addr, scale, array_length) */ + dequantize_symmetric_f32_s32(C_qu, C_fqu, scale_a * scale_b, size_c); + printf("Asym quant GEMM done.\n"); + + /* example 3: do symmetric quant 8-bit GEMM */ + /* quantitize the source matrices */ + /* quant_sym(input_addr, output_addr, &scale, array_length, input_min, input_Max) */ + quantize_symmetric_f32_s8(A_f, A_s, &scale_a, size_a, 0, -1); + quantize_symmetric_f32_s8(B_f, B_s, &scale_b, size_b, 0, -1); + /* do signed 8->32 bit GEMM */ + int s8s32_status = s8s32gemm(0, 0, A_s, B_s, C_qs, M, N, K, 1, 0); + if (s8s32_status != 0) { + fprintf(stderr, "s8s32gemm returns error code %d\n", s8s32_status); + ret_status = -1; + break; + } + /* dequantitize the result */ + /* dequant(input_addr, output_addr, scale, array_length) */ + dequantize_symmetric_f32_s32(C_qs, C_fqs, scale_a * scale_b, size_c); + printf("Sym quant GEMM done.\n"); + + /* evaluate the results */ + float max_diff_qu = 0, max_diff_qs = 0; + double sum_diff_sqr_qu = 0, sum_diff_sqr_qs = 0; + for (uint32_t i = 0; i < size_c; ++i) { + float tmp_diff_qu = fabsf(C_fqu[i] - C_f[i]); + float tmp_diff_qs = fabsf(C_fqs[i] - C_f[i]); + max_diff_qu = fmaxf(max_diff_qu, tmp_diff_qu); + max_diff_qs = fmaxf(max_diff_qs, tmp_diff_qs); + sum_diff_sqr_qu += max_diff_qu * max_diff_qu; + sum_diff_sqr_qs += max_diff_qs * max_diff_qs; + } + double std_dev_qu = size_c == 1 ? 0 : sqrt(sum_diff_sqr_qu / (size_c - 1)); + double std_dev_qs = size_c == 1 ? 
0 : sqrt(sum_diff_sqr_qs / (size_c - 1)); + printf("The results of asym quant compared to std fp32: "); + printf("max_diff = %.2e, stdev = %.2e\n", max_diff_qu, std_dev_qu); + printf("The results of sym quant compared to std fp32: "); + printf("max_diff = %.2e, stdev = %.2e\n", max_diff_qs, std_dev_qs); + } while (false); + + /* clean up */ + free(A_f); + free(B_f); + free(C_f); + free(A_u); + free(B_u); + free(C_qu); + free(C_fqu); + free(A_sum); + free(B_sum); + free(A_s); + free(B_s); + free(C_qs); + free(C_fqs); + return ret_status; +} diff --git a/example/Usage_EN.md b/example/Usage_EN.md new file mode 100644 index 0000000..5521372 --- /dev/null +++ b/example/Usage_EN.md @@ -0,0 +1,61 @@ +## How to link and use EMLL in your application with CMake + +### Build EMLL + +Please refer to doc/Usage_EN.md for detailed procedure. + +### Include Headers in Your Source + +``` +#include "Gemm.h" // for GEMM functions +#include "Layer.h" // for FC functions +#include "Quant.h" // for quantization/dequantization/requantization + + +``` + +### Write CMakeLists.txt + +You can use the default CMakeLists.txt or manually rewrite it as follows: + +``` +cmake_minimum_required(VERSION ) +set(CMAKE_BUILD_TYPE ) + +set(CMAKE_C_COMPILER ndk/or/arm-gcc/compiler) +# add your compile options + +project( C) + +add_executable( ) +target_include_directories( /include) +target_link_libraries( /lib/libeml-armneon.a) + +if(ANDROID) + target_link_libraries( dl log -fopenmp) +else() + target_link_libraries( pthread -fopenmp) +endif() +``` + +### Build Your Application + +``` +cd +mkdir build && cd build +cmake .. [-DANDROID=ON # for android] [#other options of your project] +make +``` + +### Example + +The source file "Gemm.c" gives an example of using GEMM and quantization functions of EMLL library. It can be built into an executable by the following commands. + +``` +cd +mkdir build && cd build +cmake .. [-DANDROID=ON -DANDROID_NDK=/path/to/ndk #options for Android] [-DCMAKE_C_COMPILER=/path/to/gcc [-DCMAKE_SYSROOT=/path/to/gnu/sysroot] #options for GNU-Linux] [-DEML_ARMV7A=ON #armv7 device] +make +# The executable "example_emll_gemm" will be generated under the build directory, which can be executed on the target device. +``` + diff --git a/example/Usage_ZH.md b/example/Usage_ZH.md new file mode 100644 index 0000000..0bdb2c3 --- /dev/null +++ b/example/Usage_ZH.md @@ -0,0 +1,61 @@ +## 如何借助 CMake 链接和使用 EMLL + +### 构建 EMLL + +详细步骤请参阅 doc/Usage_ZH.md。 + +### 在源码中包含 EMLL 的头文件 + +``` +#include "Gemm.h" // 矩阵乘法函数 +#include "Layer.h" // 全连接函数 +#include "Quant.h" // 量化、反量化、重量化 + +<其他代码> +``` + +### 编写 CMakeLists.txt + +可以参照 example 文件夹中的 CMakeLists.txt,也可以按如下样式重写: + +``` +cmake_minimum_required(VERSION <用户指定的最低版本>) +set(CMAKE_BUILD_TYPE <用户指定的构建类型>) + +set(CMAKE_C_COMPILER ndk/or/arm-gcc/compiler) +# 添加其他编译选项 + +project(<用户指定的工程名称> C) + +add_executable(<应用程序名> <源文件>) +target_include_directories(<应用程序名> /include) +target_link_libraries(<应用程序名> /lib/libeml-armneon.a) + +if(ANDROID) + target_link_libraries(<应用程序名> dl log -fopenmp) +else() + target_link_libraries(<应用程序名> pthread -fopenmp) +endif() +``` + +### 构建应用程序 + +``` +cd +mkdir build && cd build +cmake .. [-DANDROID=ON #安卓平台] <其他您的工程需要的选项> +make +``` + +### 示例代码 + +本文件夹中的 Gemm.c 提供了 EMLL 函数的使用示例,可以通过以下命令编译它并用 adb 拷贝到端侧设备上运行。 + +``` +cd +mkdir build && cd build +cmake .. 
[-DANDROID=ON -DANDROID_NDK=/path/to/ndk #安卓平台] [-DCMAKE_C_COMPILER=/path/to/gcc [-DCMAKE_SYSROOT=/path/to/gnu/sysroot] #GNU-Linux平台] [-DEML_ARMV7A=ON #armv7平台] +make +# 在 build 文件夹中生成 example_emll_gemm 程序,可到端侧设备上运行它 +``` + diff --git a/include/Gemm.h b/include/Gemm.h new file mode 100644 index 0000000..8480143 --- /dev/null +++ b/include/Gemm.h @@ -0,0 +1,120 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +#ifndef INCLUDE_ARM_GEMM_INTERFACE +#define INCLUDE_ARM_GEMM_INTERFACE + +#ifdef __cplusplus +extern "C" { +#endif + +/******************************************************************** +Function: sgemm +Description: fp32 general matrix multiplication, do C = AB + beta * C + with OpenMP parallelization. +Input: int a_rowmajor: an integer indicating the storage order + of input matrix A. Non-zero number for + row-major storage, 0 for column-major storage. + int b_rowmajor: an integer indicating the storage order + of input matrix B. Non-zero number for + row-major storage, 0 for column-major storage. + (matrix C is fixed to column-major) + const float *A, *B: the addresses of input matrices + uint32_t M, N, K: the dimensions of matrices + A: M x K; B: K x N; C: M x N + float beta: the scale on matrix C prior to GEMM + uint32_t num_threads: the maximum number of threads + in OpenMP parallelization. + 0 : the function will determine + the number of threads from + the problem size, use as many + threads as possible up to + omp_get_max_threads() when + M, N and K are large. + positive number: limit the maximum + number of threads the function + can use in OpenMP parallelization + 1 : force serial execution +Output: float *C: the address of output matrix +Return: 0 on success, 1 on illegal parameters +********************************************************************/ +int sgemm(int a_rowmajor, int b_rowmajor, + const float *A, const float *B, float *C, + uint32_t M, uint32_t N, uint32_t K, + float beta, uint32_t num_threads); + +/************************************************************************** +Function: s8s32gemm +Description: signed 8bit -> 32bit integer matrix multiplication, + do C = AB + beta * C with OpenMP parallelization, + use *mlal NEON instructions on CPUs without ARMv8.2a feature, + use *dot NEON instructions on CPUs support ARMv8.2a-dotprod. 
+Input: int a_rowmajor, b_rowmajor: the same as in function sgemm + const int8_t *A, *B: the addresses of int8_t input matrices + M, N, K, beta, num_threads: the same meaning as in function sgemm +Output: int32_t *C: the address of int32_t output matrix C +Return: 0 on success, 1 on illegal parameters +**************************************************************************/ +int s8s32gemm(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, + int32_t beta, uint32_t num_threads); + +/************************************************************************** +Function: u8u32gemm +Description: unsigned 8bit -> 32bit integer matrix multiplication, + do C = AB with OpenMP parallelization, + use *mlal NEON instructions on CPUs without ARMv8.2a feature, + use *dot NEON instructions on CPUs support ARMv8.2a-dotprod. +Input: int a_rowmajor, b_rowmajor: the same as in function sgemm + const uint8_t *A, *B: the addresses of uint8_t input matrices + M, N, K, beta, num_threads: the same meaning as in function sgemm +Output: uint32_t *C: the address of uint32_t output matrix C +Return: 0 on success, 1 on illegal parameters +**************************************************************************/ +int u8u32gemm(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, + uint32_t beta, uint32_t num_threads); + +#if __aarch64__ +/************************************************************************** +Function: hgemm +Description: fp16 (half precision) matrix multiplication, + do C = AB with OpenMP parallelization. +Input: int a_rowmajor, b_rowmajor: the same as in function sgemm + const float16_t *A, *B: the addresses of input matrices + M, N, K, beta, num_threads: the same meaning as in function sgemm +Output: float16_t *C: the address of output matrix C +Return: 0 on success, 1 on illegal parameters, + 2 when the CPU doesn't support ARMv8.2a-fp16 +**************************************************************************/ +int hgemm(int a_rowmajor, int b_rowmajor, + const float16_t *A, const float16_t *B, float16_t *C, + uint32_t M, uint32_t N, uint32_t K, + float16_t beta, uint32_t num_threads); + +#endif //aarch64 + +#ifdef __cplusplus +} +#endif + +#endif //INCLUDE_ARM_GEMM_INTERFACE diff --git a/include/Layer.h b/include/Layer.h new file mode 100644 index 0000000..2dcd804 --- /dev/null +++ b/include/Layer.h @@ -0,0 +1,55 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include +#include + +#ifndef INCLUDE_ARM_LAYER_INTERFACE +#define INCLUDE_ARM_LAYER_INTERFACE + +#ifdef __cplusplus +extern "C" { +#endif + +/****************************************************************************** +Function: fc +Description: Function to perform transformation in a fully-connected layer, + paralleled with OpenMP. + output = src * weight + bias +Input: float *src: the address of source data matrix. + float *weight: the address of weight matrix. + float *bias: the address of bias vector. +Output: float *output: the address of output matrix. +Parameters: int M: the number of rows in source data matrix. + int K: the number of columns in source data matrix. + int N: the number of columns in output matrix. + int trans_src: 1 for column-major source data matrix, + 0 for row-major source data matrix. + int trans_weight: 1 for column-major weight matrix, + 0 for row-major weight matrix. + int num_threads: number of OpenMP threads to use. +Return: 0 on success, non-zero number on errors. +******************************************************************************/ +int fc(const float *src, const float *weight, const float *bias, + float *output, int M, int K, int N, int trans_src, int trans_weight, + int num_threads); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/include/Quant.h b/include/Quant.h new file mode 100644 index 0000000..7e6c17e --- /dev/null +++ b/include/Quant.h @@ -0,0 +1,254 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +#ifndef INCLUDE_ARM_QUANT_INTERFACE +#define INCLUDE_ARM_QUANT_INTERFACE + +#ifdef __cplusplus +extern "C" { +#endif + +/*********************************************************************** +Function: bias_int32_t +Description: Perform bias operation on a 32-bit signed int matrix. + This function can be used in asymmetric quantitized GEMM. +Parameters: dst: the address of the matrix to apply bias on + bias_dim0: the bias value on every element + bias_dim1: the address of the input bias vector which + will be applied to the matrix along its + major dimension, i.e. when the element + can be indexed by x * dim1 + y, each element + is biased by bias_dim1[y]. No bias will be + performed with NULL pointer as input. + bias_dim1_scale: the scale to be applied on elements + of bias_dim1[] prior to the bias + operation + bias_dim2: the address of the input bias vector which + whill be applied to the matrix along its + minor dimension, i.e. when the element + can be indexed by x * dim1 + y, each element + is biased by bias_dim2[x]. No bias will be + performed with NULL pointer as input. 
+ bias_dim2_scale: the scale to be applied on elements + of bias_dim2[] prior to the bias + operation + dim1: the length of the major dimension of input matrix + dim2: the length of the minor dimension of input matrix +***********************************************************************/ +void bias_int32_t(int32_t *dst, int32_t bias_dim0, + const int32_t *bias_dim1, int32_t bias_dim1_scale, + const int32_t *bias_dim2, int32_t bias_dim2_scale, + uint32_t dim1, uint32_t dim2); + +/*********************************************************************** +Function: u8u32_sum +Description: Perform summing operation of cols/rows of the unsigned + 8-bit int matrix. The sum of each col/row is an unsigned + 32-bit integer. +Parameters: src: the address of input matrix. + dst: the address of output vector. + dim1: the length of major dimension of input matrix. + dim2: the length of minor dimension of input matrix. + (the major dimension is the vertical one for column- + major matrix, or the horizontal one for row-major + matrix) + direction: the direction of summing + 0: sum along the minor dimension, + output_vector_size == dim1; + 1: sum along the major dimension, + output_vector_size == dim2. +***********************************************************************/ +void u8u32_sum(const uint8_t *src, uint32_t *dst, + uint32_t dim1, uint32_t dim2, uint8_t direction); + +/*********************************************************************** +Function: quantize_asymmetric_f32_u8 +Description: Asymmetric quantization from fp32 to unsigned 8-bit int, + producing an 8-bit zero-point integer Z0, a fp32 scale S0 + and quantitized unsigned 8-bit data Q1-Qn on the run. + For each quantitized element Qi, S0 * (Qi - Z0) can + approximate the original input (fp32) Fi. +Parameters: const float32_t *input: the address of the input fp32 array + uint8_t *output: the address of the output integer array + uint8_t *zero_point: the address to output Z0 + float32_t *scale: the address to output S0 + uint32_t size: the number of elements in the input + float32_t input_min, input_max: + the min and max of input float32_t numbers. + when input_min > input_max, the min and max + of input are reevaluated. +***********************************************************************/ +void quantize_asymmetric_f32_u8(const float32_t *input, uint8_t *output, + uint8_t *zero_point, float32_t *scale, uint32_t size, + float32_t input_min, float32_t input_max); + +/*********************************************************************** +Function: quantize_symmetric_f32_s8 +Description: symmetric quantization from fp32 to signed 8-bit int, + producing a fp32 scale S0 and quantitized 8-bit data + Q1-Qn on the run. + For each quantitized element Qi, S0 * Qi can + approximate the original input (fp32) Fi. +Parameters: const float32_t *input: the address of the input fp32 array + int8_t *output: the address of the output integer array + float32_t *scale: the address to output S0 + uint32_t size: the number of elements in the input + float32_t input_min, input_max: + the min and max of input float32_t numbers. + when input_min > input_max, the min and max + of input are reevaluated. 
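+Example:    a usage sketch with hypothetical values (the call mirrors the
+            one in example/Gemm.c and is illustrative, not normative):
+              float32_t src[4] = {-0.5f, 0.1f, 0.25f, 0.4f};
+              int8_t q[4]; float32_t s;
+              quantize_symmetric_f32_s8(src, q, &s, 4, 0, -1);
+            Here input_min > input_max (0 > -1), so the range is measured
+            from src itself; afterwards s * q[i] approximates src[i].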
+***********************************************************************/ +void quantize_symmetric_f32_s8(const float32_t *input, int8_t *output, + float32_t *scale, uint32_t size, float32_t input_min, float32_t input_max); + +/*********************************************************************** +Function: quantize_asymmetric_f32_u16 +Description: Asymmetric quantization from fp32 to unsigned 16-bit int, + producing an 16-bit zero-point integer Z0, a fp32 scale S0 + and quantitized unsigned 16-bit data Q1-Qn on the run. + This function does the same thing as + quantize_asymmetric_f32_u8 except the zero point and + outputs are 16-bit integers. +***********************************************************************/ +void quantize_asymmetric_f32_u16(const float32_t *input, uint16_t *output, + uint16_t *zero_point, float32_t *scale, uint32_t size, + float32_t input_min, float32_t input_max); + +/*********************************************************************** +Function: quantize_symmetric_f32_s16 +Description: symmetric quantization from fp32 to signed 16-bit int, + producing a fp32 scale S0 and quantitized 16-bit data + Q1-Qn on the run. This function does the same thing + as quantize_symmetric_f32_s8 except the outputs are + 16-bit integers. +***********************************************************************/ +void quantize_symmetric_f32_s16(const float32_t *input, int16_t *output, + float32_t *scale, uint32_t size, float32_t input_min, float32_t input_max); + +/*********************************************************************** +Function: dequantize_symmetric_f32_s32 +Description: Convert 32-bit signed int values to fp32 ones with scaling. +Parameters: const int32_t *src: the address of the input integer array + float32_t *dst: the address of the output fp32 array + float32_t scale: the scaling factor on the input + uint32_t size: the number of elements in the input +***********************************************************************/ +void dequantize_symmetric_f32_s32(const int32_t *src, float32_t *dst, + float32_t scale, uint32_t size); + +/************************************************************************ +Function: requantize_asymmetric_32to8 +Description: asymmetric requantization from signed 32-bit int to + unsigned 8-bit int, which produces an 8-bit zero-point + integer Z0, updates the fp32 scale S0 and outputs + requantitized unsigned 8-bit data Q1-Qn on the run. + For each requantitized element Qi, S0 * (Qi - Z0) can + approximate the original dequantized value (fp32) Fi + of the corresponding 32-bit input. +Parameters: const int32_t *input: the address of the input int array + uint8_t *output: the address of the output integer array + float *scale: the address to update scaling factor S0 + uint8_t *zero_point: the address to output Z0 + uint32_t size: the number of elements in the input + int32_t input_min, input_max: the min and max value + of input int32 numbers. if input_min > input_max, + the min and max of the input integers are recalculated. 
+Note: the function below is near-equivalent to dequantizing the input
+      with dequantize_symmetric_f32_s32() into a temporary fp32 array
+      and then re-quantizing it with quantize_asymmetric_f32_u8().
+************************************************************************/
+void requantize_asymmetric_32to8(const int32_t *input, uint8_t *output,
+  float *scale, uint8_t *zero_point, uint32_t size,
+  int32_t input_min, int32_t input_max);
+
+/************************************************************************
+Function: requantize_symmetric_32to8
+Description: Symmetric requantization from signed 32-bit int to
+             signed 8-bit int, which updates the fp32 scale S0
+             and outputs requantized signed 8-bit data Q1-Qn
+             on the fly.
+             For each requantized element Qi, S0 * Qi can
+             approximate the original dequantized value (fp32) Fi
+             of the corresponding 32-bit input.
+Parameters: const int32_t *input: the address of the input int array
+            int8_t *output: the address of the output integer array
+            float *scale: the address of the scaling factor S0,
+                          which is updated in place
+            uint32_t size: the number of elements in the input
+            int32_t input_min, input_max: the min and max values
+              of the input int32 numbers. If input_min > input_max,
+              the min and max of the input integers are recalculated.
+Note: the function below is near-equivalent to dequantizing the input
+      with dequantize_symmetric_f32_s32() into a temporary fp32 array
+      and then re-quantizing it with quantize_symmetric_f32_s8().
+************************************************************************/
+void requantize_symmetric_32to8(const int32_t *input, int8_t *output,
+  float *scale, uint32_t size,
+  int32_t input_min, int32_t input_max);
+
+/************************************************************************
+ * Function: requantize_asymmetric_32to16
+ * Description: Asymmetric requantization from signed 32-bit int to
+ *              unsigned 16-bit int, which does the same thing as
+ *              requantize_asymmetric_32to8 except that the outputs
+ *              and zero point are 16-bit integers
+ ***********************************************************************/
+void requantize_asymmetric_32to16(const int32_t *input, uint16_t *output,
+  float *scale, uint16_t *zero_point, uint32_t size,
+  int32_t input_min, int32_t input_max);
+
+/************************************************************************
+ * Function: requantize_symmetric_32to16
+ * Description: Symmetric requantization from signed 32-bit int to
+ *              signed 16-bit int, which does the same thing as
+ *              requantize_symmetric_32to8 except that the outputs
+ *              are 16-bit integers
+ ***********************************************************************/
+void requantize_symmetric_32to16(const int32_t *input, int16_t *output,
+  float *scale, uint32_t size,
+  int32_t input_min, int32_t input_max);
+
+/************************************************************************
+ * Function: requantize_asymmetric_16to8
+ * Description: Asymmetric requantization from signed 16-bit int to
+ *              unsigned 8-bit int, which does the same thing as
+ *              requantize_asymmetric_32to8 except that the inputs
+ *              are 16-bit integers
+ ***********************************************************************/
+void requantize_asymmetric_16to8(const int16_t *input, uint8_t *output,
+  float *scale, uint8_t *zero_point, uint32_t size,
+  int16_t input_min, int16_t input_max);
+
+/************************************************************************
+ * Function: requantize_symmetric_16to8
+ * Description: Symmetric requantization from
signed 16-bit int to + * signed 8-bit int, which does the same thing as + * requantize_symmetric_32to8 except that the inputs + * are 16-bit integers + ***********************************************************************/ +void requantize_symmetric_16to8(const int16_t *input, int8_t *output, + float *scale, uint32_t size, + int16_t input_min, int16_t input_max); + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/include/arm_neon/ARMCompareAndSwap.h b/include/arm_neon/ARMCompareAndSwap.h new file mode 100644 index 0000000..6d11746 --- /dev/null +++ b/include/arm_neon/ARMCompareAndSwap.h @@ -0,0 +1,56 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: ARMCompareAndSwap.h + * Description: Atomic compare and swap functions on ARM processors + *****************************************************************************/ + +#include + +/****************************************************************************** + * Function: atomicCAS_U32 + * Description: Atomic "compare and swap" of 32-bit integer in main memory. + * Parameters: comp: the value to compare + * write: the value to write + * dst: the memory location of 32-bit integer + * Operation: # atomic operation + * { + * uint32_t ret = *dst; + * if (*dst == comp) *dst = write; + * return ret; + * } + * Return: The original value of the 32-bit integer in memory + *****************************************************************************/ +uint32_t atomicCAS_U32(uint32_t comp, uint32_t write, uint32_t *dst); + +/****************************************************************************** + * Function: atomicCAS_U64 + * Description: Atomic "compare and swap" of 64-bit integer in main memory. + * Parameters: comp: the value to compare + * write: the value to write + * dst: the memory location of 64-bit integer + * Operation: # atomic operation + * { + * uint64_t ret = *dst; + * if (*dst == comp) *dst = write; + * return ret; + * } + * Return: The original value of the 64-bit integer in memory + *****************************************************************************/ +uint64_t atomicCAS_U64(uint64_t comp, uint64_t write, uint64_t *dst); + diff --git a/include/arm_neon/ARMCpuType.h b/include/arm_neon/ARMCpuType.h new file mode 100644 index 0000000..94aac99 --- /dev/null +++ b/include/arm_neon/ARMCpuType.h @@ -0,0 +1,85 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/
+/* You may obtain a copy of the License at */
+/* */
+/* http://www.apache.org/licenses/LICENSE-2.0 */
+/* */
+/* Unless required by applicable law or agreed to in writing, software */
+/* distributed under the License is distributed on an "AS IS" BASIS, */
+/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
+/* See the License for the specific language governing permissions and */
+/* limitations under the License. */
+/*****************************************************************************/
+
+
+/******************************************************************************
+ * File:        ARMCpuType.h
+ * Description: Functions supporting run-time ARM CPU detection:
+ *              CPU pipeline type & ISA support.
+ *              On ARM, a user program is not allowed to access the system
+ *              registers holding CPUID information. As a result, CPU
+ *              recognition relies on a "healthy" Linux kernel which reads
+ *              those registers and stores their information into sysfs
+ *              during the initialization process.
+ *****************************************************************************/
+
+#include <stdint.h>
+
+/* currently these functions work only on Linux kernels released since 2015 */
+
+#ifndef INCLUDE_ARM_CPUTYPE
+#define INCLUDE_ARM_CPUTYPE
+
+/*****************************************************************************
+ * Function: blas_arm_get_cpu_type
+ * Description: Detect the NEON pipeline type of the CPU. There are 4 major
+ *              types of NEON pipelines:
+ *              (1) only 1 64-bit NEON pipeline, shared by vector load &
+ *                  arith, with in-order execution,
+ *                  like that in cortex-A7 and cortex-A35.
+ *              (2) 2 64-bit NEON pipelines that can be combined to execute
+ *                  128-bit wide operations, shared by vector load & arith,
+ *                  with in-order execution & dual-issue ability,
+ *                  like that in cortex-A53.
+ *              (3) the same NEON pipelines as stated in (2), with an
+ *                  additional load unit capable of simple 64-bit NEON loads
+ *                  and element insertion, like that in cortex-A55.
+ *              (4) at least 2 64-bit NEON pipelines, with out-of-order
+ *                  execution and additional load unit(s) supporting vector
+ *                  loads, like that in cortex-A57.
+ * Parameter: cpuid: the ID of the CPU core whose type needs to be determined,
+ *                   e.g. the return value of sched_getcpu() when the core
+ *                   where the calling thread runs needs to be determined.
+ * Return: An 8-bit integer representing the type, 35 for (1), 53 for (2),
+ *         55 for (3) and 0 for (4)
+ ****************************************************************************/
+uint8_t blas_arm_get_cpu_type(uint8_t cpuid);
+
+/*****************************************************************************
+ * Function: blas_arm_get_fp16_support()
+ * Description: Determine the support level for half-precision arithmetic
+ *              operations on the current system. Relies on a "healthy"
+ *              Linux kernel which detects the CPU correctly.
+ * Return: 0 for no support, 1 for support of conversion from/to fp32,
+ *         2 for support of add/mul/fma operations
+ ****************************************************************************/
+uint8_t blas_arm_get_fp16_support();
+
+/*****************************************************************************
+ * Function: blas_arm_get_i8i32_support()
+ * Description: Determine the support level for int8->int32 accumulate
+ *              operations on the current system. Relies on a "healthy"
+ *              Linux kernel which detects the CPU correctly.
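+ *              A typical caller uses the returned level to pick an integer
+ *              GEMM path at run time. Illustrative sketch only; the
+ *              run_*_gemm() calls are hypothetical placeholders, not
+ *              library functions:
+ *                  int lvl = blas_arm_get_i8i32_support();
+ *                  if (lvl == 2)      run_dotprod_gemm();
+ *                  else if (lvl == 1) run_mlal_gemm();
+ *                  else               run_serial_gemm();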
+ * Return: 0 for no support, 1 for support with *mlal instructions,
+ *         2 for support with *dot instructions
+ ****************************************************************************/
+/* return an integer indicating i8->i32 GEMM support */
+/* return 0 for non-support from SIMD */
+/* return 1 for basic support with SIMD multiply-add */
+/* return 2 when armv8.2a-dotprod is available */
+uint8_t blas_arm_get_i8i32_support();
+
+#endif
diff --git a/include/arm_neon/NeonBias.h b/include/arm_neon/NeonBias.h
new file mode 100644
index 0000000..22bd6c0
--- /dev/null
+++ b/include/arm_neon/NeonBias.h
@@ -0,0 +1,200 @@
+/*****************************************************************************/
+/* Copyright YouDao, Inc. */
+/* */
+/* Licensed under the Apache License, Version 2.0 (the "License"); */
+/* you may not use this file except in compliance with the License. */
+/* You may obtain a copy of the License at */
+/* */
+/* http://www.apache.org/licenses/LICENSE-2.0 */
+/* */
+/* Unless required by applicable law or agreed to in writing, software */
+/* distributed under the License is distributed on an "AS IS" BASIS, */
+/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
+/* See the License for the specific language governing permissions and */
+/* limitations under the License. */
+/*****************************************************************************/
+
+
+/*****************************************************************************
+ * File:        NeonBias.h
+ * Description: Bias functions based on ARM NEON instructions.
+ ****************************************************************************/
+
+#include <arm_neon.h>
+#include <stdbool.h>
+#include <stdint.h>
+
+#ifndef INCLUDE_NEON_BIAS
+#define INCLUDE_NEON_BIAS
+
+/*****************************************************************************
+ * Template: NEON_BIAS
+ * Description: Function template for NEON-based bias operations
+ * Template Parameters: type_scalar: the type of scalar data,
+ *                          e.g. float for fp32 bias
+ *                      type_vector: the type of SIMD vector data,
+ *                          e.g. float32x4_t
+ *                      type_short: the short name of the data type in NEON
+ *                          intrinsics, e.g. f32 for fp32 bias
+ *                      vector_size: the length of the SIMD vector, e.g. 4
+ *                          when type_vector == float32x4_t
+ *                      fma: the short name of the multiply-add operation in
+ *                          the names of NEON intrinsics. Use "fma" for fused
+ *                          multiply-add and "mla" for sequential multiply-add
+ * Function Parameters: C: the address of the matrix to apply bias on
+ *                      bias_dim0: the bias value applied to every element
+ *                      bias_dim1: the address of the input bias vector which
+ *                          will be applied to the matrix along its
+ *                          major dimension, i.e. when the element
+ *                          can be indexed by x * dim1 + y, each element
+ *                          is biased by bias_dim1[y]. No bias will be
+ *                          performed with NULL pointer as input.
+ *                      bias_dim1_scale: the scale to be applied on elements
+ *                          of bias_dim1[] prior to the bias
+ *                          operation
+ *                      bias_dim2: the address of the input bias vector which
+ *                          will be applied to the matrix along its
+ *                          minor dimension, i.e. when the element
+ *                          can be indexed by x * dim1 + y, each element
+ *                          is biased by bias_dim2[x]. No bias will be
+ *                          performed with NULL pointer as input.
+ * bias_dim2_scale: the scale to be applied on elements + * of bias_dim2[] prior to the bias + * operation + * dim1: the length of the major dimension of input matrix + * dim2: the length of the minor dimension of input matrix + ****************************************************************************/ +#define NEON_BIAS(type_scalar, type_vector, type_short, vector_size, fma) \ +void bias_##type_scalar(type_scalar *C,\ + type_scalar bias_dim0,\ + const type_scalar *bias_dim1,\ + type_scalar bias_dim1_scale,\ + const type_scalar *bias_dim2,\ + type_scalar bias_dim2_scale,\ + uint32_t dim1, uint32_t dim2) {\ +\ + bool do_bias_0 = (bias_dim0 != 0);\ + bool do_bias_1 = bias_dim1 && (bias_dim1_scale != 0);\ + bool do_bias_2 = bias_dim2 && (bias_dim2_scale != 0);\ +\ + if (!do_bias_0 && !do_bias_1 && !do_bias_2) return;\ +\ + if (!do_bias_1 && (do_bias_0 || do_bias_2)) {\ + type_scalar *c_ptr = C;\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + const type_scalar bs = bias_dim0 + \ + (bias_dim2 ? bias_dim2[dim2_pos] * bias_dim2_scale : (type_scalar)0);\ + const type_vector bv = vdupq_n_##type_short(bs);\ + uint32_t dim1_left = dim1;\ + for (; dim1_left >= vector_size * 4; dim1_left -= vector_size * 4) {\ + type_vector c1 = vld1q_##type_short(c_ptr);\ + type_vector c2 = vld1q_##type_short(c_ptr + vector_size);\ + type_vector c3 = vld1q_##type_short(c_ptr + vector_size * 2);\ + type_vector c4 = vld1q_##type_short(c_ptr + vector_size * 3);\ + c1 = vaddq_##type_short(c1, bv);\ + c2 = vaddq_##type_short(c2, bv);\ + c3 = vaddq_##type_short(c3, bv);\ + c4 = vaddq_##type_short(c4, bv);\ + vst1q_##type_short(c_ptr, c1);\ + vst1q_##type_short(c_ptr + vector_size, c2);\ + vst1q_##type_short(c_ptr + vector_size * 2, c3);\ + vst1q_##type_short(c_ptr + vector_size * 3, c4);\ + c_ptr += vector_size * 4;\ + }\ + for (; dim1_left >= vector_size; dim1_left -= vector_size) {\ + type_vector c1 = vld1q_##type_short(c_ptr);\ + c1 = vaddq_##type_short(c1, bv);\ + vst1q_##type_short(c_ptr, c1); c_ptr += vector_size;\ + }\ + for (; dim1_left > 0; dim1_left--) {\ + *c_ptr += bs; c_ptr++;\ + }\ + }\ + } else if (do_bias_1 && !do_bias_0 && !do_bias_2) {\ + type_scalar *c_ptr = C;\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + uint32_t dim1_left = dim1;\ + const type_scalar *bias_ptr = bias_dim1;\ + for (; dim1_left >= vector_size * 4; dim1_left -= vector_size * 4) {\ + type_vector c1 = vld1q_##type_short(c_ptr);\ + type_vector c2 = vld1q_##type_short(c_ptr + vector_size);\ + type_vector c3 = vld1q_##type_short(c_ptr + vector_size * 2);\ + type_vector c4 = vld1q_##type_short(c_ptr + vector_size * 3);\ + type_vector b1 = vld1q_##type_short(bias_ptr);\ + type_vector b2 = vld1q_##type_short(bias_ptr + vector_size);\ + type_vector b3 = vld1q_##type_short(bias_ptr + vector_size * 2);\ + type_vector b4 = vld1q_##type_short(bias_ptr + vector_size * 3);\ + bias_ptr += vector_size * 4;\ + c1 = v##fma##q_n_##type_short(c1, b1, bias_dim1_scale);\ + c2 = v##fma##q_n_##type_short(c2, b2, bias_dim1_scale);\ + c3 = v##fma##q_n_##type_short(c3, b3, bias_dim1_scale);\ + c4 = v##fma##q_n_##type_short(c4, b4, bias_dim1_scale);\ + vst1q_##type_short(c_ptr, c1);\ + vst1q_##type_short(c_ptr + vector_size, c2);\ + vst1q_##type_short(c_ptr + vector_size * 2, c3);\ + vst1q_##type_short(c_ptr + vector_size * 3, c4);\ + c_ptr += vector_size * 4;\ + }\ + for (; dim1_left >= vector_size; dim1_left -= vector_size) {\ + type_vector c1 = vld1q_##type_short(c_ptr);\ + type_vector b1 = vld1q_##type_short(bias_ptr);\ + 
bias_ptr += vector_size;\ + c1 = v##fma##q_n_##type_short(c1, b1, bias_dim1_scale);\ + vst1q_##type_short(c_ptr, c1);\ + c_ptr += vector_size;\ + }\ + for (; dim1_left > 0; dim1_left--) {\ + *c_ptr += (*bias_ptr) * bias_dim1_scale; bias_ptr++; c_ptr++;\ + }\ + }\ + } else {\ + type_scalar *c_ptr = C;\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + const type_scalar bs = bias_dim0 + \ + (bias_dim2 ? bias_dim2[dim2_pos] * bias_dim2_scale : (type_scalar)0);\ + const type_vector bv = vdupq_n_##type_short(bs);\ + const type_scalar *bias_ptr = bias_dim1;\ + uint32_t dim1_left = dim1;\ + for (; dim1_left >= vector_size * 4; dim1_left -= vector_size * 4) {\ + type_vector c1 = vld1q_##type_short(c_ptr);\ + type_vector c2 = vld1q_##type_short(c_ptr + vector_size);\ + type_vector c3 = vld1q_##type_short(c_ptr + vector_size * 2);\ + type_vector c4 = vld1q_##type_short(c_ptr + vector_size * 3);\ + c1 = vaddq_##type_short(c1, bv);\ + c2 = vaddq_##type_short(c2, bv);\ + c3 = vaddq_##type_short(c3, bv);\ + c4 = vaddq_##type_short(c4, bv);\ + type_vector b1 = vld1q_##type_short(bias_ptr);\ + type_vector b2 = vld1q_##type_short(bias_ptr + vector_size);\ + type_vector b3 = vld1q_##type_short(bias_ptr + vector_size * 2);\ + type_vector b4 = vld1q_##type_short(bias_ptr + vector_size * 3);\ + bias_ptr += vector_size * 4;\ + c1 = v##fma##q_n_##type_short(c1, b1, bias_dim1_scale);\ + c2 = v##fma##q_n_##type_short(c2, b2, bias_dim1_scale);\ + c3 = v##fma##q_n_##type_short(c3, b3, bias_dim1_scale);\ + c4 = v##fma##q_n_##type_short(c4, b4, bias_dim1_scale);\ + vst1q_##type_short(c_ptr, c1);\ + vst1q_##type_short(c_ptr + vector_size, c2);\ + vst1q_##type_short(c_ptr + vector_size * 2, c3);\ + vst1q_##type_short(c_ptr + vector_size * 3, c4);\ + c_ptr += vector_size * 4;\ + }\ + for (; dim1_left >= vector_size; dim1_left -= vector_size) {\ + type_vector c1 = vld1q_##type_short(c_ptr);\ + c1 = vaddq_##type_short(c1, bv);\ + type_vector b1 = vld1q_##type_short(bias_ptr);\ + bias_ptr += vector_size;\ + c1 = v##fma##q_n_##type_short(c1, b1, bias_dim1_scale);\ + vst1q_##type_short(c_ptr, c1);\ + c_ptr += vector_size;\ + }\ + for (; dim1_left > 0; dim1_left--) {\ + *c_ptr += (*bias_ptr) * bias_dim1_scale + bs;\ + bias_ptr++; c_ptr++;\ + }\ + }\ + }\ +} + +#endif + diff --git a/include/arm_neon/NeonExtreme.h b/include/arm_neon/NeonExtreme.h new file mode 100644 index 0000000..3255e30 --- /dev/null +++ b/include/arm_neon/NeonExtreme.h @@ -0,0 +1,112 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonExtreme.h + * Description: Source code template for NEON max/min functions. 
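+ *              The NEON_FIND_EXTREME(type, short, dvec, qvec, dlen) template
+ *              below expands to an inline function
+ *              inline_find_extreme_<type>() that scans an array once and
+ *              returns both its minimum and maximum. As an illustration (the
+ *              instantiation parameters here are assumed, not taken from the
+ *              library sources),
+ *              NEON_FIND_EXTREME(float32_t, f32, float32x2_t, float32x4_t, 2)
+ *              would define
+ *              inline_find_extreme_float32_t(const float32_t *dat,
+ *                  uint32_t size, float32_t *min, float32_t *max).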
+ *****************************************************************************/ + +#include "common/ExpandMacro.h" +#include +#include + +#ifndef INCLUDE_NEON_EXTREME +#define INCLUDE_NEON_EXTREME + +#define NEON_REDUC_S_ITEM(n, type, short) {\ + type tmin = vget_lane_##short(vmin1d, n - 1);\ + type tmax = vget_lane_##short(vmax1d, n - 1);\ + smin = tmin < smin ? tmin : smin;\ + smax = tmax > smax ? tmax : smax;\ +} + +#define NEON_REDUC_S_MIN_MAX(n, type, short) \ + MACRO_EXPANSION_##n(VOID_BASE, NEON_REDUC_S_ITEM, type, short) + +#define NEON_FIND_EXTREME(type, short, dvec, qvec, dlen) \ +static inline void inline_find_extreme_##type(const type *dat, uint32_t size,\ + type *min, type *max) {\ +\ + qvec vmin1, vmin2, vmin3, vmin4;\ + qvec vmax1, vmax2, vmax3, vmax4;\ +\ + if (size == 0) return;\ + vmin1 = vmin2 = vmin3 = vmin4 = \ + vmax1 = vmax2 = vmax3 = vmax4 = vld1q_dup_##short(dat);\ + uint32_t elem_left = size;\ + for (; elem_left >= dlen * 8; elem_left -= dlen * 8) {\ + qvec l1 = vld1q_##short(dat);\ + qvec l2 = vld1q_##short(dat + dlen * 2);\ + qvec l3 = vld1q_##short(dat + dlen * 4);\ + qvec l4 = vld1q_##short(dat + dlen * 6);\ + dat += dlen * 8;\ + vmin1 = vminq_##short(vmin1, l1);\ + vmax1 = vmaxq_##short(vmax1, l1);\ + vmin2 = vminq_##short(vmin2, l2);\ + vmax2 = vmaxq_##short(vmax2, l2);\ + vmin3 = vminq_##short(vmin3, l3);\ + vmax3 = vmaxq_##short(vmax3, l3);\ + vmin4 = vminq_##short(vmin4, l4);\ + vmax4 = vmaxq_##short(vmax4, l4);\ + }\ + vmin1 = vminq_##short(vmin1, vmin3);\ + vmin2 = vminq_##short(vmin2, vmin4);\ + vmax1 = vmaxq_##short(vmax1, vmax3);\ + vmax2 = vmaxq_##short(vmax2, vmax4);\ + if (elem_left >= dlen * 4) {\ + qvec l1 = vld1q_##short(dat);\ + qvec l2 = vld1q_##short(dat + dlen * 2);\ + dat += dlen * 4;\ + vmin1 = vminq_##short(vmin1, l1);\ + vmax1 = vmaxq_##short(vmax1, l1);\ + vmin2 = vminq_##short(vmin2, l2);\ + vmax2 = vmaxq_##short(vmax2, l2);\ + elem_left -= dlen * 4;\ + }\ + vmin1 = vminq_##short(vmin1, vmin2);\ + vmax1 = vmaxq_##short(vmax1, vmax2);\ + if (elem_left >= dlen * 2) {\ + qvec l1 = vld1q_##short(dat);\ + dat += dlen * 2;\ + vmin1 = vminq_##short(vmin1, l1);\ + vmax1 = vmaxq_##short(vmax1, l1);\ + elem_left -= dlen * 2;\ + }\ + dvec vmin1d = vmin_##short(vget_low_##short(vmin1),\ + vget_high_##short(vmin1));\ + dvec vmax1d = vmax_##short(vget_low_##short(vmax1),\ + vget_high_##short(vmax1));\ + if (elem_left >= dlen) {\ + dvec d1 = vld1_##short(dat);\ + dat += dlen;\ + vmin1d = vmin_##short(vmin1d, d1);\ + vmax1d = vmax_##short(vmax1d, d1);\ + elem_left -= dlen;\ + }\ + type smin = vget_lane_##short(vmin1d, 0);\ + type smax = vget_lane_##short(vmax1d, 0);\ + NEON_REDUC_S_MIN_MAX(dlen, type, short)\ + for (; elem_left > 0; elem_left--) {\ + type s1 = *dat++;\ + smin = s1 < smin ? s1 : smin;\ + smax = s1 > smax ? s1 : smax;\ + }\ + *min = smin;\ + *max = smax;\ +} + +#endif diff --git a/include/arm_neon/NeonI8I32DotGemmSkinnyDot.h b/include/arm_neon/NeonI8I32DotGemmSkinnyDot.h new file mode 100644 index 0000000..c3dfa07 --- /dev/null +++ b/include/arm_neon/NeonI8I32DotGemmSkinnyDot.h @@ -0,0 +1,153 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/***************************************************************************** + * File: NeonI8I32DotGemmSkinnyDot.h + * Description: Source code template for NEON 8->32bit GEMM skinny dot kernel + ****************************************************************************/ + +#include "common/CommonSkinnyDot.h" +#include "arm_neon/NeonIntOpSign.h" + +#ifndef INCLUDE_I8I32_DOT_SKINNYDOT +#define INCLUDE_I8I32_DOT_SKINNYDOT + +typedef I8 I8I32DOTGEMM_SKINNYDOT_ASCALAR; +typedef I8 I8I32DOTGEMM_SKINNYDOT_BSCALAR; +typedef I32 I8I32DOTGEMM_SKINNYDOT_CSCALAR; + +typedef I16 I8I32DOTGEMM_SKINNYDOT_AVEC1; +typedef I16 I8I32DOTGEMM_SKINNYDOT_BVEC1; +typedef I32 I8I32DOTGEMM_SKINNYDOT_CVEC1; + +typedef I8X8 I8I32DOTGEMM_SKINNYDOT_AVEC4; +typedef I8X8 I8I32DOTGEMM_SKINNYDOT_BVEC4; +typedef I32X2 I8I32DOTGEMM_SKINNYDOT_CVEC4; + +typedef I8X8 I8I32DOTGEMM_SKINNYDOT_AVEC8; +typedef I8X8 I8I32DOTGEMM_SKINNYDOT_BVEC8; +typedef I32X2 I8I32DOTGEMM_SKINNYDOT_CVEC8; + +typedef I8X16 I8I32DOTGEMM_SKINNYDOT_AVEC16; +typedef I8X16 I8I32DOTGEMM_SKINNYDOT_BVEC16; +typedef I32X4 I8I32DOTGEMM_SKINNYDOT_CVEC16; + +#define GEMM_SKINNY_DOT_UNIT_DEDUCE(TYPE, ...) \ + GEMM_SKINNY_DOT_##TYPE##_UNIT(__VA_ARGS__) + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32DOTGEMM, 16) { + return VDOTQ_I32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32DOTGEMM, 8) { + return VDOT_I32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32DOTGEMM, 4) { + return VDOT_I32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32DOTGEMM, 1) { + return c_vec + a_vec * b_vec; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32DOTGEMM, 16) { +#if __aarch64__ + __asm__("prfm pldl1keep,[%0,#80]"::"r"(a_ptr):); +#else + __asm__("pld [%0,#80]"::"r"(a_ptr):); +#endif + return VLD1Q_I8(a_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32DOTGEMM, 8) { +#if __aarch64__ + __asm__("prfm pldl1keep,[%0,#72]"::"r"(a_ptr):); +#else + __asm__("pld [%0,#72]"::"r"(a_ptr):); +#endif + return VLD1_I8(a_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32DOTGEMM, 4) { +#if __aarch64__ + I8X8 ret; /* higher 4 elements not used */ + __asm__("ldr %s0,[%1]; prfm pldl1keep,[%1,#72]":"=w"(ret):"r"(a_ptr):"memory"); +#else + register I8X8 ret __asm("d0"); /* higher 4 elements not used */ + __asm__("vld1.32 {%0[0]},[%1]; pld [%1,#72]":"=w"(ret):"r"(a_ptr):"memory"); +#endif + return ret; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32DOTGEMM, 1) { + return *a_ptr; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32DOTGEMM, 16) { + return VLD1Q_I8(b_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32DOTGEMM, 8) { + return VLD1_I8(b_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32DOTGEMM, 4) { +#if __aarch64__ + I8X8 ret; /* higher 4 elements not used */ + __asm__("ldr %s0,[%1]":"=w"(ret):"r"(b_ptr):"memory"); +#else + register I8X8 ret __asm("d0"); /* higher 4 elements not used */ + __asm__("vld1.32 {%0[0]},[%1]":"=w"(ret):"r"(b_ptr):"memory"); +#endif + return ret; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, 
I8I32DOTGEMM, 1) { + return *b_ptr; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32DOTGEMM, 16, 8) { + return VADD_I32(VGET_LOW_I32(c_vec), VGET_HIGH_I32(c_vec)); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32DOTGEMM, 8, 4) { + const static I32X2 z0 = {0, 0}; + return VPADD_I32(c_vec, z0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32DOTGEMM, 4, 1) { + return VGET_LANE_I32(c_vec, 0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32DOTGEMM, 16) { + return VDUPQ_N_I32(0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32DOTGEMM, 8) { + return VDUP_N_I32(0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32DOTGEMM, 4) { + return VDUP_N_I32(0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32DOTGEMM, 1) { + return 0; +} + +#endif \ No newline at end of file diff --git a/include/arm_neon/NeonI8I32MlaGemmCopy.h b/include/arm_neon/NeonI8I32MlaGemmCopy.h new file mode 100644 index 0000000..1ac52c3 --- /dev/null +++ b/include/arm_neon/NeonI8I32MlaGemmCopy.h @@ -0,0 +1,181 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/***************************************************************************** + * File: NeonI8I32MlaGemmCopy.h + * Description: Source code template for NEON 8->32bit GEMM packing functions + ****************************************************************************/ + +#include "NeonIntOpSign.h" + +#ifndef INCLUDE_NEON_I8I32_COPY +#define INCLUDE_NEON_I8I32_COPY + +static inline void pref_ab(const I8 *dat) { +#if __aarch64__ + __asm__ ("prfm pldl1keep,[%0,#64]\n\t"::"r"(dat):); +#else + __asm__ ("pld [%0,#64]\n\t"::"r"(dat):); +#endif +} + +#define NCOPY_LOOP_K8_UNROLL4(inc, dst_ptr, src1, src2, src3, src4) \ + for (dim1_count = dim1_cache; dim1_count > 7; dim1_count -= 8) {\ + I8X8 d1 = VLD1_I8(src1); src1 += 8; pref_ab(src1);\ + I8X8 d2 = VLD1_I8(src2); src2 += 8; pref_ab(src2);\ + I8X8 d3 = VLD1_I8(src3); src3 += 8; pref_ab(src3);\ + I8X8 d4 = VLD1_I8(src4); src4 += 8; pref_ab(src4);\ + I16X8X4 tm1;\ + tm1.val[0] = VMOVL_I8(d1); tm1.val[1] = VMOVL_I8(d2);\ + tm1.val[2] = VMOVL_I8(d3); tm1.val[3] = VMOVL_I8(d4);\ + VST4Q_LANE_I16(dst_ptr, tm1, 0);\ + VST4Q_LANE_I16(dst_ptr + inc, tm1, 1);\ + VST4Q_LANE_I16(dst_ptr + inc * 2, tm1, 2);\ + VST4Q_LANE_I16(dst_ptr + inc * 3, tm1, 3);\ + VST4Q_LANE_I16(dst_ptr + inc * 4, tm1, 4);\ + VST4Q_LANE_I16(dst_ptr + inc * 5, tm1, 5);\ + VST4Q_LANE_I16(dst_ptr + inc * 6, tm1, 6);\ + VST4Q_LANE_I16(dst_ptr + inc * 7, tm1, 7);\ + dst_ptr += inc * 8;\ + } + +#define NCOPY_LOOP_K8_UNROLL3(inc, dst_ptr, src1, src2, src3) \ + for (dim1_count = dim1_cache; dim1_count > 7; dim1_count -= 8) {\ + I8X8 d1 = VLD1_I8(src1); src1 += 8; pref_ab(src1);\ + I8X8 d2 = VLD1_I8(src2); src2 += 8; pref_ab(src2);\ + I8X8 d3 = VLD1_I8(src3); src3 += 8; pref_ab(src3);\ + I16X8X3 tm1;\ + tm1.val[0] = VMOVL_I8(d1);\ + tm1.val[1] = 
VMOVL_I8(d2);\ + tm1.val[2] = VMOVL_I8(d3);\ + VST3Q_LANE_I16(dst_ptr, tm1, 0);\ + VST3Q_LANE_I16(dst_ptr + inc, tm1, 1);\ + VST3Q_LANE_I16(dst_ptr + inc * 2, tm1, 2);\ + VST3Q_LANE_I16(dst_ptr + inc * 3, tm1, 3);\ + VST3Q_LANE_I16(dst_ptr + inc * 4, tm1, 4);\ + VST3Q_LANE_I16(dst_ptr + inc * 5, tm1, 5);\ + VST3Q_LANE_I16(dst_ptr + inc * 6, tm1, 6);\ + VST3Q_LANE_I16(dst_ptr + inc * 7, tm1, 7);\ + dst_ptr += inc * 8;\ + } + +#define NCOPY_UNROLL_12 {\ + I16 *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\ + NCOPY_LOOP_K8_UNROLL4(12, dst_h1, src1, src2, src3, src4)\ + dst_h1 = dst1 + 4;\ + NCOPY_LOOP_K8_UNROLL4(12, dst_h1, src5, src6, src7, src8)\ + dst_h1 = dst1 + 8;\ + NCOPY_LOOP_K8_UNROLL4(12, dst_h1, src9, src10, src11, src12)\ + dst1 = dst_h1 - 8;\ + NCOPY_STD(12)\ +} + +#define NCOPY_UNROLL_8 {\ + I16 *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\ + NCOPY_LOOP_K8_UNROLL4(8, dst_h1, src1, src2, src3, src4)\ + dst_h1 = dst1 + 4;\ + NCOPY_LOOP_K8_UNROLL4(8, dst_h1, src5, src6, src7, src8)\ + dst1 = dst_h1 - 4;\ + NCOPY_STD(8)\ +} + +#define NCOPY_UNROLL_6 {\ + I16 *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\ + NCOPY_LOOP_K8_UNROLL3(6, dst_h1, src1, src2, src3)\ + dst_h1 = dst1 + 3;\ + NCOPY_LOOP_K8_UNROLL3(6, dst_h1, src4, src5, src6)\ + dst1 = dst_h1 - 3;\ + NCOPY_STD(6)\ +} + +#define NCOPY_UNROLL_4 {\ + uint32_t dim1_cache = dim1_count;\ + NCOPY_LOOP_K8_UNROLL4(4, dst1, src1, src2, src3, src4)\ + NCOPY_STD(4)\ +} + +#define NCOPY_UNROLL_2 NCOPY_STD(2) +#define NCOPY_UNROLL_1 NCOPY_STD(1) + +#ifdef GEMM_UNSIGNED_INT +#define NCOPY_uint8_t_uint16_t(unroll) NCOPY_UNROLL_##unroll +#else +#define NCOPY_int8_t_int16_t(unroll) NCOPY_UNROLL_##unroll +#endif + +#define TCOPY_UNIT_1(src_ptr, dst_ptr, dst_offset) \ + TCOPY_UNIT_STD(src_ptr, dst_ptr, dst_offset, 1) + +#define TCOPY_UNIT_2(src_ptr, dst_ptr, dst_offset) \ + TCOPY_UNIT_STD(src_ptr, dst_ptr, dst_offset, 2) + +static inline I16X4 vld1_i16_i8(const I8 *src) { +#if __aarch64__ + I16X4 ret; + __asm__("ldr %s0,[%1]; "ISHLL" %0.8h,%0.8b,#0\n\t" + :"=w"(ret):"r"(src):"memory","cc"); + return ret; +#else + I16X8 ret; + __asm__("vld1.32 {d0[0]},[%1]; "ASM_VMOVL_I8" %q0,d0\n\t" + :"=w"(ret):"r"(src):"memory","cc","d0"); + return VGET_LOW_I16(ret); +#endif +} + +static inline I16X8 vld1q_i16_i8(const I8 *src) { + return VMOVL_I8(VLD1_I8(src)); +} + +#define TCOPY_UNIT_4(src_ptr, dst_ptr, dst_offset) {\ + I16X4 tmp = vld1_i16_i8(src_ptr);\ + VST1_I16(dst_ptr + dst_offset, tmp);\ +} + +#define TCOPY_UNIT_6(src_ptr, dst_ptr, dst_offset) {\ + I16X4 tmp = vld1_i16_i8(src_ptr);\ + I16 t5 = src_ptr[4];\ + I16 t6 = src_ptr[5];\ + pref_ab(src_ptr + 6);\ + VST1_I16(dst_ptr + dst_offset, tmp);\ + dst_ptr[dst_offset + 4] = t5;\ + dst_ptr[dst_offset + 5] = t6;\ +} + +#define TCOPY_UNIT_8(src_ptr, dst_ptr, dst_offset) {\ + I16X8 tmp = vld1q_i16_i8(src_ptr);\ + pref_ab(src_ptr + 8);\ + VST1Q_I16(dst_ptr + dst_offset, tmp);\ +} + +#define TCOPY_UNIT_12(src_ptr, dst_ptr, dst_offset) {\ + I16X8 tmpq = vld1q_i16_i8(src_ptr);\ + I16X4 tmpd = vld1_i16_i8(src_ptr + 8);\ + pref_ab(src_ptr + 12);\ + VST1Q_I16(dst_ptr + dst_offset, tmpq);\ + VST1_I16(dst_ptr + dst_offset + 8, tmpd);\ +} + +#ifdef GEMM_UNSIGNED_INT +#define TCOPY_UNIT_uint8_t_uint16_t(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) +#else +#define TCOPY_UNIT_int8_t_int16_t(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) +#endif + +#endif diff --git 
a/include/arm_neon/NeonI8I32MlaGemmKernel.h b/include/arm_neon/NeonI8I32MlaGemmKernel.h new file mode 100644 index 0000000..5c9fcde --- /dev/null +++ b/include/arm_neon/NeonI8I32MlaGemmKernel.h @@ -0,0 +1,742 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonI8I32MlaGemmKernel.h + * Description: Inline kernel function templates for NEON 8->32bit ingeter + * GEMM. This template can be used to generate signed and unsigned + * integer matmul functions. + * Names: "KERNEL_MxNy": code blocks that read panels from souce + * matrices and multiply them. + * "SAVE_MxNy": code blocks that store the multiply results + * to a region of output matrix. + *****************************************************************************/ + +#include "arm_neon/NeonIntOpSign.h" + +#ifndef INCLUDE_NEON_I8I32_KERNEL +#define INCLUDE_NEON_I8I32_KERNEL + +#define COMMON_KERNEL_HEADER(a_head, b_head) \ + const I16 *a_ptr = a_head;\ + const I16 *b_ptr = b_head;\ + uint32_t k_left = K; + +#define KERNEL_M1N1 \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, bd1;\ + I32X4 cq1 = VDUPQ_N_I32(0);\ + for (; k_left > 3; k_left -= 4) {\ + ad1 = VLD1_I16(a_ptr); a_ptr += 4;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_I16(cq1, ad1, bd1);\ + }\ + I32X2 cd1 = VADD_I32(VGET_LOW_I32(cq1), VGET_HIGH_I32(cq1));\ + I32 cs1 = VGET_LANE_I32(cd1, 0) + VGET_LANE_I32(cd1, 1);\ + for (; k_left > 0; k_left--) {\ + cs1 += (I32)(*a_ptr++) * (I32)(*b_ptr++);\ + } + +#define SAVE_M1N1 \ + cs1 += c_ptr[0] * beta; c_ptr[0] = cs1; + +#define KERNEL_M2N1_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, bd1;\ + I16X4X2 add1;\ + I32X4 cq1, cq2;\ + cq1 = cq2 = VDUPQ_N_I32(0);\ + for (; k_left > 3; k_left -= 4) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + add1 = VUZP_I16(ad1, ad2);\ + cq1 = VMLAL_I16(cq1, add1.val[0], bd1);\ + cq2 = VMLAL_I16(cq2, add1.val[1], bd1);\ + }\ + I32X2 cd1 = VADD_I32(VGET_LOW_I32(cq1), VGET_HIGH_I32(cq1));\ + I32X2 cd2 = VADD_I32(VGET_LOW_I32(cq2), VGET_HIGH_I32(cq2));\ + I32 cs1 = VGET_LANE_I32(cd1, 0) + VGET_LANE_I32(cd1, 1);\ + I32 cs2 = VGET_LANE_I32(cd2, 0) + VGET_LANE_I32(cd2, 1);\ + for (; k_left > 0; k_left--) {\ + I32 bs1 = *b_ptr++;\ + cs1 += (I32)a_ptr[0] * bs1;\ + cs2 += (I32)a_ptr[1] * bs1;\ + a_ptr += 2;\ + } + +#define KERNEL_M2N1 KERNEL_M2N1_UNIT(a_head, b_head) +#define KERNEL_M1N2 KERNEL_M2N1_UNIT(b_head, a_head) + +#define SAVE_M2N1 \ + cs1 += c_ptr[0] * beta; cs2 += c_ptr[1] * beta;\ + c_ptr[0] = cs1; c_ptr[1] = cs2; + +#define SAVE_M1N2 \ + cs1 += c_ptr[0] * beta; cs2 += c_ptr[ldc] * beta;\ + c_ptr[0] = cs1; c_ptr[ldc] = cs2; + +#define KERNEL_M2N2 \ + 
COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, bd1, bd2;\ + I16X4X2 add1, bdd1;\ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + for (; k_left > 3; k_left -= 4) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + bd1 = VLD1_I16(b_ptr); bd2 = VLD1_I16(b_ptr + 4); b_ptr += 8;\ + add1 = VUZP_I16(ad1, ad2); bdd1 = VUZP_I16(bd1, bd2);\ + cq1 = VMLAL_I16(cq1, add1.val[0], bdd1.val[0]);\ + cq2 = VMLAL_I16(cq2, add1.val[1], bdd1.val[0]);\ + cq3 = VMLAL_I16(cq3, add1.val[0], bdd1.val[1]);\ + cq4 = VMLAL_I16(cq4, add1.val[1], bdd1.val[1]);\ + }\ + I32X2 cd1 = VADD_I32(VGET_LOW_I32(cq1), VGET_HIGH_I32(cq1));\ + I32X2 cd2 = VADD_I32(VGET_LOW_I32(cq2), VGET_HIGH_I32(cq2));\ + I32X2 cd3 = VADD_I32(VGET_LOW_I32(cq3), VGET_HIGH_I32(cq3));\ + I32X2 cd4 = VADD_I32(VGET_LOW_I32(cq4), VGET_HIGH_I32(cq4));\ + I32 cs1 = VGET_LANE_I32(cd1, 0) + VGET_LANE_I32(cd1, 1);\ + I32 cs2 = VGET_LANE_I32(cd2, 0) + VGET_LANE_I32(cd2, 1);\ + I32 cs3 = VGET_LANE_I32(cd3, 0) + VGET_LANE_I32(cd3, 1);\ + I32 cs4 = VGET_LANE_I32(cd4, 0) + VGET_LANE_I32(cd4, 1);\ + for (; k_left > 0; k_left--) {\ + I32 as1 = a_ptr[0];\ + I32 as2 = a_ptr[1]; a_ptr += 2;\ + I32 bs1 = b_ptr[0];\ + I32 bs2 = b_ptr[1]; b_ptr += 2;\ + cs1 += as1 * bs1; cs2 += as2 * bs1;\ + cs3 += as1 * bs2; cs4 += as2 * bs2;\ + } + +#define SAVE_M2N2 \ + I32 *c_l1 = c_ptr + ldc;\ + cs1 += c_ptr[0] * beta; cs2 += c_ptr[1] * beta;\ + cs3 += c_l1[0] * beta; cs4 += c_l1[1] * beta;\ + c_ptr[0] = cs1; c_ptr[1] = cs2;\ + c_l1[0] = cs3; c_l1[1] = cs4; + +#define KERNEL_M4N1_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, ad4, bd1;\ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + for (; k_left > 3; k_left -= 4) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); ad4 = VLD1_I16(a_ptr + 12); a_ptr += 16;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad2, bd1, 1);\ + cq3 = VMLAL_LANE_I16(cq3, ad3, bd1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, ad4, bd1, 3);\ + }\ + cq1 = VADDQ_I32(cq1, cq3); cq2 = VADDQ_I32(cq2, cq4);\ + cq1 = VADDQ_I32(cq1, cq2);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); a_ptr += 4;\ + I16 bs1 = *b_ptr++;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1);\ + } + +#define KERNEL_M4N1 KERNEL_M4N1_UNIT(a_head, b_head) +#define KERNEL_M1N4 KERNEL_M4N1_UNIT(b_head, a_head) + +#define SAVE_M4N1 \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + VST1Q_I32(c_ptr, cq1); + +#define UNIT_SAVE_M1N4(cq1) \ + c_tmp[0] = c_tmp[0] * beta + VGETQ_LANE_I32(cq1, 0);\ + c_tmp[ldc] = c_tmp[ldc] * beta + VGETQ_LANE_I32(cq1, 1);\ + c_tmp += ldc * 2;\ + c_tmp[0] = c_tmp[0] * beta + VGETQ_LANE_I32(cq1, 2);\ + c_tmp[ldc] = c_tmp[ldc] * beta + VGETQ_LANE_I32(cq1, 3);\ + c_tmp += ldc * 2; + +#define SAVE_M1N4 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M1N4(cq1) + +#define KERNEL_M4N2_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, bd1;\ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + for (; k_left > 1; k_left -= 2) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad1, bd1, 1);\ + cq3 = VMLAL_LANE_I16(cq3, ad2, bd1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, ad2, bd1, 3);\ + }\ + cq1 = VADDQ_I32(cq1, cq3); cq2 = VADDQ_I32(cq2, cq4);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); a_ptr += 4;\ + I16 bs1 = 
b_ptr[0];\ + I16 bs2 = b_ptr[1]; b_ptr += 2;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1); cq2 = VMLAL_N_I16(cq2, ad1, bs2);\ + } + +#define KERNEL_M4N2 KERNEL_M4N2_UNIT(a_head, b_head) +#define KERNEL_M2N4 KERNEL_M4N2_UNIT(b_head, a_head) + +#define UNIT_SAVE_M4N2(cq1, cq2) \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_tmp), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_tmp + ldc), beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + ldc, cq2);\ + c_tmp += ldc * 2; + +#define SAVE_M4N2 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M4N2(cq1, cq2) + +#define UNIT_SAVE_M2N4(cq1, cq2) {\ + I32X4X2 tm1 = VZIPQ_I32(cq1, cq2);\ + I32X2 l1 = VMLA_N_I32(VGET_LOW_I32(tm1.val[0]),\ + VLD1_I32(c_tmp), beta);\ + I32X2 l2 = VMLA_N_I32(VGET_HIGH_I32(tm1.val[0]),\ + VLD1_I32(c_tmp + ldc), beta);\ + VST1_I32(c_tmp, l1); VST1_I32(c_tmp + ldc, l2); c_tmp += ldc * 2;\ + I32X2 l3 = VMLA_N_I32(VGET_LOW_I32(tm1.val[1]),\ + VLD1_I32(c_tmp), beta);\ + I32X2 l4 = VMLA_N_I32(VGET_HIGH_I32(tm1.val[1]),\ + VLD1_I32(c_tmp + ldc), beta);\ + VST1_I32(c_tmp, l3); VST1_I32(c_tmp + ldc, l4); c_tmp += ldc * 2;\ +} + +#define SAVE_M2N4 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq2) + +#define KERNEL_M4N4 \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, bd1;\ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); a_ptr += 4;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad1, bd1, 1);\ + cq3 = VMLAL_LANE_I16(cq3, ad1, bd1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, ad1, bd1, 3);\ + } + +#define SAVE_M4N4 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M4N2(cq1, cq2) UNIT_SAVE_M4N2(cq3, cq4) + +#define KERNEL_M8N1_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, ad4, ad5, ad6, ad7, ad8, bd1;\ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + for (; k_left > 3; k_left -= 4) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); ad4 = VLD1_I16(a_ptr + 12);\ + ad5 = VLD1_I16(a_ptr + 16); ad6 = VLD1_I16(a_ptr + 20);\ + ad7 = VLD1_I16(a_ptr + 24); ad8 = VLD1_I16(a_ptr + 28); a_ptr += 32;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad2, bd1, 0);\ + cq3 = VMLAL_LANE_I16(cq3, ad3, bd1, 1);\ + cq4 = VMLAL_LANE_I16(cq4, ad4, bd1, 1);\ + cq1 = VMLAL_LANE_I16(cq1, ad5, bd1, 2);\ + cq2 = VMLAL_LANE_I16(cq2, ad6, bd1, 2);\ + cq3 = VMLAL_LANE_I16(cq3, ad7, bd1, 3);\ + cq4 = VMLAL_LANE_I16(cq4, ad8, bd1, 3);\ + }\ + cq1 = VADDQ_I32(cq1, cq3); cq2 = VADDQ_I32(cq2, cq4);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + I16 bs1 = *b_ptr++;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1); cq2 = VMLAL_N_I16(cq2, ad2, bs1);\ + } + +#define KERNEL_M8N1 KERNEL_M8N1_UNIT(a_head, b_head) +#define KERNEL_M1N8 KERNEL_M8N1_UNIT(b_head, a_head) + +#define SAVE_M8N1 \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_ptr + 4), beta);\ + VST1Q_I32(c_ptr, cq1); VST1Q_I32(c_ptr + 4, cq2); + +#define SAVE_M1N8 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M1N4(cq1) UNIT_SAVE_M1N4(cq2) + +#define KERNEL_M8N2_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, ad4, bd1;\ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + for (; k_left > 1; k_left -= 2) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); ad4 = VLD1_I16(a_ptr + 12); a_ptr += 16;\ + bd1 = 
VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad2, bd1, 0);\ + cq3 = VMLAL_LANE_I16(cq3, ad1, bd1, 1);\ + cq4 = VMLAL_LANE_I16(cq4, ad2, bd1, 1);\ + cq1 = VMLAL_LANE_I16(cq1, ad3, bd1, 2);\ + cq2 = VMLAL_LANE_I16(cq2, ad4, bd1, 2);\ + cq3 = VMLAL_LANE_I16(cq3, ad3, bd1, 3);\ + cq4 = VMLAL_LANE_I16(cq4, ad4, bd1, 3);\ + }\ + if (k_left > 0) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + I16 bs1 = b_ptr[0];\ + I16 bs2 = b_ptr[1]; b_ptr += 2;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1); cq2 = VMLAL_N_I16(cq2, ad2, bs1);\ + cq3 = VMLAL_N_I16(cq3, ad1, bs2); cq4 = VMLAL_N_I16(cq4, ad2, bs2);\ + } + +#define KERNEL_M8N2 KERNEL_M8N2_UNIT(a_head, b_head) +#define KERNEL_M2N8 KERNEL_M8N2_UNIT(b_head, a_head) + +#define UNIT_SAVE_M8N2(cq1, cq2, cq3, cq4) \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_tmp), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_tmp + 4), beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_tmp + ldc), beta);\ + cq4 = VMLAQ_N_I32(cq4, VLD1Q_I32(c_tmp + ldc + 4), beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + 4, cq2);\ + VST1Q_I32(c_tmp + ldc, cq3); VST1Q_I32(c_tmp + ldc + 4, cq4);\ + c_tmp += ldc * 2; + +#define SAVE_M8N2 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M8N2(cq1, cq2, cq3, cq4) + +#define SAVE_M2N8 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq3) UNIT_SAVE_M2N4(cq2, cq4) + +#define KERNEL_M8N4_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, bd1;\ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = cq7 = cq8 = VDUPQ_N_I32(0);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad2, bd1, 0);\ + cq3 = VMLAL_LANE_I16(cq3, ad1, bd1, 1);\ + cq4 = VMLAL_LANE_I16(cq4, ad2, bd1, 1);\ + cq5 = VMLAL_LANE_I16(cq5, ad1, bd1, 2);\ + cq6 = VMLAL_LANE_I16(cq6, ad2, bd1, 2);\ + cq7 = VMLAL_LANE_I16(cq7, ad1, bd1, 3);\ + cq8 = VMLAL_LANE_I16(cq8, ad2, bd1, 3);\ + } + +#define KERNEL_M8N4 KERNEL_M8N4_UNIT(a_head, b_head) +#define KERNEL_M4N8 KERNEL_M8N4_UNIT(b_head, a_head) + +#define SAVE_M8N4 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq1, cq2, cq3, cq4) UNIT_SAVE_M8N2(cq5, cq6, cq7, cq8) + +#define UNIT_SAVE_M4N4_TRANS(cq1, cq2, cq3, cq4) {\ + I32X4 l1 = VLD1Q_I32(c_tmp);\ + I32X4 l2 = VLD1Q_I32(c_tmp + ldc);\ + I32X4 l3 = VLD1Q_I32(c_tmp + ldc * 2);\ + I32X4 l4 = VLD1Q_I32(c_tmp + ldc * 3);\ + I32X4X2 tm1 = VZIPQ_I32(cq1, cq2);\ + I32X4X2 tm2 = VZIPQ_I32(cq3, cq4);\ + cq1 = VCOMBINE_I32(VGET_LOW_I32(tm1.val[0]), VGET_LOW_I32(tm2.val[0]));\ + cq2 = VCOMBINE_I32(VGET_HIGH_I32(tm1.val[0]), VGET_HIGH_I32(tm2.val[0]));\ + cq3 = VCOMBINE_I32(VGET_LOW_I32(tm1.val[1]), VGET_LOW_I32(tm2.val[1]));\ + cq4 = VCOMBINE_I32(VGET_HIGH_I32(tm1.val[1]), VGET_HIGH_I32(tm2.val[1]));\ + cq1 = VMLAQ_N_I32(cq1, l1, beta); cq2 = VMLAQ_N_I32(cq2, l2, beta);\ + cq3 = VMLAQ_N_I32(cq3, l3, beta); cq4 = VMLAQ_N_I32(cq4, l4, beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + ldc, cq2);\ + VST1Q_I32(c_tmp + ldc * 2, cq3); VST1Q_I32(c_tmp + ldc * 3, cq4);\ + c_tmp += ldc * 4;\ +} + +#define SAVE_M4N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M4N4_TRANS(cq1, cq3, cq5, cq7)\ + UNIT_SAVE_M4N4_TRANS(cq2, cq4, cq6, cq8) + +#define KERNEL_M8N8 \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, bd1, bd2;\ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + I32X4 cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16;\ + cq01 = cq02 = cq03 = cq04 = 
cq05 = cq06 = cq07 = cq08 = VDUPQ_N_I32(0);\ + cq09 = cq10 = cq11 = cq12 = cq13 = cq14 = cq15 = cq16 = VDUPQ_N_I32(0);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4); a_ptr += 8;\ + bd1 = VLD1_I16(b_ptr); bd2 = VLD1_I16(b_ptr + 4); b_ptr += 8;\ + cq01 = VMLAL_LANE_I16(cq01, ad1, bd1, 0);\ + cq02 = VMLAL_LANE_I16(cq02, ad2, bd1, 0);\ + cq03 = VMLAL_LANE_I16(cq03, ad1, bd1, 1);\ + cq04 = VMLAL_LANE_I16(cq04, ad2, bd1, 1);\ + cq05 = VMLAL_LANE_I16(cq05, ad1, bd1, 2);\ + cq06 = VMLAL_LANE_I16(cq06, ad2, bd1, 2);\ + cq07 = VMLAL_LANE_I16(cq07, ad1, bd1, 3);\ + cq08 = VMLAL_LANE_I16(cq08, ad2, bd1, 3);\ + cq09 = VMLAL_LANE_I16(cq09, ad1, bd2, 0);\ + cq10 = VMLAL_LANE_I16(cq10, ad2, bd2, 0);\ + cq11 = VMLAL_LANE_I16(cq11, ad1, bd2, 1);\ + cq12 = VMLAL_LANE_I16(cq12, ad2, bd2, 1);\ + cq13 = VMLAL_LANE_I16(cq13, ad1, bd2, 2);\ + cq14 = VMLAL_LANE_I16(cq14, ad2, bd2, 2);\ + cq15 = VMLAL_LANE_I16(cq15, ad1, bd2, 3);\ + cq16 = VMLAL_LANE_I16(cq16, ad2, bd2, 3);\ + } + +#define SAVE_M8N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq01, cq02, cq03, cq04)\ + UNIT_SAVE_M8N2(cq05, cq06, cq07, cq08)\ + UNIT_SAVE_M8N2(cq09, cq10, cq11, cq12)\ + UNIT_SAVE_M8N2(cq13, cq14, cq15, cq16) + +#define KERNEL_M12N1_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3;\ + I16 bs1;\ + I32X4 cq1, cq2, cq3;\ + cq1 = cq2 = cq3 = VDUPQ_N_I32(0);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); a_ptr += 12;\ + bs1 = *b_ptr++;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1);\ + cq2 = VMLAL_N_I16(cq2, ad2, bs1);\ + cq3 = VMLAL_N_I16(cq3, ad3, bs1);\ + } + +#define KERNEL_M12N1 KERNEL_M12N1_UNIT(a_head, b_head) +#define KERNEL_M1N12 KERNEL_M12N1_UNIT(b_head, a_head) + +#define SAVE_M12N1 \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_ptr + 4), beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_ptr + 8), beta);\ + VST1Q_I32(c_ptr, cq1); VST1Q_I32(c_ptr + 4, cq2); VST1Q_I32(c_ptr + 8, cq3); + +#define SAVE_M1N12 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M1N4(cq1)\ + UNIT_SAVE_M1N4(cq2) UNIT_SAVE_M1N4(cq3) + +#define KERNEL_M12N2_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, ad4, ad5, ad6, bd1;\ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = VDUPQ_N_I32(0);\ + for (; k_left > 1; k_left -= 2) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); ad4 = VLD1_I16(a_ptr + 12);\ + ad5 = VLD1_I16(a_ptr + 16); ad6 = VLD1_I16(a_ptr + 20); a_ptr += 24;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, ad2, bd1, 0);\ + cq3 = VMLAL_LANE_I16(cq3, ad3, bd1, 0);\ + cq4 = VMLAL_LANE_I16(cq4, ad1, bd1, 1);\ + cq5 = VMLAL_LANE_I16(cq5, ad2, bd1, 1);\ + cq6 = VMLAL_LANE_I16(cq6, ad3, bd1, 1);\ + cq1 = VMLAL_LANE_I16(cq1, ad4, bd1, 2);\ + cq2 = VMLAL_LANE_I16(cq2, ad5, bd1, 2);\ + cq3 = VMLAL_LANE_I16(cq3, ad6, bd1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, ad4, bd1, 3);\ + cq5 = VMLAL_LANE_I16(cq5, ad5, bd1, 3);\ + cq6 = VMLAL_LANE_I16(cq6, ad6, bd1, 3);\ + }\ + if (k_left > 0) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); a_ptr += 12;\ + I16 bs1 = b_ptr[0];\ + I16 bs2 = b_ptr[1]; b_ptr += 2;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1);\ + cq2 = VMLAL_N_I16(cq2, ad2, bs1);\ + cq3 = VMLAL_N_I16(cq3, ad3, bs1);\ + cq4 = VMLAL_N_I16(cq4, ad1, bs2);\ + cq5 = VMLAL_N_I16(cq5, ad2, bs2);\ + cq6 = VMLAL_N_I16(cq6, ad3, 
bs2);\ + } + +#define KERNEL_M12N2 KERNEL_M12N2_UNIT(a_head, b_head) +#define KERNEL_M2N12 KERNEL_M12N2_UNIT(b_head, a_head) + +#define UNIT_SAVE_M12N2(cq1, cq2, cq3, cq4, cq5, cq6) \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_tmp), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_tmp + 4), beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_tmp + 8), beta);\ + cq4 = VMLAQ_N_I32(cq4, VLD1Q_I32(c_tmp + ldc), beta);\ + cq5 = VMLAQ_N_I32(cq5, VLD1Q_I32(c_tmp + ldc + 4), beta);\ + cq6 = VMLAQ_N_I32(cq6, VLD1Q_I32(c_tmp + ldc + 8), beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + 4, cq2);\ + VST1Q_I32(c_tmp + 8, cq3); VST1Q_I32(c_tmp + ldc, cq4);\ + VST1Q_I32(c_tmp + ldc + 4, cq5); VST1Q_I32(c_tmp + ldc + 8, cq6);\ + c_tmp += ldc * 2; + +#define SAVE_M12N2 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M12N2(cq1, cq2, cq3, cq4, cq5, cq6) + +#define SAVE_M2N12 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq4) UNIT_SAVE_M2N4(cq2, cq5)\ + UNIT_SAVE_M2N4(cq3, cq6) + +#define KERNEL_M12N4_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, bd1;\ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06;\ + I32X4 cq07, cq08, cq09, cq10, cq11, cq12;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = VDUPQ_N_I32(0);\ + cq07 = cq08 = cq09 = cq10 = cq11 = cq12 = VDUPQ_N_I32(0);\ + for (; k_left > 0; k_left--) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); a_ptr += 12;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq01 = VMLAL_LANE_I16(cq01, ad1, bd1, 0);\ + cq02 = VMLAL_LANE_I16(cq02, ad2, bd1, 0);\ + cq03 = VMLAL_LANE_I16(cq03, ad3, bd1, 0);\ + cq04 = VMLAL_LANE_I16(cq04, ad1, bd1, 1);\ + cq05 = VMLAL_LANE_I16(cq05, ad2, bd1, 1);\ + cq06 = VMLAL_LANE_I16(cq06, ad3, bd1, 1);\ + cq07 = VMLAL_LANE_I16(cq07, ad1, bd1, 2);\ + cq08 = VMLAL_LANE_I16(cq08, ad2, bd1, 2);\ + cq09 = VMLAL_LANE_I16(cq09, ad3, bd1, 2);\ + cq10 = VMLAL_LANE_I16(cq10, ad1, bd1, 3);\ + cq11 = VMLAL_LANE_I16(cq11, ad2, bd1, 3);\ + cq12 = VMLAL_LANE_I16(cq12, ad3, bd1, 3);\ + } + +#define KERNEL_M12N4 KERNEL_M12N4_UNIT(a_head, b_head) +#define KERNEL_M4N12 KERNEL_M12N4_UNIT(b_head, a_head) + +#define SAVE_M12N4 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M12N2(cq01, cq02, cq03, cq04, cq05, cq06)\ + UNIT_SAVE_M12N2(cq07, cq08, cq09, cq10, cq11, cq12) + +#define SAVE_M4N12 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M4N4_TRANS(cq01, cq04, cq07, cq10)\ + UNIT_SAVE_M4N4_TRANS(cq02, cq05, cq08, cq11)\ + UNIT_SAVE_M4N4_TRANS(cq03, cq06, cq09, cq12) + +#define KERNEL_M6N1_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, ad4, ad5, ad6, bd1;\ + I16X4X2 add1;\ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = VDUPQ_N_I32(0);\ + for (; k_left > 3; k_left -= 4) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); ad4 = VLD1_I16(a_ptr + 12);\ + ad5 = VLD1_I16(a_ptr + 16); ad6 = VLD1_I16(a_ptr + 20); a_ptr += 24;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + add1 = VUZP_I16(ad2, ad5);\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_I16(cq2, add1.val[0], bd1);\ + cq3 = VMLAL_LANE_I16(cq3, ad3, bd1, 1);\ + cq4 = VMLAL_LANE_I16(cq4, ad4, bd1, 2);\ + cq5 = VMLAL_I16(cq5, add1.val[1], bd1);\ + cq6 = VMLAL_LANE_I16(cq6, ad6, bd1, 3);\ + }\ + cq1 = VADDQ_I32(cq1, cq4); cq3 = VADDQ_I32(cq3, cq6);\ + cq4 = VCOMBINE_I32(VGET_LOW_I32(cq2), VGET_LOW_I32(cq5));\ + cq6 = VCOMBINE_I32(VGET_HIGH_I32(cq2), VGET_HIGH_I32(cq5));\ + cq2 = VADDQ_I32(cq4, cq6);\ + I32 cs1 = VGETQ_LANE_I32(cq1, 0) + VGETQ_LANE_I32(cq2, 1);\ + I32 cs2 = VGETQ_LANE_I32(cq1, 1) + 
VGETQ_LANE_I32(cq2, 3);\ + I32 cs3 = VGETQ_LANE_I32(cq1, 2) + VGETQ_LANE_I32(cq3, 0);\ + I32 cs4 = VGETQ_LANE_I32(cq1, 3) + VGETQ_LANE_I32(cq3, 1);\ + I32 cs5 = VGETQ_LANE_I32(cq2, 0) + VGETQ_LANE_I32(cq3, 2);\ + I32 cs6 = VGETQ_LANE_I32(cq2, 2) + VGETQ_LANE_I32(cq3, 3);\ + for (; k_left > 0; k_left--) {\ + I32 bs1 = *b_ptr++;\ + cs1 += bs1 * (I32)a_ptr[0];\ + cs2 += bs1 * (I32)a_ptr[1];\ + cs3 += bs1 * (I32)a_ptr[2];\ + cs4 += bs1 * (I32)a_ptr[3];\ + cs5 += bs1 * (I32)a_ptr[4];\ + cs6 += bs1 * (I32)a_ptr[5];\ + a_ptr += 6;\ + } + +#define KERNEL_M6N1 KERNEL_M6N1_UNIT(a_head, b_head) +#define KERNEL_M1N6 KERNEL_M6N1_UNIT(b_head, a_head) + +#define SAVE_M6N1 \ + cs1 += c_ptr[0] * beta; cs2 += c_ptr[1] * beta;\ + cs3 += c_ptr[2] * beta; cs4 += c_ptr[3] * beta;\ + cs5 += c_ptr[4] * beta; cs6 += c_ptr[5] * beta;\ + c_ptr[0] = cs1; c_ptr[1] = cs2; c_ptr[2] = cs3;\ + c_ptr[3] = cs4; c_ptr[4] = cs5; c_ptr[5] = cs6; + +#define SAVE_M1N6 \ + I32 *c_tmp = c_ptr;\ + cs1 += c_tmp[0] * beta; cs2 += c_tmp[ldc] * beta;\ + c_tmp[0] = cs1; c_tmp[ldc] = cs2; c_tmp += ldc * 2;\ + cs3 += c_tmp[0] * beta; cs4 += c_tmp[ldc] * beta;\ + c_tmp[0] = cs3; c_tmp[ldc] = cs4; c_tmp += ldc * 2;\ + cs5 += c_tmp[0] * beta; cs6 += c_tmp[ldc] * beta;\ + c_tmp[0] = cs5; c_tmp[ldc] = cs6; + +#define KERNEL_M6N2_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, bd1;\ + I16X4X2 bdd1;\ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = VDUPQ_N_I32(0);\ + for (; k_left > 1; k_left -= 2) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); a_ptr += 12;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + bdd1 = VTRN_I16(bd1, bd1);\ + cq1 = VMLAL_LANE_I16(cq1, ad1, bd1, 0);\ + cq2 = VMLAL_I16(cq2, ad2, bdd1.val[0]);\ + cq3 = VMLAL_LANE_I16(cq3, ad3, bd1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, ad1, bd1, 1);\ + cq5 = VMLAL_I16(cq5, ad2, bdd1.val[1]);\ + cq6 = VMLAL_LANE_I16(cq6, ad3, bd1, 3);\ + }\ + I32X2 cd1 = VADD_I32(VGET_LOW_I32(cq1), VGET_HIGH_I32(cq2));\ + I32X2 cd2 = VADD_I32(VGET_HIGH_I32(cq1), VGET_LOW_I32(cq3));\ + I32X2 cd3 = VADD_I32(VGET_LOW_I32(cq2), VGET_HIGH_I32(cq3));\ + I32X2 cd4 = VADD_I32(VGET_LOW_I32(cq4), VGET_HIGH_I32(cq5));\ + I32X2 cd5 = VADD_I32(VGET_HIGH_I32(cq4), VGET_LOW_I32(cq6));\ + I32X2 cd6 = VADD_I32(VGET_LOW_I32(cq5), VGET_HIGH_I32(cq6));\ + cq1 = VCOMBINE_I32(cd1, cd2); cq2 = VCOMBINE_I32(cd4, cd5);\ + I32 cs1 = VGET_LANE_I32(cd3, 0);\ + I32 cs2 = VGET_LANE_I32(cd3, 1);\ + I32 cs3 = VGET_LANE_I32(cd6, 0);\ + I32 cs4 = VGET_LANE_I32(cd6, 1);\ + if (k_left > 0) {\ + ad1 = VLD1_I16(a_ptr);\ + I32 as1 = a_ptr[4];\ + I32 as2 = a_ptr[5]; a_ptr += 6;\ + I32 bs1 = b_ptr[0];\ + I32 bs2 = b_ptr[1]; b_ptr += 2;\ + cq1 = VMLAL_N_I16(cq1, ad1, bs1);\ + cq2 = VMLAL_N_I16(cq2, ad1, bs2);\ + cs1 += as1 * bs1; cs2 += as2 * bs1;\ + cs3 += as1 * bs2; cs4 += as2 * bs2;\ + } + +#define KERNEL_M6N2 KERNEL_M6N2_UNIT(a_head, b_head) +#define KERNEL_M2N6 KERNEL_M6N2_UNIT(b_head, a_head) + +#define SAVE_M6N2 \ + I32 *c_l1 = c_ptr + ldc;\ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_l1), beta);\ + cs1 += c_ptr[4] * beta; cs2 += c_ptr[5] * beta;\ + cs3 += c_l1[4] * beta; cs4 += c_l1[5] * beta;\ + VST1Q_I32(c_ptr, cq1); VST1Q_I32(c_l1, cq2);\ + c_ptr[4] = cs1; c_ptr[5] = cs2;\ + c_l1[4] = cs3; c_l1[5] = cs4; + +#define SAVE_M2N6 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq2)\ + cs1 += c_tmp[0] * beta; cs3 += c_tmp[1] * beta;\ + c_tmp[0] = cs1; c_tmp[1] = cs3; c_tmp += ldc;\ + cs2 += c_tmp[0] * beta; cs4 += 
c_tmp[1] * beta;\ + c_tmp[0] = cs2; c_tmp[1] = cs4; + +#define KERNEL_M6N4_UNIT(a_head, b_head) \ + COMMON_KERNEL_HEADER(a_head, b_head) \ + I16X4 ad1, ad2, ad3, bd1, bd2;\ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = VDUPQ_N_I32(0);\ + for (; k_left > 1; k_left -= 2) {\ + ad1 = VLD1_I16(a_ptr); ad2 = VLD1_I16(a_ptr + 4);\ + ad3 = VLD1_I16(a_ptr + 8); a_ptr += 12;\ + bd1 = VLD1_I16(b_ptr); bd2 = VLD1_I16(b_ptr + 4); b_ptr += 8;\ + cq1 = VMLAL_LANE_I16(cq1, bd1, ad1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, bd1, ad1, 1);\ + cq3 = VMLAL_LANE_I16(cq3, bd1, ad1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, bd1, ad1, 3);\ + cq5 = VMLAL_LANE_I16(cq5, bd1, ad2, 0);\ + cq6 = VMLAL_LANE_I16(cq6, bd1, ad2, 1);\ + cq1 = VMLAL_LANE_I16(cq1, bd2, ad2, 2);\ + cq2 = VMLAL_LANE_I16(cq2, bd2, ad2, 3);\ + cq3 = VMLAL_LANE_I16(cq3, bd2, ad3, 0);\ + cq4 = VMLAL_LANE_I16(cq4, bd2, ad3, 1);\ + cq5 = VMLAL_LANE_I16(cq5, bd2, ad3, 2);\ + cq6 = VMLAL_LANE_I16(cq6, bd2, ad3, 3);\ + }\ + if (k_left > 0) {\ + ad1 = VLD1_I16(a_ptr);\ + I32 as1 = a_ptr[4];\ + I32 as2 = a_ptr[5]; a_ptr += 6;\ + bd1 = VLD1_I16(b_ptr); b_ptr += 4;\ + cq1 = VMLAL_LANE_I16(cq1, bd1, ad1, 0);\ + cq2 = VMLAL_LANE_I16(cq2, bd1, ad1, 1);\ + cq3 = VMLAL_LANE_I16(cq3, bd1, ad1, 2);\ + cq4 = VMLAL_LANE_I16(cq4, bd1, ad1, 3);\ + cq5 = VMLAL_N_I16(cq5, bd1, as1);\ + cq6 = VMLAL_N_I16(cq6, bd1, as2);\ + } + +#define KERNEL_M6N4 KERNEL_M6N4_UNIT(a_head, b_head) +#define KERNEL_M4N6 KERNEL_M6N4_UNIT(b_head, a_head) + +#define UNIT_SAVE_M6N4(cq1, cq2, cq3, cq4, cq5, cq6) \ + UNIT_SAVE_M4N4_TRANS(cq1, cq2, cq3, cq4)\ + c_tmp -= 4 * ldc;\ + c_tmp += 4;\ + UNIT_SAVE_M2N4(cq5, cq6)\ + c_tmp -= 4; + +#define SAVE_M6N4 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M6N4(cq1, cq2, cq3, cq4, cq5, cq6) + +#define SAVE_M4N6 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M4N2(cq1, cq2) UNIT_SAVE_M4N2(cq3, cq4) UNIT_SAVE_M4N2(cq5, cq6) + +#define NEON_IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim, srcint, dstint) \ +static inline void\ + inline_dualpack_gemm_a##srcint##_b##srcint##_c##dstint##_m##mdim##_n##ndim(\ + const srcint *a_head, const srcint *b_head, dstint *c_ptr,\ + uint32_t K, dstint beta, uint32_t ldc) {\ + KERNEL_M##mdim##N##ndim\ + SAVE_M##mdim##N##ndim\ +} + +#define IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim, srcint, dstint)\ + NEON_IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim, srcint, dstint) + +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 1, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 2, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 1, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 2, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 4, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 4, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 1, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 2, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 4, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 8, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 8, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 8, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 1, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 2, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 4, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 8, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 6, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 6, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 6, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 1, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 2, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 4, I16, I32) 
+IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 12, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 12, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 12, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 1, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 2, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 4, I16, I32) + +#endif diff --git a/include/arm_neon/NeonI8I32MlaGemmSkinnyDot.h b/include/arm_neon/NeonI8I32MlaGemmSkinnyDot.h new file mode 100644 index 0000000..d35a1fa --- /dev/null +++ b/include/arm_neon/NeonI8I32MlaGemmSkinnyDot.h @@ -0,0 +1,200 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonI8I32MlaGemmSkinnyDot.h + * Description: Source code template for NEON mlal 8->32bit GEMM skinny dot + * kernels. + *****************************************************************************/ + +#include "common/CommonSkinnyDot.h" +#include "arm_neon/NeonIntOpSign.h" + +#ifndef INCLUDE_I8I32_MLA_SKINNYDOT +#define INCLUDE_I8I32_MLA_SKINNYDOT + +typedef I8 I8I32MLAGEMM_SKINNYDOT_ASCALAR; +typedef I8 I8I32MLAGEMM_SKINNYDOT_BSCALAR; +typedef I32 I8I32MLAGEMM_SKINNYDOT_CSCALAR; + +typedef I16 I8I32MLAGEMM_SKINNYDOT_AVEC1; +typedef I16 I8I32MLAGEMM_SKINNYDOT_BVEC1; +typedef I32 I8I32MLAGEMM_SKINNYDOT_CVEC1; + +typedef I32X2 I8I32MLAGEMM_SKINNYDOT_AVEC2; +typedef I32X2 I8I32MLAGEMM_SKINNYDOT_BVEC2; +typedef I32X2 I8I32MLAGEMM_SKINNYDOT_CVEC2; + +typedef I8X8 I8I32MLAGEMM_SKINNYDOT_AVEC4; +typedef I8X8 I8I32MLAGEMM_SKINNYDOT_BVEC4; +typedef I32X2 I8I32MLAGEMM_SKINNYDOT_CVEC4; + +typedef I8X8 I8I32MLAGEMM_SKINNYDOT_AVEC8; +typedef I8X8 I8I32MLAGEMM_SKINNYDOT_BVEC8; +typedef I32X4 I8I32MLAGEMM_SKINNYDOT_CVEC8; + +typedef I8X16 I8I32MLAGEMM_SKINNYDOT_AVEC16; +typedef I8X16 I8I32MLAGEMM_SKINNYDOT_BVEC16; +typedef I32X4X2 I8I32MLAGEMM_SKINNYDOT_CVEC16; + +#define GEMM_SKINNY_DOT_UNIT_DEDUCE(type, ...)\ + GEMM_SKINNY_DOT_##type##_UNIT(__VA_ARGS__) + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32MLAGEMM, 16) { + I16X8 low_product = VMULL_I8(VGET_LOW_I8(a_vec), VGET_LOW_I8(b_vec)); +#if __aarch64__ + I16X8 high_product = VMULL_HIGH_I8(a_vec, b_vec); +#else + I16X8 high_product = VMULL_I8(VGET_HIGH_I8(a_vec), VGET_HIGH_I8(b_vec)); +#endif + I32X4X2 ret; + ret.val[0] = VPADALQ_I16(c_vec.val[0], low_product); + ret.val[1] = VPADALQ_I16(c_vec.val[1], high_product); + return ret; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32MLAGEMM, 8) { + I16X8 product = VMULL_I8(a_vec, b_vec); + return VPADALQ_I16(c_vec, product); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32MLAGEMM, 4) { + I16X8 product = VMULL_I8(a_vec, b_vec); + return VPADAL_I16(c_vec, VGET_LOW_I16(product)); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32MLAGEMM, 2) { + return VMLA_I32(c_vec, a_vec, b_vec); +} + 
+GEMM_SKINNY_DOT_UNIT_DEDUCE(CALC, I8I32MLAGEMM, 1) { + return c_vec + a_vec * b_vec; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32MLAGEMM, 16) { +#if __aarch64__ + __asm__("prfm pldl1keep,[%0,#80]"::"r"(a_ptr):); +#else + __asm__("pld [%0,#80]"::"r"(a_ptr):); +#endif + return VLD1Q_I8(a_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32MLAGEMM, 8) { +#if __aarch64__ + __asm__("prfm pldl1keep,[%0,#72]"::"r"(a_ptr):); +#else + __asm__("pld [%0,#72]"::"r"(a_ptr):); +#endif + return VLD1_I8(a_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32MLAGEMM, 4) { +#if __aarch64__ + I8X8 ret; /* higher 4 elements not used */ + __asm__("ldr %s0,[%1]; prfm pldl1keep,[%1,#72]":"=w"(ret):"r"(a_ptr):"memory"); + return ret; +#else + register I8X16 ret __asm("q0"); /* higher 12 elements not used */ + __asm__("vld1.32 {%e0[0]},[%1]; pld [%1,#72]":"=w"(ret):"r"(a_ptr):"memory"); + return VGET_LOW_I8(ret); +#endif +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32MLAGEMM, 2) { + I32 lo = a_ptr[0]; + I32 hi = a_ptr[1]; + return (I32X2){lo, hi}; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADA, I8I32MLAGEMM, 1) { + return *a_ptr; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32MLAGEMM, 16) { + return VLD1Q_I8(b_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32MLAGEMM, 8) { + return VLD1_I8(b_ptr); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32MLAGEMM, 4) { +#if __aarch64__ + I8X8 ret; /* higher 4 elements not used */ + __asm__("ldr %s0,[%1]":"=w"(ret):"r"(b_ptr):"memory"); + return ret; +#else +/* armeabi-gcc is always buggy. It always put a 64-bit wide + * neon variable into s* register ! + * here to use 128-bit wide neon variable to avoid this bug */ + register I8X16 ret __asm("q0"); /* higher 12 elements not used */ + __asm__("vld1.32 {%e0[0]},[%1]":"=w"(ret):"r"(b_ptr):"memory"); + return VGET_LOW_I8(ret); +#endif +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32MLAGEMM, 2) { + I32 lo = b_ptr[0]; + I32 hi = b_ptr[1]; + return (I32X2){lo, hi}; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(LOADB, I8I32MLAGEMM, 1) { + return *b_ptr; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32MLAGEMM, 16, 8) { + return VADDQ_I32(c_vec.val[0], c_vec.val[1]); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32MLAGEMM, 8, 4) { + return VADD_I32(VGET_LOW_I32(c_vec), VGET_HIGH_I32(c_vec)); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32MLAGEMM, 4, 2) { + return c_vec; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(REDUC, I8I32MLAGEMM, 2, 1) { + return VGET_LANE_I32(c_vec, 0) + VGET_LANE_I32(c_vec, 1); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32MLAGEMM, 16) { + I32X4X2 ret; + ret.val[0] = VDUPQ_N_I32(0); + ret.val[1] = VDUPQ_N_I32(0); + return ret; +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32MLAGEMM, 8) { + return VDUPQ_N_I32(0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32MLAGEMM, 4) { + return VDUP_N_I32(0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32MLAGEMM, 2) { + return VDUP_N_I32(0); +} + +GEMM_SKINNY_DOT_UNIT_DEDUCE(INITC, I8I32MLAGEMM, 1) { + return 0; +} + +#endif \ No newline at end of file diff --git a/include/arm_neon/NeonI8I32MlaGemmSkinnyGer.h b/include/arm_neon/NeonI8I32MlaGemmSkinnyGer.h new file mode 100644 index 0000000..dd21ed6 --- /dev/null +++ b/include/arm_neon/NeonI8I32MlaGemmSkinnyGer.h @@ -0,0 +1,317 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonI8I32MlaGemmSkinnyGer.h + * Description: Source code template for NEON mlal 8->32bit GEMM skinny ger + * kernels. + *****************************************************************************/ + +#include "common/CommonSkinnyGer.h" +#include "arm_neon/NeonIntOpSign.h" + +#ifndef INCLUDE_I8I32_MLA_SKINNYGER +#define INCLUDE_I8I32_MLA_SKINNYGER + +typedef I8 I8I32MLAGEMM_SKINNYGER_ASCALAR; +typedef I8 I8I32MLAGEMM_SKINNYGER_BSCALAR; +typedef I32 I8I32MLAGEMM_SKINNYGER_CSCALAR; + +typedef I16 I8I32MLAGEMM_SKINNYGER_AVEC1; +typedef I16 I8I32MLAGEMM_SKINNYGER_BVEC1; +typedef I32 I8I32MLAGEMM_SKINNYGER_CVEC1; + +typedef I16X4 I8I32MLAGEMM_SKINNYGER_AVEC4; +typedef I16X4 I8I32MLAGEMM_SKINNYGER_BVEC4; +typedef I32X4 I8I32MLAGEMM_SKINNYGER_CVEC4; + +typedef I16X8 I8I32MLAGEMM_SKINNYGER_AVEC8; +typedef I32X4X2 I8I32MLAGEMM_SKINNYGER_CVEC8; + +typedef I16X8X2 I8I32MLAGEMM_SKINNYGER_AVEC16; +typedef I32X4X4 I8I32MLAGEMM_SKINNYGER_CVEC16; + +#if !__aarch64__ +#ifdef VMLAL_HIGH_LANE_I16 +#undef VMLAL_HIGH_LANE_I16 +#endif +#ifdef VMLAL_HIGH_N_I16 +#undef VMLAL_HIGH_N_I16 +#endif +#define VMLAL_HIGH_LANE_I16(c, a, b, v) VMLAL_LANE_I16(c, VGET_HIGH_I16(a), b, v) +#define VMLAL_HIGH_N_I16(c, a, b) VMLAL_N_I16(c, VGET_HIGH_I16(a), b) +#endif +#define VMLAL_LOW_LANE_I16(c, a, b, v) VMLAL_LANE_I16(c, VGET_LOW_I16(a), b, v) +#define VMLAL_LOW_N_I16(c, a, b) VMLAL_N_I16(c, VGET_LOW_I16(a), b) + +#define GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(type, a, b, c)\ + GEMM_SKINNY_GER_CALC_UNIT(type, a, b, c) + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 16, 4, 1) { + I32X4X4 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec.val[0], b_vec, 0); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec.val[0], b_vec, 0); + ret.val[2] = VMLAL_LOW_LANE_I16(c_vec.val[2], a_vec.val[1], b_vec, 0); + ret.val[3] = VMLAL_HIGH_LANE_I16(c_vec.val[3], a_vec.val[1], b_vec, 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 16, 4, 2) { + I32X4X4 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec.val[0], b_vec, 1); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec.val[0], b_vec, 1); + ret.val[2] = VMLAL_LOW_LANE_I16(c_vec.val[2], a_vec.val[1], b_vec, 1); + ret.val[3] = VMLAL_HIGH_LANE_I16(c_vec.val[3], a_vec.val[1], b_vec, 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 16, 4, 3) { + I32X4X4 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec.val[0], b_vec, 2); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec.val[0], b_vec, 2); + ret.val[2] = VMLAL_LOW_LANE_I16(c_vec.val[2], a_vec.val[1], b_vec, 2); + ret.val[3] = VMLAL_HIGH_LANE_I16(c_vec.val[3], a_vec.val[1], b_vec, 2); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 16, 4, 4) { + I32X4X4 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec.val[0], b_vec, 3); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec.val[0], b_vec, 3); + ret.val[2] = 
VMLAL_LOW_LANE_I16(c_vec.val[2], a_vec.val[1], b_vec, 3); + ret.val[3] = VMLAL_HIGH_LANE_I16(c_vec.val[3], a_vec.val[1], b_vec, 3); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 16, 1, 1) { + I32X4X4 ret; + ret.val[0] = VMLAL_LOW_N_I16(c_vec.val[0], a_vec.val[0], b_vec); + ret.val[1] = VMLAL_HIGH_N_I16(c_vec.val[1], a_vec.val[0], b_vec); + ret.val[2] = VMLAL_LOW_N_I16(c_vec.val[2], a_vec.val[1], b_vec); + ret.val[3] = VMLAL_HIGH_N_I16(c_vec.val[3], a_vec.val[1], b_vec); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 8, 4, 1) { + I32X4X2 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec, b_vec, 0); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec, b_vec, 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 8, 4, 2) { + I32X4X2 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec, b_vec, 1); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec, b_vec, 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 8, 4, 3) { + I32X4X2 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec, b_vec, 2); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec, b_vec, 2); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 8, 4, 4) { + I32X4X2 ret; + ret.val[0] = VMLAL_LOW_LANE_I16(c_vec.val[0], a_vec, b_vec, 3); + ret.val[1] = VMLAL_HIGH_LANE_I16(c_vec.val[1], a_vec, b_vec, 3); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 8, 1, 1) { + I32X4X2 ret; + ret.val[0] = VMLAL_LOW_N_I16(c_vec.val[0], a_vec, b_vec); + ret.val[1] = VMLAL_HIGH_N_I16(c_vec.val[1], a_vec, b_vec); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 4, 4, 1) { + return VMLAL_LANE_I16(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 4, 4, 2) { + return VMLAL_LANE_I16(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 4, 4, 3) { + return VMLAL_LANE_I16(c_vec, a_vec, b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 4, 4, 4) { + return VMLAL_LANE_I16(c_vec, a_vec, b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 4, 1, 1) { + return VMLAL_N_I16(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 1, 4, 1) { + return c_vec + a_vec * VGET_LANE_I16(b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 1, 4, 2) { + return c_vec + a_vec * VGET_LANE_I16(b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 1, 4, 3) { + return c_vec + a_vec * VGET_LANE_I16(b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 1, 4, 4) { + return c_vec + a_vec * VGET_LANE_I16(b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT_DEDUCE(I8I32MLAGEMM, 1, 1, 1) { + return c_vec + a_vec * b_vec; +} + +#define GEMM_SKINNY_GER_LOADA_UNIT_DEDUCE(type, a)\ + GEMM_SKINNY_GER_LOADA_UNIT(type, a) + +GEMM_SKINNY_GER_LOADA_UNIT_DEDUCE(I8I32MLAGEMM, 16) { + I8X16 ld = VLD1Q_I8(a_ptr); + I16X8X2 ret; + ret.val[0] = VMOVL_I8(VGET_LOW_I8(ld)); +#if __aarch64__ + ret.val[1] = VMOVL_HIGH_I8(ld); + __asm__("prfm pldl1keep,[%0,#80]"::"r"(a_ptr):); +#else + ret.val[1] = VMOVL_I8(VGET_HIGH_I8(ld)); + __asm__("pld [%0,#80]"::"r"(a_ptr):); +#endif + return ret; +} + +GEMM_SKINNY_GER_LOADA_UNIT_DEDUCE(I8I32MLAGEMM, 8) { + I8X8 t1 = VLD1_I8(a_ptr); +#if __aarch64__ + __asm__("prfm pldl1keep,[%0,#72]"::"r"(a_ptr):); +#else + __asm__("pld [%0,#72]"::"r"(a_ptr):); +#endif + return VMOVL_I8(t1); +} + +GEMM_SKINNY_GER_LOADA_UNIT_DEDUCE(I8I32MLAGEMM, 4) { +#if __aarch64__ + I16X4 ret; + __asm__("ldr %s0,[%1]; "ISHLL" 
%0.8h,%0.8b,#0; prfm pldl1keep,[%1,#72]\n\t" + :"=w"(ret):"r"(a_ptr):"memory","cc"); + return ret; +#else + I16X8 ret; + __asm__("vld1.32 {d0[0]},[%1]; "ASM_VMOVL_I8" %q0,d0; pld [%1,#68]\n\t" + :"=w"(ret):"r"(a_ptr):"memory","cc","d0"); + return VGET_LOW_I16(ret); +#endif +} + +GEMM_SKINNY_GER_LOADA_UNIT_DEDUCE(I8I32MLAGEMM, 1) { + return *a_ptr; +} + +#define GEMM_SKINNY_GER_LOADC_UNIT_DEDUCE(type, a)\ + GEMM_SKINNY_GER_LOADC_UNIT(type, a) + +GEMM_SKINNY_GER_LOADC_UNIT_DEDUCE(I8I32MLAGEMM, 16) { + I32X4X4 ret; + ret.val[0] = VLD1Q_I32(c_ptr); + ret.val[1] = VLD1Q_I32(c_ptr + 4); + ret.val[2] = VLD1Q_I32(c_ptr + 8); + ret.val[3] = VLD1Q_I32(c_ptr + 12); + return ret; +} + +GEMM_SKINNY_GER_LOADC_UNIT_DEDUCE(I8I32MLAGEMM, 8) { + I32X4X2 ret; + ret.val[0] = VLD1Q_I32(c_ptr); + ret.val[1] = VLD1Q_I32(c_ptr + 4); + return ret; +} + +GEMM_SKINNY_GER_LOADC_UNIT_DEDUCE(I8I32MLAGEMM, 4) { + return VLD1Q_I32(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT_DEDUCE(I8I32MLAGEMM, 1) { + return *c_ptr; +} + +#define GEMM_SKINNY_GER_STOREC_UNIT_DEDUCE(type, c)\ + GEMM_SKINNY_GER_STOREC_UNIT(type, c) + +GEMM_SKINNY_GER_STOREC_UNIT_DEDUCE(I8I32MLAGEMM, 16) { + VST1Q_I32(c_ptr, c_vec.val[0]); + VST1Q_I32(c_ptr + 4, c_vec.val[1]); + VST1Q_I32(c_ptr + 8, c_vec.val[2]); + VST1Q_I32(c_ptr + 12, c_vec.val[3]); +} + +GEMM_SKINNY_GER_STOREC_UNIT_DEDUCE(I8I32MLAGEMM, 8) { + VST1Q_I32(c_ptr, c_vec.val[0]); + VST1Q_I32(c_ptr + 4, c_vec.val[1]); +} + +GEMM_SKINNY_GER_STOREC_UNIT_DEDUCE(I8I32MLAGEMM, 4) { + VST1Q_I32(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT_DEDUCE(I8I32MLAGEMM, 1) { + *c_ptr = c_vec; +} + +#define GEMM_SKINNY_GER_LOADB_UNIT_DEDUCE(mode, type, b)\ + GEMM_SKINNY_GER_LOADB_UNIT_##mode(type, b) + +GEMM_SKINNY_GER_LOADB_UNIT_DEDUCE(BROWMAJOR, I8I32MLAGEMM, 4) { + I16X4 ret = VDUP_N_I16(0); + I16 r1 = *b_ptr; b_ptr += ldb; + I16 r2 = *b_ptr; b_ptr += ldb; + I16 r3 = *b_ptr; b_ptr += ldb; + I16 r4 = *b_ptr; + ret = VSET_LANE_I16(r1, ret, 0); + ret = VSET_LANE_I16(r2, ret, 1); + ret = VSET_LANE_I16(r3, ret, 2); + ret = VSET_LANE_I16(r4, ret, 3); + return ret; +} + +GEMM_SKINNY_GER_LOADB_UNIT_DEDUCE(BROWMAJOR, I8I32MLAGEMM, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_LOADB_UNIT_DEDUCE(BCOLMAJOR, I8I32MLAGEMM, 4) { +#if __aarch64__ + I16X4 ret; + __asm__("ldr %s0,[%1]; "ISHLL" %0.8h,%0.8b,#0\n\t" + :"=w"(ret):"r"(b_ptr):"memory","cc"); + return ret; +#else + I16X8 ret; + __asm__("vld1.32 {d0[0]},[%1]; "ASM_VMOVL_I8" %q0,d0\n\t" + :"=w"(ret):"r"(b_ptr):"memory","cc","d0"); + return VGET_LOW_I16(ret); +#endif +} + +GEMM_SKINNY_GER_LOADB_UNIT_DEDUCE(BCOLMAJOR, I8I32MLAGEMM, 1) { + return *b_ptr; +} + +#endif diff --git a/include/arm_neon/NeonIntOpSign.h b/include/arm_neon/NeonIntOpSign.h new file mode 100644 index 0000000..62339a6 --- /dev/null +++ b/include/arm_neon/NeonIntOpSign.h @@ -0,0 +1,441 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonIntOpSign.h + * Description: Sign-irrelevant representations of NEON intrinsics and + * ASMs for integer operations involved in 8->32bit GEMM. + * With a macro as the (signed/unsigned) switch, these + * representations are converted to the corresponding + * type of intrinsics/ASMs. + *****************************************************************************/ + +#include <arm_neon.h> +#include <stdint.h> + +#ifndef INCLUDE_NEON_INTEGER_SIGN +#define INCLUDE_NEON_INTEGER_SIGN + +#ifdef GEMM_UNSIGNED_INT +#define I8 uint8_t +#define I16 uint16_t +#define I32 uint32_t +#define I8X8 uint8x8_t +#define I8X16 uint8x16_t +#define I16X4 uint16x4_t +#define I16X8 uint16x8_t +#define I32X2 uint32x2_t +#define I32X4 uint32x4_t +#define I64X2 uint64x2_t +#define I8X8X2 uint8x8x2_t +#define I8X16X2 uint8x16x2_t +#define I16X4X2 uint16x4x2_t +#define I16X8X2 uint16x8x2_t +#define I32X2X2 uint32x2x2_t +#define I32X4X2 uint32x4x2_t +#define I8X8X3 uint8x8x3_t +#define I8X16X3 uint8x16x3_t +#define I16X4X3 uint16x4x3_t +#define I16X8X3 uint16x8x3_t +#define I32X2X3 uint32x2x3_t +#define I32X4X3 uint32x4x3_t +#define I8X8X4 uint8x8x4_t +#define I8X16X4 uint8x16x4_t +#define I16X4X4 uint16x4x4_t +#define I16X8X4 uint16x8x4_t +#define I32X2X4 uint32x2x4_t +#define I32X4X4 uint32x4x4_t +#else +#define I8 int8_t +#define I16 int16_t +#define I32 int32_t +#define I8X8 int8x8_t +#define I8X16 int8x16_t +#define I16X4 int16x4_t +#define I16X8 int16x8_t +#define I32X2 int32x2_t +#define I32X4 int32x4_t +#define I64X2 int64x2_t +#define I8X8X2 int8x8x2_t +#define I8X16X2 int8x16x2_t +#define I16X4X2 int16x4x2_t +#define I16X8X2 int16x8x2_t +#define I32X2X2 int32x2x2_t +#define I32X4X2 int32x4x2_t +#define I8X8X3 int8x8x3_t +#define I8X16X3 int8x16x3_t +#define I16X4X3 int16x4x3_t +#define I16X8X3 int16x8x3_t +#define I32X2X3 int32x2x3_t +#define I32X4X3 int32x4x3_t +#define I8X8X4 int8x8x4_t +#define I8X16X4 int8x16x4_t +#define I16X4X4 int16x4x4_t +#define I16X8X4 int16x8x4_t +#define I32X2X4 int32x2x4_t +#define I32X4X4 int32x4x4_t +#endif + +/* asm instruction switch */ +#if __aarch64__ +#ifdef GEMM_UNSIGNED_INT +#define IMLAL "umlal" +#define IMLAL2 "umlal2" +#define ISHLL "ushll" +#define ISHLL2 "ushll2" +#define IXTL "uxtl" +#define IADALP "uadalp" +#define IMULL "umull" +#define IMULL2 "umull2" +#define IDOT "udot" +#else +#define IMLAL "smlal" +#define IMLAL2 "smlal2" +#define ISHLL "sshll" +#define ISHLL2 "sshll2" +#define IXTL "sxtl" +#define IADALP "sadalp" +#define IMULL "smull" +#define IMULL2 "smull2" +#define IDOT "sdot" +#endif +#else //armv7a +#ifdef GEMM_UNSIGNED_INT +#define ASM_VMLAL_I16 "vmlal.u16" +#define ASM_VMOVL_I8 "vmovl.u8" +#define ASM_VPADAL_I16 "vpadal.u16" +#define ASM_VMULL_I8 "vmull.u8" +#else +#define ASM_VMLAL_I16 "vmlal.s16" +#define ASM_VMOVL_I8 "vmovl.s8" +#define ASM_VPADAL_I16 "vpadal.s16" +#define ASM_VMULL_I8 "vmull.s8" +#endif +#endif + +/* intrinsic function switch */ +#ifdef GEMM_UNSIGNED_INT +#define VMLAQ_N_I32(a, b, c) vmlaq_n_u32(a, b, c) +#define VMLA_N_I32(a, b, c) vmla_n_u32(a, b, c) +#define VMLA_I32(a, b, c) vmla_u32(a, b, c) +#define VMLA_LANE_I32(a, b, c, d) vmla_lane_u32(a, b, c, d) +#define VLD1_I8(a) vld1_u8(a) +#define VLD1Q_I8(a) vld1q_u8(a) +#define VLD1_I16(a) vld1_u16(a) +#define VLD1Q_I16(a) vld1q_u16(a) +#define VLD1_I32(a) vld1_u32(a) +#define VLD1Q_I32(a) vld1q_u32(a) +#define VLD3Q_I32(a) 
vld3q_u32(a) +#define VMOVL_I8(a) vmovl_u8(a) +#define VMOVL_HIGH_I8(a) vmovl_high_u8(a) +#define VST1_I32(a, b) vst1_u32(a, b) +#define VST2_I32(a, b) vst2_u32(a, b) +#define VST3_I32(a, b) vst3_u32(a, b) +#define VST4_I32(a, b) vst4_u32(a, b) +#define VST1Q_I32(a, b) vst1q_u32(a, b) +#define VST2Q_I32(a, b) vst2q_u32(a, b) +#define VST3Q_I32(a, b) vst3q_u32(a, b) +#define VST4Q_I32(a, b) vst4q_u32(a, b) +#define VST1_LANE_I32(a, b, c) vst1_lane_u32(a, b, c) +#define VST1Q_LANE_I32(a, b, c) vst1q_lane_u32(a, b, c) +#define VST2_LANE_I32(a, b, c) vst2_lane_u32(a, b, c) +#define VST2Q_LANE_I32(a, b, c) vst2q_lane_u32(a, b, c) +#define VST3_LANE_I32(a, b, c) vst3_lane_u32(a, b, c) +#define VST3Q_LANE_I32(a, b, c) vst3q_lane_u32(a, b, c) +#define VST4_LANE_I32(a, b, c) vst4_lane_u32(a, b, c) +#define VST4Q_LANE_I32(a, b, c) vst4q_lane_u32(a, b, c) +#define VST1_I16(a, b) vst1_u16(a, b) +#define VST1Q_I16(a, b) vst1q_u16(a, b) +#define VST1_LANE_I16(a, b, c) vst1_lane_u16(a, b, c) +#define VST1Q_LANE_I16(a, b, c) vst1q_lane_u16(a, b, c) +#define VST2_LANE_I16(a, b, c) vst2_lane_u16(a, b, c) +#define VST2Q_LANE_I16(a, b, c) vst2q_lane_u16(a, b, c) +#define VST3_LANE_I16(a, b, c) vst3_lane_u16(a, b, c) +#define VST3Q_LANE_I16(a, b, c) vst3q_lane_u16(a, b, c) +#define VST4_LANE_I16(a, b, c) vst4_lane_u16(a, b, c) +#define VST4Q_LANE_I16(a, b, c) vst4q_lane_u16(a, b, c) +#define VMLAL_LANE_I16(a, b, c, d) vmlal_lane_u16(a, b, c, d) +#define VMLAL_HIGH_LANE_I16(a, b, c, d) vmlal_high_lane_u16(a, b, c, d) +#define VMLAL_N_I16(a, b, c) vmlal_n_u16(a, b, c) +#define VMLAL_HIGH_N_I16(a, b, c) vmlal_high_n_u16(a, b, c) +#define VMLAL_I16(a, b, c) vmlal_u16(a, b, c) +#define VPADAL_I16(a, b) vpadal_u16(a, b) +#define VPADALQ_I16(a, b) vpadalq_u16(a, b) +#define VPADD_I32(a, b) vpadd_u32(a, b) +#define VMULL_I8(a, b) vmull_u8(a, b) +#define VMULL_HIGH_I8(a, b) vmull_high_u8(a, b) +#define VGET_LOW_I8(a) vget_low_u8(a) +#define VGET_HIGH_I8(a) vget_high_u8(a) +#define VGET_LOW_I16(a) vget_low_u16(a) +#define VGET_HIGH_I16(a) vget_high_u16(a) +#define VGET_LANE_I16(a, b) vget_lane_u16(a, b) +#define VGETQ_LANE_I16(a, b) vgetq_lane_u16(a, b) +#define VGET_LOW_I32(a) vget_low_u32(a) +#define VGET_HIGH_I32(a) vget_high_u32(a) +#define VGET_LANE_I32(a, b) vget_lane_u32(a, b) +#define VGETQ_LANE_I32(a, b) vgetq_lane_u32(a, b) +#define VDUP_N_I32(a) vdup_n_u32(a) +#define VDUPQ_N_I32(a) vdupq_n_u32(a) +#define VDUP_N_I16(a) vdup_n_u16(a) +#define VDUPQ_N_I16(a) vdupq_n_u16(a) +#define VDUP_N_I8(a) vdup_n_u8(a) +#define VDUPQ_N_I8(a) vdupq_n_u8(a) +#define VSET_LANE_I32(a, b, c) vset_lane_u32(a, b, c) +#define VSETQ_LANE_I32(a, b, c) vsetq_lane_u32(a, b, c) +#define VSET_LANE_I16(a, b, c) vset_lane_u16(a, b, c) +#define VSETQ_LANE_I16(a, b, c) vsetq_lane_u16(a, b, c) +#define VZIP_I16(a, b) vzip_u16(a, b) +#define VUZP_I16(a, b) vuzp_u16(a, b) +#define VZIPQ_I32(a, b) vzipq_u32(a, b) +#define VZIP1_I32(a, b) vzip1_u32(a, b) +#define VZIP2_I32(a, b) vzip2_u32(a, b) +#define VZIP1Q_I32(a, b) vzip1q_u32(a, b) +#define VZIP2Q_I32(a, b) vzip2q_u32(a, b) +#define VZIP1Q_I64(a, b) vzip1q_u64(a, b) +#define VZIP2Q_I64(a, b) vzip2q_u64(a, b) +#define VTRN_I16(a, b) vtrn_u16(a, b) +#define VADD_I32(a, b) vadd_u32(a, b) +#define VADDQ_I32(a, b) vaddq_u32(a, b) +#define VCOMBINE_I32(a, b) vcombine_u32(a, b) +#define VDOT_I32(a, b, c) vdot_u32(a, b, c) +#define VDOT_LANE_I32(a, b, c, d) vdot_lane_u32(a, b, c, d) +#define VDOTQ_I32(a, b, c) vdotq_u32(a, b, c) +#define VDOTQ_LANE_I32(a, b, c, d) vdotq_lane_u32(a, b, c, d) 
+#define VDOTQ_LANEQ_I32(a, b, c, d) vdotq_laneq_u32(a, b, c, d) +#define VREINTERPRETQ_I32_I64(a) vreinterpretq_u32_u64(a) +#define VREINTERPRETQ_I64_I32(a) vreinterpretq_u64_u32(a) +#define VREINTERPRET_I8_I32(a) vreinterpret_u8_u32(a) +#define VREINTERPRETQ_I8_I32(a) vreinterpretq_u8_u32(a) +#define VREINTERPRET_I32_I8(a) vreinterpret_u32_u8(a) +#define VREINTERPRETQ_I32_I8(a) vreinterpretq_u32_u8(a) +#else +#define VMLAQ_N_I32(a, b, c) vmlaq_n_s32(a, b, c) +#define VMLA_N_I32(a, b, c) vmla_n_s32(a, b, c) +#define VMLA_I32(a, b, c) vmla_s32(a, b, c) +#define VMLA_LANE_I32(a, b, c, d) vmla_lane_s32(a, b, c, d) +#define VLD1_I8(a) vld1_s8(a) +#define VLD1Q_I8(a) vld1q_s8(a) +#define VLD1_I16(a) vld1_s16(a) +#define VLD1Q_I16(a) vld1q_s16(a) +#define VLD1_I32(a) vld1_s32(a) +#define VLD1Q_I32(a) vld1q_s32(a) +#define VLD3Q_I32(a) vld3q_s32(a) +#define VMOVL_I8(a) vmovl_s8(a) +#define VMOVL_HIGH_I8(a) vmovl_high_s8(a) +#define VST1_I32(a, b) vst1_s32(a, b) +#define VST2_I32(a, b) vst2_s32(a, b) +#define VST3_I32(a, b) vst3_s32(a, b) +#define VST4_I32(a, b) vst4_s32(a, b) +#define VST1Q_I32(a, b) vst1q_s32(a, b) +#define VST2Q_I32(a, b) vst2q_s32(a, b) +#define VST3Q_I32(a, b) vst3q_s32(a, b) +#define VST4Q_I32(a, b) vst4q_s32(a, b) +#define VST1_LANE_I32(a, b, c) vst1_lane_s32(a, b, c) +#define VST1Q_LANE_I32(a, b, c) vst1q_lane_s32(a, b, c) +#define VST2_LANE_I32(a, b, c) vst2_lane_s32(a, b, c) +#define VST2Q_LANE_I32(a, b, c) vst2q_lane_s32(a, b, c) +#define VST3_LANE_I32(a, b, c) vst3_lane_s32(a, b, c) +#define VST3Q_LANE_I32(a, b, c) vst3q_lane_s32(a, b, c) +#define VST4_LANE_I32(a, b, c) vst4_lane_s32(a, b, c) +#define VST4Q_LANE_I32(a, b, c) vst4q_lane_s32(a, b, c) +#define VST1_I16(a, b) vst1_s16(a, b) +#define VST1Q_I16(a, b) vst1q_s16(a, b) +#define VST1_LANE_I16(a, b, c) vst1_lane_s16(a, b, c) +#define VST1Q_LANE_I16(a, b, c) vst1q_lane_s16(a, b, c) +#define VST2_LANE_I16(a, b, c) vst2_lane_s16(a, b, c) +#define VST2Q_LANE_I16(a, b, c) vst2q_lane_s16(a, b, c) +#define VST3_LANE_I16(a, b, c) vst3_lane_s16(a, b, c) +#define VST3Q_LANE_I16(a, b, c) vst3q_lane_s16(a, b, c) +#define VST4_LANE_I16(a, b, c) vst4_lane_s16(a, b, c) +#define VST4Q_LANE_I16(a, b, c) vst4q_lane_s16(a, b, c) +#define VMLAL_LANE_I16(a, b, c, d) vmlal_lane_s16(a, b, c, d) +#define VMLAL_HIGH_LANE_I16(a, b, c, d) vmlal_high_lane_s16(a, b, c, d) +#define VMLAL_N_I16(a, b, c) vmlal_n_s16(a, b, c) +#define VMLAL_HIGH_N_I16(a, b, c) vmlal_high_n_s16(a, b, c) +#define VMLAL_I16(a, b, c) vmlal_s16(a, b, c) +#define VPADAL_I16(a, b) vpadal_s16(a, b) +#define VPADALQ_I16(a, b) vpadalq_s16(a, b) +#define VPADD_I32(a, b) vpadd_s32(a, b) +#define VMULL_I8(a, b) vmull_s8(a, b) +#define VMULL_HIGH_I8(a, b) vmull_high_s8(a, b) +#define VGET_LOW_I8(a) vget_low_s8(a) +#define VGET_HIGH_I8(a) vget_high_s8(a) +#define VGET_LOW_I16(a) vget_low_s16(a) +#define VGET_HIGH_I16(a) vget_high_s16(a) +#define VGET_LANE_I16(a, b) vget_lane_s16(a, b) +#define VGETQ_LANE_I16(a, b) vgetq_lane_s16(a, b) +#define VGET_LOW_I32(a) vget_low_s32(a) +#define VGET_HIGH_I32(a) vget_high_s32(a) +#define VGET_LANE_I32(a, b) vget_lane_s32(a, b) +#define VGETQ_LANE_I32(a, b) vgetq_lane_s32(a, b) +#define VDUP_N_I32(a) vdup_n_s32(a) +#define VDUPQ_N_I32(a) vdupq_n_s32(a) +#define VDUP_N_I16(a) vdup_n_s16(a) +#define VDUPQ_N_I16(a) vdupq_n_s16(a) +#define VDUP_N_I8(a) vdup_n_s8(a) +#define VDUPQ_N_I8(a) vdupq_n_s8(a) +#define VSET_LANE_I32(a, b, c) vset_lane_s32(a, b, c) +#define VSETQ_LANE_I32(a, b, c) vsetq_lane_s32(a, b, c) +#define VSET_LANE_I16(a, b, c) 
vset_lane_s16(a, b, c) +#define VSETQ_LANE_I16(a, b, c) vsetq_lane_s16(a, b, c) +#define VZIP_I16(a, b) vzip_s16(a, b) +#define VUZP_I16(a, b) vuzp_s16(a, b) +#define VZIPQ_I32(a, b) vzipq_s32(a, b) +#define VZIP1_I32(a, b) vzip1_s32(a, b) +#define VZIP2_I32(a, b) vzip2_s32(a, b) +#define VZIP1Q_I32(a, b) vzip1q_s32(a, b) +#define VZIP2Q_I32(a, b) vzip2q_s32(a, b) +#define VZIP1Q_I64(a, b) vzip1q_s64(a, b) +#define VZIP2Q_I64(a, b) vzip2q_s64(a, b) +#define VTRN_I16(a, b) vtrn_s16(a, b) +#define VADD_I32(a, b) vadd_s32(a, b) +#define VADDQ_I32(a, b) vaddq_s32(a, b) +#define VCOMBINE_I32(a, b) vcombine_s32(a, b) +#define VDOT_I32(a, b, c) vdot_s32(a, b, c) +#define VDOT_LANE_I32(a, b, c, d) vdot_lane_s32(a, b, c, d) +#define VDOTQ_I32(a, b, c) vdotq_s32(a, b, c) +#define VDOTQ_LANE_I32(a, b, c, d) vdotq_lane_s32(a, b, c, d) +#define VDOTQ_LANEQ_I32(a, b, c, d) vdotq_laneq_s32(a, b, c, d) +#define VREINTERPRETQ_I32_I64(a) vreinterpretq_s32_s64(a) +#define VREINTERPRETQ_I64_I32(a) vreinterpretq_s64_s32(a) +#define VREINTERPRET_I8_I32(a) vreinterpret_s8_s32(a) +#define VREINTERPRETQ_I8_I32(a) vreinterpretq_s8_s32(a) +#define VREINTERPRET_I32_I8(a) vreinterpret_s32_s8(a) +#define VREINTERPRETQ_I32_I8(a) vreinterpretq_s32_s8(a) +#endif + +#ifndef GEMM_UNSIGNED_INT +#define I8I32MLAGEMM s8s32mlagemm +#define I8I32MLAGEMM_SKINNYGER_ASCALAR s8s32mlagemm_skinnyger_ascalar +#define I8I32MLAGEMM_SKINNYGER_BSCALAR s8s32mlagemm_skinnyger_bscalar +#define I8I32MLAGEMM_SKINNYGER_CSCALAR s8s32mlagemm_skinnyger_cscalar +#define I8I32MLAGEMM_SKINNYGER_AVEC1 s8s32mlagemm_skinnyger_avec1 +#define I8I32MLAGEMM_SKINNYGER_BVEC1 s8s32mlagemm_skinnyger_bvec1 +#define I8I32MLAGEMM_SKINNYGER_CVEC1 s8s32mlagemm_skinnyger_cvec1 +#define I8I32MLAGEMM_SKINNYGER_AVEC2 s8s32mlagemm_skinnyger_avec2 +#define I8I32MLAGEMM_SKINNYGER_BVEC2 s8s32mlagemm_skinnyger_bvec2 +#define I8I32MLAGEMM_SKINNYGER_CVEC2 s8s32mlagemm_skinnyger_cvec2 +#define I8I32MLAGEMM_SKINNYGER_AVEC4 s8s32mlagemm_skinnyger_avec4 +#define I8I32MLAGEMM_SKINNYGER_BVEC4 s8s32mlagemm_skinnyger_bvec4 +#define I8I32MLAGEMM_SKINNYGER_CVEC4 s8s32mlagemm_skinnyger_cvec4 +#define I8I32MLAGEMM_SKINNYGER_AVEC8 s8s32mlagemm_skinnyger_avec8 +#define I8I32MLAGEMM_SKINNYGER_BVEC8 s8s32mlagemm_skinnyger_bvec8 +#define I8I32MLAGEMM_SKINNYGER_CVEC8 s8s32mlagemm_skinnyger_cvec8 +#define I8I32MLAGEMM_SKINNYGER_AVEC16 s8s32mlagemm_skinnyger_avec16 +#define I8I32MLAGEMM_SKINNYGER_BVEC16 s8s32mlagemm_skinnyger_bvec16 +#define I8I32MLAGEMM_SKINNYGER_CVEC16 s8s32mlagemm_skinnyger_cvec16 +#define I8I32MLAGEMM_SKINNYDOT_ASCALAR s8s32mlagemm_skinnydot_ascalar +#define I8I32MLAGEMM_SKINNYDOT_BSCALAR s8s32mlagemm_skinnydot_bscalar +#define I8I32MLAGEMM_SKINNYDOT_CSCALAR s8s32mlagemm_skinnydot_cscalar +#define I8I32MLAGEMM_SKINNYDOT_AVEC1 s8s32mlagemm_skinnydot_avec1 +#define I8I32MLAGEMM_SKINNYDOT_BVEC1 s8s32mlagemm_skinnydot_bvec1 +#define I8I32MLAGEMM_SKINNYDOT_CVEC1 s8s32mlagemm_skinnydot_cvec1 +#define I8I32MLAGEMM_SKINNYDOT_AVEC2 s8s32mlagemm_skinnydot_avec2 +#define I8I32MLAGEMM_SKINNYDOT_BVEC2 s8s32mlagemm_skinnydot_bvec2 +#define I8I32MLAGEMM_SKINNYDOT_CVEC2 s8s32mlagemm_skinnydot_cvec2 +#define I8I32MLAGEMM_SKINNYDOT_AVEC4 s8s32mlagemm_skinnydot_avec4 +#define I8I32MLAGEMM_SKINNYDOT_BVEC4 s8s32mlagemm_skinnydot_bvec4 +#define I8I32MLAGEMM_SKINNYDOT_CVEC4 s8s32mlagemm_skinnydot_cvec4 +#define I8I32MLAGEMM_SKINNYDOT_AVEC8 s8s32mlagemm_skinnydot_avec8 +#define I8I32MLAGEMM_SKINNYDOT_BVEC8 s8s32mlagemm_skinnydot_bvec8 +#define I8I32MLAGEMM_SKINNYDOT_CVEC8 s8s32mlagemm_skinnydot_cvec8 
+#define I8I32MLAGEMM_SKINNYDOT_AVEC16 s8s32mlagemm_skinnydot_avec16 +#define I8I32MLAGEMM_SKINNYDOT_BVEC16 s8s32mlagemm_skinnydot_bvec16 +#define I8I32MLAGEMM_SKINNYDOT_CVEC16 s8s32mlagemm_skinnydot_cvec16 +#else +#define I8I32MLAGEMM u8u32mlagemm +#define I8I32MLAGEMM_SKINNYGER_ASCALAR u8u32mlagemm_skinnyger_ascalar +#define I8I32MLAGEMM_SKINNYGER_BSCALAR u8u32mlagemm_skinnyger_bscalar +#define I8I32MLAGEMM_SKINNYGER_CSCALAR u8u32mlagemm_skinnyger_cscalar +#define I8I32MLAGEMM_SKINNYGER_AVEC1 u8u32mlagemm_skinnyger_avec1 +#define I8I32MLAGEMM_SKINNYGER_BVEC1 u8u32mlagemm_skinnyger_bvec1 +#define I8I32MLAGEMM_SKINNYGER_CVEC1 u8u32mlagemm_skinnyger_cvec1 +#define I8I32MLAGEMM_SKINNYGER_AVEC2 u8u32mlagemm_skinnyger_avec2 +#define I8I32MLAGEMM_SKINNYGER_BVEC2 u8u32mlagemm_skinnyger_bvec2 +#define I8I32MLAGEMM_SKINNYGER_CVEC2 u8u32mlagemm_skinnyger_cvec2 +#define I8I32MLAGEMM_SKINNYGER_AVEC4 u8u32mlagemm_skinnyger_avec4 +#define I8I32MLAGEMM_SKINNYGER_BVEC4 u8u32mlagemm_skinnyger_bvec4 +#define I8I32MLAGEMM_SKINNYGER_CVEC4 u8u32mlagemm_skinnyger_cvec4 +#define I8I32MLAGEMM_SKINNYGER_AVEC8 u8u32mlagemm_skinnyger_avec8 +#define I8I32MLAGEMM_SKINNYGER_BVEC8 u8u32mlagemm_skinnyger_bvec8 +#define I8I32MLAGEMM_SKINNYGER_CVEC8 u8u32mlagemm_skinnyger_cvec8 +#define I8I32MLAGEMM_SKINNYGER_AVEC16 u8u32mlagemm_skinnyger_avec16 +#define I8I32MLAGEMM_SKINNYGER_BVEC16 u8u32mlagemm_skinnyger_bvec16 +#define I8I32MLAGEMM_SKINNYGER_CVEC16 u8u32mlagemm_skinnyger_cvec16 +#define I8I32MLAGEMM_SKINNYDOT_ASCALAR u8u32mlagemm_skinnydot_ascalar +#define I8I32MLAGEMM_SKINNYDOT_BSCALAR u8u32mlagemm_skinnydot_bscalar +#define I8I32MLAGEMM_SKINNYDOT_CSCALAR u8u32mlagemm_skinnydot_cscalar +#define I8I32MLAGEMM_SKINNYDOT_AVEC1 u8u32mlagemm_skinnydot_avec1 +#define I8I32MLAGEMM_SKINNYDOT_BVEC1 u8u32mlagemm_skinnydot_bvec1 +#define I8I32MLAGEMM_SKINNYDOT_CVEC1 u8u32mlagemm_skinnydot_cvec1 +#define I8I32MLAGEMM_SKINNYDOT_AVEC2 u8u32mlagemm_skinnydot_avec2 +#define I8I32MLAGEMM_SKINNYDOT_BVEC2 u8u32mlagemm_skinnydot_bvec2 +#define I8I32MLAGEMM_SKINNYDOT_CVEC2 u8u32mlagemm_skinnydot_cvec2 +#define I8I32MLAGEMM_SKINNYDOT_AVEC4 u8u32mlagemm_skinnydot_avec4 +#define I8I32MLAGEMM_SKINNYDOT_BVEC4 u8u32mlagemm_skinnydot_bvec4 +#define I8I32MLAGEMM_SKINNYDOT_CVEC4 u8u32mlagemm_skinnydot_cvec4 +#define I8I32MLAGEMM_SKINNYDOT_AVEC8 u8u32mlagemm_skinnydot_avec8 +#define I8I32MLAGEMM_SKINNYDOT_BVEC8 u8u32mlagemm_skinnydot_bvec8 +#define I8I32MLAGEMM_SKINNYDOT_CVEC8 u8u32mlagemm_skinnydot_cvec8 +#define I8I32MLAGEMM_SKINNYDOT_AVEC16 u8u32mlagemm_skinnydot_avec16 +#define I8I32MLAGEMM_SKINNYDOT_BVEC16 u8u32mlagemm_skinnydot_bvec16 +#define I8I32MLAGEMM_SKINNYDOT_CVEC16 u8u32mlagemm_skinnydot_cvec16 +#endif + +#ifndef GEMM_UNSIGNED_INT +#define I8I32DOTGEMM s8s32dotgemm +#define I8I32DOTGEMM_SKINNYDOT_ASCALAR s8s32dotgemm_skinnydot_ascalar +#define I8I32DOTGEMM_SKINNYDOT_BSCALAR s8s32dotgemm_skinnydot_bscalar +#define I8I32DOTGEMM_SKINNYDOT_CSCALAR s8s32dotgemm_skinnydot_cscalar +#define I8I32DOTGEMM_SKINNYDOT_AVEC1 s8s32dotgemm_skinnydot_avec1 +#define I8I32DOTGEMM_SKINNYDOT_BVEC1 s8s32dotgemm_skinnydot_bvec1 +#define I8I32DOTGEMM_SKINNYDOT_CVEC1 s8s32dotgemm_skinnydot_cvec1 +#define I8I32DOTGEMM_SKINNYDOT_AVEC2 s8s32dotgemm_skinnydot_avec2 +#define I8I32DOTGEMM_SKINNYDOT_BVEC2 s8s32dotgemm_skinnydot_bvec2 +#define I8I32DOTGEMM_SKINNYDOT_CVEC2 s8s32dotgemm_skinnydot_cvec2 +#define I8I32DOTGEMM_SKINNYDOT_AVEC4 s8s32dotgemm_skinnydot_avec4 +#define I8I32DOTGEMM_SKINNYDOT_BVEC4 s8s32dotgemm_skinnydot_bvec4 +#define 
I8I32DOTGEMM_SKINNYDOT_CVEC4 s8s32dotgemm_skinnydot_cvec4 +#define I8I32DOTGEMM_SKINNYDOT_AVEC8 s8s32dotgemm_skinnydot_avec8 +#define I8I32DOTGEMM_SKINNYDOT_BVEC8 s8s32dotgemm_skinnydot_bvec8 +#define I8I32DOTGEMM_SKINNYDOT_CVEC8 s8s32dotgemm_skinnydot_cvec8 +#define I8I32DOTGEMM_SKINNYDOT_AVEC16 s8s32dotgemm_skinnydot_avec16 +#define I8I32DOTGEMM_SKINNYDOT_BVEC16 s8s32dotgemm_skinnydot_bvec16 +#define I8I32DOTGEMM_SKINNYDOT_CVEC16 s8s32dotgemm_skinnydot_cvec16 +#else +#define I8I32DOTGEMM u8u32dotgemm +#define I8I32DOTGEMM_SKINNYDOT_ASCALAR u8u32dotgemm_skinnydot_ascalar +#define I8I32DOTGEMM_SKINNYDOT_BSCALAR u8u32dotgemm_skinnydot_bscalar +#define I8I32DOTGEMM_SKINNYDOT_CSCALAR u8u32dotgemm_skinnydot_cscalar +#define I8I32DOTGEMM_SKINNYDOT_AVEC1 u8u32dotgemm_skinnydot_avec1 +#define I8I32DOTGEMM_SKINNYDOT_BVEC1 u8u32dotgemm_skinnydot_bvec1 +#define I8I32DOTGEMM_SKINNYDOT_CVEC1 u8u32dotgemm_skinnydot_cvec1 +#define I8I32DOTGEMM_SKINNYDOT_AVEC2 u8u32dotgemm_skinnydot_avec2 +#define I8I32DOTGEMM_SKINNYDOT_BVEC2 u8u32dotgemm_skinnydot_bvec2 +#define I8I32DOTGEMM_SKINNYDOT_CVEC2 u8u32dotgemm_skinnydot_cvec2 +#define I8I32DOTGEMM_SKINNYDOT_AVEC4 u8u32dotgemm_skinnydot_avec4 +#define I8I32DOTGEMM_SKINNYDOT_BVEC4 u8u32dotgemm_skinnydot_bvec4 +#define I8I32DOTGEMM_SKINNYDOT_CVEC4 u8u32dotgemm_skinnydot_cvec4 +#define I8I32DOTGEMM_SKINNYDOT_AVEC8 u8u32dotgemm_skinnydot_avec8 +#define I8I32DOTGEMM_SKINNYDOT_BVEC8 u8u32dotgemm_skinnydot_bvec8 +#define I8I32DOTGEMM_SKINNYDOT_CVEC8 u8u32dotgemm_skinnydot_cvec8 +#define I8I32DOTGEMM_SKINNYDOT_AVEC16 u8u32dotgemm_skinnydot_avec16 +#define I8I32DOTGEMM_SKINNYDOT_BVEC16 u8u32dotgemm_skinnydot_bvec16 +#define I8I32DOTGEMM_SKINNYDOT_CVEC16 u8u32dotgemm_skinnydot_cvec16 +#endif + +#endif diff --git a/include/arm_neon/NeonQuant.h b/include/arm_neon/NeonQuant.h new file mode 100644 index 0000000..dfb84b8 --- /dev/null +++ b/include/arm_neon/NeonQuant.h @@ -0,0 +1,814 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonQuant.h + * Description: Source code template for NEON quantization kernels. 
+ *****************************************************************************/ + +#include "arm_neon/NeonExtreme.h" + +#ifndef INCLUDE_NEON_QUANT +#define INCLUDE_NEON_QUANT + +static inline void inline_dequant_cvt_f32_s32( + float *dst, const int32_t *src, float scale, uint32_t size) { + + const float32x4_t sc4 = vdupq_n_f32(scale); + const float32x2_t sc2 = vdup_n_f32(scale); + for (; size >= 16; size -= 16) { + int32x4_t v1 = vld1q_s32(src); + int32x4_t v2 = vld1q_s32(src + 4); + int32x4_t v3 = vld1q_s32(src + 8); + int32x4_t v4 = vld1q_s32(src + 12); src += 16; + float32x4_t q1 = vcvtq_f32_s32(v1); + float32x4_t q2 = vcvtq_f32_s32(v2); + float32x4_t q3 = vcvtq_f32_s32(v3); + float32x4_t q4 = vcvtq_f32_s32(v4); + q1 = vmulq_f32(q1, sc4); + q2 = vmulq_f32(q2, sc4); + q3 = vmulq_f32(q3, sc4); + q4 = vmulq_f32(q4, sc4); + vst1q_f32(dst, q1); + vst1q_f32(dst + 4, q2); + vst1q_f32(dst + 8, q3); + vst1q_f32(dst + 12, q4); dst += 16; + } + if (size >= 8) { + int32x4_t v1 = vld1q_s32(src); + int32x4_t v2 = vld1q_s32(src + 4); src += 8; + float32x4_t q1 = vcvtq_f32_s32(v1); + float32x4_t q2 = vcvtq_f32_s32(v2); + q1 = vmulq_f32(q1, sc4); + q2 = vmulq_f32(q2, sc4); + vst1q_f32(dst, q1); + vst1q_f32(dst + 4, q2); dst += 8; + size -= 8; + } + if (size >= 4) { + int32x4_t v1 = vld1q_s32(src); src += 4; + float32x4_t q1 = vcvtq_f32_s32(v1); + q1 = vmulq_f32(q1, sc4); + vst1q_f32(dst, q1); dst += 4; + size -= 4; + } + if (size >= 2) { + int32x2_t v1 = vld1_s32(src); src += 2; + float32x2_t d1 = vcvt_f32_s32(v1); + d1 = vmul_f32(d1, sc2); + vst1_f32(dst, d1); dst += 2; + size -= 2; + } + if (size >= 1) { + *dst = (float)(*src) * scale; + } +} + +static inline void inline_quant_asym_u8_from_f32( + const float32_t *src, uint8_t *dst, + uint32_t size, uint8_t zero_point, float32_t scale) { + + if (scale <= 0) return; + if (size == 0) return; + const float32_t add_zero_s = (float32_t)zero_point + 0.5f; + const float32x4_t add_zero_q = vdupq_n_f32(add_zero_s); + const float32_t mult_s = 1.0f / scale; + const float32x4_t mult_q = vdupq_n_f32(mult_s); + + for (; size >= 16; size -= 16) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); + float32x4_t f3 = vld1q_f32(src + 8); + float32x4_t f4 = vld1q_f32(src + 12); src += 16; + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + f2 = vmlaq_f32(add_zero_q, f2, mult_q); + f3 = vmlaq_f32(add_zero_q, f3, mult_q); + f4 = vmlaq_f32(add_zero_q, f4, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint32x4_t u2 = vcvtq_u32_f32(f2); + uint32x4_t u3 = vcvtq_u32_f32(f3); + uint32x4_t u4 = vcvtq_u32_f32(f4); + uint16x4_t t1 = vqmovn_u32(u1); + uint16x4_t t2 = vqmovn_u32(u2); + uint16x4_t t3 = vqmovn_u32(u3); + uint16x4_t t4 = vqmovn_u32(u4); + uint8x8_t d1 = vqmovn_u16(vcombine_u16(t1, t2)); + uint8x8_t d2 = vqmovn_u16(vcombine_u16(t3, t4)); + vst1_u8(dst, d1); + vst1_u8(dst + 8, d2); dst += 16; + } + if (size >= 8) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); src += 8; + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + f2 = vmlaq_f32(add_zero_q, f2, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint32x4_t u2 = vcvtq_u32_f32(f2); + uint16x4_t t1 = vqmovn_u32(u1); + uint16x4_t t2 = vqmovn_u32(u2); + uint8x8_t d1 = vqmovn_u16(vcombine_u16(t1, t2)); + vst1_u8(dst, d1); dst += 8; + size -= 8; + } + if (size >= 4) { + float32x4_t f1 = vld1q_f32(src); src += 4; + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint16x4_t t1 = vqmovn_u32(u1); + uint16x4_t z1 = vdup_n_u16(0); + uint8x8_t d1 = 
vqmovn_u16(vcombine_u16(t1, z1)); + vst1_lane_u8(dst, d1, 0); + vst1_lane_u8(dst + 1, d1, 1); + vst1_lane_u8(dst + 2, d1, 2); + vst1_lane_u8(dst + 3, d1, 3); + dst += 4; + size -= 4; + } + for (; size > 0; size--) { + float32_t f1 = *src++; + f1 = f1 * mult_s + add_zero_s; + f1 = f1 < 0 ? 0.0 : f1; + f1 = f1 > 255 ? 255.0 : f1; + uint32_t u1 = (uint32_t)f1; + uint8_t s1 = u1 >= 256 ? 255 : u1; + *dst = s1; dst++; + } +} + +static inline void inline_quant_asym_u16_from_f32( + const float32_t *src, uint16_t *dst, + uint32_t size, uint16_t zero_point, float32_t scale) { + + if (scale <= 0) return; + if (size == 0) return; + const float32_t add_zero_s = (float32_t)zero_point + 0.5f; + const float32x4_t add_zero_q = vdupq_n_f32(add_zero_s); + const float32_t mult_s = 1.0f / scale; + const float32x4_t mult_q = vdupq_n_f32(mult_s); + + for (; size >= 16; size -= 16) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); + float32x4_t f3 = vld1q_f32(src + 8); + float32x4_t f4 = vld1q_f32(src + 12); src += 16; + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + f2 = vmlaq_f32(add_zero_q, f2, mult_q); + f3 = vmlaq_f32(add_zero_q, f3, mult_q); + f4 = vmlaq_f32(add_zero_q, f4, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint32x4_t u2 = vcvtq_u32_f32(f2); + uint32x4_t u3 = vcvtq_u32_f32(f3); + uint32x4_t u4 = vcvtq_u32_f32(f4); + uint16x4_t t1 = vqmovn_u32(u1); + uint16x4_t t2 = vqmovn_u32(u2); + uint16x4_t t3 = vqmovn_u32(u3); + uint16x4_t t4 = vqmovn_u32(u4); + vst1_u16(dst, t1); + vst1_u16(dst + 4, t2); + vst1_u16(dst + 8, t3); + vst1_u16(dst + 12, t4); dst += 16; + } + if (size >= 8) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); src += 8; + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + f2 = vmlaq_f32(add_zero_q, f2, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint32x4_t u2 = vcvtq_u32_f32(f2); + uint16x4_t t1 = vqmovn_u32(u1); + uint16x4_t t2 = vqmovn_u32(u2); + vst1_u16(dst, t1); + vst1_u16(dst + 4, t2); dst += 8; + size -= 8; + } + if (size >= 4) { + float32x4_t f1 = vld1q_f32(src); src += 4; + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint16x4_t t1 = vqmovn_u32(u1); + vst1_u16(dst, t1); dst += 4; + size -= 4; + } + if (size > 0) { + float32x4_t f1 = vdupq_n_f32(0); + f1 = vsetq_lane_f32(src[0], f1, 0); + if (size > 1) f1 = vsetq_lane_f32(src[1], f1, 1); + if (size > 2) f1 = vsetq_lane_f32(src[2], f1, 2); + f1 = vmlaq_f32(add_zero_q, f1, mult_q); + uint32x4_t u1 = vcvtq_u32_f32(f1); + uint16x4_t t1 = vqmovn_u32(u1); + vst1_lane_u16(dst, t1, 0); + if (size > 1) vst1_lane_u16(dst + 1, t1, 1); + if (size > 2) vst1_lane_u16(dst + 2, t1, 2); + } +} + +#if !__aarch64__ +static inline int32x4_t vcvtaq_s32_f32(float32x4_t src) { + const static float32x4_t cvt_positive_offset = {0.5f, 0.5f, 0.5f, 0.5f}; + const static float32x4_t cvt_negative_offset = {-0.5f, -0.5f, -0.5f, -0.5f}; + const static float32x4_t cmp_ref = {0.0f, 0.0f, 0.0f, 0.0f}; + uint32x4_t mask = vcgtq_f32(src, cmp_ref); //src big, set 1 + float32x4_t offset = vbslq_f32(mask, cvt_positive_offset, cvt_negative_offset); + src = vaddq_f32(src, offset); + return vcvtq_s32_f32(src); +} +#endif + +static inline void inline_quant_sym_s8_from_f32( + const float32_t *src, int8_t *dst, + uint32_t size, float32_t scale) { + + if (scale <= 0) return; + if (size == 0) return; + const float32_t mult_s = 1.0f / scale; + const float32x4_t mult_q = vdupq_n_f32(mult_s); + + for (; size >= 16; size -= 16) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 
4); + float32x4_t f3 = vld1q_f32(src + 8); + float32x4_t f4 = vld1q_f32(src + 12); src += 16; + f1 = vmulq_f32(f1, mult_q); + f2 = vmulq_f32(f2, mult_q); + f3 = vmulq_f32(f3, mult_q); + f4 = vmulq_f32(f4, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int32x4_t i2 = vcvtaq_s32_f32(f2); + int32x4_t i3 = vcvtaq_s32_f32(f3); + int32x4_t i4 = vcvtaq_s32_f32(f4); + int16x4_t v1 = vqmovn_s32(i1); + int16x4_t v2 = vqmovn_s32(i2); + int16x4_t v3 = vqmovn_s32(i3); + int16x4_t v4 = vqmovn_s32(i4); + int8x8_t w1 = vqmovn_s16(vcombine_s16(v1, v2)); + int8x8_t w2 = vqmovn_s16(vcombine_s16(v3, v4)); + vst1_s8(dst, w1); + vst1_s8(dst + 8, w2); dst += 16; + } + if (size >= 8) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); src += 8; + f1 = vmulq_f32(f1, mult_q); + f2 = vmulq_f32(f2, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int32x4_t i2 = vcvtaq_s32_f32(f2); + int16x4_t v1 = vqmovn_s32(i1); + int16x4_t v2 = vqmovn_s32(i2); + int8x8_t w1 = vqmovn_s16(vcombine_s16(v1, v2)); + vst1_s8(dst, w1); dst += 8; + size -= 8; + } + if (size >= 4) { + float32x4_t f1 = vld1q_f32(src); src += 4; + f1 = vmulq_f32(f1, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int16x4_t v1 = vqmovn_s32(i1); + int16x4_t z1 = vdup_n_s16(0); + int8x8_t w1 = vqmovn_s16(vcombine_s16(v1, z1)); + vst1_lane_s8(dst, w1, 0); + vst1_lane_s8(dst + 1, w1, 1); + vst1_lane_s8(dst + 2, w1, 2); + vst1_lane_s8(dst + 3, w1, 3); dst += 4; + size -= 4; + } + for (; size > 0; size--) { + float32_t f1 = *src++; + f1 *= mult_s; + f1 += f1 > 0 ? 0.5f : -0.5f; + f1 = f1 < -128 ? -128.0 : f1; + f1 = f1 > 127 ? 127.0 : f1; + int8_t s1 = f1; + *dst = s1; dst++; + } +} + +static inline void inline_quant_sym_s16_from_f32( + const float32_t *src, int16_t *dst, + uint32_t size, float32_t scale) { + + if (scale <= 0) return; + if (size == 0) return; + const float32_t mult_s = 1.0f / scale; + const float32x4_t mult_q = vdupq_n_f32(mult_s); + + for (; size >= 16; size -= 16) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); + float32x4_t f3 = vld1q_f32(src + 8); + float32x4_t f4 = vld1q_f32(src + 12); src += 16; + f1 = vmulq_f32(f1, mult_q); + f2 = vmulq_f32(f2, mult_q); + f3 = vmulq_f32(f3, mult_q); + f4 = vmulq_f32(f4, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int32x4_t i2 = vcvtaq_s32_f32(f2); + int32x4_t i3 = vcvtaq_s32_f32(f3); + int32x4_t i4 = vcvtaq_s32_f32(f4); + int16x4_t v1 = vqmovn_s32(i1); + int16x4_t v2 = vqmovn_s32(i2); + int16x4_t v3 = vqmovn_s32(i3); + int16x4_t v4 = vqmovn_s32(i4); + vst1_s16(dst, v1); + vst1_s16(dst + 4, v2); + vst1_s16(dst + 8, v3); + vst1_s16(dst + 12, v4); dst += 16; + } + if (size >= 8) { + float32x4_t f1 = vld1q_f32(src); + float32x4_t f2 = vld1q_f32(src + 4); src += 8; + f1 = vmulq_f32(f1, mult_q); + f2 = vmulq_f32(f2, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int32x4_t i2 = vcvtaq_s32_f32(f2); + int16x4_t v1 = vqmovn_s32(i1); + int16x4_t v2 = vqmovn_s32(i2); + vst1_s16(dst, v1); + vst1_s16(dst + 4, v2); dst += 8; + size -= 8; + } + if (size >= 4) { + float32x4_t f1 = vld1q_f32(src); src += 4; + f1 = vmulq_f32(f1, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int16x4_t v1 = vqmovn_s32(i1); + vst1_s16(dst, v1); dst += 4; + size -= 4; + } + if (size > 0) { + float32x4_t f1 = vdupq_n_f32(0); + f1 = vsetq_lane_f32(src[0], f1, 0); + if (size > 1) f1 = vsetq_lane_f32(src[1], f1, 1); + if (size > 2) f1 = vsetq_lane_f32(src[2], f1, 2); + f1 = vmulq_f32(f1, mult_q); + int32x4_t i1 = vcvtaq_s32_f32(f1); + int16x4_t v1 = vqmovn_s32(i1); + vst1_lane_s16(dst, v1, 0); 
+ if (size > 1) vst1_lane_s16(dst + 1, v1, 1); + if (size > 2) vst1_lane_s16(dst + 2, v1, 2); + } +} + +static inline void inline_requant_asym_u8_from_s32_mulhi(const int32_t *src, + uint8_t *dst, uint32_t size, uint8_t src_lshift, + int32_t mult_factor_22redun, uint8_t zero_point) { + + if (size == 0) return; + const int32x4_t src_sh4 = vdupq_n_s32(src_lshift); + const int32x4_t mult_v4 = vdupq_n_s32(mult_factor_22redun); + const int16x4_t add_z4 = vdup_n_s16((int16_t)zero_point << 6); + + for (; size > 15; size -= 16) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); + int32x4_t l3 = vld1q_s32(src + 8); + int32x4_t l4 = vld1q_s32(src + 12); src += 16; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l3 = vqrshlq_s32(l3, src_sh4); + l4 = vqrshlq_s32(l4, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + l3 = vqrdmulhq_s32(l3, mult_v4); + l4 = vqrdmulhq_s32(l4, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + int16x4_t m2 = vrshrn_n_s32(l2, 16); + int16x4_t m3 = vrshrn_n_s32(l3, 16); + int16x4_t m4 = vrshrn_n_s32(l4, 16); + m1 = vadd_s16(m1, add_z4); + m2 = vadd_s16(m2, add_z4); + m3 = vadd_s16(m3, add_z4); + m4 = vadd_s16(m4, add_z4); + uint8x8_t u1 = vqrshrun_n_s16(vcombine_s16(m1, m2), 6); + uint8x8_t u2 = vqrshrun_n_s16(vcombine_s16(m3, m4), 6); + vst1_u8(dst, u1); + vst1_u8(dst + 8, u2); dst += 16; + } + if (size > 7) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); src += 8; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + int16x4_t m2 = vrshrn_n_s32(l2, 16); + m1 = vadd_s16(m1, add_z4); + m2 = vadd_s16(m2, add_z4); + uint8x8_t u1 = vqrshrun_n_s16(vcombine_s16(m1, m2), 6); + vst1_u8(dst, u1); dst += 8; + size -= 8; + } + if (size > 3) { + int32x4_t l1 = vld1q_s32(src); src += 4; + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + m1 = vadd_s16(m1, add_z4); + uint8x8_t u1 = vqrshrun_n_s16(vcombine_s16(m1, m1), 6); + vst1_lane_u8(dst, u1, 0); + vst1_lane_u8(dst + 1, u1, 1); + vst1_lane_u8(dst + 2, u1, 2); + vst1_lane_u8(dst + 3, u1, 3); dst += 4; + size -= 4; + } + if (size > 0) { + int32x4_t l1 = vdupq_n_s32(0); + l1 = vsetq_lane_s32(src[0], l1, 0); + if (size > 1) l1 = vsetq_lane_s32(src[1], l1, 1); + if (size > 2) l1 = vsetq_lane_s32(src[2], l1, 2); + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + m1 = vadd_s16(m1, add_z4); + uint8x8_t u1 = vqrshrun_n_s16(vcombine_s16(m1, m1), 6); + vst1_lane_u8(dst, u1, 0); + if (size > 1) vst1_lane_u8(dst + 1, u1, 1); + if (size > 2) vst1_lane_u8(dst + 2, u1, 2); + } +} + +static inline void inline_requant_sym_s8_from_s32_mulhi(const int32_t *src, + int8_t *dst, uint32_t size, + uint8_t src_lshift, int32_t mult_factor_22redun) { + + if (size == 0) return; + const int32x4_t src_sh4 = vdupq_n_s32(src_lshift); + const int32x4_t mult_v4 = vdupq_n_s32(mult_factor_22redun); + + for (; size > 15; size -= 16) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); + int32x4_t l3 = vld1q_s32(src + 8); + int32x4_t l4 = vld1q_s32(src + 12); src += 16; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l3 = vqrshlq_s32(l3, src_sh4); + l4 = vqrshlq_s32(l4, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + l3 = vqrdmulhq_s32(l3, mult_v4); + l4 = 
vqrdmulhq_s32(l4, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + int16x4_t m2 = vrshrn_n_s32(l2, 16); + int16x4_t m3 = vrshrn_n_s32(l3, 16); + int16x4_t m4 = vrshrn_n_s32(l4, 16); + int8x8_t s1 = vqrshrn_n_s16(vcombine_s16(m1, m2), 7); + int8x8_t s2 = vqrshrn_n_s16(vcombine_s16(m3, m4), 7); + vst1_s8(dst, s1); + vst1_s8(dst + 8, s2); dst += 16; + } + if (size > 7) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); src += 8; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + int16x4_t m2 = vrshrn_n_s32(l2, 16); + int8x8_t s1 = vqrshrn_n_s16(vcombine_s16(m1, m2), 7); + vst1_s8(dst, s1); dst += 8; + size -= 8; + } + if (size > 3) { + int32x4_t l1 = vld1q_s32(src); src += 4; + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + int8x8_t s1 = vqrshrn_n_s16(vcombine_s16(m1, m1), 7); + vst1_lane_s8(dst, s1, 0); + vst1_lane_s8(dst + 1, s1, 1); + vst1_lane_s8(dst + 2, s1, 2); + vst1_lane_s8(dst + 3, s1, 3); dst += 4; + size -= 4; + } + if (size > 0) { + int32x4_t l1 = vdupq_n_s32(0); + l1 = vsetq_lane_s32(src[0], l1, 0); + if (size > 1) l1 = vsetq_lane_s32(src[1], l1, 1); + if (size > 2) l1 = vsetq_lane_s32(src[2], l1, 2); + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + int16x4_t m1 = vrshrn_n_s32(l1, 16); + int8x8_t s1 = vqrshrn_n_s16(vcombine_s16(m1, m1), 7); + vst1_lane_s8(dst, s1, 0); + if (size > 1) vst1_lane_s8(dst + 1, s1, 1); + if (size > 2) vst1_lane_s8(dst + 2, s1, 2); + } +} + +static inline void inline_requant_asym_u16_from_s32_mulhi(const int32_t *src, + uint16_t *dst, uint32_t size, uint8_t src_lshift, + int32_t mult_factor, uint16_t zero_point) { + + if (size == 0) return; + const int32x4_t src_sh4 = vdupq_n_s32(src_lshift); + const int32x4_t mult_v4 = vdupq_n_s32(mult_factor); + const int32x4_t add_z4 = vdupq_n_s32((int32_t)zero_point << 14); + + for (; size > 15; size -= 16) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); + int32x4_t l3 = vld1q_s32(src + 8); + int32x4_t l4 = vld1q_s32(src + 12); src += 16; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l3 = vqrshlq_s32(l3, src_sh4); + l4 = vqrshlq_s32(l4, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + l3 = vqrdmulhq_s32(l3, mult_v4); + l4 = vqrdmulhq_s32(l4, mult_v4); + l1 = vqaddq_s32(l1, add_z4); + l2 = vqaddq_s32(l2, add_z4); + l3 = vqaddq_s32(l3, add_z4); + l4 = vqaddq_s32(l4, add_z4); + uint16x4_t m1 = vqrshrun_n_s32(l1, 14); + uint16x4_t m2 = vqrshrun_n_s32(l2, 14); + uint16x4_t m3 = vqrshrun_n_s32(l3, 14); + uint16x4_t m4 = vqrshrun_n_s32(l4, 14); + vst1_u16(dst, m1); + vst1_u16(dst + 4, m2); + vst1_u16(dst + 8, m3); + vst1_u16(dst + 12, m4); dst += 16; + } + for (; size > 7; size -= 8) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); src += 8; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + l1 = vqaddq_s32(l1, add_z4); + l2 = vqaddq_s32(l2, add_z4); + uint16x4_t m1 = vqrshrun_n_s32(l1, 14); + uint16x4_t m2 = vqrshrun_n_s32(l2, 14); + vst1_u16(dst, m1); + vst1_u16(dst + 4, m2); dst += 8; + } + for (; size > 3; size -= 4) { + int32x4_t l1 = vld1q_s32(src); src += 4; + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l1 = vqaddq_s32(l1, add_z4); + uint16x4_t m1 = vqrshrun_n_s32(l1, 14); + 
vst1_u16(dst, m1); dst += 4; + } + if (size > 0) { + int32x4_t l1 = vdupq_n_s32(0); + l1 = vsetq_lane_s32(src[0], l1, 0); + if (size > 1) l1 = vsetq_lane_s32(src[1], l1, 1); + if (size > 2) l1 = vsetq_lane_s32(src[2], l1, 2); + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l1 = vqaddq_s32(l1, add_z4); + uint16x4_t m1 = vqrshrun_n_s32(l1, 14); + vst1_lane_u16(dst, m1, 0); + if (size > 1) vst1_lane_u16(dst + 1, m1, 1); + if (size > 2) vst1_lane_u16(dst + 2, m1, 2); + } +} + +static inline void inline_requant_sym_s16_from_s32_mulhi(const int32_t *src, + int16_t *dst, uint32_t size, + uint8_t src_lshift, int32_t mult_factor) { + + if (size == 0) return; + const int32x4_t src_sh4 = vdupq_n_s32(src_lshift); + const int32x4_t mult_v4 = vdupq_n_s32(mult_factor); + + for (; size > 15; size -= 16) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); + int32x4_t l3 = vld1q_s32(src + 8); + int32x4_t l4 = vld1q_s32(src + 12); src += 16; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l3 = vqrshlq_s32(l3, src_sh4); + l4 = vqrshlq_s32(l4, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + l3 = vqrdmulhq_s32(l3, mult_v4); + l4 = vqrdmulhq_s32(l4, mult_v4); + int16x4_t m1 = vqrshrn_n_s32(l1, 15); + int16x4_t m2 = vqrshrn_n_s32(l2, 15); + int16x4_t m3 = vqrshrn_n_s32(l3, 15); + int16x4_t m4 = vqrshrn_n_s32(l4, 15); + vst1_s16(dst, m1); + vst1_s16(dst + 4, m2); + vst1_s16(dst + 8, m3); + vst1_s16(dst + 12, m4); dst += 16; + } + if (size > 7) { + int32x4_t l1 = vld1q_s32(src); + int32x4_t l2 = vld1q_s32(src + 4); src += 8; + l1 = vqrshlq_s32(l1, src_sh4); + l2 = vqrshlq_s32(l2, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + l2 = vqrdmulhq_s32(l2, mult_v4); + int16x4_t m1 = vqrshrn_n_s32(l1, 15); + int16x4_t m2 = vqrshrn_n_s32(l2, 15); + vst1_s16(dst, m1); + vst1_s16(dst + 4, m2); dst += 8; + size -= 8; + } + if (size > 3) { + int32x4_t l1 = vld1q_s32(src); src += 4; + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + int16x4_t m1 = vqrshrn_n_s32(l1, 15); + vst1_s16(dst, m1); dst += 4; + size -= 4; + } + if (size > 0) { + int32x4_t l1 = vdupq_n_s32(0); + l1 = vsetq_lane_s32(src[0], l1, 0); + if (size > 1) l1 = vsetq_lane_s32(src[1], l1, 1); + if (size > 2) l1 = vsetq_lane_s32(src[2], l1, 2); + l1 = vqrshlq_s32(l1, src_sh4); + l1 = vqrdmulhq_s32(l1, mult_v4); + int16x4_t m1 = vqrshrn_n_s32(l1, 15); + vst1_lane_s16(dst, m1, 0); + if (size > 1) vst1_lane_s16(dst + 1, m1, 1); + if (size > 2) vst1_lane_s16(dst + 2, m1, 2); + } +} + +static inline void inline_requant_asym_u8_from_s16_mulhi(const int16_t *src, + uint8_t *dst, uint32_t size, uint8_t src_lshift, + int16_t mult_factor, uint8_t zero_point) { + + if (size == 0) return; + const int16x8_t src_sh8 = vdupq_n_s16(src_lshift); + const int16x8_t mult_v8 = vdupq_n_s16(mult_factor); + const int16x8_t add_z8 = vdupq_n_s16((int16_t)zero_point << 6); + + for (; size > 31; size -= 32) { + int16x8_t l1 = vld1q_s16(src); + int16x8_t l2 = vld1q_s16(src + 8); + int16x8_t l3 = vld1q_s16(src + 16); + int16x8_t l4 = vld1q_s16(src + 24); src += 32; + l1 = vqrshlq_s16(l1, src_sh8); + l2 = vqrshlq_s16(l2, src_sh8); + l3 = vqrshlq_s16(l3, src_sh8); + l4 = vqrshlq_s16(l4, src_sh8); + l1 = vqrdmulhq_s16(l1, mult_v8); + l2 = vqrdmulhq_s16(l2, mult_v8); + l3 = vqrdmulhq_s16(l3, mult_v8); + l4 = vqrdmulhq_s16(l4, mult_v8); + l1 = vqaddq_s16(l1, add_z8); + l2 = vqaddq_s16(l2, add_z8); + l3 = vqaddq_s16(l3, add_z8); + l4 = vqaddq_s16(l4, add_z8); + uint8x8_t m1 = 
vqrshrun_n_s16(l1, 6); + uint8x8_t m2 = vqrshrun_n_s16(l2, 6); + uint8x8_t m3 = vqrshrun_n_s16(l3, 6); + uint8x8_t m4 = vqrshrun_n_s16(l4, 6); + vst1_u8(dst, m1); + vst1_u8(dst + 8, m2); + vst1_u8(dst + 16, m3); + vst1_u8(dst + 24, m4); dst += 32; + } + if (size > 15) { + int16x8_t l1 = vld1q_s16(src); + int16x8_t l2 = vld1q_s16(src + 8); src += 16; + l1 = vqrshlq_s16(l1, src_sh8); + l2 = vqrshlq_s16(l2, src_sh8); + l1 = vqrdmulhq_s16(l1, mult_v8); + l2 = vqrdmulhq_s16(l2, mult_v8); + l1 = vqaddq_s16(l1, add_z8); + l2 = vqaddq_s16(l2, add_z8); + uint8x8_t m1 = vqrshrun_n_s16(l1, 6); + uint8x8_t m2 = vqrshrun_n_s16(l2, 6); + vst1_u8(dst, m1); + vst1_u8(dst + 8, m2); dst += 16; + size -= 16; + } + if (size > 7) { + int16x8_t l1 = vld1q_s16(src); src += 8; + l1 = vqrshlq_s16(l1, src_sh8); + l1 = vqrdmulhq_s16(l1, mult_v8); + l1 = vqaddq_s16(l1, add_z8); + uint8x8_t m1 = vqrshrun_n_s16(l1, 6); + vst1_u8(dst, m1); dst += 8; + size -= 8; + } + if (size > 3) { + int16x4_t l1 = vld1_s16(src); src += 4; + l1 = vqrshl_s16(l1, vget_low_s16(src_sh8)); + l1 = vqrdmulh_s16(l1, vget_low_s16(mult_v8)); + l1 = vqadd_s16(l1, vget_low_s16(add_z8)); + uint8x8_t m1 = vqrshrun_n_s16(vcombine_s16(l1, vdup_n_s16(0)), 6); + vst1_lane_u8(dst, m1, 0); + vst1_lane_u8(dst + 1, m1, 1); + vst1_lane_u8(dst + 2, m1, 2); + vst1_lane_u8(dst + 3, m1, 3); dst += 4; + size -= 4; + } + if (size > 0) { + int16x4_t l1 = vdup_n_s16(0); + l1 = vset_lane_s16(src[0], l1, 0); + if (size > 1) l1 = vset_lane_s16(src[1], l1, 1); + if (size > 2) l1 = vset_lane_s16(src[2], l1, 2); + l1 = vqrshl_s16(l1, vget_low_s16(src_sh8)); + l1 = vqrdmulh_s16(l1, vget_low_s16(mult_v8)); + l1 = vqadd_s16(l1, vget_low_s16(add_z8)); + uint8x8_t m1 = vqrshrun_n_s16(vcombine_s16(l1, vdup_n_s16(0)), 6); + vst1_lane_u8(dst, m1, 0); + if (size > 1) vst1_lane_u8(dst + 1, m1, 1); + if (size > 2) vst1_lane_u8(dst + 2, m1, 2); + } +} + +static inline void inline_requant_sym_s8_from_s16_mulhi(const int16_t *src, + int8_t *dst, uint32_t size, + uint8_t src_lshift, int16_t mult_factor) { + + if (size == 0) return; + const int16x8_t src_sh8 = vdupq_n_s16(src_lshift); + const int16x8_t mult_v8 = vdupq_n_s16(mult_factor); + + for (; size > 31; size -= 32) { + int16x8_t l1 = vld1q_s16(src); + int16x8_t l2 = vld1q_s16(src + 8); + int16x8_t l3 = vld1q_s16(src + 16); + int16x8_t l4 = vld1q_s16(src + 24); src += 32; + l1 = vqrshlq_s16(l1, src_sh8); + l2 = vqrshlq_s16(l2, src_sh8); + l3 = vqrshlq_s16(l3, src_sh8); + l4 = vqrshlq_s16(l4, src_sh8); + l1 = vqrdmulhq_s16(l1, mult_v8); + l2 = vqrdmulhq_s16(l2, mult_v8); + l3 = vqrdmulhq_s16(l3, mult_v8); + l4 = vqrdmulhq_s16(l4, mult_v8); + int8x8_t m1 = vqrshrn_n_s16(l1, 7); + int8x8_t m2 = vqrshrn_n_s16(l2, 7); + int8x8_t m3 = vqrshrn_n_s16(l3, 7); + int8x8_t m4 = vqrshrn_n_s16(l4, 7); + vst1_s8(dst, m1); + vst1_s8(dst + 8, m2); + vst1_s8(dst + 16, m3); + vst1_s8(dst + 24, m4); dst += 32; + } + if (size > 15) { + int16x8_t l1 = vld1q_s16(src); + int16x8_t l2 = vld1q_s16(src + 8); src += 16; + l1 = vqrshlq_s16(l1, src_sh8); + l2 = vqrshlq_s16(l2, src_sh8); + l1 = vqrdmulhq_s16(l1, mult_v8); + l2 = vqrdmulhq_s16(l2, mult_v8); + int8x8_t m1 = vqrshrn_n_s16(l1, 7); + int8x8_t m2 = vqrshrn_n_s16(l2, 7); + vst1_s8(dst, m1); + vst1_s8(dst + 8, m2); dst += 16; + size -= 16; + } + if (size > 7) { + int16x8_t l1 = vld1q_s16(src); src += 8; + l1 = vqrshlq_s16(l1, src_sh8); + l1 = vqrdmulhq_s16(l1, mult_v8); + int8x8_t m1 = vqrshrn_n_s16(l1, 7); + vst1_s8(dst, m1); dst += 8; + size -= 8; + } + if (size > 3) { + int16x4_t l1 = vld1_s16(src); 
src += 4; + l1 = vqrshl_s16(l1, vget_low_s16(src_sh8)); + l1 = vqrdmulh_s16(l1, vget_low_s16(mult_v8)); + int8x8_t m1 = vqrshrn_n_s16(vcombine_s16(l1, vdup_n_s16(0)), 7); + vst1_lane_s8(dst, m1, 0); + vst1_lane_s8(dst + 1, m1, 1); + vst1_lane_s8(dst + 2, m1, 2); + vst1_lane_s8(dst + 3, m1, 3); dst += 4; + size -= 4; + } + if (size > 0) { + int16x4_t l1 = vdup_n_s16(0); + l1 = vset_lane_s16(src[0], l1, 0); + if (size > 1) l1 = vset_lane_s16(src[1], l1, 1); + if (size > 2) l1 = vset_lane_s16(src[2], l1, 2); + l1 = vqrshl_s16(l1, vget_low_s16(src_sh8)); + l1 = vqrdmulh_s16(l1, vget_low_s16(mult_v8)); + int8x8_t m1 = vqrshrn_n_s16(vcombine_s16(l1, vdup_n_s16(0)), 7); + vst1_lane_s8(dst, m1, 0); + if (size > 1) vst1_lane_s8(dst + 1, m1, 1); + if (size > 2) vst1_lane_s8(dst + 2, m1, 2); + } +} + +#endif diff --git a/include/arm_neon/NeonSgemmCopy.h b/include/arm_neon/NeonSgemmCopy.h new file mode 100644 index 0000000..736ba52 --- /dev/null +++ b/include/arm_neon/NeonSgemmCopy.h @@ -0,0 +1,217 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonSgemmCopy.h + * Description: Code templates for NEON SGEMM packing functions. 
+ *****************************************************************************/
+
+#include <arm_neon.h>
+
+#ifndef INCLUDE_NEON_SGEMM_COPY
+#define INCLUDE_NEON_SGEMM_COPY
+
+#if __aarch64__
+static inline void pref_ab(const float *dat) {
+ __asm__ ("prfm pldl1keep,[%0,#64]\n\t"::"r"(dat):);
+}
+#else
+static inline void pref_ab(const float *dat) {
+ __asm__ ("pld [%0,#64]\n\t"::"r"(dat):);
+}
+#endif
+
+#define NCOPY_NEON_LOOP_K8_UNROLL4(inc, dst_ptr, src1, src2, src3, src4) \
+ for (dim1_count = dim1_cache; dim1_count > 7; dim1_count -= 8) {\
+ t1.val[0] = vld1q_f32(src1); t2.val[0] = vld1q_f32(src1 + 4);\
+ src1 += 8; pref_ab(src1);\
+ t1.val[1] = vld1q_f32(src2); t2.val[1] = vld1q_f32(src2 + 4);\
+ src2 += 8; pref_ab(src2);\
+ t1.val[2] = vld1q_f32(src3); t2.val[2] = vld1q_f32(src3 + 4);\
+ src3 += 8; pref_ab(src3);\
+ t1.val[3] = vld1q_f32(src4); t2.val[3] = vld1q_f32(src4 + 4);\
+ src4 += 8; pref_ab(src4);\
+ vst4q_lane_f32(dst_ptr, t1, 0);\
+ vst4q_lane_f32(dst_ptr + inc, t1, 1);\
+ vst4q_lane_f32(dst_ptr + inc * 2, t1, 2);\
+ vst4q_lane_f32(dst_ptr + inc * 3, t1, 3);\
+ vst4q_lane_f32(dst_ptr + inc * 4, t2, 0);\
+ vst4q_lane_f32(dst_ptr + inc * 5, t2, 1);\
+ vst4q_lane_f32(dst_ptr + inc * 6, t2, 2);\
+ vst4q_lane_f32(dst_ptr + inc * 7, t2, 3);\
+ dst_ptr += inc * 8;\
+ }\
+
+
+#define NCOPY_UNROLL_24 {\
+ float32x4x4_t t1, t2;\
+ float *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(24, dst_h1, src1, src2, src3, src4)\
+ dst_h1 = dst1 + 4;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(24, dst_h1, src5, src6, src7, src8)\
+ dst_h1 = dst1 + 8;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(24, dst_h1, src9, src10, src11, src12)\
+ dst_h1 = dst1 + 12;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(24, dst_h1, src13, src14, src15, src16)\
+ dst_h1 = dst1 + 16;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(24, dst_h1, src17, src18, src19, src20)\
+ dst_h1 = dst1 + 20;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(24, dst_h1, src21, src22, src23, src24)\
+ dst1 = dst_h1 - 20;\
+ NCOPY_STD(24)\
+}
+
+#define NCOPY_UNROLL_12 {\
+ float32x4x4_t t1, t2;\
+ float *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(12, dst_h1, src1, src2, src3, src4)\
+ dst_h1 = dst1 + 4;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(12, dst_h1, src5, src6, src7, src8)\
+ dst_h1 = dst1 + 8;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(12, dst_h1, src9, src10, src11, src12)\
+ dst1 = dst_h1 - 8;\
+ NCOPY_STD(12)\
+}
+
+#define NCOPY_UNROLL_8 {\
+ float32x4x4_t t1, t2;\
+ float *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(8, dst_h1, src1, src2, src3, src4)\
+ dst_h1 = dst1 + 4;\
+ NCOPY_NEON_LOOP_K8_UNROLL4(8, dst_h1, src5, src6, src7, src8)\
+ dst1 = dst_h1 - 4;\
+ NCOPY_STD(8)\
+}
+
+#define NCOPY_UNROLL_6 {\
+ float32x4x3_t t1, t2;\
+ float *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\
+ for (; dim1_count > 7; dim1_count -= 8) {\
+ t1.val[0] = vld1q_f32(src1); t2.val[0] = vld1q_f32(src1 + 4);\
+ src1 += 8; pref_ab(src1);\
+ t1.val[1] = vld1q_f32(src2); t2.val[1] = vld1q_f32(src2 + 4);\
+ src2 += 8; pref_ab(src2);\
+ t1.val[2] = vld1q_f32(src3); t2.val[2] = vld1q_f32(src3 + 4);\
+ src3 += 8; pref_ab(src3);\
+ vst3q_lane_f32(dst_h1, t1, 0);\
+ vst3q_lane_f32(dst_h1 + 6, t1, 1);\
+ vst3q_lane_f32(dst_h1 + 12, t1, 2);\
+ vst3q_lane_f32(dst_h1 + 18, t1, 3);\
+ vst3q_lane_f32(dst_h1 + 24, t2, 0);\
+ vst3q_lane_f32(dst_h1 + 30, t2, 1);\
+ vst3q_lane_f32(dst_h1 + 36, t2, 2);\
+ vst3q_lane_f32(dst_h1 + 42, t2, 3);\
+ dst_h1 += 48;\
+ }\
+ float *dst_h2 = dst1 + 3;\
+ for (dim1_count = dim1_cache; dim1_count > 7; dim1_count -= 8) {\
+
t1.val[0] = vld1q_f32(src4); t2.val[0] = vld1q_f32(src4 + 4);\ + src4 += 8; pref_ab(src4);\ + t1.val[1] = vld1q_f32(src5); t2.val[1] = vld1q_f32(src5 + 4);\ + src5 += 8; pref_ab(src5);\ + t1.val[2] = vld1q_f32(src6); t2.val[2] = vld1q_f32(src6 + 4);\ + src6 += 8; pref_ab(src6);\ + vst3q_lane_f32(dst_h2, t1, 0);\ + vst3q_lane_f32(dst_h2 + 6, t1, 1);\ + vst3q_lane_f32(dst_h2 + 12, t1, 2);\ + vst3q_lane_f32(dst_h2 + 18, t1, 3);\ + vst3q_lane_f32(dst_h2 + 24, t2, 0);\ + vst3q_lane_f32(dst_h2 + 30, t2, 1);\ + vst3q_lane_f32(dst_h2 + 36, t2, 2);\ + vst3q_lane_f32(dst_h2 + 42, t2, 3);\ + dst_h2 += 48;\ + }\ + dst1 = dst_h1;\ + NCOPY_STD(6)\ +} + +#define NCOPY_UNROLL_4 {\ + float32x4x4_t t1;\ + for (; dim1_count > 3; dim1_count -= 4) {\ + t1.val[0] = vld1q_f32(src1); src1 += 4; pref_ab(src1);\ + t1.val[1] = vld1q_f32(src2); src2 += 4; pref_ab(src2);\ + t1.val[2] = vld1q_f32(src3); src3 += 4; pref_ab(src3);\ + t1.val[3] = vld1q_f32(src4); src4 += 4; pref_ab(src4);\ + vst4q_f32(dst1,t1); dst1 += 16;\ + }\ + NCOPY_STD(4)\ +} + +#define NCOPY_UNROLL_2 NCOPY_STD(2) +#define NCOPY_UNROLL_1 NCOPY_STD(1) + +//#define NCOPY_a(unroll) NCOPY_UNROLL_##unroll +//#define NCOPY_b(unroll) NCOPY_UNROLL_##unroll + +#define TCOPY_UNIT_1(src_ptr, dst_ptr, dst_offset) \ + dst_ptr[dst_offset] = *src_ptr; + +#define TCOPY_UNIT_2(src_ptr, dst_ptr, dst_offset) {\ + float32x2_t tmp = vld1_f32(src_ptr);\ + vst1_f32(dst_ptr + dst_offset, tmp);\ +} + +#define TCOPY_UNIT_4(src_ptr, dst_ptr, dst_offset) {\ + float32x4_t tmp = vld1q_f32(src_ptr); pref_ab(src_ptr + 4);\ + vst1q_f32(dst_ptr + dst_offset, tmp);\ +} + +#define TCOPY_UNIT_6(src_ptr, dst_ptr, dst_offset) {\ + float32x4_t tmpq = vld1q_f32(src_ptr);\ + float32x2_t tmpd = vld1_f32(src_ptr + 4); pref_ab(src_ptr + 6);\ + vst1q_f32(dst_ptr + dst_offset, tmpq);\ + vst1_f32(dst_ptr + dst_offset + 4, tmpd);\ +} + +#define TCOPY_UNIT_8(src_ptr, dst_ptr, dst_offset) {\ + float32x4_t tmp1 = vld1q_f32(src_ptr);\ + float32x4_t tmp2 = vld1q_f32(src_ptr + 4); pref_ab(src_ptr + 8);\ + vst1q_f32(dst_ptr + dst_offset, tmp1);\ + vst1q_f32(dst_ptr + dst_offset + 4, tmp2);\ +} + +#define TCOPY_UNIT_12(src_ptr, dst_ptr, dst_offset) {\ + float32x4_t tmp1 = vld1q_f32(src_ptr);\ + float32x4_t tmp2 = vld1q_f32(src_ptr + 4);\ + float32x4_t tmp3 = vld1q_f32(src_ptr + 8); pref_ab(src_ptr + 12);\ + vst1q_f32(dst_ptr + dst_offset, tmp1);\ + vst1q_f32(dst_ptr + dst_offset + 4, tmp2);\ + vst1q_f32(dst_ptr + dst_offset + 8, tmp3);\ +} + +#define TCOPY_UNIT_24(src_ptr, dst_ptr, dst_offset) {\ + float32x4_t tmp1 = vld1q_f32(src_ptr);\ + float32x4_t tmp2 = vld1q_f32(src_ptr + 4);\ + float32x4_t tmp3 = vld1q_f32(src_ptr + 8);\ + float32x4_t tmp4 = vld1q_f32(src_ptr + 12);\ + float32x4_t tmp5 = vld1q_f32(src_ptr + 16); pref_ab(src_ptr + 24);\ + float32x4_t tmp6 = vld1q_f32(src_ptr + 20); pref_ab(src_ptr + 40);\ + vst1q_f32(dst_ptr + dst_offset, tmp1);\ + vst1q_f32(dst_ptr + dst_offset + 4, tmp2);\ + vst1q_f32(dst_ptr + dst_offset + 8, tmp3);\ + vst1q_f32(dst_ptr + dst_offset + 12, tmp4);\ + vst1q_f32(dst_ptr + dst_offset + 16, tmp5);\ + vst1q_f32(dst_ptr + dst_offset + 20, tmp6);\ +} + +//#define TCOPY_UNIT_a(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) + +//#define TCOPY_UNIT_b(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) + +#endif diff --git a/include/arm_neon/NeonSgemmKernel.h b/include/arm_neon/NeonSgemmKernel.h new file mode 100644 index 0000000..3f00c3c --- /dev/null +++ 
b/include/arm_neon/NeonSgemmKernel.h @@ -0,0 +1,973 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: NeonSgemmKernel.h + * Description: Common building blocks for NEON SGEMM kernel functions. + *****************************************************************************/ + +#include +#include + +#ifndef INCLUDE_NEON_SGEMM_KERNEL +#define INCLUDE_NEON_SGEMM_KERNEL + +#if __aarch64__ + +static inline void pref_c(float *dat) { + __asm__ ("prfm pstl1keep,[%0]\n\t"::"r"(dat):); +} + +#else + +static inline void pref_c(float *dat) { + __asm__ ("pld [%0]\n\t"::"r"(dat):); +} + +#define vfmaq_lane_f32(c1,a1,b1,id) vmlaq_lane_f32(c1,a1,b1,id) +#define vfma_lane_f32(c1,a1,b1,id) vmla_lane_f32(c1,a1,b1,id) +#define vmlaq_laneq0_f32(c1,a1,b1) vmlaq_lane_f32(c1,a1,vget_low_f32(b1),0) +#define vmlaq_laneq1_f32(c1,a1,b1) vmlaq_lane_f32(c1,a1,vget_low_f32(b1),1) +#define vmlaq_laneq2_f32(c1,a1,b1) vmlaq_lane_f32(c1,a1,vget_high_f32(b1),0) +#define vmlaq_laneq3_f32(c1,a1,b1) vmlaq_lane_f32(c1,a1,vget_high_f32(b1),1) +#define vfmaq_laneq_f32(c1,a1,b1,laneid) vmlaq_laneq##laneid##_f32(c1,a1,b1) +#define vmla_laneq0_f32(c1,a1,b1) vmla_lane_f32(c1,a1,vget_low_f32(b1),0) +#define vmla_laneq1_f32(c1,a1,b1) vmla_lane_f32(c1,a1,vget_low_f32(b1),1) +#define vmla_laneq2_f32(c1,a1,b1) vmla_lane_f32(c1,a1,vget_high_f32(b1),0) +#define vmla_laneq3_f32(c1,a1,b1) vmla_lane_f32(c1,a1,vget_high_f32(b1),1) +#define vfma_laneq_f32(c1,a1,b1,laneid) vmla_laneq##laneid##_f32(c1,a1,b1) +#define vfma_n_f32(c1,a1,b1) vmla_n_f32(c1,a1,b1) +#define vfmaq_n_f32(c1,a1,b1) vmlaq_n_f32(c1,a1,b1) +#define vfma_f32(c1,a1,b1) vmla_f32(c1,a1,b1) +#define vfmaq_f32(c1,a1,b1) vmlaq_f32(c1,a1,b1) + +#endif + +#define NEON_SGEMM_KERNEL_M1N1 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + float32x2_t ad01, bd01;\ + float32x2_t cd01 = vdup_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + ad01 = vld1_f32(a_ptr); a_ptr += 2;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + for (; k_left > 3; k_left-=2) {\ + cd01 = vfma_f32(cd01, ad01, bd01);\ + ad01 = vld1_f32(a_ptr); a_ptr += 2;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 1) {\ + cd01 = vfma_f32(cd01, ad01, bd01); k_left -= 2;\ + }\ + float cs01 = vget_lane_f32(cd01, 0) + vget_lane_f32(cd01, 1);\ + if (k_left > 0) {\ + cs01 += (*a_ptr) * (*b_ptr1); a_ptr++;\ + } + +#define NEON_SGEMM_SAVE_M1N1 \ + cs01 += beta * (*c_ptr);\ + *c_ptr = cs01; + +#define NEON_SGEMM_KERNEL_M2N1_UNIT(a_ptr1, b_ptr1) \ + float32x2_t ad01, ad02, bd01, cd01, cd02;\ + cd01 = cd02 = vdup_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + ad01 = vld1_f32(a_ptr1); ad02 = vld1_f32(a_ptr1 + 2); a_ptr1 += 4;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 
+= 2;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0); ad01 = vld1_f32(a_ptr1);\ + cd02 = vfma_lane_f32(cd02, ad02, bd01, 1); ad02 = vld1_f32(a_ptr1 + 2);\ + a_ptr1 += 4; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if(k_left > 1) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f32(cd02, ad02, bd01, 1); k_left -= 2;\ + }\ + cd01 = vadd_f32(cd01, cd02);\ + if(k_left > 0) {\ + ad01 = vld1_f32(a_ptr1); a_ptr1 += 2;\ + cd01 = vfma_n_f32(cd01, ad01, *b_ptr1); b_ptr1++;\ + } + +#define NEON_SGEMM_KERNEL_M2N1 \ + const float *b_ptr1 = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M2N1_UNIT(a_ptr, b_ptr1) + +#define NEON_SGEMM_KERNEL_M1N2 \ + const float *b_ptr1 = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M2N1_UNIT(b_ptr1, a_ptr) + +#define NEON_SGEMM_SAVE_M2N1 \ + float32x2_t ct1 = vld1_f32(c_ptr);\ + cd01 = vfma_n_f32(cd01, ct1, beta);\ + vst1_f32(c_ptr, cd01); + +#define NEON_SGEMM_SAVE_M1N2_UNIT(cd01) \ + c_tmp[0] = c_tmp[0] * beta + vget_lane_f32(cd01, 0);\ + c_tmp[ldc] = c_tmp[ldc] * beta + vget_lane_f32(cd01, 1);\ + c_tmp += ldc * 2; + +#define NEON_SGEMM_SAVE_M1N2 float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M1N2_UNIT(cd01) + +#define NEON_SGEMM_KERNEL_M2N2 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + float32x2_t ad01, ad02, bd01, bd02;\ + float32x2_t cd01, cd02, cd03, cd04;\ + cd01 = cd02 = cd03 = cd04 = vdup_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + ad01 = vld1_f32(a_ptr); ad02 = vld1_f32(a_ptr + 2); a_ptr += 4;\ + bd01 = vld1_f32(b_ptr1); bd02 = vld1_f32(b_ptr1 + 2); b_ptr1 += 4;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f32(cd02, ad01, bd01, 1);\ + ad01 = vld1_f32(a_ptr); bd01 = vld1_f32(b_ptr1);\ + cd03 = vfma_lane_f32(cd03, ad02, bd02, 0);\ + cd04 = vfma_lane_f32(cd04, ad02, bd02, 1);\ + ad02 = vld1_f32(a_ptr + 2); a_ptr += 4;\ + bd02 = vld1_f32(b_ptr1 + 2); b_ptr1 += 4;\ + }\ + if (k_left > 1) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f32(cd02, ad01, bd01, 1);\ + cd03 = vfma_lane_f32(cd03, ad02, bd02, 0);\ + cd04 = vfma_lane_f32(cd04, ad02, bd02, 1); k_left -= 2;\ + }\ + cd01 = vadd_f32(cd01, cd03);\ + cd02 = vadd_f32(cd02, cd04);\ + if (k_left > 0) {\ + ad01 = vld1_f32(a_ptr); a_ptr += 2;\ + bd01 = vld1_f32(b_ptr1);\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f32(cd02, ad01, bd01, 1);\ + } + +#define NEON_SGEMM_SAVE_M2N2_UNIT(cd01, cd02) \ + ct1 = vld1_f32(c_tmp);\ + ct2 = vld1_f32(c_tmp + ldc);\ + cd01 = vfma_n_f32(cd01, ct1, beta);\ + cd02 = vfma_n_f32(cd02, ct2, beta);\ + vst1_f32(c_tmp, cd01);\ + vst1_f32(c_tmp + ldc, cd02); c_tmp += ldc * 2; + +#define NEON_SGEMM_SAVE_M2N2 \ + float *c_tmp = c_ptr;\ + float32x2_t ct1, ct2; NEON_SGEMM_SAVE_M2N2_UNIT(cd01, cd02) + +#define NEON_SGEMM_KERNEL_M4N1_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x4_t aq01, aq02, cq01, cq02;\ + float32x2_t bd01;\ + cq01 = cq02 = vdupq_n_f32(0.0f);\ + if (k_left > 1) {\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0); aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 1); aq02 = vld1q_f32(a_ptr1 + 4);\ + a_ptr1 += 8; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 1);\ + k_left -= 
2;\ + }\ + cq01 = vaddq_f32(cq01, cq02);\ + if (k_left > 0) {\ + aq01 = vld1q_f32(a_ptr1); a_ptr1 += 4;\ + cq01 = vfmaq_n_f32(cq01, aq01, *b_ptr1); b_ptr1++;\ + } + +#define NEON_SGEMM_KERNEL_M4N1 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M4N1_UNIT(a_ptr, b_ptr1) + +#define NEON_SGEMM_KERNEL_M1N4 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M4N1_UNIT(b_ptr1, a_ptr) + +#define NEON_SGEMM_SAVE_M4N1 \ + float32x4_t ct1 = vld1q_f32(c_ptr);\ + cq01 = vfmaq_n_f32(cq01, ct1, beta);\ + vst1q_f32(c_ptr, cq01); + +#define NEON_SGEMM_SAVE_M1N4_UNIT(cq01) \ + c_tmp[0] = c_tmp[0] * beta + vgetq_lane_f32(cq01, 0);\ + c_tmp[ldc] = c_tmp[ldc] * beta + vgetq_lane_f32(cq01, 1);\ + c_tmp += ldc * 2;\ + c_tmp[0] = c_tmp[0] * beta + vgetq_lane_f32(cq01, 2);\ + c_tmp[ldc] = c_tmp[ldc] * beta + vgetq_lane_f32(cq01, 3);\ + c_tmp += ldc * 2; + +#define NEON_SGEMM_SAVE_M1N4 \ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M1N4_UNIT(cq01) + +#define NEON_SGEMM_KERNEL_M4N2_UNIT(a_ptr1, b_ptr1) \ + float32x4_t aq01, aq02, cq01, cq02, cq03, cq04;\ + float32x2_t bd01, bd02;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + bd01 = vld1_f32(b_ptr1); bd02 = vld1_f32(b_ptr1 + 2); b_ptr1 += 4;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq01, bd01, 1);\ + aq01 = vld1q_f32(a_ptr1); bd01 = vld1_f32(b_ptr1);\ + cq03 = vfmaq_lane_f32(cq03, aq02, bd02, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq02, bd02, 1);\ + aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + bd02 = vld1_f32(b_ptr1 + 2); b_ptr1 += 4;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq01, bd01, 1);\ + cq03 = vfmaq_lane_f32(cq03, aq02, bd02, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq02, bd02, 1); k_left -= 2;\ + }\ + cq01 = vaddq_f32(cq01, cq03);\ + cq02 = vaddq_f32(cq02, cq04);\ + if (k_left > 0) {\ + aq01 = vld1q_f32(a_ptr1); a_ptr1 += 4;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq01, bd01, 1);\ + } + +#define NEON_SGEMM_KERNEL_M4N2 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M4N2_UNIT(a_ptr, b_ptr1) + +#define NEON_SGEMM_KERNEL_M2N4 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M4N2_UNIT(b_ptr1, a_ptr) + +#define NEON_SGEMM_SAVE_M4N2_UNIT(cq01, cq02) \ + ct1 = vld1q_f32(c_tmp); ct2 = vld1q_f32(c_tmp + ldc);\ + cq01 = vfmaq_n_f32(cq01, ct1, beta);\ + cq02 = vfmaq_n_f32(cq02, ct2, beta);\ + vst1q_f32(c_tmp, cq01);\ + vst1q_f32(c_tmp + ldc, cq02);\ + c_tmp += ldc * 2; + +#define NEON_SGEMM_SAVE_M4N2 \ + float32x4_t ct1, ct2;\ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M4N2_UNIT(cq01, cq02) + +#define NEON_SGEMM_SAVE_M2N4_UNIT(cq01, cq02) \ + ctd1 = vzipq_f32(cq01, cq02);\ + cd1 = vget_low_f32(ctd1.val[0]);\ + cd2 = vget_high_f32(ctd1.val[0]);\ + cd3 = vget_low_f32(ctd1.val[1]);\ + cd4 = vget_high_f32(ctd1.val[1]);\ + cd1 = vfma_n_f32(cd1, vld1_f32(c_tmp), beta);\ + cd2 = vfma_n_f32(cd2, vld1_f32(c_tmp + ldc), beta);\ + cd3 = vfma_n_f32(cd3, vld1_f32(c_tmp + ldc * 2), beta);\ + cd4 = vfma_n_f32(cd4, vld1_f32(c_tmp + ldc * 3), beta);\ + vst1_f32(c_tmp, cd1);\ + vst1_f32(c_tmp + ldc, cd2);\ + vst1_f32(c_tmp + ldc * 2, cd3);\ + vst1_f32(c_tmp + ldc * 3, cd4);\ + c_tmp += ldc * 4; + +#define 
NEON_SGEMM_SAVE_M2N4 \ + float32x4x2_t ctd1; float32x2_t cd1, cd2, cd3, cd4;\ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M2N4_UNIT(cq01, cq02) + +#define NEON_SGEMM_KERNEL_M4N4 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + float32x4_t aq01, cq01, cq02, cq03, cq04;\ + float32x2_t bd01, bd02;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 0) {\ + aq01 = vld1q_f32(a_ptr); a_ptr += 4;\ + bd01 = vld1_f32(b_ptr1); bd02 = vld1_f32(b_ptr1 + 2); b_ptr1 += 4;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq01, bd01, 1); bd01 = vld1_f32(b_ptr1);\ + cq03 = vfmaq_lane_f32(cq03, aq01, bd02, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq01, bd02, 1); bd02 = vld1_f32(b_ptr1 + 2);\ + b_ptr1 += 4; aq01 = vld1q_f32(a_ptr); a_ptr += 4;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq01, bd01, 1);\ + cq03 = vfmaq_lane_f32(cq03, aq01, bd02, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq01, bd02, 1);\ + } + +#define NEON_SGEMM_SAVE_M4N4_UNIT(cq01, cq02, cq03, cq04) \ + ct1 = vld1q_f32(c_tmp);\ + ct2 = vld1q_f32(c_tmp + ldc);\ + ct3 = vld1q_f32(c_tmp + ldc * 2);\ + ct4 = vld1q_f32(c_tmp + ldc * 3);\ + cq01 = vfmaq_n_f32(cq01, ct1, beta);\ + cq02 = vfmaq_n_f32(cq02, ct2, beta);\ + cq03 = vfmaq_n_f32(cq03, ct3, beta);\ + cq04 = vfmaq_n_f32(cq04, ct4, beta);\ + vst1q_f32(c_tmp, cq01);\ + vst1q_f32(c_tmp + ldc, cq02);\ + vst1q_f32(c_tmp + ldc * 2, cq03);\ + vst1q_f32(c_tmp + ldc * 3, cq04); c_tmp += ldc * 4; + +#define NEON_SGEMM_SAVE_M4N4 \ + float32x4_t ct1, ct2, ct3, ct4;\ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M4N4_UNIT(cq01, cq02, cq03, cq04) + +#define NEON_SGEMM_KERNEL_M8N1_UNIT(a_ptr1, b_ptr1) \ + float32x4_t aq01, aq02, aq03, aq04, cq01, cq02, cq03, cq04;\ + float32x2_t bd01;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4);\ + aq03 = vld1q_f32(a_ptr1 + 8); aq04 = vld1q_f32(a_ptr1 + 12); a_ptr1 += 16;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0); aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0); aq02 = vld1q_f32(a_ptr1 + 4);\ + cq03 = vfmaq_lane_f32(cq03, aq03, bd01, 1); aq03 = vld1q_f32(a_ptr1 + 8);\ + cq04 = vfmaq_lane_f32(cq04, aq04, bd01, 1); aq04 = vld1q_f32(a_ptr1 + 12);\ + a_ptr1 += 16; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0);\ + cq03 = vfmaq_lane_f32(cq03, aq03, bd01, 1);\ + cq04 = vfmaq_lane_f32(cq04, aq04, bd01, 1); k_left -= 2;\ + }\ + cq01 = vaddq_f32(cq01, cq03);\ + cq02 = vaddq_f32(cq02, cq04);\ + if (k_left > 0) {\ + float bs1 = *b_ptr1; b_ptr1++;\ + aq01 = vld1q_f32(a_ptr1);\ + aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + cq01 = vfmaq_n_f32(cq01, aq01, bs1);\ + cq02 = vfmaq_n_f32(cq02, aq02, bs1);\ + } + +#define NEON_SGEMM_KERNEL_M8N1 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M8N1_UNIT(a_ptr, b_ptr1) + +#define NEON_SGEMM_KERNEL_M1N8 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M8N1_UNIT(b_ptr1, a_ptr) + +#define NEON_SGEMM_SAVE_M8N1 \ + float32x4_t ct1, ct2;\ + ct1 = vld1q_f32(c_ptr); ct2 = vld1q_f32(c_ptr + 4);\ + cq01 = vfmaq_n_f32(cq01, ct1, beta);\ + cq02 = vfmaq_n_f32(cq02, ct2, beta);\ + vst1q_f32(c_ptr, 
cq01);\ + vst1q_f32(c_ptr + 4, cq02); + +#define NEON_SGEMM_SAVE_M1N8 \ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M1N4_UNIT(cq01) NEON_SGEMM_SAVE_M1N4_UNIT(cq02) + +#define NEON_SGEMM_KERNEL_M8N2_UNIT(a_ptr1, b_ptr1) \ + float32x4_t aq01, aq02, cq01, cq02, cq03, cq04;\ + float32x2_t bd01;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 0) {\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq03 = vfmaq_lane_f32(cq03, aq01, bd01, 1); aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq02, bd01, 1); aq02 = vld1q_f32(a_ptr1 + 4);\ + a_ptr1 += 8; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq03 = vfmaq_lane_f32(cq03, aq01, bd01, 1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq02, bd01, 1);\ + } + +#define NEON_SGEMM_KERNEL_M8N2 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M8N2_UNIT(a_ptr, b_ptr1) + +#define NEON_SGEMM_KERNEL_M2N8 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M8N2_UNIT(b_ptr1, a_ptr) + +#define NEON_SGEMM_SAVE_M8N2_UNIT(cq01, cq02, cq03, cq04) \ + ct1 = vld1q_f32(c_tmp);\ + ct2 = vld1q_f32(c_tmp + 4);\ + ct3 = vld1q_f32(c_tmp + ldc);\ + ct4 = vld1q_f32(c_tmp + ldc + 4);\ + cq01 = vfmaq_n_f32(cq01, ct1, beta);\ + cq02 = vfmaq_n_f32(cq02, ct2, beta);\ + cq03 = vfmaq_n_f32(cq03, ct3, beta);\ + cq04 = vfmaq_n_f32(cq04, ct4, beta);\ + vst1q_f32(c_tmp, cq01);\ + vst1q_f32(c_tmp + 4, cq02);\ + vst1q_f32(c_tmp + ldc, cq03);\ + vst1q_f32(c_tmp + ldc + 4, cq04); c_tmp += 2 * ldc; + +#define NEON_SGEMM_SAVE_M8N2 \ + float *c_tmp = c_ptr;\ + float32x4_t ct1, ct2, ct3, ct4;\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq01, cq02, cq03, cq04) + +#define NEON_SGEMM_SAVE_M2N8 \ + float32x4x2_t ctd1; float32x2_t cd1, cd2, cd3, cd4;\ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M2N4_UNIT(cq01, cq03) NEON_SGEMM_SAVE_M2N4_UNIT(cq02, cq04) + +#define NEON_SGEMM_KERNEL_M8N4_UNIT(a_ptr1, b_ptr1) \ + float32x4_t aq01, aq02, bq01, cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = cq07 = cq08 = vdupq_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 0) {\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + bq01 = vld1q_f32(b_ptr1); b_ptr1 += 4;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_laneq_f32(cq01, aq01, bq01, 0);\ + cq03 = vfmaq_laneq_f32(cq03, aq01, bq01, 1);\ + cq05 = vfmaq_laneq_f32(cq05, aq01, bq01, 2);\ + cq07 = vfmaq_laneq_f32(cq07, aq01, bq01, 3);\ + aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_laneq_f32(cq02, aq02, bq01, 0);\ + cq04 = vfmaq_laneq_f32(cq04, aq02, bq01, 1);\ + cq06 = vfmaq_laneq_f32(cq06, aq02, bq01, 2);\ + cq08 = vfmaq_laneq_f32(cq08, aq02, bq01, 3);\ + aq02 = vld1q_f32(a_ptr1 + 4); a_ptr1 += 8;\ + bq01 = vld1q_f32(b_ptr1); b_ptr1 += 4;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_laneq_f32(cq01, aq01, bq01, 0);\ + cq03 = vfmaq_laneq_f32(cq03, aq01, bq01, 1);\ + cq05 = vfmaq_laneq_f32(cq05, aq01, bq01, 2);\ + cq07 = vfmaq_laneq_f32(cq07, aq01, bq01, 3);\ + cq02 = vfmaq_laneq_f32(cq02, aq02, bq01, 0);\ + cq04 = vfmaq_laneq_f32(cq04, aq02, bq01, 1);\ + cq06 = vfmaq_laneq_f32(cq06, aq02, bq01, 2);\ + cq08 = vfmaq_laneq_f32(cq08, aq02, bq01, 3);\ + } + +#define NEON_SGEMM_KERNEL_M8N4 \ + const float 
*a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M8N4_UNIT(a_ptr, b_ptr1) + +#define NEON_SGEMM_KERNEL_M4N8 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + NEON_SGEMM_KERNEL_M8N4_UNIT(b_ptr1, a_ptr) + +#define NEON_SGEMM_SAVE_M8N4 \ + float *c_tmp = c_ptr;\ + float32x4_t ct1, ct2, ct3, ct4;\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq01, cq02, cq03, cq04)\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq05, cq06, cq07, cq08) + +#define TRANSPOSE_4x4(cq1, cq2, cq3, cq4) {\ + float32x4x2_t ctd1 = vzipq_f32(cq1, cq2);\ + float32x4x2_t ctd2 = vzipq_f32(cq3, cq4);\ + cq1 = vcombine_f32(vget_low_f32(ctd1.val[0]), vget_low_f32(ctd2.val[0]));\ + cq2 = vcombine_f32(vget_high_f32(ctd1.val[0]), vget_high_f32(ctd2.val[0]));\ + cq3 = vcombine_f32(vget_low_f32(ctd1.val[1]), vget_low_f32(ctd2.val[1]));\ + cq4 = vcombine_f32(vget_high_f32(ctd1.val[1]), vget_high_f32(ctd2.val[1]));\ +} + +#define NEON_SGEMM_SAVE_M4N8 \ + float *c_tmp = c_ptr;\ + float32x4_t ct1, ct2, ct3, ct4;\ + TRANSPOSE_4x4(cq01, cq03, cq05, cq07)\ + TRANSPOSE_4x4(cq02, cq04, cq06, cq08)\ + NEON_SGEMM_SAVE_M4N4_UNIT(cq01, cq03, cq05, cq07)\ + NEON_SGEMM_SAVE_M4N4_UNIT(cq02, cq04, cq06, cq08) + +#define NEON_SGEMM_KERNEL_M8N8 \ + const float *a_ptr = a_head;\ + const float *b_ptr1 = b_head;\ + float32x4_t aq01, aq02, bq01, bq02;\ + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + float32x4_t cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = cq07 = cq08 = vdupq_n_f32(0.0f);\ + cq09 = cq10 = cq11 = cq12 = cq13 = cq14 = cq15 = cq16 = vdupq_n_f32(0.0f);\ + uint32_t k_left = K;\ + if (k_left > 0) {\ + aq01 = vld1q_f32(a_ptr); aq02 = vld1q_f32(a_ptr + 4); a_ptr += 8;\ + bq01 = vld1q_f32(b_ptr1); bq02 = vld1q_f32(b_ptr1 + 4); b_ptr1 += 8;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_laneq_f32(cq01, aq01, bq01, 0);\ + cq03 = vfmaq_laneq_f32(cq03, aq01, bq01, 1);\ + cq05 = vfmaq_laneq_f32(cq05, aq01, bq01, 2);\ + cq07 = vfmaq_laneq_f32(cq07, aq01, bq01, 3);\ + cq02 = vfmaq_laneq_f32(cq02, aq02, bq01, 0);\ + cq04 = vfmaq_laneq_f32(cq04, aq02, bq01, 1);\ + cq06 = vfmaq_laneq_f32(cq06, aq02, bq01, 2);\ + cq08 = vfmaq_laneq_f32(cq08, aq02, bq01, 3);\ + bq01 = vld1q_f32(b_ptr1);\ + cq09 = vfmaq_laneq_f32(cq09, aq01, bq02, 0);\ + cq11 = vfmaq_laneq_f32(cq11, aq01, bq02, 1);\ + cq13 = vfmaq_laneq_f32(cq13, aq01, bq02, 2);\ + cq15 = vfmaq_laneq_f32(cq15, aq01, bq02, 3);\ + aq01 = vld1q_f32(a_ptr);\ + cq10 = vfmaq_laneq_f32(cq10, aq02, bq02, 0);\ + cq12 = vfmaq_laneq_f32(cq12, aq02, bq02, 1);\ + cq14 = vfmaq_laneq_f32(cq14, aq02, bq02, 2);\ + cq16 = vfmaq_laneq_f32(cq16, aq02, bq02, 3);\ + aq02 = vld1q_f32(a_ptr + 4); a_ptr += 8;\ + bq02 = vld1q_f32(b_ptr1 + 4); b_ptr1 += 8;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_laneq_f32(cq01, aq01, bq01, 0);\ + cq03 = vfmaq_laneq_f32(cq03, aq01, bq01, 1);\ + cq05 = vfmaq_laneq_f32(cq05, aq01, bq01, 2);\ + cq07 = vfmaq_laneq_f32(cq07, aq01, bq01, 3);\ + cq02 = vfmaq_laneq_f32(cq02, aq02, bq01, 0);\ + cq04 = vfmaq_laneq_f32(cq04, aq02, bq01, 1);\ + cq06 = vfmaq_laneq_f32(cq06, aq02, bq01, 2);\ + cq08 = vfmaq_laneq_f32(cq08, aq02, bq01, 3);\ + cq09 = vfmaq_laneq_f32(cq09, aq01, bq02, 0);\ + cq11 = vfmaq_laneq_f32(cq11, aq01, bq02, 1);\ + cq13 = vfmaq_laneq_f32(cq13, aq01, bq02, 2);\ + cq15 = vfmaq_laneq_f32(cq15, aq01, bq02, 3);\ + cq10 = vfmaq_laneq_f32(cq10, aq02, bq02, 0);\ + cq12 = vfmaq_laneq_f32(cq12, aq02, bq02, 1);\ + cq14 = vfmaq_laneq_f32(cq14, aq02, bq02, 2);\ + cq16 = vfmaq_laneq_f32(cq16, aq02, bq02, 3);\ + } + +#define 
NEON_SGEMM_SAVE_M8N8 \ + float *c_tmp = c_ptr;\ + float32x4_t ct1, ct2, ct3, ct4;\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq01, cq02, cq03, cq04)\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq05, cq06, cq07, cq08)\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq09, cq10, cq11, cq12)\ + NEON_SGEMM_SAVE_M8N2_UNIT(cq13, cq14, cq15, cq16) + +#define NEON_SGEMM_KERNEL_M6N1_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x2_t cd01, cd02, cd03, cd04, cd05, cd06;\ + float32x2_t ad01, ad02, ad03, ad04, ad05, ad06, bd01;\ + cd01 = cd02 = cd03 = cd04 = cd05 = cd06 = vdup_n_f32(0.0f);\ + if (k_left > 1) {\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + ad01 = vld1_f32(a_ptr1); ad02 = vld1_f32(a_ptr1 + 2);\ + ad03 = vld1_f32(a_ptr1 + 4); ad04 = vld1_f32(a_ptr1 + 6);\ + ad05 = vld1_f32(a_ptr1 + 8); ad06 = vld1_f32(a_ptr1 + 10);\ + a_ptr1 += 12;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0); ad01 = vld1_f32(a_ptr1);\ + cd02 = vfma_lane_f32(cd02, ad02, bd01, 0); ad02 = vld1_f32(a_ptr1 + 2);\ + cd03 = vfma_lane_f32(cd03, ad03, bd01, 0); ad03 = vld1_f32(a_ptr1 + 4);\ + cd04 = vfma_lane_f32(cd04, ad04, bd01, 1); ad04 = vld1_f32(a_ptr1 + 6);\ + cd05 = vfma_lane_f32(cd05, ad05, bd01, 1); ad05 = vld1_f32(a_ptr1 + 8);\ + cd06 = vfma_lane_f32(cd06, ad06, bd01, 1); ad06 = vld1_f32(a_ptr1 + 10);\ + a_ptr1 += 12; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 1) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f32(cd02, ad02, bd01, 0);\ + cd03 = vfma_lane_f32(cd03, ad03, bd01, 0);\ + cd04 = vfma_lane_f32(cd04, ad04, bd01, 1);\ + cd05 = vfma_lane_f32(cd05, ad05, bd01, 1);\ + cd06 = vfma_lane_f32(cd06, ad06, bd01, 1); k_left -= 2;\ + }\ + cd01 = vadd_f32(cd01, cd04);\ + cd02 = vadd_f32(cd02, cd05);\ + cd03 = vadd_f32(cd03, cd06);\ + if (k_left > 0) {\ + float bs1 = *b_ptr1; b_ptr1++;\ + ad01 = vld1_f32(a_ptr1);\ + ad02 = vld1_f32(a_ptr1 + 2);\ + ad03 = vld1_f32(a_ptr1 + 4); a_ptr1 += 6;\ + cd01 = vfma_n_f32(cd01, ad01, bs1);\ + cd02 = vfma_n_f32(cd02, ad02, bs1);\ + cd03 = vfma_n_f32(cd03, ad03, bs1);\ + } + +#define NEON_SGEMM_KERNEL_M6N1 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M6N1_UNIT(a_ptr, b_ptr) + +#define NEON_SGEMM_KERNEL_M1N6 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M6N1_UNIT(b_ptr, a_ptr) + +#define NEON_SGEMM_SAVE_M6N1 \ + float32x2_t ct1, ct2, ct3;\ + ct1 = vld1_f32(c_ptr); ct2 = vld1_f32(c_ptr + 2); ct3 = vld1_f32(c_ptr + 4);\ + cd01 = vfma_n_f32(cd01, ct1, beta);\ + cd02 = vfma_n_f32(cd02, ct2, beta);\ + cd03 = vfma_n_f32(cd03, ct3, beta);\ + vst1_f32(c_ptr, cd01); vst1_f32(c_ptr + 2, cd02); vst1_f32(c_ptr + 4, cd03); + +#define NEON_SGEMM_SAVE_M1N6 \ + float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M1N2_UNIT(cd01) NEON_SGEMM_SAVE_M1N2_UNIT(cd02) NEON_SGEMM_SAVE_M1N2_UNIT(cd03) + +#define NEON_SGEMM_KERNEL_M6N2_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x2_t cd01, cd02, cd03, cd04, cd05, cd06;\ + float32x2_t ad01, ad02, ad03, bd01;\ + cd01 = cd02 = cd03 = cd04 = cd05 = cd06 = vdup_n_f32(0.0f);\ + if (k_left > 0) {\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + ad01 = vld1_f32(a_ptr1); ad02 = vld1_f32(a_ptr1 + 2);\ + ad03 = vld1_f32(a_ptr1 + 4); a_ptr1 += 6;\ + }\ + for (; k_left > 1; k_left--) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd04 = vfma_lane_f32(cd04, ad01, bd01, 1); ad01 = vld1_f32(a_ptr1);\ + cd02 = vfma_lane_f32(cd02, ad02, bd01, 0);\ + cd05 = vfma_lane_f32(cd05, ad02, bd01, 1); ad02 = vld1_f32(a_ptr1 + 2);\ + cd03 = vfma_lane_f32(cd03, ad03, bd01, 0);\ + 
cd06 = vfma_lane_f32(cd06, ad03, bd01, 1); ad03 = vld1_f32(a_ptr1 + 4);\ + a_ptr1 += 6; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 0) {\ + cd01 = vfma_lane_f32(cd01, ad01, bd01, 0);\ + cd04 = vfma_lane_f32(cd04, ad01, bd01, 1);\ + cd02 = vfma_lane_f32(cd02, ad02, bd01, 0);\ + cd05 = vfma_lane_f32(cd05, ad02, bd01, 1);\ + cd03 = vfma_lane_f32(cd03, ad03, bd01, 0);\ + cd06 = vfma_lane_f32(cd06, ad03, bd01, 1);\ + } + +#define NEON_SGEMM_KERNEL_M6N2 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M6N2_UNIT(a_ptr, b_ptr) + +#define NEON_SGEMM_KERNEL_M2N6 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M6N2_UNIT(b_ptr, a_ptr) + +#define TRANS_M2N2(cd01, cd02) \ + cdd1 = vzip_f32(cd01, cd02); cd01 = cdd1.val[0]; cd02 = cdd1.val[1]; + +#define NEON_SGEMM_SAVE_M6N2 \ + float32x2_t ct1, ct2; float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M2N2_UNIT(cd01, cd04) c_tmp = c_ptr + 2;\ + NEON_SGEMM_SAVE_M2N2_UNIT(cd02, cd05) c_tmp = c_ptr + 4;\ + NEON_SGEMM_SAVE_M2N2_UNIT(cd03, cd06) + +#define NEON_SGEMM_SAVE_M2N6 \ + float32x2x2_t cdd1; float32x2_t ct1, ct2; float *c_tmp = c_ptr;\ + TRANS_M2N2(cd01, cd04) TRANS_M2N2(cd02, cd05) TRANS_M2N2(cd03, cd06)\ + NEON_SGEMM_SAVE_M2N2_UNIT(cd01, cd04)\ + NEON_SGEMM_SAVE_M2N2_UNIT(cd02, cd05)\ + NEON_SGEMM_SAVE_M2N2_UNIT(cd03, cd06) + +#define NEON_SGEMM_KERNEL_M6N4_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06;\ + float32x4_t bq01; float32x2_t ad01, ad02, ad03;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = vdupq_n_f32(0.0f);\ + if (k_left > 0) {\ + bq01 = vld1q_f32(b_ptr1); b_ptr1 += 4;\ + ad01 = vld1_f32(a_ptr1); ad02 = vld1_f32(a_ptr1 + 2);\ + ad03 = vld1_f32(a_ptr1 + 4); a_ptr1 += 6;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_lane_f32(cq01, bq01, ad01, 0);\ + cq02 = vfmaq_lane_f32(cq02, bq01, ad01, 1); ad01 = vld1_f32(a_ptr1);\ + cq03 = vfmaq_lane_f32(cq03, bq01, ad02, 0);\ + cq04 = vfmaq_lane_f32(cq04, bq01, ad02, 1); ad02 = vld1_f32(a_ptr1 + 2);\ + cq05 = vfmaq_lane_f32(cq05, bq01, ad03, 0);\ + cq06 = vfmaq_lane_f32(cq06, bq01, ad03, 1); ad03 = vld1_f32(a_ptr1 + 4);\ + a_ptr1 += 6; bq01 = vld1q_f32(b_ptr1); b_ptr1 += 4;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_lane_f32(cq01, bq01, ad01, 0);\ + cq02 = vfmaq_lane_f32(cq02, bq01, ad01, 1);\ + cq03 = vfmaq_lane_f32(cq03, bq01, ad02, 0);\ + cq04 = vfmaq_lane_f32(cq04, bq01, ad02, 1);\ + cq05 = vfmaq_lane_f32(cq05, bq01, ad03, 0);\ + cq06 = vfmaq_lane_f32(cq06, bq01, ad03, 1);\ + } + +#define NEON_SGEMM_KERNEL_M6N4 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M6N4_UNIT(a_ptr, b_ptr) + +#define NEON_SGEMM_KERNEL_M4N6 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M6N4_UNIT(b_ptr, a_ptr) + +#define NEON_SGEMM_SAVE_M6N4 \ + float32x4x2_t ctd1; float32x2_t cd1, cd2, cd3, cd4; float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M2N4_UNIT(cq01, cq02) c_tmp = c_ptr + 2;\ + NEON_SGEMM_SAVE_M2N4_UNIT(cq03, cq04) c_tmp = c_ptr + 4;\ + NEON_SGEMM_SAVE_M2N4_UNIT(cq05, cq06) + +#define NEON_SGEMM_SAVE_M4N6 \ + float32x4_t ct1, ct2; float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M4N2_UNIT(cq01, cq02) NEON_SGEMM_SAVE_M4N2_UNIT(cq03, cq04)\ + NEON_SGEMM_SAVE_M4N2_UNIT(cq05, cq06) + +#define NEON_SGEMM_KERNEL_M12N1_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06;\ + float32x4_t aq01, aq02, aq03, aq04, aq05, aq06;\ + float32x2_t bd01;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = 
vdupq_n_f32(0.0f);\ + if (k_left > 1) {\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4);\ + aq03 = vld1q_f32(a_ptr1 + 8); aq04 = vld1q_f32(a_ptr1 + 12);\ + aq05 = vld1q_f32(a_ptr1 + 16); aq06 = vld1q_f32(a_ptr1 + 20);\ + a_ptr1 += 24;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0); aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0); aq02 = vld1q_f32(a_ptr1 + 4);\ + cq03 = vfmaq_lane_f32(cq03, aq03, bd01, 0); aq03 = vld1q_f32(a_ptr1 + 8);\ + cq04 = vfmaq_lane_f32(cq04, aq04, bd01, 1); aq04 = vld1q_f32(a_ptr1 + 12);\ + cq05 = vfmaq_lane_f32(cq05, aq05, bd01, 1); aq05 = vld1q_f32(a_ptr1 + 16);\ + cq06 = vfmaq_lane_f32(cq06, aq06, bd01, 1); aq06 = vld1q_f32(a_ptr1 + 20);\ + a_ptr1 += 24; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0);\ + cq03 = vfmaq_lane_f32(cq03, aq03, bd01, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq04, bd01, 1);\ + cq05 = vfmaq_lane_f32(cq05, aq05, bd01, 1);\ + cq06 = vfmaq_lane_f32(cq06, aq06, bd01, 1); k_left -= 2;\ + }\ + cq01 = vaddq_f32(cq01, cq04);\ + cq02 = vaddq_f32(cq02, cq05);\ + cq03 = vaddq_f32(cq03, cq06);\ + if (k_left > 0) {\ + float bs1 = *b_ptr1; b_ptr1++;\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4);\ + aq03 = vld1q_f32(a_ptr1 + 8); a_ptr1 += 12;\ + cq01 = vfmaq_n_f32(cq01, aq01, bs1);\ + cq02 = vfmaq_n_f32(cq02, aq02, bs1);\ + cq03 = vfmaq_n_f32(cq03, aq03, bs1);\ + } + +#define NEON_SGEMM_KERNEL_M12N1 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M12N1_UNIT(a_ptr, b_ptr) + +#define NEON_SGEMM_KERNEL_M1N12 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M12N1_UNIT(b_ptr, a_ptr) + +#define NEON_SGEMM_SAVE_M12N1 \ + float32x4_t ct1, ct2, ct3;\ + ct1 = vld1q_f32(c_ptr); ct2 = vld1q_f32(c_ptr + 4); ct3 = vld1q_f32(c_ptr + 8);\ + cq01 = vfmaq_n_f32(cq01, ct1, beta);\ + cq02 = vfmaq_n_f32(cq02, ct2, beta);\ + cq03 = vfmaq_n_f32(cq03, ct3, beta);\ + vst1q_f32(c_ptr, cq01); vst1q_f32(c_ptr + 4, cq02); vst1q_f32(c_ptr + 8, cq03); + +#define NEON_SGEMM_SAVE_M1N12 \ + float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M1N4_UNIT(cq01) NEON_SGEMM_SAVE_M1N4_UNIT(cq02) NEON_SGEMM_SAVE_M1N4_UNIT(cq03) + +#define NEON_SGEMM_KERNEL_M12N2_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06;\ + float32x4_t aq01, aq02, aq03; float32x2_t bd01;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = vdupq_n_f32(0.0f);\ + if (k_left > 0) {\ + bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4);\ + aq03 = vld1q_f32(a_ptr1 + 8); a_ptr1 += 12;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq01, bd01, 1); aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0);\ + cq05 = vfmaq_lane_f32(cq05, aq02, bd01, 1); aq02 = vld1q_f32(a_ptr1 + 4);\ + cq03 = vfmaq_lane_f32(cq03, aq03, bd01, 0);\ + cq06 = vfmaq_lane_f32(cq06, aq03, bd01, 1); aq03 = vld1q_f32(a_ptr1 + 8);\ + a_ptr1 += 12; bd01 = vld1_f32(b_ptr1); b_ptr1 += 2;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_lane_f32(cq01, aq01, bd01, 0);\ + cq04 = vfmaq_lane_f32(cq04, aq01, bd01, 1);\ + cq02 = vfmaq_lane_f32(cq02, aq02, bd01, 0);\ + cq05 = vfmaq_lane_f32(cq05, aq02, bd01, 1);\ + cq03 = vfmaq_lane_f32(cq03, aq03, bd01, 0);\ + cq06 = vfmaq_lane_f32(cq06, aq03, bd01, 1);\ + } + 
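+/* Note on the kernel/save macro pairs above: each NEON_SGEMM_KERNEL_MxNy
+ * macro accumulates an x-by-y tile of C from packed panels of A and B, and
+ * the matching NEON_SGEMM_SAVE_MxNy macro merges that tile into C with the
+ * beta scaling. The swapped shapes reuse the same unit: for example,
+ * NEON_SGEMM_KERNEL_M2N12 below expands NEON_SGEMM_KERNEL_M12N2_UNIT with
+ * the A and B pointers exchanged, and its save macro transposes the tile on
+ * store. NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC, defined near the end of this
+ * header, stitches each pair into an inline edge-tile function.
+ *
+ * Illustrative call sketch (the panel and block names here are hypothetical,
+ * not defined in this patch):
+ *
+ *   // a_panel: packed panel of A with 12 floats per k step
+ *   // b_panel: packed panel of B with 2 floats per k step
+ *   // c_block: 12 x 2 destination block inside C, ldc = leading dimension
+ *   inline_dualpack_gemm_afloat_bfloat_cfloat_m12_n2(a_panel, b_panel,
+ *       c_block, K, beta, ldc);
+ */
+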
+#define NEON_SGEMM_KERNEL_M12N2 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M12N2_UNIT(a_ptr, b_ptr) + +#define NEON_SGEMM_KERNEL_M2N12 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M12N2_UNIT(b_ptr, a_ptr) + +#define NEON_SGEMM_SAVE_M12N2 \ + float32x4_t ct1, ct2; float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M4N2_UNIT(cq01, cq04) c_tmp = c_ptr + 4;\ + NEON_SGEMM_SAVE_M4N2_UNIT(cq02, cq05) c_tmp = c_ptr + 8;\ + NEON_SGEMM_SAVE_M4N2_UNIT(cq03, cq06) + +#define NEON_SGEMM_SAVE_M2N12 \ + float32x4x2_t ctd1; float32x2_t cd1, cd2, cd3, cd4;\ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M2N4_UNIT(cq01, cq04)\ + NEON_SGEMM_SAVE_M2N4_UNIT(cq02, cq05) NEON_SGEMM_SAVE_M2N4_UNIT(cq03, cq06) + +#define NEON_SGEMM_KERNEL_M12N4_UNIT(a_ptr1, b_ptr1) \ + uint32_t k_left = K;\ + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + float32x4_t cq09, cq10, cq11, cq12, aq01, aq02, aq03, bq01;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = vdupq_n_f32(0.0f);\ + cq07 = cq08 = cq09 = cq10 = cq11 = cq12 = vdupq_n_f32(0.0f);\ + if (k_left > 0) {\ + bq01 = vld1q_f32(b_ptr1); b_ptr1 += 4;\ + aq01 = vld1q_f32(a_ptr1); aq02 = vld1q_f32(a_ptr1 + 4);\ + aq03 = vld1q_f32(a_ptr1 + 8); a_ptr1 += 12;\ + }\ + for (; k_left > 1; k_left--) {\ + cq01 = vfmaq_laneq_f32(cq01, aq01, bq01, 0);\ + cq04 = vfmaq_laneq_f32(cq04, aq01, bq01, 1);\ + cq07 = vfmaq_laneq_f32(cq07, aq01, bq01, 2);\ + cq10 = vfmaq_laneq_f32(cq10, aq01, bq01, 3);\ + aq01 = vld1q_f32(a_ptr1);\ + cq02 = vfmaq_laneq_f32(cq02, aq02, bq01, 0);\ + cq05 = vfmaq_laneq_f32(cq05, aq02, bq01, 1);\ + cq08 = vfmaq_laneq_f32(cq08, aq02, bq01, 2);\ + cq11 = vfmaq_laneq_f32(cq11, aq02, bq01, 3);\ + aq02 = vld1q_f32(a_ptr1 + 4);\ + cq03 = vfmaq_laneq_f32(cq03, aq03, bq01, 0);\ + cq06 = vfmaq_laneq_f32(cq06, aq03, bq01, 1);\ + cq09 = vfmaq_laneq_f32(cq09, aq03, bq01, 2);\ + cq12 = vfmaq_laneq_f32(cq12, aq03, bq01, 3);\ + aq03 = vld1q_f32(a_ptr1 + 8); a_ptr1 += 12;\ + bq01 = vld1q_f32(b_ptr1); b_ptr1 += 4;\ + }\ + if (k_left > 0) {\ + cq01 = vfmaq_laneq_f32(cq01, aq01, bq01, 0);\ + cq04 = vfmaq_laneq_f32(cq04, aq01, bq01, 1);\ + cq07 = vfmaq_laneq_f32(cq07, aq01, bq01, 2);\ + cq10 = vfmaq_laneq_f32(cq10, aq01, bq01, 3);\ + cq02 = vfmaq_laneq_f32(cq02, aq02, bq01, 0);\ + cq05 = vfmaq_laneq_f32(cq05, aq02, bq01, 1);\ + cq08 = vfmaq_laneq_f32(cq08, aq02, bq01, 2);\ + cq11 = vfmaq_laneq_f32(cq11, aq02, bq01, 3);\ + cq03 = vfmaq_laneq_f32(cq03, aq03, bq01, 0);\ + cq06 = vfmaq_laneq_f32(cq06, aq03, bq01, 1);\ + cq09 = vfmaq_laneq_f32(cq09, aq03, bq01, 2);\ + cq12 = vfmaq_laneq_f32(cq12, aq03, bq01, 3);\ + } + +#define NEON_SGEMM_KERNEL_M12N4 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M12N4_UNIT(a_ptr, b_ptr) + +#define NEON_SGEMM_KERNEL_M4N12 \ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + NEON_SGEMM_KERNEL_M12N4_UNIT(b_ptr, a_ptr) + +#define NEON_SGEMM_SAVE_M12N4 \ + float32x4_t ct1, ct2, ct3, ct4;\ + float *c_tmp = c_ptr; NEON_SGEMM_SAVE_M4N4_UNIT(cq01, cq04, cq07, cq10)\ + c_tmp = c_ptr + 4; NEON_SGEMM_SAVE_M4N4_UNIT(cq02, cq05, cq08, cq11)\ + c_tmp = c_ptr + 8; NEON_SGEMM_SAVE_M4N4_UNIT(cq03, cq06, cq09, cq12) + +#define NEON_SGEMM_SAVE_M4N12 \ + float *c_tmp = c_ptr;\ + float32x4_t ct1, ct2, ct3, ct4;\ + TRANSPOSE_4x4(cq01, cq04, cq07, cq10)\ + TRANSPOSE_4x4(cq02, cq05, cq08, cq11)\ + TRANSPOSE_4x4(cq03, cq06, cq09, cq12)\ + NEON_SGEMM_SAVE_M4N4_UNIT(cq01, cq04, cq07, cq10)\ + NEON_SGEMM_SAVE_M4N4_UNIT(cq02, cq05, cq08, cq11)\ + 
NEON_SGEMM_SAVE_M4N4_UNIT(cq03, cq06, cq09, cq12) + +#define NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim) \ +static inline void inline_dualpack_gemm_afloat_bfloat_cfloat_m##mdim##_n##ndim(\ + const float *a_head, const float *b_head, float *c_ptr,\ + uint32_t K, float beta, uint32_t ldc) {\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim\ + NEON_SGEMM_SAVE_M##mdim##N##ndim\ +} + +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 1) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 2) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 1) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 2) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 4) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 4) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 1) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 2) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 4) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 8) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 8) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 8) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 1) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 2) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 4) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 8) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 6) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 6) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 6) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 1) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 2) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 4) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 12) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 12) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 12) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 1) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 2) +NEON_SGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 4) + +#endif + diff --git a/include/arm_neon/NeonSum.h b/include/arm_neon/NeonSum.h new file mode 100644 index 0000000..1794109 --- /dev/null +++ b/include/arm_neon/NeonSum.h @@ -0,0 +1,394 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/***************************************************************************** + * File: NeonSum.h + * Description: Sum functions based on ARM NEON instructions. 
+ ****************************************************************************/ + +#include +#include +#include + +#ifndef INCLUDE_NEON_SUM +#define INCLUDE_NEON_SUM + +static inline int16x8_t vaddl_low_s8(int8x16_t v1, int8x16_t v2) { + return vaddl_s8(vget_low_s8(v1), vget_low_s8(v2)); +} + +static inline int32x4_t vaddl_low_s16(int16x8_t v1, int16x8_t v2) { + return vaddl_s16(vget_low_s16(v1), vget_low_s16(v2)); +} + +static inline uint16x8_t vaddl_low_u8(uint8x16_t v1, uint8x16_t v2) { + return vaddl_u8(vget_low_u8(v1), vget_low_u8(v2)); +} + +static inline uint32x4_t vaddl_low_u16(uint16x8_t v1, uint16x8_t v2) { + return vaddl_u16(vget_low_u16(v1), vget_low_u16(v2)); +} + +#if !__aarch64__ +static inline int16x8_t vaddl_high_s8(int8x16_t v1, int8x16_t v2) { + return vaddl_s8(vget_high_s8(v1), vget_high_s8(v2)); +} + +static inline int32x4_t vaddl_high_s16(int16x8_t v1, int16x8_t v2) { + return vaddl_s16(vget_high_s16(v1), vget_high_s16(v2)); +} + +static inline uint16x8_t vaddl_high_u8(uint8x16_t v1, uint8x16_t v2) { + return vaddl_u8(vget_high_u8(v1), vget_high_u8(v2)); +} + +static inline uint32x4_t vaddl_high_u16(uint16x8_t v1, uint16x8_t v2) { + return vaddl_u16(vget_high_u16(v1), vget_high_u16(v2)); +} + +static inline int32x4_t vaddw_high_s16(int32x4_t qv, int16x8_t dv) { + return vaddw_s16(qv, vget_high_s16(dv)); +} + +static inline uint32x4_t vaddw_high_u16(uint32x4_t qv, uint16x8_t dv) { + return vaddw_u16(qv, vget_high_u16(dv)); +} +#endif + +static inline void pref_src(const void *dat) { +#if __aarch64__ + __asm__("prfm pldl1keep,[%0,#64]\n\t"::"r"(dat):); +#else + __asm__("pld [%0,#64]\n\t"::"r"(dat):); +#endif +} + +/***************************************************************************** + * Template: NEON_I8I32_SUM + * Description: Function template for NEON-based summing operation of a matrix. + * Template Parameters: sign_short: the integer sign char in the name of + * NEON intrinsics. Please use 's' for signed int + * and 'u' for unsigned int. + * sign_scalar: the string showing integer sign in the + * name of integer type. Please use "int" for + * signed int and "uint" for unsigned int. + * Function Parameters: src: the address of input matrix. + * dst: the address of output vector. + * dim1: the length of major dimension of input matrix. + * dim2: the length of minor dimension of input matrix. + * (the major dimension is the vertical one for column- + * major matrix, or the horizontal one for row-major + * matrix) + * direction: the direction of summing + * 0: sum along the minor dimension, + * output_vector_size == dim1; + * 1: sum along the major dimension, + * output_vector_size == dim2. 
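+ * Example: assuming this template is instantiated with
+ * (sign_short = s, sign_scalar = int) elsewhere in the
+ * library, the generated function is
+ * s8s32_sum(const int8_t *src, int32_t *dst,
+ * uint32_t dim1, uint32_t dim2, uint8_t direction);
+ * calling it with direction == 0 on a dim1 x dim2
+ * int8 matrix writes dim1 int32 sums to dst (one per
+ * position along the major dimension), while
+ * direction == 1 writes dim2 sums taken along the
+ * major dimension instead.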
+ ****************************************************************************/ +#define NEON_I8I32_SUM(sign_short, sign_scalar) \ +void sign_short##8##sign_short##32##_sum(const sign_scalar##8_t *src,\ + sign_scalar##32_t *dst, uint32_t dim1, uint32_t dim2, uint8_t direction) {\ +\ + if (direction == 0) {/* output_size = dim1 */\ + /* first zero output */\ + const sign_scalar##32x4_t z1 = vdupq_n_##sign_short##32(0);\ + uint32_t dim1_left = dim1;\ + sign_scalar##32_t *dst1 = dst;\ + for (; dim1_left > 3; dim1_left -= 4) {\ + vst1q_##sign_short##32(dst1, z1); dst1 += 4;\ + }\ + for (; dim1_left > 0; dim1_left--) {\ + *dst1 = 0; dst1++;\ + }\ + /* then accumulate */\ + const sign_scalar##8_t *src1 = src;\ + uint32_t dim2_left = dim2;\ + for (; dim2_left > 3; dim2_left -= 4) {\ + const sign_scalar##8_t *src_l1 = src1;\ + const sign_scalar##8_t *src_l2 = src1 + dim1;\ + const sign_scalar##8_t *src_l3 = src1 + dim1 * 2;\ + const sign_scalar##8_t *src_l4 = src_l2 + dim1 * 2;\ + src1 = src_l3 + dim1 * 2;\ + sign_scalar##32_t *dst1 = dst;\ + dim1_left = dim1;\ + for (; dim1_left > 15; dim1_left -= 16) {\ + sign_scalar##8x16_t q1 = vld1q_##sign_short##8(src_l1);\ + src_l1 += 16; pref_src(src_l1);\ + sign_scalar##8x16_t q2 = vld1q_##sign_short##8(src_l2);\ + src_l2 += 16; pref_src(src_l2);\ + sign_scalar##8x16_t q3 = vld1q_##sign_short##8(src_l3);\ + src_l3 += 16; pref_src(src_l3);\ + sign_scalar##8x16_t q4 = vld1q_##sign_short##8(src_l4);\ + src_l4 += 16; pref_src(src_l4);\ + sign_scalar##16x8_t m1 = vaddl_low_##sign_short##8(q1, q2);\ + sign_scalar##16x8_t m2 = vaddl_high_##sign_short##8(q1, q2);\ + sign_scalar##16x8_t m3 = vaddl_low_##sign_short##8(q3, q4);\ + sign_scalar##16x8_t m4 = vaddl_high_##sign_short##8(q3, q4);\ + sign_scalar##32x4_t c1 = vld1q_##sign_short##32(dst1);\ + sign_scalar##32x4_t c2 = vld1q_##sign_short##32(dst1 + 4);\ + sign_scalar##32x4_t c3 = vld1q_##sign_short##32(dst1 + 8);\ + sign_scalar##32x4_t c4 = vld1q_##sign_short##32(dst1 + 12);\ + m1 = vaddq_##sign_short##16(m1, m3);\ + m2 = vaddq_##sign_short##16(m2, m4);\ + c1 = vaddw_##sign_short##16(c1, vget_low_##sign_short##16(m1));\ + c2 = vaddw_high_##sign_short##16(c2, m1);\ + c3 = vaddw_##sign_short##16(c3, vget_low_##sign_short##16(m2));\ + c4 = vaddw_high_##sign_short##16(c4, m2);\ + vst1q_##sign_short##32(dst1, c1);\ + vst1q_##sign_short##32(dst1 + 4, c2);\ + vst1q_##sign_short##32(dst1 + 8, c3);\ + vst1q_##sign_short##32(dst1 + 12, c4); dst1 += 16;\ + }\ + if (dim1_left > 7) {\ + sign_scalar##8x8_t d1 = vld1_##sign_short##8(src_l1); src_l1 += 8;\ + sign_scalar##8x8_t d2 = vld1_##sign_short##8(src_l2); src_l2 += 8;\ + sign_scalar##8x8_t d3 = vld1_##sign_short##8(src_l3); src_l3 += 8;\ + sign_scalar##8x8_t d4 = vld1_##sign_short##8(src_l4); src_l4 += 8;\ + sign_scalar##32x4_t c1 = vld1q_##sign_short##32(dst1);\ + sign_scalar##32x4_t c2 = vld1q_##sign_short##32(dst1 + 4);\ + sign_scalar##16x8_t m1 = vaddl_##sign_short##8(d1, d2);\ + sign_scalar##16x8_t m2 = vaddl_##sign_short##8(d3, d4);\ + m1 = vaddq_##sign_short##16(m1, m2);\ + c1 = vaddw_##sign_short##16(c1, vget_low_##sign_short##16(m1));\ + c2 = vaddw_high_##sign_short##16(c2, m1);\ + vst1q_##sign_short##32(dst1, c1);\ + vst1q_##sign_short##32(dst1 + 4, c2); dst1 += 8;\ + dim1_left -= 8;\ + }\ + for (; dim1_left > 0; dim1_left--) {\ + sign_scalar##16_t s1 = *src_l1++;\ + sign_scalar##16_t s2 = *src_l2++;\ + sign_scalar##16_t s3 = *src_l3++;\ + sign_scalar##16_t s4 = *src_l4++;\ + sign_scalar##32_t cs1 = *dst1;\ + s1 += s2; s3 += s4; s1 += s3; cs1 += s1;\ + *dst1 = 
cs1; dst1++;\ + }\ + }\ + for (; dim2_left > 0; dim2_left--) {\ + sign_scalar##32_t *dst1 = dst;\ + dim1_left = dim1;\ + for (; dim1_left > 15; dim1_left -= 16) {\ + sign_scalar##8x8_t d1 = vld1_##sign_short##8(src1);\ + sign_scalar##8x8_t d2 = vld1_##sign_short##8(src1 + 8); src1 += 16;\ + sign_scalar##16x8_t q1 = vmovl_##sign_short##8(d1);\ + sign_scalar##16x8_t q2 = vmovl_##sign_short##8(d2);\ + sign_scalar##32x4_t c1 = vld1q_##sign_short##32(dst1);\ + sign_scalar##32x4_t c2 = vld1q_##sign_short##32(dst1 + 4);\ + sign_scalar##32x4_t c3 = vld1q_##sign_short##32(dst1 + 8);\ + sign_scalar##32x4_t c4 = vld1q_##sign_short##32(dst1 + 12);\ + c1 = vaddw_##sign_short##16(c1, vget_low_##sign_short##16(q1));\ + c2 = vaddw_high_##sign_short##16(c2, q1);\ + c3 = vaddw_##sign_short##16(c3, vget_low_##sign_short##16(q2));\ + c4 = vaddw_high_##sign_short##16(c4, q2);\ + vst1q_##sign_short##32(dst1, c1);\ + vst1q_##sign_short##32(dst1 + 4, c2);\ + vst1q_##sign_short##32(dst1 + 8, c3);\ + vst1q_##sign_short##32(dst1 + 12, c4);\ + dst1 += 16;\ + }\ + if (dim1_left > 7) {\ + sign_scalar##8x8_t d1 = vld1_##sign_short##8(src1); src1 += 8;\ + sign_scalar##16x8_t q1 = vmovl_##sign_short##8(d1);\ + sign_scalar##32x4_t c1 = vld1q_##sign_short##32(dst1);\ + sign_scalar##32x4_t c2 = vld1q_##sign_short##32(dst1 + 4);\ + c1 = vaddw_##sign_short##16(c1, vget_low_##sign_short##16(q1));\ + c2 = vaddw_high_##sign_short##16(c2, q1);\ + vst1q_##sign_short##32(dst1, c1);\ + vst1q_##sign_short##32(dst1 + 4, c2);\ + dst1 += 8; dim1_left -= 8;\ + }\ + for (; dim1_left > 0; dim1_left--) {\ + *dst1 += *src1; src1++; dst1++;\ + }\ + }\ + } else {/* output size = dim2 */\ + const sign_scalar##8_t *src1 = src;\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; dim2_pos++) {\ + sign_scalar##32x4_t cq1 = vdupq_n_##sign_short##32(0);\ + uint32_t dim1_left = dim1;\ + for (; dim1_left > 15; dim1_left -= 16) {\ + sign_scalar##8x16_t aq1 = vld1q_##sign_short##8(src1); src1 += 16;\ + sign_scalar##16x8_t tq1 = vpaddlq_##sign_short##8(aq1);\ + cq1 = vpadalq_##sign_short##16(cq1, tq1);\ + }\ + sign_scalar##32x2_t cd1 = vadd_##sign_short##32(\ + vget_low_##sign_short##32(cq1), vget_high_##sign_short##32(cq1));\ + if (dim1_left > 7) {\ + sign_scalar##8x8_t ad1 = vld1_##sign_short##8(src1); src1 += 8;\ + sign_scalar##16x4_t td1 = vpaddl_##sign_short##8(ad1);\ + cd1 = vpadal_##sign_short##16(cd1, td1);\ + dim1_left -= 8;\ + }\ + sign_scalar##32_t cs1 = vget_lane_##sign_short##32(\ + vpadd_##sign_short##32(cd1, cd1), 0);\ + for (; dim1_left > 0; dim1_left--) {\ + cs1 += *src1; src1++;\ + }\ + dst[dim2_pos] = cs1;\ + }\ + }\ +} + +static inline int32x4_t vmull_low_s16(int16x8_t a, int16x8_t b) { + return vmull_s16(vget_low_s16(a), vget_low_s16(b)); +} + +static inline uint32x4_t vmull_low_u16(uint16x8_t a, uint16x8_t b) { + return vmull_u16(vget_low_u16(a), vget_low_u16(b)); +} + +#if !__aarch64__ +static inline int32x4_t vmull_high_s16(int16x8_t a, int16x8_t b) { + return vmull_s16(vget_high_s16(a), vget_high_s16(b)); +} + +static inline uint32x4_t vmull_high_u16(uint16x8_t a, uint16x8_t b) { + return vmull_u16(vget_high_u16(a), vget_high_u16(b)); +} +#endif + +#define NEON_I16_SUMSQUARE(sign_short, sign_scalar) \ +void sign_short##16_sumsquare(const sign_scalar##16_t *dat,\ + sign_scalar##32_t *sum, sign_scalar##64_t *sumsquare, uint32_t size) {\ +\ + sign_scalar##32x4_t sum1 = vdupq_n_##sign_short##32(0);\ + sign_scalar##32x4_t sum2 = vdupq_n_##sign_short##32(0);\ + sign_scalar##64x2_t sumsq1 = vdupq_n_##sign_short##64(0);\ + sign_scalar##64x2_t 
sumsq2 = vdupq_n_##sign_short##64(0);\ + sign_scalar##64x2_t sumsq3 = vdupq_n_##sign_short##64(0);\ + sign_scalar##64x2_t sumsq4 = vdupq_n_##sign_short##64(0);\ +\ + if (!sumsquare) {\ + if (sum) {\ + for (; size > 15; size -= 16) {\ + sign_scalar##16x8_t l1 = vld1q_##sign_short##16(dat);\ + sign_scalar##16x8_t l2 = vld1q_##sign_short##16(dat + 8); dat += 16;\ + sum1 = vpadalq_##sign_short##16(sum1, l1);\ + sum2 = vpadalq_##sign_short##16(sum2, l2);\ + }\ + sum1 = vaddq_##sign_short##32(sum1, sum2);\ + if (size > 7) {\ + sign_scalar##16x8_t l1 = vld1q_##sign_short##16(dat); dat += 8;\ + sum1 = vpadalq_##sign_short##16(sum1, l1);\ + size -= 8;\ + }\ + if (size > 3) {\ + sign_scalar##16x4_t l1 = vld1_##sign_short##16(dat); dat += 4;\ + sum1 = vaddw_##sign_short##16(sum1, l1);\ + size -= 4;\ + }\ + sign_scalar##32x2_t sumd = vadd_##sign_short##32(\ + vget_low_##sign_short##32(sum1), vget_high_##sign_short##32(sum1));\ + sign_scalar##32_t sums = vget_lane_##sign_short##32(sumd, 0) + \ + vget_lane_##sign_short##32(sumd, 1);\ + for (; size > 0; size--) {\ + sign_scalar##32_t l1 = *dat++;\ + sums += l1;\ + }\ + *sum = sums;\ + }\ + } else if (!sum) {\ + for (; size > 15; size -= 16) {\ + sign_scalar##16x8_t l1 = vld1q_##sign_short##16(dat);\ + sign_scalar##16x8_t l2 = vld1q_##sign_short##16(dat + 8); dat += 16;\ + sign_scalar##32x4_t sq1 = vmull_low_##sign_short##16(l1, l1);\ + sign_scalar##32x4_t sq2 = vmull_high_##sign_short##16(l1, l1);\ + sign_scalar##32x4_t sq3 = vmull_low_##sign_short##16(l2, l2);\ + sign_scalar##32x4_t sq4 = vmull_high_##sign_short##16(l2, l2);\ + sumsq1 = vpadalq_##sign_short##32(sumsq1, sq1);\ + sumsq2 = vpadalq_##sign_short##32(sumsq2, sq2);\ + sumsq3 = vpadalq_##sign_short##32(sumsq3, sq3);\ + sumsq4 = vpadalq_##sign_short##32(sumsq4, sq4);\ + }\ + sumsq1 = vaddq_##sign_short##64(sumsq1, sumsq3);\ + sumsq2 = vaddq_##sign_short##64(sumsq2, sumsq4);\ + if (size > 7) {\ + sign_scalar##16x8_t l1 = vld1q_##sign_short##16(dat); dat += 8;\ + sign_scalar##32x4_t sq1 = vmull_low_##sign_short##16(l1, l1);\ + sign_scalar##32x4_t sq2 = vmull_high_##sign_short##16(l1, l1);\ + sumsq1 = vpadalq_##sign_short##32(sumsq1, sq1);\ + sumsq2 = vpadalq_##sign_short##32(sumsq2, sq2);\ + size -= 8;\ + }\ + sumsq1 = vaddq_##sign_short##64(sumsq1, sumsq2);\ + if (size > 3) {\ + sign_scalar##16x4_t l1 = vld1_##sign_short##16(dat); dat += 4;\ + sign_scalar##32x4_t sq1 = vmull_##sign_short##16(l1, l1);\ + sumsq1 = vpadalq_##sign_short##32(sumsq1, sq1);\ + size -= 4;\ + }\ + sign_scalar##64_t sumsqs = vgetq_lane_##sign_short##64(sumsq1, 0) + \ + vgetq_lane_##sign_short##64(sumsq1, 1);\ + for (; size > 0; size--) {\ + sign_scalar##32_t l1 = *dat++;\ + sumsqs += l1 * l1;\ + }\ + *sumsquare = sumsqs;\ + } else {\ + for (; size > 15; size -= 16) {\ + sign_scalar##16x8_t l1 = vld1q_##sign_short##16(dat);\ + sign_scalar##16x8_t l2 = vld1q_##sign_short##16(dat + 8); dat += 16;\ + sum1 = vpadalq_##sign_short##16(sum1, l1);\ + sum2 = vpadalq_##sign_short##16(sum2, l2);\ + sign_scalar##32x4_t sq1 = vmull_low_##sign_short##16(l1, l1);\ + sign_scalar##32x4_t sq2 = vmull_high_##sign_short##16(l1, l1);\ + sign_scalar##32x4_t sq3 = vmull_low_##sign_short##16(l2, l2);\ + sign_scalar##32x4_t sq4 = vmull_high_##sign_short##16(l2, l2);\ + sumsq1 = vpadalq_##sign_short##32(sumsq1, sq1);\ + sumsq2 = vpadalq_##sign_short##32(sumsq2, sq2);\ + sumsq3 = vpadalq_##sign_short##32(sumsq3, sq3);\ + sumsq4 = vpadalq_##sign_short##32(sumsq4, sq4);\ + }\ + sum1 = vaddq_##sign_short##32(sum1, sum2);\ + sumsq1 = 
vaddq_##sign_short##64(sumsq1, sumsq3);\ + sumsq2 = vaddq_##sign_short##64(sumsq2, sumsq4);\ + if (size > 7) {\ + sign_scalar##16x8_t l1 = vld1q_##sign_short##16(dat); dat += 8;\ + sum1 = vpadalq_##sign_short##16(sum1, l1);\ + sign_scalar##32x4_t sq1 = vmull_low_##sign_short##16(l1, l1);\ + sign_scalar##32x4_t sq2 = vmull_high_##sign_short##16(l1, l1);\ + sumsq1 = vpadalq_##sign_short##32(sumsq1, sq1);\ + sumsq2 = vpadalq_##sign_short##32(sumsq2, sq2);\ + size -= 8;\ + }\ + sumsq1 = vaddq_##sign_short##64(sumsq1, sumsq2);\ + if (size > 3) {\ + sign_scalar##16x4_t l1 = vld1_##sign_short##16(dat); dat += 4;\ + sum1 = vaddw_##sign_short##16(sum1, l1);\ + sign_scalar##32x4_t sq1 = vmull_##sign_short##16(l1, l1);\ + sumsq1 = vpadalq_##sign_short##32(sumsq1, sq1);\ + size -= 4;\ + }\ + sign_scalar##32x2_t sumd = vadd_##sign_short##32(\ + vget_low_##sign_short##32(sum1), vget_high_##sign_short##32(sum1));\ + sign_scalar##32_t sums = vget_lane_##sign_short##32(sumd, 0) + \ + vget_lane_##sign_short##32(sumd, 1);\ + sign_scalar##64_t sumsqs = vgetq_lane_##sign_short##64(sumsq1, 0) + \ + vgetq_lane_##sign_short##64(sumsq1, 1);\ + for (; size > 0; size--) {\ + sign_scalar##32_t l1 = *dat++;\ + sums += l1;\ + sumsqs += l1 * l1;\ + }\ + *sum = sums;\ + *sumsquare = sumsqs;\ + }\ +} + +#endif + diff --git a/include/common/CommonCopy.h b/include/common/CommonCopy.h new file mode 100644 index 0000000..779dee9 --- /dev/null +++ b/include/common/CommonCopy.h @@ -0,0 +1,121 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +/****************************************************************************** + * File: CommonCopy.h + * Description: Common building blocks for packing functions in GEMM operation + * Terms: "ncopy": pack from K-major source matrix + * "tcopy": pack from K-minor source matrix + *****************************************************************************/ + +#include "ExpandMacro.h" +#include + +#ifndef INCLUDE_COMMON_COPY +#define INCLUDE_COMMON_COPY + +#define NCOPY_INIT_SRC_PTR_ITEM(n, type) \ + const type *src##n = src0 + (n - 1) * ld_dim; +#define NCOPY_INIT_SRC_PTR(n, type) \ + MACRO_EXPANSION_##n(VOID_BASE, NCOPY_INIT_SRC_PTR_ITEM, type) + +#define NCOPY_COPY_1(n) \ + dst1[n - 1] = *src##n; src##n ++; +#define NCOPY_COPY(n) \ + MACRO_EXPANSION_##n(VOID_BASE, NCOPY_COPY_1) dst1 += n; + +/* a standard-C fallback for NCOPY_ */ +#define NCOPY_STD(unroll) \ + for (; dim1_count > 0; dim1_count--) {\ + NCOPY_COPY(unroll)\ + } + +/* the macro NCOPY__(unroll) is architecture dependant, + * * which should be defined in the source file including this header */ +#define NCOPY_LOOP(unroll, type, stype) \ + for (; dim2_count >= unroll; dim2_count -= unroll) {\ + uint32_t dim1_count = dim1;\ + NCOPY_INIT_SRC_PTR(unroll, type)\ + NCOPY_##type##_##stype(unroll)\ + src0 += ld_dim * unroll;\ + } +#define NCOPY(max_unroll, side, type) \ + MACRO_EXP_E_##max_unroll(NCOPY_LOOP, side, type) + +#define GENERIC_NCOPY_FUNC(gemmtype, type, stype, max_unroll) \ +void gemmtype##_##type##_##stype##_ncopy_unroll##max_unroll(\ + const type * __restrict__ src, stype * __restrict__ dst,\ + uint32_t ld_dim, uint32_t dim1, uint32_t dim2) {\ + const type *src0 = src;\ + stype *dst1 = dst;\ + uint32_t dim2_count = dim2;\ + NCOPY(max_unroll, type, stype)\ +} + + +/* this macro is the fallback for TCOPY_UNIT__ */ +#define TCOPY_UNIT_STD(src_ptr, dst_ptr, dst_offset, num_elements) \ + _Pragma("omp simd")\ + for (int i = 0; i < num_elements; ++i) \ + dst_ptr[dst_offset + i] = src_ptr[i]; + +/* the macro + * TCOPY_UNIT__(src_ptr, dst_ptr, dst_offset, num_elements) + * is architecture dependant, + * which should be defined in source file including this header */ + +#define TCOPY_LINE_1(n, unroll, type, stype) \ + TCOPY_UNIT_##type##_##stype(src##n, dst1, ((n-1)*unroll), unroll)\ + src##n += unroll; +#define TCOPY_LINES(n, unroll, type, stype) \ + MACRO_EXPANSION_##n(VOID_BASE, TCOPY_LINE_1, unroll, type, stype) + +#define TCOPY_LOOP(unroll, type, stype, read_width) \ + dst1 = dst + (dim1 - dim1_count) * dim2 + (dim2 - dim2_count) * unroll;\ + for (; dim1_count >= unroll; dim1_count -= unroll) {\ + TCOPY_LINES(read_width, unroll, type, stype)\ + dst1 += dim2 * unroll;\ + } +#define TCOPY(max_unroll, type, stype, read_width) \ + MACRO_EXPANSION_E_##max_unroll(TCOPY_LOOP, type, stype, read_width) + +#define GENERIC_TCOPY_FUNC(gemmtype, type, stype, max_unroll) \ +void gemmtype##_##type##_##stype##_tcopy_unroll##max_unroll(\ + const type * __restrict__ src, stype * __restrict__ dst,\ + uint32_t ld_dim, uint32_t dim1, uint32_t dim2) {\ + uint32_t dim2_count = dim2;\ + const type *src0 = src;\ + for (; dim2_count > 3; dim2_count -= 4) {\ + const type *src1 = src0;\ + const type *src2 = src0 + ld_dim;\ + const type *src3 = src0 + ld_dim * 2;\ + const type *src4 = src2 + ld_dim * 2;\ + stype *dst1;\ + uint32_t dim1_count = dim1;\ + TCOPY(max_unroll, type, stype, 4)\ + src0 += ld_dim * 4;\ + }\ + for (; dim2_count > 0; dim2_count--) {\ + const 
type *src1 = src0;\ + stype *dst1;\ + uint32_t dim1_count = dim1;\ + TCOPY(max_unroll, type, stype, 1)\ + src0 += ld_dim;\ + }\ +} + +#endif diff --git a/include/common/CommonDriver.h b/include/common/CommonDriver.h new file mode 100644 index 0000000..22bd3f2 --- /dev/null +++ b/include/common/CommonDriver.h @@ -0,0 +1,497 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: CommonDriver.h + * Description: Common driver functions for GEMM operation. A driver function + * does blocking and calls packing/kernel/skinny_kernel functions + * to perform efficient matrix multiplication + *****************************************************************************/ + +#include "ExpandMacro.h" +#include "CommonSched.h" +#include +#include +#include +#include +#ifndef EMLL_SERIAL_ONLY +#include +#endif + +#ifndef INCLUDE_COMMON_DRIVER +#define INCLUDE_COMMON_DRIVER + +#define SKINNY_FUNC_BASE(FUNCTITLE, VOIDTITLE, ...) VOIDTITLE + +#define SKINNY_FUNC_LIST_ITEM(NO, FUNCTITLE, VOIDTITLE, ...)\ + ,FUNCTITLE##NO##__VA_ARGS__ + +#define SKINNY_GEMM_FUNC_LIST(NUM, FUNCTITLE, VOIDTITLE, ...)\ + MACRO_EXP_##NUM(SKINNY_FUNC_BASE, SKINNY_FUNC_LIST_ITEM,\ + FUNCTITLE, VOIDTITLE, ##__VA_ARGS__) + +/* blocking parameters */ + +#ifndef GEMM_R_MN +#define GEMM_R_MN 1024 +#endif + +#ifndef GEMM_D_MN +#define GEMM_D_MN 192 +#endif + +#ifndef GEMM_D_K +#define GEMM_D_K 192 +#endif + +/* GEMM_D_K * GEMM_UNROLL_M or GEMM_D_K * GEMM_UNROLL_N fit in L1 cache */ +/* GEMM_D_K * GEMM_D_MN fit in L2 cache */ +/* GEMM_R_MN is the last to optimize, not crucial to performance */ + +#if GEMM_D_MN > GEMM_R_MN +#define GEMM_S_MN GEMM_D_MN +#else +#define GEMM_S_MN GEMM_R_MN +#endif + +#ifndef SCRATCH_K_CORD +#define SCRATCH_K_CORD(k) (k) +#endif + +#define SCRATCH_GEMM_D_K (SCRATCH_K_CORD(GEMM_D_K - 1) + 1) + +#define GEMM_STATIC_BUFFER(gemmtype, sbtype, satype) \ +__thread __attribute__((aligned(4096))) satype\ + blas_##gemmtype##_sa[GEMM_S_MN * SCRATCH_GEMM_D_K];\ +__thread __attribute__((aligned(4096))) sbtype\ + blas_##gemmtype##_sb[GEMM_S_MN * SCRATCH_GEMM_D_K]; + +/* serial driver function with packing both source matrices, + * loop order: N { M { K } } */ +#define GEMM_SERIAL_FUNC_LM(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_m, unroll_n) \ +static void gemmtype##_serial_lm_m##unroll_m##n##unroll_n(\ + int a_rowmajor, int b_rowmajor,\ + const atype *A, const btype *B, ctype *C,\ + uint32_t M, uint32_t N, uint32_t K, ctype beta_inp) {\ +\ + satype * const sa = blas_##gemmtype##_sa;\ + sbtype * const sb = blas_##gemmtype##_sb;\ +\ + uint32_t m_pos, n_pos, k_pos, m_inc, n_inc, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= (GEMM_D_K << 1)) 
k_inc = GEMM_D_K;\ + else if (k_inc > GEMM_D_K) k_inc >>= 1;\ + ctype beta = (k_pos == 0) ? beta_inp : 1;\ + for (n_pos = 0; n_pos < N; n_pos += n_inc) {\ + n_inc = N - n_pos;\ + if (n_inc >= (GEMM_R_MN << 1)) n_inc = GEMM_R_MN;\ + else if (n_inc > GEMM_R_MN) n_inc >>= 1;\ + if (b_rowmajor) {\ + gemmtype##_##btype##_##sbtype##_tcopy_unroll##unroll_n(\ + B + k_pos * N + n_pos, sb, N, n_inc, k_inc);\ + } else {\ + gemmtype##_##btype##_##sbtype##_ncopy_unroll##unroll_n(\ + B + n_pos * K + k_pos, sb, K, k_inc, n_inc);\ + }\ + for (m_pos = 0; m_pos < M; m_pos += m_inc) {\ + m_inc = M - m_pos;\ + if (m_inc > GEMM_D_MN) m_inc = GEMM_D_MN;\ + if (a_rowmajor) {\ + gemmtype##_##atype##_##satype##_ncopy_unroll##unroll_m(\ + A + m_pos * K + k_pos, sa, K, k_inc, m_inc);\ + } else {\ + gemmtype##_##atype##_##satype##_tcopy_unroll##unroll_m(\ + A + k_pos * M + m_pos, sa, M, m_inc, k_inc);\ + }\ + uint32_t scratch_k_inc = (k_inc == 0) ? 0 :\ + SCRATCH_K_CORD(k_inc - 1) + 1;\ + gemmtype##_kernel_lm_m##unroll_m##n##unroll_n(m_inc, n_inc,\ + scratch_k_inc,\ + beta, sa, sb, C + n_pos * M + m_pos, M);\ + }\ + }\ + }\ +} + +/* serial driver function with packing both source matrices, + * loop order: M { N { K } } */ +#define GEMM_SERIAL_FUNC_LN(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_m, unroll_n) \ +static void gemmtype##_serial_ln_m##unroll_m##n##unroll_n(\ + int a_rowmajor, int b_rowmajor,\ + const atype *A, const btype *B, ctype *C,\ + uint32_t M, uint32_t N, uint32_t K, ctype beta_inp) {\ +\ + satype * const sa = blas_##gemmtype##_sa;\ + sbtype * const sb = blas_##gemmtype##_sb;\ +\ + uint32_t m_pos, n_pos, k_pos, m_inc, n_inc, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= (GEMM_D_K << 1)) k_inc = GEMM_D_K;\ + else if (k_inc > GEMM_D_K) k_inc >>= 1;\ + ctype beta = (k_pos == 0) ? beta_inp : 1;\ + for (m_pos = 0; m_pos < M; m_pos += m_inc) {\ + m_inc = M - m_pos;\ + if (m_inc >= (GEMM_R_MN << 1)) m_inc = GEMM_R_MN;\ + else if (m_inc > GEMM_R_MN) m_inc >>= 1;\ + if (a_rowmajor) {\ + gemmtype##_##atype##_##satype##_ncopy_unroll##unroll_m(\ + A + m_pos * K + k_pos, sa, K, k_inc, m_inc);\ + } else {\ + gemmtype##_##atype##_##satype##_tcopy_unroll##unroll_m(\ + A + k_pos * M + m_pos, sa, M, m_inc, k_inc);\ + }\ + for (n_pos = 0; n_pos < N; n_pos += n_inc) {\ + n_inc = N - n_pos;\ + if (n_inc > GEMM_D_MN) n_inc = GEMM_D_MN;\ + if (b_rowmajor) {\ + gemmtype##_##btype##_##sbtype##_tcopy_unroll##unroll_n(\ + B + k_pos * N + n_pos, sb, N, n_inc, k_inc);\ + } else {\ + gemmtype##_##btype##_##sbtype##_ncopy_unroll##unroll_n(\ + B + n_pos * K + k_pos, sb, K, k_inc, n_inc);\ + }\ + uint32_t scratch_k_inc = (k_inc == 0) ? 
0 :\ + SCRATCH_K_CORD(k_inc - 1) + 1;\ + gemmtype##_kernel_ln_m##unroll_m##n##unroll_n(m_inc, n_inc,\ + scratch_k_inc,\ + beta, sa, sb, C + n_pos * M + m_pos, M);\ + }\ + }\ + }\ +} + +/* inline function to check arguments */ +static inline bool inline_gemm_par_valid(const void *A, const void *B, + void *C, uint32_t M, uint32_t N, uint32_t K) { + + bool a_valid = A || (M == 0 || K == 0); + bool b_valid = B || (N == 0 || K == 0); + bool c_valid = C || (M == 0 || N == 0); + + return a_valid && b_valid && c_valid; +} + +/* serial GEMM driver function */ +#define GEMM_SERIAL_FUNC(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l2, unroll_l1, skin1_maxm, skin1_maxn, skin2_maxm, skin2_maxn, ...)\ +\ +GEMM_STATIC_BUFFER(gemmtype, sbtype, satype)\ +\ +GEMM_SERIAL_FUNC_LM(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l2, unroll_l1)\ +\ +GEMM_SERIAL_FUNC_LN(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l1, unroll_l2)\ +\ +static void arowmajor_bskinny_void(\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) { return; }\ +\ +static void bcolmajor_askinny_void(\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) { return; }\ +\ +static void acolmajor_bskinny_void(\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) { return; }\ +\ +static void browmajor_askinny_void(\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) { return; }\ +\ +static void (* gemmtype##_bskinny1[]) (\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) = {\ + SKINNY_GEMM_FUNC_LIST(skin1_maxn,\ + gemmtype##_arowmajor_bskinny_a##atype##_b##btype##_n,\ + arowmajor_bskinny_void) };\ +\ +static void (* gemmtype##_askinny1[]) (\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) = {\ + SKINNY_GEMM_FUNC_LIST(skin1_maxm,\ + gemmtype##_arowmajor_bskinny_a##btype##_b##atype##_n,\ + bcolmajor_askinny_void) };\ +\ +static void (* gemmtype##_bskinny2[]) (\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) = {\ + SKINNY_GEMM_FUNC_LIST(skin2_maxn,\ + gemmtype##_acolmajor_bskinny_a##atype##_b##btype##_n,\ + acolmajor_bskinny_void) };\ +\ +static void (* gemmtype##_askinny2[]) (\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order, ctype beta_inp) = {\ + SKINNY_GEMM_FUNC_LIST(skin2_maxm,\ + gemmtype##_acolmajor_bskinny_a##btype##_b##atype##_n,\ + browmajor_askinny_void) };\ +\ +int gemmtype##_serial(int a_rowmajor, int b_rowmajor,\ + const atype *A, const btype *B, ctype *C,\ + uint32_t M, uint32_t N, uint32_t K, ctype beta_inp) {\ +\ + if (!inline_gemm_par_valid(A, B, C, M, N, K)) return 1;\ + if (0 __VA_ARGS__) return 2;\ +\ + if (K == 0) {\ + if (beta_inp != (ctype)1.0) {\ + const uint64_t MN = (uint64_t)M * (uint64_t)N;\ + for (uint64_t pos = 0; pos < MN; ++pos) {\ + C[pos] *= beta_inp;\ + }\ + }\ + return 0;\ + }\ +\ + if (N <= skin1_maxn && a_rowmajor) {\ + (* gemmtype##_bskinny1[N])(A, B, C, M, K, b_rowmajor ? 1 : 0, beta_inp);\ + return 0;\ + }\ + if (M <= skin1_maxm && !b_rowmajor) {\ + (* gemmtype##_askinny1[M])(B, A, C, N, K, a_rowmajor ? 
2 : 3, beta_inp);\ + return 0;\ + }\ + if (N <= skin2_maxn && !a_rowmajor) {\ + (* gemmtype##_bskinny2[N])(A, B, C, M, K, b_rowmajor ? 1 : 0, beta_inp);\ + return 0;\ + }\ + if (M <= skin2_maxm && b_rowmajor) {\ + (* gemmtype##_askinny2[M])(B, A, C, N, K, a_rowmajor ? 2 : 3, beta_inp);\ + return 0;\ + }\ +\ + if ((N >> 1) > M) {\ + gemmtype##_serial_ln_m##unroll_l1##n##unroll_l2(\ + a_rowmajor, b_rowmajor, A, B, C, M, N, K, beta_inp);\ + } else {\ + gemmtype##_serial_lm_m##unroll_l2##n##unroll_l1(\ + a_rowmajor, b_rowmajor, A, B, C, M, N, K, beta_inp);\ + }\ + return 0;\ +} + +#ifdef EMLL_SERIAL_ONLY + +#define GEMM_PARALLEL_FUNC(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l2, unroll_l1, skin1_maxm, skin1_maxn, skin2_maxm, skin2_maxn, ...) \ +\ +GEMM_SERIAL_FUNC(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l2, unroll_l1, skin1_maxm, skin1_maxn, skin2_maxm, skin2_maxn, ##__VA_ARGS__)\ +int gemmtype(int a_rowmajor, int b_rowmajor,\ + const atype *A, const btype *B, ctype *C,\ + uint32_t M, uint32_t N, uint32_t K, ctype beta_inp, uint32_t num_threads) {\ +\ + return gemmtype##_serial(a_rowmajor, b_rowmajor, A, B, C, M, N, K, beta_inp);\ +} + +#else + +/* OpenMP GEMM driver function */ +#define GEMM_PARALLEL_FUNC(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l2, unroll_l1, skin1_maxm, skin1_maxn, skin2_maxm, skin2_maxn, ...) \ +\ +GEMM_SERIAL_FUNC(gemmtype, atype, satype, btype, sbtype, ctype,\ + unroll_l2, unroll_l1, skin1_maxm, skin1_maxn, skin2_maxm, skin2_maxn, ##__VA_ARGS__)\ +\ +static void arowmajor_bskinny_void_omp(\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) { return; }\ +\ +static void bcolmajor_askinny_void_omp(\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) { return; }\ +\ +static void acolmajor_bskinny_void_omp(\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) { return; }\ +\ +static void browmajor_askinny_void_omp(\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) { return; }\ +\ +static void (* gemmtype##_bskinny1_omp[]) (\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) = {\ + SKINNY_GEMM_FUNC_LIST(skin1_maxn,\ + gemmtype##_arowmajor_bskinny_a##atype##_b##btype##_n,\ + arowmajor_bskinny_void_omp, _omp) };\ +\ +static void (* gemmtype##_askinny1_omp[]) (\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) = {\ + SKINNY_GEMM_FUNC_LIST(skin1_maxm,\ + gemmtype##_arowmajor_bskinny_a##btype##_b##atype##_n,\ + bcolmajor_askinny_void_omp, _omp) };\ +\ +static void (* gemmtype##_bskinny2_omp[]) (\ + const atype *A_mat, const btype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) = {\ + SKINNY_GEMM_FUNC_LIST(skin2_maxn,\ + gemmtype##_acolmajor_bskinny_a##atype##_b##btype##_n,\ + acolmajor_bskinny_void_omp, _omp) };\ +\ +static void (* gemmtype##_askinny2_omp[]) (\ + const btype *A_mat, const atype *B_skin, ctype *C_skin,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + ctype beta_inp, uint32_t num_threads) = {\ + 
SKINNY_GEMM_FUNC_LIST(skin2_maxm,\ + gemmtype##_acolmajor_bskinny_a##btype##_b##atype##_n,\ + browmajor_askinny_void_omp, _omp) };\ +\ +int gemmtype(int a_rowmajor, int b_rowmajor,\ + const atype *A, const btype *B, ctype *C,\ + uint32_t M, uint32_t N, uint32_t K, ctype beta_inp, uint32_t num_threads) {\ +\ + uint32_t rec_threads = (uint64_t)M * (uint64_t)N * (uint64_t)K \ + / (GEMM_D_K * GEMM_D_MN * unroll_l1) + 1;\ + if (num_threads > rec_threads) num_threads = rec_threads;\ + if (!num_threads) num_threads = rec_threads;\ + uint32_t max_threads = omp_get_max_threads();\ + if (num_threads > max_threads) num_threads = max_threads;\ +\ + if (num_threads == 1 || K == 0) {\ + return gemmtype##_serial(a_rowmajor, b_rowmajor, A, B, C, M, N, K, beta_inp);\ + }\ +\ + if (!inline_gemm_par_valid(A, B, C, M, N, K)) return 1;\ + if (0 __VA_ARGS__) return 2;\ +\ + omp_set_num_threads(num_threads);\ + if (N <= skin1_maxn && a_rowmajor) {\ + (* gemmtype##_bskinny1_omp[N])(\ + A, B, C, M, K, b_rowmajor ? 1 : 0, beta_inp, num_threads);\ + return 0;\ + }\ + if (M <= skin1_maxm && !b_rowmajor) {\ + (* gemmtype##_askinny1_omp[M])(\ + B, A, C, N, K, a_rowmajor ? 2 : 3, beta_inp, num_threads);\ + return 0;\ + }\ + if (N <= skin2_maxn && !a_rowmajor) {\ + (* gemmtype##_bskinny2_omp[N])(\ + A, B, C, M, K, b_rowmajor ? 1 : 0, beta_inp, num_threads);\ + return 0;\ + }\ + if (M <= skin2_maxm && b_rowmajor) {\ + (* gemmtype##_askinny2_omp[M])(\ + B, A, C, N, K, a_rowmajor ? 2 : 3, beta_inp, num_threads);\ + return 0;\ + }\ +\ + satype * const blas_master_sa = blas_##gemmtype##_sa;\ + sbtype * const blas_master_sb = blas_##gemmtype##_sb;\ + uint32_t acopy_dim_left, bcopy_dim_left;\ + uint64_t mn_task_end;\ +\ + _Pragma("omp parallel")\ + {\ + const uint32_t tid = omp_get_thread_num();\ + uint32_t k_pos, m_pos, n_pos, k_inc, m_inc, n_inc;\ + uint32_t bcopy_dim_start, bcopy_dim_end, acopy_dim_start, acopy_dim_end;\ + uint32_t gemm_mstart, gemm_mend, gemm_nstart, gemm_nend;\ + uint64_t gemm_mn_max;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= (GEMM_D_K << 1)) k_inc = GEMM_D_K;\ + else if (k_inc > GEMM_D_K) k_inc >>= 1;\ + const uint32_t scratch_k_inc = (k_inc == 0) ? 0 :\ + SCRATCH_K_CORD(k_inc - 1) + 1;\ + const ctype beta = (k_pos == 0) ? 
beta_inp : 1;\ + for (n_pos = 0; n_pos < N; n_pos += n_inc) {\ + n_inc = N - n_pos;\ + if (n_inc >= (GEMM_R_MN << 1)) n_inc = GEMM_R_MN;\ + else if (n_inc > GEMM_R_MN) n_inc >>= 1;\ +\ + if (!tid) bcopy_dim_left = n_inc;\ + _Pragma("omp barrier");\ +\ + while (get_copy_task(&bcopy_dim_left, unroll_l1 << 3,\ + &bcopy_dim_start, &bcopy_dim_end)) {\ + if (b_rowmajor) {\ + gemmtype##_##btype##_##sbtype##_tcopy_unroll##unroll_l1(\ + B + k_pos * N + n_pos + bcopy_dim_start,\ + blas_master_sb + bcopy_dim_start * scratch_k_inc,\ + N, bcopy_dim_end - bcopy_dim_start, k_inc);\ + } else {\ + gemmtype##_##btype##_##sbtype##_ncopy_unroll##unroll_l1(\ + B + K * (n_pos + bcopy_dim_start) + k_pos,\ + blas_master_sb + bcopy_dim_start * scratch_k_inc,\ + K, k_inc, bcopy_dim_end - bcopy_dim_start);\ + }\ + }\ +\ + for (m_pos = 0; m_pos < M; m_pos += m_inc) {\ + m_inc = M - m_pos;\ + if (m_inc >= (GEMM_R_MN << 1)) m_inc = GEMM_R_MN;\ + else if (m_inc > GEMM_R_MN) m_inc >>= 1;\ +\ + if (!tid) acopy_dim_left = m_inc;\ + _Pragma("omp barrier");\ +\ + while (get_copy_task(&acopy_dim_left, unroll_l2 << 3,\ + &acopy_dim_start, &acopy_dim_end)) {\ + if (a_rowmajor) {\ + gemmtype##_##atype##_##satype##_ncopy_unroll##unroll_l2(\ + A + K * (m_pos + acopy_dim_start) + k_pos,\ + blas_master_sa + acopy_dim_start * scratch_k_inc,\ + K, k_inc, acopy_dim_end - acopy_dim_start);\ + } else {\ + gemmtype##_##atype##_##satype##_tcopy_unroll##unroll_l2(\ + A + M * k_pos + m_pos + acopy_dim_start,\ + blas_master_sa + acopy_dim_start * scratch_k_inc,\ + M, acopy_dim_end - acopy_dim_start, k_inc);\ + }\ + }\ +\ + if (!tid) mn_task_end = (uint64_t)n_pos << 32 | (uint64_t)m_pos;\ + gemm_mn_max = ((uint64_t)(n_pos + n_inc) << 32)\ + | (uint64_t)(m_pos + m_inc);\ + _Pragma("omp barrier");\ +\ + while (get_mn_task(&mn_task_end,\ + &gemm_mstart, &gemm_nstart, &gemm_mend, &gemm_nend,\ + ((uint64_t)unroll_l1 << 32) | ((GEMM_D_MN >> 2) / unroll_l2 * unroll_l2),\ + GEMM_D_MN, n_pos, gemm_mn_max, num_threads)) {\ +\ + gemmtype##_kernel_lm_m##unroll_l2##n##unroll_l1(\ + gemm_mend - gemm_mstart, gemm_nend - gemm_nstart, scratch_k_inc, beta,\ + blas_master_sa + (gemm_mstart - m_pos) * scratch_k_inc,\ + blas_master_sb + (gemm_nstart - n_pos) * scratch_k_inc,\ + C + M * gemm_nstart + gemm_mstart, M);\ + }\ + }\ + }\ + }\ + }\ + return 0;\ +} + +#endif + +#endif diff --git a/include/common/CommonKernel.h b/include/common/CommonKernel.h new file mode 100644 index 0000000..04dfbc3 --- /dev/null +++ b/include/common/CommonKernel.h @@ -0,0 +1,190 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +/****************************************************************************** + * File: CommonKernel.h + * Description: The common skeleton of regular GEMM kernel functions with both + * source matrices packed before computation + * Extention: For supporting a new CPU arch, the following steps are needed + * in addition to including this header: + * (1) implement a collection of inline GEMM functions, each with + * fixed M & N but variable K(as input param), for the + * multiplication of column-major matrix A with row-major + * matrix B and update the results to column-major matrix C. + * A SGEMM inline function with M = 2 and N = 4 can be + * implemented like this: + * static inline void + * inline_dualpack_gemm_afloat_bfloat_cfloat_m2_n4( + * const float *a_head, const float *b_head, + * float *c_ptr, uint32_t K, float beta, uint32_t ldc) { + * float c0, c1, c2, c3, c4, c5, c6, c7; + * c0 = c1 = c2 = c3 = c4 = c5 = c6 = c7 = 0.0f; + * for (; K > 0; K--) { + * float a0 = a_head[0]; + * float a1 = a_head[1]; a_head += 2; + * float b0 = b_head[0]; + * float b1 = b_head[1]; + * float b2 = b_head[2]; + * float b3 = b_head[3]; b_head += 4; + * c0 += a0 * b0; c1 += a1 * b0; + * c2 += a0 * b1; c3 += a1 * b1; + * c4 += a0 * b2; c5 += a1 * b2; + * c6 += a0 * b3; c7 += a1 * b3; + * } + * c_ptr[0] = c_ptr[0] * beta + c0; + * c_ptr[1] = c_ptr[1] * beta + c1; + * c_ptr += ldc; + * c_ptr[0] = c_ptr[0] * beta + c2; + * c_ptr[1] = c_ptr[1] * beta + c3; + * c_ptr += ldc; + * c_ptr[0] = c_ptr[0] * beta + c4; + * c_ptr[1] = c_ptr[1] * beta + c5; + * c_ptr += ldc; + * c_ptr[0] = c_ptr[0] * beta + c6; + * c_ptr[1] = c_ptr[1] * beta + c7; + * } + * (2) Construct kernel functions with the aid of macros. + * Please refer to src/neon_armv7a/SgemmKernel.c for example. 
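+ *             As an illustration of step (2), an instantiation of the
+ *             macros defined below might look like this (the block sizes
+ *             8 and 12 are hypothetical; see the architecture-specific
+ *             kernel sources for the values actually used):
+ *                 DUALPACK_KERNEL_FUNC_LM(sgemm, float, float, float, 8, 12)
+ *                 DUALPACK_KERNEL_FUNC_LN(sgemm, float, float, float, 12, 8)
+ *             These expand to sgemm_kernel_lm_m8n12() and
+ *             sgemm_kernel_ln_m12n8(), and they require the corresponding
+ *             inline_dualpack_gemm_afloat_bfloat_cfloat_m*_n* helpers for
+ *             every sub-block size reached by the macro expansion.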
+ *****************************************************************************/ + +#include "ExpandMacro.h" +#include + +#ifndef INCLUDE_COMMON_KERNEL +#define INCLUDE_COMMON_KERNEL + +/* the macros COMPUTE_MmNn are architecture dependant, + * which should be defined in the source file including this header */ + +#define COMPUTE_STD_INIT_SLICE(n_pos, mdim, ctype) \ + ctype c_reg##n_pos[mdim];\ + _Pragma("omp simd")\ + for (int j = 0; j < mdim; ++j) {\ + c_reg##n_pos[j] = 0;\ + } + +#define COMPUTE_STD_ACC_SLICE(n_pos, mdim, ndim, k_off) \ + _Pragma("omp simd")\ + for (int j = 0; j < mdim; ++j) {\ + c_reg##n_pos[j] += a_ptr[j + k_off * mdim] *\ + b_ptr[n_pos - 1 + k_off * ndim];\ + } + +#define COMPUTE_STD_SAVE_SLICE(n_pos, mdim, c_str) \ + _Pragma("omp simd")\ + for (int j = 0; j < mdim; ++j) {\ + c_str[j] = c_str[j] * beta + c_reg##n_pos[j];\ + }\ + c_str += ldc; + +#define COMPUTE_STD(mdim, ndim, atype, btype, ctype) \ +static inline void\ + inline_dualpack_gemm_a##atype##_b##btype##_c##ctype##_m##mdim##_n##ndim(\ + const atype *a_head, const btype *b_head, ctype *c_ptr,\ + uint32_t K, ctype beta, uint32_t ldc) {\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_INIT_SLICE, mdim, ctype)\ + const atype * a_ptr = a_head;\ + const btype * b_ptr = b_head;\ + uint32_t k_left = K;\ + for (; k_left > 3; k_left -= 4) {\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_ACC_SLICE, mdim, ndim, 0)\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_ACC_SLICE, mdim, ndim, 1)\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_ACC_SLICE, mdim, ndim, 2)\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_ACC_SLICE, mdim, ndim, 3)\ + a_ptr += mdim * 4;\ + b_ptr += ndim * 4;\ + }\ + for (; k_left > 0; k_left--) {\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_ACC_SLICE, mdim, ndim, 0)\ + a_ptr += mdim;\ + b_ptr += ndim;\ + }\ + ctype *c_str = c_ptr;\ + MACRO_EXP_##ndim(VOID_BASE, COMPUTE_STD_SAVE_SLICE, mdim, c_str)\ +} + +#define MICRO_COMPUTE_LM_LOOP(mdim, ndim, atype, btype, ctype) \ + for (; m_left >= mdim; m_left -= mdim) {\ + inline_dualpack_gemm_a##atype##_b##btype##_c##ctype##_m##mdim##_n##ndim(\ + a_head, b_head, c_ptr, K, beta, ldc);\ + a_head += mdim * K;\ + c_ptr += mdim;\ + } + +#define MICRO_COMPUTE_LN_LOOP(ndim, mdim, atype, btype, ctype) \ + for (; n_left >= ndim; n_left -= ndim) {\ + inline_dualpack_gemm_a##atype##_b##btype##_c##ctype##_m##mdim##_n##ndim(\ + a_head, b_head, c_ptr, K, beta, ldc);\ + b_head += ndim * K;\ + c_ptr += ndim * ldc;\ + } + +#define MICRO_COMPUTE_LM(mdim, ndim, atype, btype, ctype) \ + MACRO_EXPANSION_E_##mdim(MICRO_COMPUTE_LM_LOOP, ndim, atype, btype, ctype) + +#define MICRO_COMPUTE_LN(mdim, ndim, atype, btype, ctype) \ + MACRO_EXPANSION_E_##ndim(MICRO_COMPUTE_LN_LOOP, mdim, atype, btype, ctype) + +#define DUALPACK_COMPUTE_LM(ndim, satype, sbtype, ctype, block_m_max) \ + for (; n_left >= ndim; n_left -= ndim) {\ + const satype *a_head = sa;\ + ctype *c_ptr = c_head;\ + uint32_t m_left = M;\ + MICRO_COMPUTE_LM(block_m_max, ndim, satype, sbtype, ctype)\ + b_head += K * ndim;\ + c_head += ldc * ndim;\ + } + +#define DUALPACK_COMPUTE_LN(mdim, satype, sbtype, ctype, block_n_max) \ + for (; m_left >= mdim; m_left -= mdim) {\ + const sbtype *b_head = sb;\ + ctype *c_ptr = c_head;\ + uint32_t n_left = N;\ + MICRO_COMPUTE_LN(mdim, block_n_max, satype, sbtype, ctype)\ + a_head += K * mdim;\ + c_head += mdim;\ + } + +#define ASSEMBLE_DUALPACK_COMPUTE_LM(ndim, satype, sbtype, ctype, block_m_max) \ + MACRO_EXP_E_##ndim(DUALPACK_COMPUTE_LM, satype, sbtype, ctype, block_m_max) + +#define 
ASSEMBLE_DUALPACK_COMPUTE_LN(mdim, satype, sbtype, ctype, block_n_max) \ + MACRO_EXP_E_##mdim(DUALPACK_COMPUTE_LN, satype, sbtype, ctype, block_n_max) + +#define DUALPACK_KERNEL_FUNC_LM(gemmtype, satype, sbtype, ctype, block_m_max, block_n_max) \ +void gemmtype##_kernel_lm_m##block_m_max##n##block_n_max(\ + uint32_t M, uint32_t N, uint32_t K, ctype beta,\ + const satype * __restrict__ sa, const sbtype * __restrict__ sb,\ + ctype * __restrict__ C, uint32_t ldc) {\ + uint32_t n_left = N;\ + const sbtype *b_head = sb;\ + ctype *c_head = C;\ + ASSEMBLE_DUALPACK_COMPUTE_LM(block_n_max, satype, sbtype, ctype, block_m_max)\ +} + +#define DUALPACK_KERNEL_FUNC_LN(gemmtype, satype, sbtype, ctype, block_m_max, block_n_max) \ +void gemmtype##_kernel_ln_m##block_m_max##n##block_n_max(\ + uint32_t M, uint32_t N, uint32_t K, ctype beta,\ + const satype * __restrict__ sa, const sbtype * __restrict__ sb,\ + ctype * __restrict__ C, uint32_t ldc) {\ + uint32_t m_left = M;\ + const satype *a_head = sa;\ + ctype *c_head = C;\ + ASSEMBLE_DUALPACK_COMPUTE_LN(block_m_max, satype, sbtype, ctype, block_n_max)\ +} + +#endif diff --git a/include/common/CommonLayer.h b/include/common/CommonLayer.h new file mode 100644 index 0000000..cbb7b63 --- /dev/null +++ b/include/common/CommonLayer.h @@ -0,0 +1,90 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: CommonLayer.h + * Description: Function templates for operations in neural network layers + *****************************************************************************/ + +#include +#include + +#ifndef INCLUDE_COMMON_LAYER +#define INCLUDE_COMMON_LAYER + +/* function template for fully-connected layer, serial & OpenMP */ +#define SIMPLE_FC_FUNC(gemmtype, wtype, itype, otype, ...) 
\ +int fc##__VA_ARGS__(const itype *src, const wtype *weight,\ + const otype *bias, otype *output, int M, int K, int N,\ + int trans_src, int trans_weight, int num_threads) {\ +\ + int status = gemmtype(trans_weight, trans_src,\ + weight, src, output, N, M, K, 0, num_threads);\ + if (status) return status;\ + bias_##otype(output, 0.0, bias, 1.0, NULL, 0.0, N, M);\ + return status;\ +} + +/* function template for bias layer */ +#define STD_BIAS_FUNC(type) \ +void bias_##type(type *C, type bias_dim0, const type *bias_dim1,\ + type bias_dim1_scale, const type *bias_dim2, type bias_dim2_scale,\ + uint32_t dim1, uint32_t dim2) {\ +\ + if (!C) return;\ +\ + bool do_bias_0 = (bias_dim0 != 0);\ + bool do_bias_1 = bias_dim1 && (bias_dim1_scale != 0);\ + bool do_bias_2 = bias_dim2 && (bias_dim2_scale != 0);\ +\ + if (!do_bias_0 && !do_bias_1 && !do_bias_2) return;\ +\ + if (!do_bias_1 && (do_bias_0 || do_bias_2)) {\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + type *c_ptr = C + dim2_pos * dim1;\ + const type bs = bias_dim0 + \ + (bias_dim2 ? bias_dim2[dim2_pos] * bias_dim2_scale : 0);\ + _Pragma("omp simd")\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + c_ptr[dim1_pos] += bs;\ + }\ + }\ + } else if (do_bias_1 && !do_bias_0 && !do_bias_2) {\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + type *c_ptr = C + dim2_pos * dim1;\ + const type *bias_ptr = bias_dim1;\ + _Pragma("omp simd")\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + c_ptr[dim1_pos] += bias_ptr[dim1_pos] * bias_dim1_scale;\ + }\ + }\ + } else {\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + type *c_ptr = C + dim2_pos * dim1;\ + const type bs = bias_dim0 + \ + (bias_dim2 ? bias_dim2[dim2_pos] * bias_dim2_scale : 0);\ + const type *bias_ptr = bias_dim1;\ + _Pragma("omp simd")\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + c_ptr[dim1_pos] += bs +\ + bias_ptr[dim1_pos] * bias_dim1_scale;\ + }\ + }\ + }\ +} + +#endif diff --git a/include/common/CommonQuant.h b/include/common/CommonQuant.h new file mode 100644 index 0000000..700b160 --- /dev/null +++ b/include/common/CommonQuant.h @@ -0,0 +1,311 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: CommonQuant.h + * Description: Function templates for quant/dequant/requant functions. 
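+ *              As an illustration, expanding QUANTIZE_ASYMMETRIC(32, 8)
+ *              defined below emits quantize_asymmetric_f32_u8(), which
+ *              maps float32 data onto uint8. With hypothetical inputs
+ *              input_min = -1.0f and input_max = 3.0f it derives
+ *                  scale      = (3.0 - (-1.0)) / 255   (about 0.0157)
+ *                  zero_point = round(1.0 / scale)     (= 64)
+ *              before calling the architecture-specific inline quantizer.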
+ *****************************************************************************/ + +#include +#include + +#ifndef INCLUDE_COMMON_QUANT +#define INCLUDE_COMMON_QUANT + +/* function template for asymmetric quantization fp -> uint */ +#define QUANTIZE_ASYMMETRIC(inbits, outbits) \ +void quantize_asymmetric_f##inbits##_u##outbits(\ + const float##inbits##_t *input, uint##outbits##_t *output,\ + uint##outbits##_t *zero_point, float##inbits##_t *scale, uint32_t size,\ + float##inbits##_t input_min, float##inbits##_t input_max) {\ +\ + if (size == 0) return;\ + float##inbits##_t min, max;\ + if (input_min <= input_max) {\ + min = input_min;\ + max = input_max;\ + } else {\ + inline_find_extreme_float##inbits##_t(input, size, &min, &max);\ + }\ +\ + if (min > 0) min = 0.0;\ + if (max < 0) max = 0.0;\ + const float##inbits##_t max_diff = max - min;\ + if (max_diff == 0.0) {\ + memset(output, 0, size * (outbits >> 3));\ + *zero_point = 0;\ + *scale = 1.0;\ + return;\ + }\ +\ + const float##inbits##_t sc = max_diff *\ + (float##inbits##_t)(1.0 / (uint##outbits##_t)-1);\ + *scale = sc;\ + unsigned long long z = ((float##inbits##_t)0.0 - min) / sc\ + + (float##inbits##_t)0.5;\ + const uint##outbits##_t zp = z > (uint##outbits##_t)-1 ?\ + (uint##outbits##_t)-1 : z;\ + *zero_point = zp;\ +\ + inline_quant_asym_u##outbits##_from_f##inbits(input, output, size, zp, sc);\ +} + +/* function template for symmetric quantization fp -> int */ +#define QUANTIZE_SYMMETRIC(inbits, outbits) \ +void quantize_symmetric_f##inbits##_s##outbits(\ + const float##inbits##_t *input, int##outbits##_t *output,\ + float##inbits##_t *scale, uint32_t size,\ + float##inbits##_t input_min, float##inbits##_t input_max) {\ +\ + if (size == 0) return;\ + float##inbits##_t min, max;\ + if (input_min <= input_max) {\ + min = input_min;\ + max = input_max;\ + } else {\ + inline_find_extreme_float##inbits##_t(input, size, &min, &max);\ + }\ +\ + const uint##outbits##_t out_abs_max = (uint##outbits##_t)-1 >> 1;\ + const float##inbits##_t sc_positive = max *\ + (float##inbits##_t)(1.0 / out_abs_max);\ + const float##inbits##_t sc_negative = min *\ + (float##inbits##_t)(-1.0 / (out_abs_max + 1));\ + const float##inbits##_t sc =\ + sc_positive > sc_negative ? sc_positive : sc_negative;\ + if (sc == 0.0) {\ + memset(output, 0, size * (outbits >> 3));\ + *scale = 1.0;\ + return;\ + }\ + *scale = sc;\ +\ + inline_quant_sym_s##outbits##_from_f##inbits(input, output, size, sc);\ +} + +/****************************************************************************** + * Template: REQUANTIZE_ASYMMETRIC_MULHI + * Description: Function template of asymmetric requantization + * based on "mulhi" operations. + * Basically, the requantization can be done like this: + * (1) determine the min and max of input integers + * if min > 0, min is set to 0 + * if max < 0, max is set to 0 + * (2) calculate scaling factor Sint on input integers: + * Sint = expression_range_of_output_uint / (max - min) + * (3) calculate zero point Z of output + * Z = -min * Si + * (4) inflate input integers {Ii} to output ints {Oi} + * for i in input index range + * Oi = Ii * Sint + Z + * (5) update scaling factor S + * S /= Sint + * The steps (1) - (4) are just identical to that in asymmetric + * quantization if the inputs are floating numbers. For integers + * the situation gets a bit more complicated. The scaling factor + * Sint need to be expressed by integer(s). 
For precision reasons + * the exponent and mantissa part of Sint should be stored in + * individual integers Bint and Eint: + * Sint = (2^exp + mantissa) * (2^-exp) = Bint * (2^-Eint) + * Bint = 2^exp + mantissa, Eint = exp + * Also, the multiplication Ii * Sint in step (4) changes to + * (Ii * Bint) >> Eint. + * + * For integer multiplications on CPU, there're 3 types of + * operations normally: + * (1) keep all bits of the product, so the length of result + * is twice of that of input + * (2) keep only lower half of the product, with output length + * unchanged: "mullo" operation + * (3) keep only higher half of the product, with output length + * unchanged: "mulhi" operation + * Among the 3 types of operations, type (2) is useful only when + * the inputs are small enough (sum of valid bits must be no more + * than input length). For type (1), keeping the lower half of + * product is not necessary if the input numbers are big enough + * (near expression limit). So we choose type (3) for precision + * and efficiency concerns. + * Generally, we determine a left-shift number L, a mult-factor + * M, a right-shift number R and a zero-point Z according to + * the min and max of input integers. Then the following steps + * are performed on each input integer Ii: + * (1) left-shift Ii by L, which can make the min or max number + * approach the expression limit of input integer type, + * so as to minimize the precision loss in subsequent + * "mulhi" operation. + * (2) perform "mulhi" operation of shifted Ii with mult-factor + * M to yield (rounded) higher-half product Pi. The value + * of M is also near the expression limit of its type. + * (3) right (saturated rounded) shift of Pi by R. + * The right shift is needed to fit results into + * the expression range of output type. + * (4) add shifted Pi with Z to get output integer Oi. + * Parameters: fp: the type of scale to update (float/double/float16_t/...) + * inbits: the number of bits of input integral type + * outbits: the number of bits of output integral type + * accbits: must be 2 * outbits + * Dependency: the following inline functions should be implemented prior + * to the introduction of this macro: + * (1) inline_find_extreme_int_t( + * const int_t *dat, uint32_t size, + * int_t *min, int_t *max) {...} + * This function determines the minimum (write to *min) + * and maximum (write to *max) value of input dat[] which + * has "size" elements. + * (2) inline_requant_asym_u_from_s_mulhi( + * const int_t *input, uint_t *output, + * uint32_t size, uint8_t L, int_t M, + * uint_t Z) {...} + * This function performs left-shift on input by L, then + * "mulhi" it with a mult-factor M, right-shift the + * product by R and add it with Z to get output, + * just as the 4 steps shown above. + * The right-shift value R is fixed to accbits-outbits-3 + * so it is not in the parameter list. 
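+ *              A plain-C reference for dependency (1), written here for
+ *              inbits = 32 (illustrative only; the shipped versions are
+ *              expected to use NEON), could look like:
+ *                  static inline void inline_find_extreme_int32_t(
+ *                      const int32_t *dat, uint32_t size,
+ *                      int32_t *min, int32_t *max) {
+ *                    int32_t lo = dat[0], hi = dat[0];
+ *                    for (uint32_t i = 1; i < size; ++i) {
+ *                      if (dat[i] < lo) lo = dat[i];
+ *                      if (dat[i] > hi) hi = dat[i];
+ *                    }
+ *                    *min = lo; *max = hi;
+ *                  }
+ *              (the requantize template returns early when size == 0,
+ *              so dat[0] is always valid at the call site).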
+ *****************************************************************************/ +#define REQUANTIZE_ASYMMETRIC_MULHI(fp, inbits, outbits, accbits) \ +void requantize_asymmetric_##inbits##to##outbits(\ + const int##inbits##_t *input, uint##outbits##_t *output,\ + fp *scale, uint##outbits##_t *zero_point, uint32_t size,\ + int##inbits##_t input_min, int##inbits##_t input_max) {\ +\ + if (size == 0) return;\ + const fp scale_org = *scale;\ + if (scale_org == 0.0) {\ + *zero_point = 0;\ + memset(output, 0, size * sizeof(uint##outbits##_t));\ + return;\ + }\ +\ + int##inbits##_t min, max;\ + if (input_min <= input_max) {\ + min = input_min;\ + max = input_max;\ + } else {\ + inline_find_extreme_int##inbits##_t(input, size, &min, &max);\ + }\ + max = max < 0 ? 0 : max;\ + min = min > 0 ? 0 : min;\ + if (min == max) {\ + *zero_point = 0;\ + memset(output, 0, size * sizeof(uint##outbits##_t));\ + return;\ + }\ +\ + int##inbits##_t abs_max = -min;\ + if (max > abs_max) abs_max = max;\ + unsigned int max_digits = 0;\ + for (; abs_max > 0; ++max_digits) abs_max >>= 1;\ +\ + const int src_lshift = inbits - 1 - max_digits;\ + const uint##inbits##_t range = (uint##inbits##_t)max - (uint##inbits##_t)min;\ +\ + uint##accbits##_t mult_par = \ + ((uint##accbits##_t)1 << (accbits - 3)) -\ + ((uint##accbits##_t)1 << (accbits - outbits - 3));\ +\ + int##accbits##_t lsh_range = (int##accbits##_t)range << src_lshift;\ + int##inbits##_t mult_factor = mult_par / lsh_range;\ + if (mult_par % lsh_range > lsh_range >> 1) {\ + mult_factor++;\ + }\ +\ + int##accbits##_t z_mid = (int##accbits##_t)((-min) << src_lshift) * \ + (int##accbits##_t)mult_factor;\ + int##inbits##_t z_mid2 = z_mid >> (accbits - outbits - 3);\ + if (z_mid & ((int##accbits##_t)1 << (accbits - outbits - 4))) z_mid2++;\ + uint##outbits##_t zp = z_mid2 < 0 ?\ + 0 : (z_mid2 > (uint##outbits##_t)-1 ? (uint##outbits##_t)-1 : z_mid2);\ + *zero_point = zp;\ +\ + *scale = (*scale) * (fp)range * ((fp)1 / (fp)((uint##outbits##_t)-1));\ + inline_requant_asym_u##outbits##_from_s##inbits##_mulhi(input, output,\ + size, src_lshift, mult_factor, zp);\ +} + +/****************************************************************************** + * Template: REQUANTIZE_SYMMETRIC_MULHI + * Description: Function template of symmetric requantization + * based on "mulhi" operations. + * Parameters: fp: the type of scale to update (float/double/float16_t/...) + * inbits: the number of bits of input integral type + * outbits: the number of bits of output integral type + * accbits: must be 2 * outbits + * Dependency: the following inline functions should be implemented prior + * to the introduction of this macro: + * (1) inline_find_extreme_int_t( + * const int_t *dat, uint32_t size, + * int_t *min, int_t *max) {...} + * This function determines the minimum (write to *min) + * and maximum (write to *max) value of input dat[] which + * has "size" elements. + * (2) inline_requant_sym_s_from_s_mulhi( + * const int_t *input, int_t *output, + * uint32_t size, uint8_t L, int_t M) {...} + * This function performs left-shift on input by L, then + * "mulhi" it with a mult-factor M, finally right-shift + * the product by R to get output, + * The right-shift value R is fixed to accbits-outbits-2 + * so it is not in the parameter list. 
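+ *              As an illustration, expanding
+ *                  REQUANTIZE_SYMMETRIC_MULHI(float, 32, 8, 16)
+ *              emits requantize_symmetric_32to8(), which rescales int32
+ *              accumulators to int8 with a float scale; accbits = 16
+ *              satisfies the accbits == 2 * outbits requirement, and the
+ *              fixed right shift becomes accbits - outbits - 2 = 6.
+ *              Whether this particular instantiation is used is up to the
+ *              architecture-specific sources.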
+ *****************************************************************************/ +#define REQUANTIZE_SYMMETRIC_MULHI(fp, inbits, outbits, accbits) \ +void requantize_symmetric_##inbits##to##outbits(\ + const int##inbits##_t *input, int##outbits##_t *output,\ + fp *scale, uint32_t size,\ + int##inbits##_t input_min, int##inbits##_t input_max) {\ +\ + if (size == 0) return;\ + const fp scale_org = *scale;\ + if (scale_org == 0.0) {\ + memset(output, 0, size * sizeof(uint##outbits##_t));\ + return;\ + }\ +\ + int##inbits##_t min, max;\ + if (input_min <= input_max) {\ + min = input_min;\ + max = input_max;\ + } else {\ + inline_find_extreme_int##inbits##_t(input, size, &min, &max);\ + }\ + int##inbits##_t max_abs = max;\ + if (max_abs < -min) max_abs = -min;\ + if (max_abs == 0) {\ + memset(output, 0, size * sizeof(uint##outbits##_t));\ + return;\ + }\ +\ + int##inbits##_t tmp = max_abs;\ + unsigned int max_digits = 0;\ + for (; tmp > 0; ++max_digits) tmp >>= 1;\ +\ + const int src_lshift = inbits - 1 - max_digits;\ + uint##accbits##_t mult_par = \ + ((uint##accbits##_t)1 << (accbits - 3)) -\ + ((uint##accbits##_t)1 << (accbits - outbits - 2));\ + uint##accbits##_t lsh_max_abs = max_abs << src_lshift;\ + int##inbits##_t mult_factor = mult_par / lsh_max_abs;\ + if (mult_par % lsh_max_abs > lsh_max_abs >> 1) {\ + mult_factor++;\ + }\ +\ + *scale = (*scale) * (fp)max_abs * ((fp)1 / (fp)(((uint##outbits##_t)-1) >> 1));\ + inline_requant_sym_s##outbits##_from_s##inbits##_mulhi(input, output,\ + size, src_lshift, mult_factor);\ +} + +#endif diff --git a/include/common/CommonSched.h b/include/common/CommonSched.h new file mode 100644 index 0000000..4b21964 --- /dev/null +++ b/include/common/CommonSched.h @@ -0,0 +1,265 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/************************************************************************* + * File: CommonSched.h + * Description: Functions associated with task distribution and + * synchronization in parallelized calculations +*************************************************************************/ +#include + +#ifndef INCLUDE_COMMON_SCHEDULE +#define INCLUDE_COMMON_SCHEDULE + +/* The atomic compare-and-swap instructions are platform- + * specific, which need to be given elsewhere. 
+ * If the compiler is GCC, the simplest way is to activate the following + * macro before including this file */ +#ifdef GCC_BUILTIN_SYNC +static uint32_t atomicCAS_U32(uint32_t comp, uint32_t write, uint32_t *dst) { + __atomic_compare_exchange_n(dst, &comp, write, + 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + return comp; +} + +static uint64_t atomicCAS_U64(uint64_t comp, uint64_t write, uint64_t *dst) { + __atomic_compare_exchange_n(dst, &comp, write, + 0, __ATOMIC_SEQ_CST, __ATOMIC_SEQ_CST); + return comp; +} +#endif + +#ifndef D_SCALE +#define D_SCALE 8 //dynamical scaling factor in 2D task distribution +#endif + +/****************************************************************************** + * Function: get_mn_task + * Description: Function for a running OpenMP thread to get GEMM task (m, n) + * from a (shared) task container atomically. + * Design: The task destribution in parallelized GEMM is done in (M, N) + * 2-dimension space. In parallel run, Several threads are + * deployed to compute a MxN output matrix, just like dividing + * a rectangular piece of paper with length = N and width = M. + * The MxN piece is cut into a number of tiny rectangular pieces + * that are distributed to threads. Instead of finishing all the + * cutting work before distribution, this function does cutting + * and distribution simultaneously. When a thread becomes idle + * in parallel zone, it will call this function to cut a new + * piece (task) from the remaining paper and take ownership of + * the newly-cut small piece(then work on it), or go to a barrier + * and wait till other threads finish their work when the paper + * has been used up. The size of the newly-get piece(task) is + * proportional to the area of the remaining paper. + * A typical cutting process is shown below: + * + * ____________ ______ + * | | | | + * | | first cut _____|m1 | + * M | | -----------> | n1 | + * | | m1xn1 | remain | + * |____________| |____________| + * N | + * second cut | m1xn2 + * (m must be m1) | + * V __ + * m1| | + * ____________ third cut _________| | + * | | <----------- | n1+n2 | + * |M-m1 | m1xn3 |M-m1 | + * |____________| n3==N-n1-n2 |____________| + * N (m must be m1) + * | + * fourth cut | m2xn4 + * | + * V + * _______ + * | | + * ____| | ---> ---> ---> repeat till nothing left + * |____________| + * + * Calls: atomicCAS_U64 (atomic compare and swap, make the cuts atomic) + * Input: uint64_t *task_end: the address of a long interger recording + * the start point of the next task. The + * long integer represents the coordinates + * (low-32bit = m, high-32bit = n) of the + * vortex right above the inward-pointing one + * in the remaining paper when it is a + * concave hexagon, or the upper left corner + * when the remaining is a rectangle. + * uint32_t n_pos_min: input the lower bound of N-direction. + * uint64_t m_n_pos_max: input the upper bounds + * along M and N axis + * low 32 bits: m_max + * high 32 bits: n_max. + * uint64_t m_n_task_min: the minimum task size of + * m (low 32bit) and n (high 32bit) + * uint32_t m_task_max: the maximum task size of m + * uint32_t num_threads: input the number of OpenMP threads. + * Output: uint32_t *m_start: to output the starting m of the new task. + * uint32_t *n_start: to output the starting n of the new task. + * uint32_t *m_end: to output the ending m of the new task. + * uint32_t *n_end: to output the ending n of the new task. + * Return: 0 if there's no task left, 1 if a new task has been acquired. 
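+ * Example: an illustrative call pattern inside an OpenMP parallel region;
+ *          the local variable names and the task-size values here are
+ *          placeholders, only the argument order follows the definition
+ *          below (initial task_end = 0 assumes n_pos_min = 0, as in the
+ *          callers within this library):
+ *            uint64_t task_end = 0;  // shared among all threads
+ *            uint32_t ms, ns, me, ne;
+ *            uint64_t task_min = (8ULL << 32) | 24;  // n_min = 8, m_min = 24
+ *            uint64_t pos_max = ((uint64_t)N << 32) | M;
+ *            while (get_mn_task(&task_end, &ms, &ns, &me, &ne,
+ *                               task_min, 192, 0, pos_max, num_threads)) {
+ *              // compute the output block [ms, me) x [ns, ne)
+ *            }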
+ *****************************************************************************/ +static uint32_t get_mn_task(uint64_t *task_end, + uint32_t *m_start, uint32_t *n_start, uint32_t *m_end, uint32_t *n_end, + uint64_t m_n_task_min, uint32_t m_task_max, + uint32_t n_pos_min, uint64_t m_n_pos_max, uint32_t num_threads) { + + const uint32_t m_pos_max = m_n_pos_max & 0xFFFFFFFF; + const uint32_t n_pos_max = m_n_pos_max >> 32; + const uint32_t m_task_min_raw = m_n_task_min & 0xFFFFFFFF; + const uint32_t n_task_min_raw = m_n_task_min >> 32; + const uint32_t m_task_min = (m_task_min_raw) ? m_task_min_raw : 24; + const uint32_t n_task_min = (n_task_min_raw) ? n_task_min_raw : 8; + + if (n_pos_max <= n_pos_min) return 0; + + uint32_t mstart, nstart, mend, nend; + uint64_t task_end_read, task_end_load; + + do { + task_end_load = *task_end; + mstart = task_end_load & 0xFFFFFFFF; + nstart = task_end_load >> 32; + + /* if there is no task left, return 0 */ + if (mstart >= m_pos_max || nstart >= n_pos_max) return 0; + + /* determine how many tasks left in 2D space */ + const uint64_t mn_left = (uint64_t)(n_pos_max - n_pos_min) * + (uint64_t)(m_pos_max - mstart); + + /* determine the msize of the next task */ + /* msize should only depend on mstart, not affected by nstart */ + uint32_t msize = mn_left / (uint64_t)(num_threads * D_SCALE * n_task_min); + msize = msize / m_task_min * m_task_min; + if (msize > m_task_max) msize = m_task_max; + if (msize < m_task_min) msize = m_task_min; + if (msize > m_pos_max - mstart) msize = m_pos_max - mstart; + + /* determine the nsize of the next task */ + uint32_t n_inc = (nstart >= n_pos_min) ? nstart - n_pos_min : 0; + uint32_t nsize = (mn_left - (uint64_t)msize * (uint64_t)n_inc) / + (uint64_t)(num_threads * D_SCALE * msize); + nsize = nsize / n_task_min * n_task_min; + if (nsize < n_task_min) nsize = n_task_min; + if (nsize > n_pos_max - nstart) nsize = n_pos_max - nstart; + + nend = nstart + nsize; + mend = mstart + msize; + uint32_t nextm = mstart; + uint32_t nextn = nend; + if (nend == n_pos_max) { + nextm = mend; nextn = n_pos_min; + } + uint64_t task_end_write = ((uint64_t)nextn << 32) | (uint64_t)nextm; + task_end_read = atomicCAS_U64(task_end_load, task_end_write, task_end); + } while (task_end_read != task_end_load); + + /* write back task info */ + *m_start = mstart; + *n_start = nstart; + *m_end = mend; + *n_end = nend; + /* if a task has been successfully required, return 1 */ + return 1; +} + +/****************************************************************************** + * Function: get_copy_task + * Description: Function for a running thread to get GEMM copy task (m or n) + * from a (shared) task container atomically + * Calls: atomicCAS_U32 + * Input: uint32_t *dim_left: the address of an interger recording + * the amount of remaining work to do, + * which is shared among threads. + * uint32_t min_task: the default size of task to get. + * Output: uint32_t *dim_start: to output the starting position + * of the new task. + * uint32_t *dim_end: to output the ending position + * of the new task. + * Return: 0 if there's no task left, 1 if a new task has been acquired. 
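+ * Example: an illustrative call pattern (variable names and the minimum
+ *          task size of 32 are placeholders); each thread repeatedly takes
+ *          a chunk from the shared counter until nothing is left:
+ *            uint32_t k_left_shared = K;  // shared among all threads
+ *            uint32_t ks, ke;
+ *            while (get_copy_task(&k_left_shared, 32, &ks, &ke)) {
+ *              // pack (copy) source elements in the range [ks, ke)
+ *            }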
+ *****************************************************************************/ +static uint32_t get_copy_task(uint32_t *dim_left, uint32_t min_task, + uint32_t *dim_start, uint32_t *dim_end) { + + if (!min_task) min_task = 24; + + uint32_t dim_left_load, dim_left_read, dim_left_write, dim_get; + + do { + dim_left_load = *dim_left; + /* if no task left, return 0 */ + if (dim_left_load == 0) return 0; + + /* determine task size */ + dim_get = dim_left_load % min_task; + if (dim_get == 0) dim_get = min_task; + + dim_left_write = dim_left_load - dim_get; + dim_left_read = atomicCAS_U32(dim_left_load, dim_left_write, dim_left); + + } while (dim_left_read != dim_left_load); + + *dim_start = dim_left_write; + *dim_end = dim_left_load; + return 1; +} + +/****************************************************************************** + * Function: get_irreg_task + * Description: Function for a running thread to get 1D computing task + * from a (shared) task container atomically + * Calls: atomicCAS_U32 + * Input: uint32_t *dim_end: the address of an interger recording + * how much work has been done, + * which is shared among threads. + * uint32_t min_task: specify the default size of a task. + * uint32_t max_dim: input the amount of work to do. + * Output: uint32_t *task_start: to output the starting position + * of the new task. + * uint32_t *task_end: to output the ending position + * of the new task. + * Return: 0 if there's no task left, 1 if a new task has been acquired. + *****************************************************************************/ +static uint32_t get_irreg_task(uint32_t *dim_end, + uint32_t *task_start, uint32_t *task_end, + uint32_t min_task, uint32_t max_dim) { + + if (!min_task) min_task = 4; + uint32_t dim_end_load, dim_end_read, dim_end_write; + + do { + dim_end_load = *dim_end; + /* if no task left, return 0 */ + if (dim_end_load >= max_dim) return 0; + + dim_end_write = dim_end_load + min_task; + if (dim_end_write > max_dim) dim_end_write = max_dim; + + dim_end_read = atomicCAS_U32(dim_end_load, dim_end_write, dim_end); + } while (dim_end_read != dim_end_load); + + *task_start = dim_end_load; + *task_end = dim_end_write; + return 1; +} + +#endif + diff --git a/include/common/CommonSkinnyDot.h b/include/common/CommonSkinnyDot.h new file mode 100644 index 0000000..28e6175 --- /dev/null +++ b/include/common/CommonSkinnyDot.h @@ -0,0 +1,586 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** +* File: CommonSkinnyDot.h +* Description: Common building blocks for regular * skinny or skinny * regular +* matmul when the regular matrix is row-major in the former +* case or column-major in the latter case. 
These 2 kinds of matmul +* involving skinny matrices require a special efficient kernel +* different from that in regular * regular matmul. Specifically, +* The regular matrix is no longer reordered (packed) during +* calculation. Elements from the regular matrix are accessed +* sequentially and only once.The GEMM calculation is decomposed +* into sequential DOT operations.The skinny source matrix is +* accessed repeatedly in DOT operations so it is always packed +* in a scratch array. +* Extension: To support a new CPU architecture, the following tasks should +* be done in addition to including this header: +* (1) Use typedef to define _skinnydot_[a/b/c]scalar and +* _skinnydot_[a/b/c]vec[\d]. For example, when +* developing avx2 SGEMM regular*skinny kernels, the following +* lines should be added when the maximum vector length is 8 +* in K dimension: +* // scalar types in main memory +* typedef float sgemm_skinnydot_ascalar; +* typedef float sgemm_skinnydot_bscalar; +* typedef float sgemm_skinnydot_cscalar; +* // (converted) vector types in registers +* typedef float sgemm_skinnydot_avec1; +* typedef __m128 sgemm_skinnydot_avec4; +* typedef __m256 sgemm_skinnydot_avec8; +* typedef float sgemm_skinnydot_bvec1; +* typedef __m128 sgemm_skinnydot_bvec4; +* typedef __m256 sgemm_skinnydot_bvec8; +* typedef float sgemm_skinnydot_cvec1; +* typedef __m128 sgemm_skinnydot_cvec4; +* typedef __m256 sgemm_skinnydot_cvec8; +* (2) Implement inline functions for basic vector-vector +* multiply-add operations. Here are examples for +* inline functions of avx2 SGEMM with k_veclen = 8, 4 and 1. +* These functions multiplies each element in a_vec with the +* corresponding element in b_vec and add the result +* to the corresponding element in c_vec: +* GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 8) { +* return _mm256_fmadd_ps(a_vec, b_vec, c_vec); +* } +* GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 4) { +* return _mm_fmadd_ps(a_vec, b_vec, c_vec); +* } +* GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 1) { +* return a_vec * b_vec + c_vec; +* } +* (3) Implement load and store inline functions for matrix +* a & b like this (each catagory 1 example (k_veclen = 8)): +* GEMM_SKINNY_DOT_LOADA_UNIT(sgemm, 8) { +* _mm_prefetch((char *)(a_ptr + 24), _MM_HINT_T0); +* return _mm256_loadu_ps(a_ptr); +* } +* GEMM_SKINNY_DOT_LOADB_UNIT(sgemm, 8) { +* return _mm256_loadu_ps(b_ptr); +* } +* (4) Implement inline vectorized reduction functions: +* // reduction from vec[8] to vec[4] +* GEMM_SKINNY_DOT_REDUC_UNIT(sgemm, 8, 4) { +* return _mm_add_ps(_mm256_extractf128_ps(c_vec, 0), +* _mm256_extractf128_ps(c_vec, 1)); +* } +* // reduction from vec[4] to vec[1] +* GEMM_SKINNY_DOT_REDUC_UNIT(sgemm, 4, 1) { +* __m128 z0 = _mm_setzero_ps(); +* c_vec = _mm_hadd_ps(c_vec, z0); +* c_vec = _mm_hadd_ps(c_vec, z0); +* return _mm_cvtss_f32(c_vec); +* } +* (5) Implement inline vector initialization functions. +* A function in this category returns a vector filled with +* zeros. +* GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 8) { +* return _mm256_setzero_ps(); +* } +* GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 4) { +* return _mm_setzero_ps(); +* } +* GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 1) { +* return 0; +* } +* (5) Finally build kernel functions from inline functions +* defined above. For each kernel function only 1 line +* is needed. 
The following line defines regular*skinny +* kernel functions (serial and OpenMP) for the minimum +* dimension length = 2 with k_veclen = {1, 4, 8} and +* m_unroll = {1, 2, 4}: +* GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 2, 13, 7, 8192, +* float, float) +* The last 2 parameters in the macro are for function +* name mangling, providing the data type for regular +* and skinny matrix respectively. The last number in +* macro parameters (8192) specify the scratch size +* for skinny matrix which should be adjusted to the size +* of L1 cache. The second number (13) is the sum of all +* implemented k_veclen values. The third number (7) is +* the sum of all m_unroll values covered. +******************************************************************************/ + +#include "common/ExpandMacro.h" +#include "common/CommonSched.h" + +#include +#ifndef EMLL_SERIAL_ONLY +#include +#endif + +#ifndef INCLUDE_COMMON_SKINNY_DOT +#define INCLUDE_COMMON_SKINNY_DOT + +/* computation units basic in skinny_dot function */ +#define GEMM_SKINNY_DOT_CALC_UNIT(gemm, k_veclen) \ +static inline gemm##_skinnydot_cvec##k_veclen\ + inline_##gemm##_arowmajor_bskinny_fma_unit_m1n1k##k_veclen(\ + gemm##_skinnydot_cvec##k_veclen c_vec,\ + gemm##_skinnydot_avec##k_veclen a_vec,\ + gemm##_skinnydot_bvec##k_veclen b_vec) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_DOT_CALC_UNIT(gemm, k_veclen) { + * gemm##_skinnydot_cvec##k_veclen ret; + * for (int i = 0; i < k_veclen; ++i) { + * ret[i] = a_vec[i] * b_vec[i] + c_vec[i]; + * } + * return ret; + * } + */ + +#define GEMM_SKINNY_DOT_LOADA_UNIT(gemm, k_veclen) \ +static inline gemm##_skinnydot_avec##k_veclen\ + inline_##gemm##_arowmajor_bskinny_loada_unit_k##k_veclen(\ + const gemm##_skinnydot_ascalar *a_ptr) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_DOT_LOADA_UNIT(gemm, k_veclen) { + * gemm##_skinnydot_avec##k_veclen ret; + * for (int i = 0; i < k_veclen; ++i) { + * ret[i] = a_ptr[i]; + * } + * prefetch(a_ptr + pref_distance); + * return ret; + * } + */ + +#define GEMM_SKINNY_DOT_LOADB_UNIT(gemm, k_veclen) \ +static inline gemm##_skinnydot_bvec##k_veclen\ + inline_##gemm##_arowmajor_bskinny_loadb_unit_k##k_veclen(\ + const gemm##_skinnydot_bscalar *b_ptr) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_DOT_LOADB_UNIT(gemm, k_veclen) { + * gemm##_skinnydot_bvec##k_veclen ret; + * for (int i = 0; i < k_veclen; ++i) { + * ret[i] = b_ptr[i]; + * } + * return ret; + * } + */ + +#define GEMM_SKINNY_DOT_REDUC_UNIT(gemm, old_k_vlen, new_k_vlen) \ +static inline gemm##_skinnydot_cvec##new_k_vlen\ + inline_##gemm##_arowmajor_bskinny_reduc_unit_##new_k_vlen##from##old_k_vlen(\ + gemm##_skinnydot_cvec##old_k_vlen c_vec) +/* The sum of all elements of the returned vector should be + * equal to that of the input c_vec, here's an example: + * GEMM_SKINNY_DOT_REDUC_UNIT(gemm, old_k_vlen, new_k_vlen) { + * gemm##_skinnydot_cvec##new_k_vlen ret; + * int i; + * for (i = 0; i < new_k_vlen; ++i) { + * ret[i] = c_vec[i]; + * } + * for (; i < old_k_vlen; ++i) { + * ret[i % new_k_vlen] += c_vec[i]; + * } + * return ret; + * } + */ + +#define GEMM_SKINNY_DOT_INITC_UNIT(gemm, k_veclen) \ +static inline gemm##_skinnydot_cvec##k_veclen\ + inline_##gemm##_arowmajor_bskinny_initc_unit_k##k_veclen() +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_DOT_INITC_UNIT(gemm, k_veclen) { + * gemm##_skinnydot_cvec##k_veclen ret = {0}; + * return ret; + * } + */ + +/* construct 
inline function from building blocks */ +#define GEMM_SKINNY_DOT_INIT_CVEC_ITEM(m_id, gemm, k_veclen, n_id) \ + gemm##_skinnydot_cvec##k_veclen c_##k_veclen##_##m_id##_##n_id =\ + inline_##gemm##_arowmajor_bskinny_initc_unit_k##k_veclen(); + +#define GEMM_SKINNY_DOT_INIT_CVEC_COL_ITEM(n_id, gemm, k_veclen, m_unroll) \ + MACRO_EXPANSION_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_INIT_CVEC_ITEM,\ + gemm, k_veclen, n_id) + +#define GEMM_SKINNY_DOT_INIT_CVEC(k_veclen, gemm, m_unroll, n_dim) \ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_INIT_CVEC_COL_ITEM,\ + gemm, k_veclen, m_unroll) + +#define GEMM_SKINNY_DOT_CALC_ITEM(m_id, gemm, k_veclen, n_id) \ + c_##k_veclen##_##m_id##_##n_id =\ + inline_##gemm##_arowmajor_bskinny_fma_unit_m1n1k##k_veclen(\ + c_##k_veclen##_##m_id##_##n_id, a_##k_veclen##_##m_id,\ + b_##k_veclen##_##n_id); + +#define GEMM_SKINNY_DOT_CALC_COL_ITEM_PACK(n_id, gemm, k_veclen, m_unroll) \ + const gemm##_skinnydot_bvec##k_veclen b_##k_veclen##_##n_id =\ + inline_##gemm##_arowmajor_bskinny_loadb_unit_k##k_veclen(\ + b_ptr + (n_id - 1) * k_veclen);\ + MACRO_EXPANSION_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_CALC_ITEM,\ + gemm, k_veclen, n_id) + +#define GEMM_SKINNY_DOT_LOADA_ITEM(m_id, gemm, k_veclen) \ + const gemm##_skinnydot_avec##k_veclen a_##k_veclen##_##m_id =\ + inline_##gemm##_arowmajor_bskinny_loada_unit_k##k_veclen(a_ptr##m_id);\ + a_ptr##m_id += k_veclen; + + +#define GEMM_SKINNY_DOT_CALC_LOOPITEM_PACK(k_veclen, gemm, m_unroll, n_dim) \ +for (; k_left >= k_veclen; k_left -= k_veclen) {\ + MACRO_EXP_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_LOADA_ITEM, gemm, k_veclen)\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_CALC_COL_ITEM_PACK,\ + gemm, k_veclen, m_unroll)\ + b_ptr += n_dim * k_veclen;\ +} + +#define GEMM_SKINNY_DOT_REDUC_ITEM(m_id, old_kvlen, new_kvlen, gemm, n_id) \ + gemm##_skinnydot_cvec##new_kvlen c_##new_kvlen##_##m_id##_##n_id =\ + inline_##gemm##_arowmajor_bskinny_reduc_unit_##new_kvlen##from##old_kvlen(\ + c_##old_kvlen##_##m_id##_##n_id); + +#define GEMM_SKINNY_DOT_REDUC_COL_ITEM(n_id, m_unroll, gemm,\ + old_kvlen, new_kvlen) \ + MACRO_EXPANSION_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_REDUC_ITEM,\ + old_kvlen, new_kvlen, gemm, n_id) + +#define GEMM_SKINNY_DOT_REDUC_CROSS_ITEM(old_kvlen, new_kvlen, gemm,\ + m_unroll, n_dim)\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_REDUC_COL_ITEM,\ + m_unroll, gemm, old_kvlen, new_kvlen) + +#define GEMM_SKINNY_DOT_INIT_APTR_ITEM(m_id, gemm) \ + const gemm##_skinnydot_ascalar *a_ptr##m_id = A + (m_id - 1) * LDK; + +#define GEMM_SKINNY_DOT_STOREC_ITEM_CC(m_id, n_id) \ + c_ptr[m_id - 1] = c_ptr[m_id - 1] * beta + c_1_##m_id##_##n_id; + +#define GEMM_SKINNY_DOT_STOREC_ITEM_CR(n_id, m_id) \ + c_ptr[n_id - 1] = c_ptr[n_id - 1] * beta + c_1_##m_id##_##n_id; + +#define GEMM_SKINNY_DOT_STOREC_CC_COL_ITEM(n_id, m_unroll) \ + MACRO_EXPANSION_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_STOREC_ITEM_CC, n_id)\ + c_ptr += LDM; + +#define GEMM_SKINNY_DOT_STOREC_CR_ROW_ITEM(m_id, n_dim) \ + MACRO_EXPANSION_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_STOREC_ITEM_CR, m_id)\ + c_ptr += n_dim; + +#define GEMM_SKINNY_DOT_INLINE_PACK_FUNC(gemm, m_unroll, n_dim, k_mask) \ +static inline void\ + inline_##gemm##_arowmajor_bskinny_m##m_unroll##n##n_dim(\ + const gemm##_skinnydot_ascalar *A, const gemm##_skinnydot_bscalar *b_ptr,\ + gemm##_skinnydot_cscalar *c_ptr, uint32_t k_left, uint32_t LDK, uint32_t LDM,\ + gemm##_skinnydot_cscalar beta, bool c_rowmajor) {\ +\ + MACRO_EXP_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_INIT_APTR_ITEM, gemm)\ + 
MACRO_EXPANSION_IMX_##k_mask(GEMM_SKINNY_DOT_INIT_CVEC,\ + GEMM_SKINNY_DOT_CALC_LOOPITEM_PACK,\ + GEMM_SKINNY_DOT_REDUC_CROSS_ITEM, gemm, m_unroll, n_dim)\ + if (c_rowmajor) {\ + MACRO_EXP_##m_unroll(VOID_BASE, GEMM_SKINNY_DOT_STOREC_CR_ROW_ITEM, n_dim)\ + } else {\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_STOREC_CC_COL_ITEM, m_unroll)\ + }\ +} + +#define GEMM_SKINNY_DOT_INLINE_FUNC_ITEM(m_unroll, gemm, n_dim, k_mask) \ + GEMM_SKINNY_DOT_INLINE_PACK_FUNC(gemm, m_unroll, n_dim, k_mask) + +#define GEMM_SKINNY_DOT_PACKK_BC_ITEM(k_id, n_id) \ + sb_ptr[k_id - 1] = b_ptr##n_id[k_id - 1]; + +#define GEMM_SKINNY_DOT_PACKK_BC_COL_ITEM(n_id, k_veclen) \ + MACRO_EXPANSION_##k_veclen(VOID_BASE, GEMM_SKINNY_DOT_PACKK_BC_ITEM, n_id)\ + b_ptr##n_id += k_veclen; sb_ptr += k_veclen; + +#define GEMM_SKINNY_DOT_PACKK_BC_LOOP(k_veclen, n_dim) \ + for (; k_left >= k_veclen; k_left -= k_veclen) {\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_PACKK_BC_COL_ITEM, k_veclen)\ + } + +#define GEMM_SKINNY_DOT_PACKK_BR_ITEM(n_id, k_id, k_veclen) \ + sb_ptr[(n_id - 1) * k_veclen + k_id - 1] = b_ptr[n_id - 1]; + +#define GEMM_SKINNY_DOT_PACKK_BR_ROW_ITEM(k_id, n_dim, k_veclen) \ + MACRO_EXPANSION_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_PACKK_BR_ITEM,\ + k_id, k_veclen)\ + b_ptr += n_dim; + +#define GEMM_SKINNY_DOT_PACKK_BR_LOOP(k_veclen, n_dim) \ + for (; k_left >= k_veclen; k_left -= k_veclen) {\ + MACRO_EXP_##k_veclen(VOID_BASE, GEMM_SKINNY_DOT_PACKK_BR_ROW_ITEM,\ + n_dim, k_veclen)\ + sb_ptr += n_dim * k_veclen;\ + } + +#define GEMM_SKINNY_DOT_PACKK_BC_INIT_BPTR_ITEM(n_id, gemm) \ + const gemm##_skinnydot_bscalar *b_ptr##n_id = b_ptr + (n_id - 1) * K; + +#define GEMM_SKINNY_DOT_INLINE_CALL_LOOP(m_unroll, gemm, n_dim) \ + if (unroll_m##m_unroll) {\ + for (; m_left >= m_unroll; m_left -= m_unroll) {\ + inline_##gemm##_arowmajor_bskinny_m##m_unroll##n##n_dim(\ + a_ptr, b_ptr, c_ptr, k_inc, K, M, beta, c_rowmajor);\ + a_ptr += K * m_unroll;\ + c_ptr += c_m_inc * m_unroll;\ + }\ + } + +#define GEMM_SKINNY_DOT_UNROLL_TEST(m_unroll, unroll_test, n_dim) \ + const bool unroll_m##m_unroll = unroll_test##_m##m_unroll##n##n_dim(M, K)\ + || (m_unroll == 1); + +#define GEMM_SKINNY_DOT_UNROLL_TEST_DEFAULT(m_unroll, n_dim) \ +static inline bool unroll_test_m##m_unroll##n##n_dim(uint32_t M, uint32_t K) {\ + return true;\ +} + +#define GEMM_SKINNY_DOT_SERIAL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) \ +__attribute__((aligned(4096))) static __thread gemm##_skinnydot_bscalar\ + blas_skinny_dot_b_scratch_##btype##n##n_dim[scratch_size];\ +void gemm##_arowmajor_bskinny_a##atype##_b##btype##_n##n_dim(\ + const gemm##_skinnydot_ascalar *A, const gemm##_skinnydot_bscalar *B,\ + gemm##_skinnydot_cscalar *C, uint32_t M, uint32_t K,\ + uint8_t b_c_order, gemm##_skinnydot_cscalar beta_inp) {\ +\ + if (K == 0) {\ + if (beta_inp != (gemm##_skinnydot_cscalar)1.0) {\ + uint64_t size = (uint64_t)M * n_dim;\ + for (uint64_t pos = 0; pos < size; ++pos) {\ + C[pos] *= beta_inp;\ + }\ + }\ + return;\ + }\ +\ + const bool b_rowmajor = b_c_order & 1;\ + const bool c_rowmajor = b_c_order & 2;\ + const uint32_t k_limit = ((scratch_size / n_dim) >> 5) << 5;\ + const uint32_t c_m_inc = c_rowmajor ? 
n_dim : 1;\ + MACRO_EXPANSION_M_##m_mask(GEMM_SKINNY_DOT_UNROLL_TEST, unroll_test, n_dim)\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= (k_limit << 1)) k_inc = k_limit;\ + else if (k_inc > k_limit) k_inc >>= 1;\ +\ + const gemm##_skinnydot_cscalar beta = (k_pos == 0) ? beta_inp : 1;\ + if (n_dim == 1) {\ + const gemm##_skinnydot_ascalar *a_ptr = A + k_pos;\ + const gemm##_skinnydot_bscalar * const b_ptr = B + k_pos;\ + gemm##_skinnydot_cscalar *c_ptr = C;\ + uint32_t m_left = M;\ + MACRO_EXPANSION_M_##m_mask(GEMM_SKINNY_DOT_INLINE_CALL_LOOP, gemm, n_dim)\ + } else {\ + if (b_rowmajor) {\ + const gemm##_skinnydot_bscalar *b_ptr = B + k_pos * n_dim;\ + gemm##_skinnydot_bscalar *sb_ptr =\ + blas_skinny_dot_b_scratch_##btype##n##n_dim;\ + uint32_t k_left = k_inc;\ + MACRO_EXPANSION_M_##k_mask(GEMM_SKINNY_DOT_PACKK_BR_LOOP, n_dim)\ + } else {\ + const gemm##_skinnydot_bscalar *b_ptr = B + k_pos;\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_PACKK_BC_INIT_BPTR_ITEM, gemm)\ + gemm##_skinnydot_bscalar *sb_ptr =\ + blas_skinny_dot_b_scratch_##btype##n##n_dim;\ + uint32_t k_left = k_inc;\ + MACRO_EXPANSION_M_##k_mask(GEMM_SKINNY_DOT_PACKK_BC_LOOP, n_dim)\ + }\ + const gemm##_skinnydot_ascalar *a_ptr = A + k_pos;\ + const gemm##_skinnydot_bscalar * const b_ptr =\ + blas_skinny_dot_b_scratch_##btype##n##n_dim;\ + gemm##_skinnydot_cscalar *c_ptr = C;\ + uint32_t m_left = M;\ + MACRO_EXPANSION_M_##m_mask(GEMM_SKINNY_DOT_INLINE_CALL_LOOP, gemm, n_dim)\ + }\ + }\ +} + +/****************************************************************************** + * Template: GEMM_SKINNY_DOT_SERIAL_FUNC + * Description: Construct serial dot-based "regular * skinny" GEMM function + * from the general algorithm. + * Parameters: gemm: The type of GEMM, e.g. sgemm, hgemm, u8u32gemm, ... + * n_dim: The width of skinny matrix that this function can handle. + * (Every such function can only process 1 width) + * k_mask: The sum of all supported accumulation vector width. + * For example, if inline calculation functions with + * k_veclen = 1, 4 and 8 are available, this parameter + * should be 1 + 4 + 8 = 13. Note that every k_veclen + * should be a power of 2. + * m_mask: The sum of all supported unroll factors of M. During the + * calculation of dot values, usually several rows are + * read concurrently from the regular matrix to improve + * the ratio of arith/load. But if too many rows are loaded + * at the same time, there will be no enough registers to + * hold dot values. So there's a balance. Let's say, if + * the optimal solution is to read 4 rows together in the + * bulk region and read one by one at the edge, this + * parameter can be set to 4 + 1 = 5. + * scratch_size: The size (number of elements) for the scratch + * array that holds rearranged (packed) block from + * the skinny source matrix. Because the skinny + * source is accessed repeatedly during calculations, + * it's better to rearrange it to make the access + * to its element fully sequential. This parameter + * should not exceed the capacity of level-2 cache. + * atype: The data type of regular source matrix. This parameter + * is only for naming the function properly so that it + * can be called correctly by driver. + * btype: The data type of skinny source matrix. This parameter + * is for naming the function only. 
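+ * Example:     (illustrative) with the inline units implemented for
+ *              k_veclen = {1, 4, 8} and m_unroll = {1, 2, 4}, the line
+ *                GEMM_SKINNY_DOT_SERIAL_FUNC(sgemm, 2, 13, 7, 8192,
+ *                                            float, float)
+ *              defines the serial kernel
+ *                sgemm_arowmajor_bskinny_afloat_bfloat_n2(A, B, C,
+ *                    M, K, b_c_order, beta)
+ *              for a skinny source matrix of width 2.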
+ *****************************************************************************/ +#define GEMM_SKINNY_DOT_SERIAL_FUNC(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype) \ + MACRO_EXP_M_##m_mask(GEMM_SKINNY_DOT_INLINE_FUNC_ITEM, gemm, n_dim, k_mask)\ + MACRO_EXP_M_##m_mask(GEMM_SKINNY_DOT_UNROLL_TEST_DEFAULT, n_dim)\ + GEMM_SKINNY_DOT_SERIAL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) + +#ifdef EMLL_SERIAL_ONLY + +#define GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) \ +GEMM_SKINNY_DOT_SERIAL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) \ +void gemm##_arowmajor_bskinny_a##atype##_b##btype##_n##n_dim##_omp(\ + const gemm##_skinnydot_ascalar *A, const gemm##_skinnydot_bscalar *B,\ + gemm##_skinnydot_cscalar *C, uint32_t M, uint32_t K,\ + uint8_t b_c_order, gemm##_skinnydot_cscalar beta_inp, uint32_t num_threads) {\ +\ + gemm##_arowmajor_bskinny_a##atype##_b##btype##_n##n_dim(A, B, C,\ + M, K, b_c_order, beta_inp);\ +} + +#else + +#define GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) \ +GEMM_SKINNY_DOT_SERIAL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) \ +/* in ARMv7, the arguments when creating a thread is limited to a certain */\ +/* number, so some arguments need to be wrapped into a struct to pass */\ +struct gemm##_skinnydot_a##atype##_b##btype##_##n_dim##_matrix_info {\ + const gemm##_skinnydot_ascalar *m_A;\ + const gemm##_skinnydot_bscalar *m_B;\ + gemm##_skinnydot_cscalar *m_C;\ + uint32_t m_M;\ +};\ +void gemm##_arowmajor_bskinny_a##atype##_b##btype##_n##n_dim##_omp(\ + const gemm##_skinnydot_ascalar *A, const gemm##_skinnydot_bscalar *B,\ + gemm##_skinnydot_cscalar *C, uint32_t M, uint32_t K,\ + uint8_t b_c_order, gemm##_skinnydot_cscalar beta_inp, uint32_t num_threads) {\ +\ + if (num_threads <= 1 || K == 0) {\ + gemm##_arowmajor_bskinny_a##atype##_b##btype##_n##n_dim(A, B, C,\ + M, K, b_c_order, beta_inp);\ + return;\ + }\ +\ + struct gemm##_skinnydot_a##atype##_b##btype##_##n_dim##_matrix_info thread_args;\ + thread_args.m_A = A;\ + thread_args.m_B = B;\ + thread_args.m_C = C;\ + thread_args.m_M = M;\ + /* use the tls scratch of master thread for shared buffer */\ + gemm##_skinnydot_bscalar * const b_scratch_master =\ + blas_skinny_dot_b_scratch_##btype##n##n_dim;\ + const bool b_rowmajor = b_c_order & 1;\ + const bool c_rowmajor = b_c_order & 2;\ + const uint32_t k_limit = ((scratch_size / n_dim) >> 5) << 5;\ + const uint32_t c_m_inc = c_rowmajor ? n_dim : 1;\ + MACRO_EXPANSION_M_##m_mask(GEMM_SKINNY_DOT_UNROLL_TEST, unroll_test, n_dim)\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= (k_limit << 1)) k_inc = k_limit;\ + else if (k_inc > k_limit) k_inc >>= 1;\ +\ + const gemm##_skinnydot_cscalar beta = (k_pos == 0) ? 
beta_inp : 1;\ + if (n_dim == 1) {\ + uint32_t m_done = 0;\ + omp_set_num_threads(num_threads);\ + _Pragma("omp parallel")\ + {\ + uint32_t m_start, m_end;\ + while(get_irreg_task(&m_done, &m_start, &m_end,\ + ((((M - m_done) / num_threads) >> 2) / MACRO_EXP_M_FIRSTITEM_##m_mask + 1)\ + * MACRO_EXP_M_FIRSTITEM_##m_mask, M)) {\ + const gemm##_skinnydot_ascalar *a_ptr = A + k_pos + m_start * K;\ + const gemm##_skinnydot_bscalar * const b_ptr = B + k_pos;\ + gemm##_skinnydot_cscalar *c_ptr = C + m_start;\ + uint32_t m_left = m_end - m_start;\ + MACRO_EXPANSION_M_##m_mask(GEMM_SKINNY_DOT_INLINE_CALL_LOOP, gemm, n_dim)\ + }\ + }\ + } else {\ + uint32_t m_done = 0;\ + uint32_t k_left_shared = k_inc;\ + omp_set_num_threads(num_threads);\ + _Pragma("omp parallel")\ + {\ + const gemm##_skinnydot_ascalar * const A = thread_args.m_A;\ + const gemm##_skinnydot_bscalar * const B = thread_args.m_B;\ + gemm##_skinnydot_cscalar * const C = thread_args.m_C;\ + const uint32_t M = thread_args.m_M;\ + uint32_t k_start, k_end;\ + while(get_copy_task(&k_left_shared, MACRO_EXP_M_FIRSTITEM_##k_mask << 3,\ + &k_start, &k_end)) {\ + if (b_rowmajor) {\ + const gemm##_skinnydot_bscalar *b_ptr = B + (k_pos + k_start) * n_dim;\ + gemm##_skinnydot_bscalar *sb_ptr = b_scratch_master + k_start * n_dim;\ + uint32_t k_left = k_end - k_start;\ + MACRO_EXPANSION_M_##k_mask(GEMM_SKINNY_DOT_PACKK_BR_LOOP, n_dim)\ + } else {\ + const gemm##_skinnydot_bscalar *b_ptr = B + k_pos + k_start;\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_DOT_PACKK_BC_INIT_BPTR_ITEM, gemm)\ + gemm##_skinnydot_bscalar *sb_ptr = b_scratch_master + k_start * n_dim;\ + uint32_t k_left = k_end - k_start;\ + MACRO_EXPANSION_M_##k_mask(GEMM_SKINNY_DOT_PACKK_BC_LOOP, n_dim)\ + }\ + }\ + _Pragma("omp barrier")\ + uint32_t m_start, m_end;\ + while(get_irreg_task(&m_done, &m_start, &m_end,\ + ((((M - m_done) / num_threads) >> 2) / MACRO_EXP_M_FIRSTITEM_##m_mask + 1)\ + * MACRO_EXP_M_FIRSTITEM_##m_mask, M)) {\ + const gemm##_skinnydot_ascalar *a_ptr = A + k_pos + m_start * K;\ + const gemm##_skinnydot_bscalar * const b_ptr = b_scratch_master;\ + gemm##_skinnydot_cscalar *c_ptr = C + c_m_inc * m_start;\ + uint32_t m_left = m_end - m_start;\ + MACRO_EXPANSION_M_##m_mask(GEMM_SKINNY_DOT_INLINE_CALL_LOOP,\ + gemm, n_dim)\ + }\ + }\ + }\ + }\ +} + +#endif + +/****************************************************************************** + * Template: GEMM_SKINNY_DOT_PARALLEL_FUNC + * Description: Construct dot-based "regular * skinny" GEMM function + * paralleled by OpenMP. + * Parameters: the same as in GEMM_SKINNY_DOT_SERIAL_FUNC + *****************************************************************************/ +#define GEMM_SKINNY_DOT_PARALLEL_FUNC(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype) \ + MACRO_EXP_M_##m_mask(GEMM_SKINNY_DOT_INLINE_FUNC_ITEM, gemm, n_dim, k_mask)\ + MACRO_EXP_M_##m_mask(GEMM_SKINNY_DOT_UNROLL_TEST_DEFAULT, n_dim)\ + GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(gemm, n_dim, k_mask, m_mask,\ + scratch_size, atype, btype, unroll_test) + +#endif diff --git a/include/common/CommonSkinnyGer.h b/include/common/CommonSkinnyGer.h new file mode 100644 index 0000000..b0af350 --- /dev/null +++ b/include/common/CommonSkinnyGer.h @@ -0,0 +1,526 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** +* File: CommonSkinnyGer.h +* Description: Common building blocks for regular * skinny or skinny * regular +* matmul when the regular matrix is column-major in the former +* case or row-major in the latter case. These 2 kinds of matmul +* involving skinny matrices require a special efficient kernel +* different from that in regular * regular matmul. Specifically, +* The regular matrix is no longer reordered (packed) during +* calculation. Elements from the regular matrix are accessed +* sequentially and only once.The GEMM calculation is decomposed +* into sequential GER operations rather than DOT ones.The +* output matrix is accessed repeatedly in GER operations so +* it is always packed in a scratch array. +* Extension: To support a new CPU architecture, the following tasks should +* be done in addition to including this header: +* (1) Use typedef to define _skinnyger_[a/b/c]scalar and +* _skinnyger_[a/b/c]vec[\d]. For example, when +* developing avx2 SGEMM regular*skinny kernels, the following +* lines should be added when the maximum vector length is 8 +* in M and 4 in K: +* typedef float sgemm_skinnyger_ascalar; +* typedef float sgemm_skinnyger_bscalar; +* typedef float sgemm_skinnyger_cscalar; +* // M vec length up to 8 +* typedef float sgemm_skinnyger_avec1; +* typedef __m128 sgemm_skinnyger_avec4; +* typedef __m256 sgemm_skinnyger_avec8; +* typedef float sgemm_skinnyger_cvec1; +* typedef __m128 sgemm_skinnyger_cvec4; +* typedef __m256 sgemm_skinnyger_cvec8; +* // K vec length up to 4 +* typedef float sgemm_skinnyger_bvec1; +* typedef __m128 sgemm_skinnyger_bvec4; +* (2) Implement inline functions for basic vector-scalar +* multiply-add operations. Here is an example for +* an inline function of avx2 SGEMM with +* m_veclen = 8, k_veclen = 4 and k_laneid = 3, +* which multiplies each element in a_vec with the +* element at lane 3 in b_vec and add the result +* to the corresponding element in c_vec: +* GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 3) { +* __m256 b_v0 = _mm256_broadcast_ss((float*)&b_vec + 2); +* return _mm256_fmadd_ps(a_vec, b_v0, c_vec); +* } +* For every combination of m_veclen and k_veclen, +* all related inline multiply-add functions +* with k_laneid from 1 to k_veclen should be implemented. 
+* (3) Implement load and store inline functions for matrix +* a/b/c like this (each catagory 1 example): +* // the 3 types of functions below should be written +* // for each m_veclen +* GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 8) { +* _mm_prefetch((char *)(a_ptr + 24), _MM_HINT_T0); +* return _mm256_loadu_ps(a_ptr); +* } +* GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 8) { +* return _mm256_loadu_ps(c_ptr); +* } +* GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 8) { +* _mm256_storeu_ps(c_ptr, c_vec); +* } +* // the 2 types of functions blow should be written +* // for each k_veclen +* GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 4) { +* float e0 = *b_ptr; b_ptr += ldb; +* float e1 = *b_ptr; b_ptr += ldb; +* float e2 = *b_ptr; b_ptr += ldb; +* float e3 = *b_ptr; +* return _mm_set_ps(e3, e2, e1, e0); +* } +* GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 4) { +* return _mm_loadu_ps(b_ptr); +* } +* (4) Finally build kernel functions from inline functions +* defined above. For each kernel function only 1 line +* is needed. The following line defines regular*skinny +* kernel functions (serial and OpenMP) for the minimum +* dimension length = 3 with m_veclen = {1, 4, 8} +* and k_veclen = {1, 4}: +* GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 3, 5, 13, 8192, +* float, float) +* The last 2 parameters in the macro are for function +* name mangling, providing the data type for regular +* and skinny matrix respectively. The last number in +* macro parameters (8192) specify the scratch size +* for output matrix which should be adjusted to the size +* of L1 cache. The second number (5) is the sum of all +* implemented k_veclen values. The third number (13) is +* the sum of all m_veclen values implemented. +******************************************************************************/ + +#define D_SCALE 2 //dynamic scaling factor in scheduling +#include "common/ExpandMacro.h" +#include "common/CommonSched.h" +#include +#include +#include +#include +#ifndef EMLL_SERIAL_ONLY +#include +#endif + +#ifndef INCLUDE_COMMON_SKINNY_GER +#define INCLUDE_COMMON_SKINNY_GER + +/* GEMM_SKINNY_GER_XXX_UNIT: computation units basic in skinny_ger function */ +/* below are only headers to this 6 functions */ +/* the function bodies should be provided according to CPU arch */ + +#define GEMM_SKINNY_GER_CALC_UNIT(gemm, m_vlen, k_vlen, k_id) \ +static inline gemm##_skinnyger_cvec##m_vlen\ + inline_##gemm##_acolmajor_bskinny_fma_unit_m##m_vlen##_kid##k_id##in##k_vlen(\ + gemm##_skinnyger_cvec##m_vlen c_vec,\ + gemm##_skinnyger_avec##m_vlen a_vec,\ + gemm##_skinnyger_bvec##k_vlen b_vec) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_GER_CALC_UNIT(gemm, m_vlen, k_vlen, k_id) { + * gemm##_skinnyger_cvec##m_vlen ret; + * for (int i = 0; i < m_vlen; ++i) { + * ret[i] = c_vec[i] + a_vec[i] * b_vec[k_id - 1]; + * } + * return ret; + * } + */ + +#define GEMM_SKINNY_GER_LOADA_UNIT(gemm, m_vlen) \ +static inline gemm##_skinnyger_avec##m_vlen\ + inline_##gemm##_acolmajor_bskinny_loada_unit_m##m_vlen(\ + const gemm##_skinnyger_ascalar *a_ptr) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_GER_LOADA_UNIT(gemm, m_vlen) { + * gemm##_skinnyger_avec##m_vlen ret; + * for (int i = 0; i < m_vlen; ++i) { + * ret[i] = a_ptr[i]; + * } + * prefetch(a_ptr + pref_distance); + * return ret; + * } + */ + +#define GEMM_SKINNY_GER_LOADC_UNIT(gemm, m_vlen) \ +static inline gemm##_skinnyger_cvec##m_vlen\ + inline_##gemm##_acolmajor_bskinny_loadc_unit_m##m_vlen(\ + const gemm##_skinnyger_cscalar *c_ptr) +/* you should give 
vectorized implementation equivalent to this: + * GEMM_SKINNY_GER_LOADC_UNIT(gemm, m_vlen) { + * gemm##_skinnyger_cvec##m_vlen ret; + * for (int i = 0; i < m_vlen; ++i) { + * ret[i] = c_ptr[i]; + * } + * return ret; + * } + */ + +#define GEMM_SKINNY_GER_STOREC_UNIT(gemm, m_vlen) \ +static inline void\ + inline_##gemm##_acolmajor_bskinny_storec_unit_m##m_vlen(\ + gemm##_skinnyger_cscalar *c_ptr,\ + gemm##_skinnyger_cvec##m_vlen c_vec) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_GER_STOREC_UNIT(gemm, m_vlen) { + * for (int i = 0; i < m_vlen; ++i) { + * c_ptr[i] = c_vec[i]; + * } + * } + */ + +#define GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(gemm, k_vlen) \ +static inline gemm##_skinnyger_bvec##k_vlen\ + inline_##gemm##_acolmajor_bskinny_loadb_browmajor_unit_k##k_vlen(\ + const gemm##_skinnyger_bscalar *b_ptr, uint32_t ldb) +/* you should give optimized implementation equivalent to this: + * GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(gemm, k_vlen) { + * gemm##_skinnyger_bvec##k_vlen ret; + * for (int i = 0; i < m_vlen; ++i) { + * ret[i] = *b_ptr; b_ptr += ldb; + * } + * } + */ + +#define GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(gemm, k_vlen) \ +static inline gemm##_skinnyger_bvec##k_vlen\ + inline_##gemm##_acolmajor_bskinny_loadb_bcolmajor_unit_k##k_vlen(\ + const gemm##_skinnyger_bscalar *b_ptr) +/* you should give vectorized implementation equivalent to this: + * GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(gemm, k_vlen) { + * gemm##_skinnyger_bvec##k_vlen ret; + * for (int i = 0; i < m_vlen; ++i) { + * ret[i] = b_ptr[i]; + * } + * } + */ + + +/* construct skinny_ger function from computation units */ +#define GEMM_SKINNY_GER_CALC_UNIT_ITEM(n_id, gemm, m_vlen, k_vlen, k_id) \ + c##m_vlen##_##n_id =\ + inline_##gemm##_acolmajor_bskinny_fma_unit_m##m_vlen##_kid##k_id##in##k_vlen(\ + c##m_vlen##_##n_id, a##m_vlen##_##k_id, b##k_vlen##_##n_id); + +#define GEMM_SKINNY_GER_CALC_UNIT_K1(k_id, gemm, m_vlen, k_vlen, n_dim) \ + const gemm##_skinnyger_avec##m_vlen a##m_vlen##_##k_id =\ + inline_##gemm##_acolmajor_bskinny_loada_unit_m##m_vlen(a_ptr##k_id);\ + a_ptr##k_id += m_vlen;\ + MACRO_EXPANSION_##n_dim(VOID_BASE, GEMM_SKINNY_GER_CALC_UNIT_ITEM,\ + gemm, m_vlen, k_vlen, k_id) + +#define GEMM_SKINNY_GER_LOADC_ITEM(n_id, gemm, m_vlen) \ + gemm##_skinnyger_cvec##m_vlen c##m_vlen##_##n_id =\ + inline_##gemm##_acolmajor_bskinny_loadc_unit_m##m_vlen(\ + c_ptr + (n_id - 1) * m_vlen); + +#define GEMM_SKINNY_GER_STOREC_ITEM(n_id, gemm, m_vlen) \ + inline_##gemm##_acolmajor_bskinny_storec_unit_m##m_vlen(\ + c_ptr + (n_id - 1) * m_vlen, c##m_vlen##_##n_id); + +#define GEMM_SKINNY_GER_COMPUTE_BLOCK(gemm, m_vlen, k_vlen, n_dim) \ + MACRO_EXPANSION_##n_dim(VOID_BASE,\ + GEMM_SKINNY_GER_LOADC_ITEM, gemm, m_vlen)\ + MACRO_EXP_##k_vlen(VOID_BASE,\ + GEMM_SKINNY_GER_CALC_UNIT_K1, gemm, m_vlen, k_vlen, n_dim)\ + MACRO_EXPANSION_##n_dim(VOID_BASE,\ + GEMM_SKINNY_GER_STOREC_ITEM, gemm, m_vlen) + +#define GEMM_SKINNY_GER_COMPUTE_BLOCK_LOOP(\ + m_vlen, gemm, k_vlen, n_dim) \ + for (; m_left >= m_vlen; m_left -= m_vlen) {\ + GEMM_SKINNY_GER_COMPUTE_BLOCK(gemm, m_vlen, k_vlen, n_dim) \ + c_ptr += n_dim * m_vlen;\ + } + +#define GEMM_SKINNY_GER_DECLARE_B_ITEM(n_id, gemm, k_vlen) \ + gemm##_skinnyger_bvec##k_vlen b##k_vlen##_##n_id; + +#define GEMM_SKINNY_GER_LOADB_BROWMAJOR_ITEM(n_id, gemm, k_vlen) \ + b##k_vlen##_##n_id =\ + inline_##gemm##_acolmajor_bskinny_loadb_browmajor_unit_k##k_vlen(\ + b_ptr, LDB); b_ptr++; + +#define GEMM_SKINNY_GER_LOADB_BCOLMAJOR_ITEM(n_id, gemm, k_vlen) \ + b##k_vlen##_##n_id =\ 
+ inline_##gemm##_acolmajor_bskinny_loadb_bcolmajor_unit_k##k_vlen(b_ptr);\ + b_ptr += LDB; + +#define GEMM_SKINNY_GER_INIT_APTR_ITEM(k_id, gemm) \ + const gemm##_skinnyger_ascalar *a_ptr##k_id = a_ptr + (k_id - 1) * LDA; + +/* define valid inline function */ +#define GEMM_SKINNY_GER_INLINE_FUNC(gemm, n_dim, k_vlen, m_mask) \ +static inline void inline_##gemm##_acolmajor_bskinny_k##k_vlen##n##n_dim(\ + const gemm##_skinnyger_ascalar *a_ptr,\ + const gemm##_skinnyger_bscalar *b_ptr,\ + gemm##_skinnyger_cscalar *c_ptr,\ + uint32_t m_left, uint32_t LDA, uint32_t LDB, bool b_rowmajor) {\ +\ + MACRO_EXP_##k_vlen(VOID_BASE, GEMM_SKINNY_GER_INIT_APTR_ITEM, gemm)\ + MACRO_EXP_##n_dim(VOID_BASE, GEMM_SKINNY_GER_DECLARE_B_ITEM, gemm, k_vlen)\ + if (b_rowmajor) {\ + MACRO_EXP_##n_dim(VOID_BASE,\ + GEMM_SKINNY_GER_LOADB_BROWMAJOR_ITEM, gemm, k_vlen)\ + } else {\ + MACRO_EXP_##n_dim(VOID_BASE,\ + GEMM_SKINNY_GER_LOADB_BCOLMAJOR_ITEM, gemm, k_vlen)\ + }\ +\ + MACRO_EXP_M_##m_mask(GEMM_SKINNY_GER_COMPUTE_BLOCK_LOOP,\ + gemm, k_vlen, n_dim)\ +} + +#define GEMM_SKINNY_GER_INLINE_FUNC_ITEM(k_vlen, gemm, n_dim, m_mask)\ + GEMM_SKINNY_GER_INLINE_FUNC(gemm, n_dim, k_vlen, m_mask) + +#define GEMM_SKINNY_GER_INLINE_FUNCS(gemm, n_dim, k_mask, m_mask) \ + MACRO_EXPANSION_M_##k_mask(GEMM_SKINNY_GER_INLINE_FUNC_ITEM, gemm, n_dim, m_mask) + +#define GEMM_SKINNY_GER_INLINE_CALL_LOOP(k_vlen, gemm, n_dim) \ + for (; k_left >= k_vlen; k_left -= k_vlen) {\ + inline_##gemm##_acolmajor_bskinny_k##k_vlen##n##n_dim(\ + a_ptr, b_ptr, c_scratch, m_inc, M, LDB, b_rowmajor);\ + a_ptr += k_vlen * M;\ + b_ptr += k_vlen * b_k_inc;\ + } + +#define GEMM_SKINNY_GER_BETA_FUNC(gemm, n_dim) \ +static inline void inline_##gemm##_acolmajor_bskinny_beta_##n_dim(\ + gemm##_skinnyger_cscalar *c_ptr, uint32_t M,\ + gemm##_skinnyger_cscalar beta) {\ +\ + if (beta == (gemm##_skinnyger_cscalar)1.0) {\ + return;\ + }\ +\ + uint64_t size = (uint64_t)M * n_dim;\ + for (; size > 7; size -= 8) {\ + c_ptr[0] *= beta; c_ptr[1] *= beta;\ + c_ptr[2] *= beta; c_ptr[3] *= beta;\ + c_ptr[4] *= beta; c_ptr[5] *= beta;\ + c_ptr[6] *= beta; c_ptr[7] *= beta;\ + c_ptr += 8;\ + }\ + for (; size > 0; size--) {\ + *c_ptr *= beta;\ + c_ptr++;\ + }\ +} + +/* params atype & btype here are for function name mangling only */ +#define GEMM_SKINNY_GER_SERIAL_FUNC(gemm, n_dim,\ + k_mask, m_mask, stack_size, atype, btype) \ +GEMM_SKINNY_GER_BETA_FUNC(gemm, n_dim)\ +GEMM_SKINNY_GER_INLINE_FUNCS(gemm, n_dim, k_mask, m_mask)\ +__attribute__((aligned(4096))) static __thread gemm##_skinnyger_cscalar\ + gemm##_acolmajor_bskinny_a##atype##_b##btype##_##n_dim##_cscratch[stack_size];\ +GEMM_SKINNY_GER_INLINE_DEPACK_FUNC(gemm, m_mask, n_dim)\ +void gemm##_acolmajor_bskinny_a##atype##_b##btype##_n##n_dim(\ + const gemm##_skinnyger_ascalar *A,\ + const gemm##_skinnyger_bscalar *B,\ + gemm##_skinnyger_cscalar *C,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + gemm##_skinnyger_cscalar beta_inp) {\ +\ + const bool b_rowmajor = b_c_order & 1;\ + const bool c_rowmajor = b_c_order & 2;\ + const uint32_t b_k_inc = b_rowmajor ? n_dim : 1;\ + const uint32_t LDB = b_rowmajor ? 
n_dim : K;\ +\ + if (n_dim == 1) {\ + uint32_t k_left = K;\ + const uint32_t m_inc = M;\ + const gemm##_skinnyger_ascalar *a_ptr = A;\ + const gemm##_skinnyger_bscalar *b_ptr = B;\ + gemm##_skinnyger_cscalar *c_scratch = C;\ + inline_##gemm##_acolmajor_bskinny_beta_##n_dim(c_scratch, M, beta_inp);\ + MACRO_EXP_M_##k_mask(GEMM_SKINNY_GER_INLINE_CALL_LOOP, gemm, 1)\ + return;\ + }\ +\ + const uint32_t m_limit = ((stack_size / n_dim) >> 5) << 5;\ + uint32_t m_pos, m_inc;\ + for (m_pos = 0; m_pos < M; m_pos += m_inc) {\ + m_inc = M - m_pos;\ + if (m_inc >= (m_limit << 1)) m_inc = m_limit;\ + else if (m_inc > m_limit) m_inc >>= 1;\ + uint32_t k_left = K;\ + const gemm##_skinnyger_ascalar *a_ptr = A + m_pos;\ + const gemm##_skinnyger_bscalar *b_ptr = B;\ + gemm##_skinnyger_cscalar *c_scratch =\ + gemm##_acolmajor_bskinny_a##atype##_b##btype##_##n_dim##_cscratch;\ + memset(c_scratch, 0, m_inc * n_dim * sizeof(gemm##_skinnyger_cscalar));\ + MACRO_EXP_M_##k_mask(GEMM_SKINNY_GER_INLINE_CALL_LOOP, gemm, n_dim)\ + inline_##gemm##_acolmajor_bskinny_depack_c_n##n_dim(c_rowmajor, C,\ + c_scratch, M, m_pos, m_inc, beta_inp);\ + }\ +} + +#ifdef EMLL_SERIAL_ONLY + +#define GEMM_SKINNY_GER_PARALLEL_FUNC(gemm, n_dim,\ + k_mask, m_mask, stack_size, atype, btype) \ +GEMM_SKINNY_GER_SERIAL_FUNC(gemm, n_dim, k_mask, m_mask, stack_size, atype, btype)\ +void gemm##_acolmajor_bskinny_a##atype##_b##btype##_n##n_dim##_omp(\ + const gemm##_skinnyger_ascalar *A,\ + const gemm##_skinnyger_bscalar *B,\ + gemm##_skinnyger_cscalar *C,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + gemm##_skinnyger_cscalar beta_inp, uint32_t num_threads) {\ +\ + gemm##_acolmajor_bskinny_a##atype##_b##btype##_n##n_dim(\ + A, B, C, M, K, b_c_order, beta_inp);\ +} + +#else + +/* params atype & btype here are for function name mangling only */ +#define GEMM_SKINNY_GER_PARALLEL_FUNC(gemm, n_dim,\ + k_mask, m_mask, stack_size, atype, btype) \ +struct gemm##_skinnyger_a##atype##_b##btype##_n##n_dim##_info {\ + const gemm##_skinnyger_ascalar *m_A;\ + const gemm##_skinnyger_bscalar *m_B;\ + gemm##_skinnyger_cscalar *m_C;\ + uint32_t m_M;\ +};\ +GEMM_SKINNY_GER_SERIAL_FUNC(gemm, n_dim, k_mask, m_mask, stack_size, atype, btype)\ +void gemm##_acolmajor_bskinny_a##atype##_b##btype##_n##n_dim##_omp(\ + const gemm##_skinnyger_ascalar *A,\ + const gemm##_skinnyger_bscalar *B,\ + gemm##_skinnyger_cscalar *C,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + gemm##_skinnyger_cscalar beta_inp, uint32_t num_threads) {\ +\ + if (num_threads <= 1) {\ + gemm##_acolmajor_bskinny_a##atype##_b##btype##_n##n_dim(\ + A, B, C, M, K, b_c_order, beta_inp);\ + return;\ + }\ +\ + inline_##gemm##_acolmajor_bskinny_beta_##n_dim(C, M, beta_inp);\ + const bool b_rowmajor = b_c_order & 1;\ + const bool c_rowmajor = b_c_order & 2;\ + const uint32_t b_k_inc = b_rowmajor ? n_dim : 1;\ + const uint32_t LDB = b_rowmajor ? n_dim : K;\ + const uint32_t m_limit = ((stack_size / n_dim) >> 5) << 5;\ + const uint32_t m_task_min = m_limit >= 256 ? 
256 : m_limit;\ + const uint64_t m_k_task_min = (16ULL << 32) | (uint64_t)m_task_min;\ + const uint64_t m_k_pos_max = ((uint64_t)K << 32) | (uint64_t)M;\ + uint64_t task_end = 0;\ +\ + struct gemm##_skinnyger_a##atype##_b##btype##_n##n_dim##_info task_info;\ + task_info.m_A = A;\ + task_info.m_B = B;\ + task_info.m_C = C;\ + task_info.m_M = M;\ +\ + omp_set_num_threads(num_threads);\ + _Pragma("omp parallel")\ + {\ + const gemm##_skinnyger_ascalar * const A = task_info.m_A;\ + const gemm##_skinnyger_bscalar * const B = task_info.m_B;\ + gemm##_skinnyger_cscalar * const C = task_info.m_C;\ + const uint32_t M = task_info.m_M;\ + uint32_t m_start, k_start, m_end, k_end, m_start_old, m_inc_old;\ + m_start_old = M; m_inc_old = 0;\ + gemm##_skinnyger_cscalar * const c_scratch = \ + gemm##_acolmajor_bskinny_a##atype##_b##btype##_##n_dim##_cscratch;\ + while(get_mn_task(&task_end, &m_start, &k_start, &m_end, &k_end,\ + m_k_task_min, m_limit, 0, m_k_pos_max, num_threads)) {\ +\ + uint32_t k_left = k_end - k_start;\ + const uint32_t m_inc = m_end - m_start;\ + const gemm##_skinnyger_ascalar *a_ptr = A + k_start * M + m_start;\ + const gemm##_skinnyger_bscalar *b_ptr = B + k_start * b_k_inc;\ + if (m_start != m_start_old) {\ + if (m_inc_old > 0) {\ + _Pragma("omp critical")\ + {\ + inline_##gemm##_acolmajor_bskinny_depack_c_n##n_dim(c_rowmajor, C,\ + c_scratch, M, m_start_old, m_inc_old, 1);\ + }\ + }\ + memset(c_scratch, 0, m_inc * n_dim * sizeof(gemm##_skinnyger_cscalar));\ + }\ + MACRO_EXP_M_##k_mask(GEMM_SKINNY_GER_INLINE_CALL_LOOP, gemm, n_dim)\ + m_start_old = m_start; m_inc_old = m_inc;\ + }\ + if (m_inc_old > 0) {\ + _Pragma("omp critical")\ + {\ + inline_##gemm##_acolmajor_bskinny_depack_c_n##n_dim(c_rowmajor, C,\ + c_scratch, M, m_start_old, m_inc_old, 1);\ + }\ + }\ + }\ +} + +#endif + +#define GEMM_SKINNY_GER_DEPACK_CRM_LOW_ITEM(n_id, m_id, m_vlen, n_dim) \ + c_wt[(m_id - 1) * n_dim + n_id - 1] =\ + c_wt[(m_id - 1) * n_dim + n_id - 1] * beta +\ + c_rd[(m_id - 1) + (n_id - 1) * m_vlen]; + +#define GEMM_SKINNY_GER_DEPACK_CRM_MID_ITEM(m_id, m_vlen, n_dim) \ + MACRO_EXPANSION_##n_dim(VOID_BASE,\ + GEMM_SKINNY_GER_DEPACK_CRM_LOW_ITEM, m_id, m_vlen, n_dim) + +#define GEMM_SKINNY_GER_DEPACK_CRM_BLOCK_LOOP(m_vlen, gemm, n_dim) \ + for (; m_left >= m_vlen; m_left -= m_vlen) {\ + MACRO_EXP_##m_vlen(VOID_BASE,\ + GEMM_SKINNY_GER_DEPACK_CRM_MID_ITEM, m_vlen, n_dim)\ + c_wt += m_vlen * n_dim;\ + c_rd += m_vlen * n_dim;\ + } + +#define GEMM_SKINNY_GER_DEPACK_CCM_LOW_ITEM(m_id, n_id, m_vlen) \ + c_wt1[m_id - 1] = c_wt1[m_id - 1] * beta +\ + c_rd[(n_id - 1) * m_vlen + m_id - 1]; + +#define GEMM_SKINNY_GER_DEPACK_CCM_MID_ITEM(n_id, m_vlen) \ + MACRO_EXPANSION_##m_vlen(VOID_BASE,\ + GEMM_SKINNY_GER_DEPACK_CCM_LOW_ITEM, n_id, m_vlen)\ + c_wt1 += M; + +#define GEMM_SKINNY_GER_DEPACK_CCM_BLOCK_LOOP(m_vlen, gemm, n_dim) \ + for (; m_left >= m_vlen; m_left -= m_vlen) {\ + gemm##_skinnyger_cscalar *c_wt1 = c_wt;\ + MACRO_EXP_##n_dim(VOID_BASE,\ + GEMM_SKINNY_GER_DEPACK_CCM_MID_ITEM, m_vlen)\ + c_wt += m_vlen;\ + c_rd += m_vlen * n_dim;\ + } + +#define GEMM_SKINNY_GER_INLINE_DEPACK_FUNC(gemm, m_mask, n_dim) \ +static void inline_##gemm##_acolmajor_bskinny_depack_c_n##n_dim(\ + bool c_rowmajor, gemm##_skinnyger_cscalar * __restrict__ C,\ + const gemm##_skinnyger_cscalar * __restrict__ c_scratch,\ + uint32_t M, uint32_t m_pos, uint32_t m_left,\ + gemm##_skinnyger_cscalar beta) {\ +\ + const gemm##_skinnyger_cscalar *c_rd = c_scratch;\ + if (c_rowmajor) {\ + gemm##_skinnyger_cscalar *c_wt = C + m_pos * n_dim;\ + 
MACRO_EXP_M_##m_mask(GEMM_SKINNY_GER_DEPACK_CRM_BLOCK_LOOP, gemm, n_dim)\ + } else {\ + gemm##_skinnyger_cscalar *c_wt = C + m_pos;\ + MACRO_EXP_M_##m_mask(GEMM_SKINNY_GER_DEPACK_CCM_BLOCK_LOOP, gemm, n_dim)\ + }\ +} + +#endif diff --git a/include/common/CommonTest.h b/include/common/CommonTest.h new file mode 100644 index 0000000..d0ca181 --- /dev/null +++ b/include/common/CommonTest.h @@ -0,0 +1,620 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: CommonTest.h + * Description: Common test framework for GEMM/Bias/Quantization functions + * Usage: Include this header, then define test functions by macros, + * last call test functions in main function. Please refer to + * test/Test*.c for example. + *****************************************************************************/ + +#include +#include +#include +#include +#include +#include + +#ifndef INCLUDE_COMMON_TEST +#define INCLUDE_COMMON_TEST + +#define STD_GEMM_AR_BC_CC(atype, btype, ctype, A, B, C, M, N, K, beta) {\ + for (uint32_t n_pos = 0; n_pos < (N); ++n_pos) {\ + ctype *c_ptr = (C) + n_pos * (M);\ + const btype *b_ptr = (B) + n_pos * (K);\ + for (uint32_t m_pos = 0; m_pos < (M); ++m_pos) {\ + const atype *a_ptr = (A) + m_pos * (K);\ + ctype sum = (ctype)0.0f;\ + for (uint32_t k_pos = 0; k_pos < (K); ++k_pos) {\ + sum += (ctype)a_ptr[k_pos] * (ctype)b_ptr[k_pos];\ + }\ + c_ptr[m_pos] = c_ptr[m_pos] * beta + sum;\ + }\ + }\ +} + +/* src: row-major; dst: column-major */ +#define STD_TRANSPOSE(T, src, dst, src_rows, src_cols) {\ + for (uint32_t src_row_pos = 0; src_row_pos < src_rows; ++src_row_pos) {\ + const T *src_ptr = src + src_row_pos * src_cols;\ + T *dst_ptr = dst + src_row_pos;\ + for (uint32_t src_col_pos = 0; src_col_pos < src_cols; ++src_col_pos) {\ + dst_ptr[src_col_pos * src_rows] = src_ptr[src_col_pos];\ + }\ + }\ +} + +/* matrix C is column-major */ +#define STD_GEMM(gemmtype, atype, btype, ctype) \ +void std_##gemmtype(const atype *A, const btype *B, ctype *C,\ + uint32_t M, uint32_t N, uint32_t K,\ + bool a_rowmajor, bool b_colmajor, ctype beta) {\ + atype *A_mat = NULL; const atype *A_rd = A;\ + if (!a_rowmajor) {\ + A_mat = (atype *)malloc(M * K * sizeof(atype));\ + STD_TRANSPOSE(atype, A, A_mat, K, M)\ + A_rd = A_mat;\ + }\ + btype *B_mat = NULL; const btype *B_rd = B;\ + if (!b_colmajor) {\ + B_mat = (btype *)malloc(N * K * sizeof(btype));\ + STD_TRANSPOSE(btype, B, B_mat, K, N)\ + B_rd = B_mat;\ + }\ + STD_GEMM_AR_BC_CC(atype, btype, ctype, A_rd, B_rd, C, M, N, K, beta)\ + if (A_mat) free(A_mat);\ + if (B_mat) free(B_mat);\ +} + +/* produce a random number from a/b, a is a random number in [-c, c] */ +/* c = dividend_abs_max; b = divisor */ +#define STD_RAND(T, dat, size, 
dividend_abs_max, divisor) {\ + const int32_t abs_max_get = (dividend_abs_max) < 0 ? \ + -(dividend_abs_max) : (dividend_abs_max);\ + const int32_t offset_get = (dividend_abs_max) < 0 ? \ + 0 : (dividend_abs_max);\ + for (uint64_t pos = 0; pos < (size); ++pos) {\ + int32_t rand_i = rand() % (2 * abs_max_get + 1);\ + rand_i -= offset_get;\ + float rand_f = (float)rand_i / (float)(divisor);\ + *((dat) + pos) = (T)rand_f;\ + }\ +} + +#define STD_MAXDIFF(T, max, dat1, dat2, size) {\ + T tmp;\ + max = (T)0.0f;\ + for (uint64_t pos = 0; pos < (size); ++pos) {\ + tmp = (*((dat2) + pos)) - (*((dat1) + pos));\ + if (tmp < 0) tmp = (T)0.0f - tmp;\ + if (tmp > max) max = tmp;\ + }\ +} + +#define SRC_SIZE 160000000 + +#define STD_TEST(gemmtype, btype, atype, ctype, dividend_abs_max, divisor) \ +STD_GEMM(gemmtype, atype, btype, ctype)\ +typedef int (*TestFunc_##gemmtype)(int, int, const atype*, const btype*, ctype*,\ + uint32_t, uint32_t, uint32_t, ctype, uint32_t);\ +void std_test_##gemmtype(TestFunc_##gemmtype test_gemm,\ + uint32_t M, uint32_t N, uint32_t K, uint8_t transAB,\ + ctype beta, uint32_t num_threads) {\ +\ + const int b_rowmajor = transAB & 2;\ + const int a_rowmajor = transAB & 1;\ +\ + const uint64_t a_size = (uint64_t)M * (uint64_t)K;\ + const uint64_t b_size = (uint64_t)N * (uint64_t)K;\ + const uint64_t c_size = (uint64_t)M * (uint64_t)N;\ + const uint64_t iters = (uint64_t)SRC_SIZE / \ + (a_size * sizeof(atype) + b_size * sizeof(btype) + 1);\ + if (iters == 0) {\ + printf("Problem size too large. return.\n");\ + return;\ + }\ + atype * const A = (atype *)malloc(a_size * iters * sizeof(atype));\ + btype * const B = (btype *)malloc(b_size * iters * sizeof(btype));\ + ctype * const C_ref = (ctype *)malloc(c_size * sizeof(ctype));\ + ctype * const C_tst = (ctype *)malloc(c_size * sizeof(ctype));\ + if (A == NULL || B == NULL || C_ref == NULL || C_tst == NULL) {\ + printf("Memory allocation failed. return.\n");\ + free(A); free(B); free(C_ref); free(C_tst);\ + return;\ + }\ + srand(time(NULL));\ + STD_RAND(float, A, a_size, dividend_abs_max, divisor)\ + for (uint64_t pos = 1; pos < iters; ++pos) {\ + memcpy(A + pos * a_size, A, a_size * sizeof(atype));\ + }\ + STD_RAND(float, B, b_size, dividend_abs_max, divisor)\ + for (uint64_t pos = 1; pos < iters; ++pos) {\ + memcpy(B + pos * b_size, B, b_size * sizeof(btype));\ + }\ + STD_RAND(float, C_tst, c_size, dividend_abs_max, divisor)\ + memcpy(C_ref, C_tst, c_size * sizeof(ctype));\ + struct timespec st, et;\ + std_##gemmtype(A, B, C_ref, M, N, K, a_rowmajor, !b_rowmajor, beta);\ + clock_gettime(CLOCK_MONOTONIC, &st);\ + int ret_status = test_gemm(a_rowmajor, b_rowmajor, A, B, C_tst,\ + M, N, K, beta, num_threads);\ + clock_gettime(CLOCK_MONOTONIC, &et);\ + double nsec = (double)(et.tv_nsec - st.tv_nsec) + 1.0e9 * \ + (double)(et.tv_sec - st.tv_sec);\ + printf("Time elapsed for the first run: %.2e ns\n", nsec);\ + if (ret_status) {\ + printf("An error has occurred in the tested gemm, error code = %d\n",\ + ret_status);\ + return;\ + }\ + ctype max;\ + STD_MAXDIFF(float, max, C_ref, C_tst, c_size)\ + printf("Max diff. 
between test and std: %.2e\n", (double)max);\ +\ + if (iters > 1) {\ + clock_gettime(CLOCK_MONOTONIC, &st);\ + for (uint64_t pos = 1; pos < iters; ++pos) {\ + test_gemm(a_rowmajor, b_rowmajor, A + a_size * pos, B + b_size * pos, C_tst,\ + M, N, K, -1, num_threads);\ + }\ + clock_gettime(CLOCK_MONOTONIC, &et);\ + double nsec = (double)(et.tv_nsec - st.tv_nsec) + 1.0e9 * \ + (double)(et.tv_sec - st.tv_sec);\ + double ops = (double)M * (double)N * (double)(2 * K - 1) * \ + (double)(iters - 1);\ + printf("Averaged time for each run after warm-up: %.2e ns\n",\ + nsec / (double)(iters - 1));\ + printf("The performance of test: %.2e GFLOPS\n", ops / nsec);\ + }\ +\ + free(A); free(B); free(C_ref); free(C_tst);\ + return;\ +} + +#define TEST_1D_OPERATION_PERF(size, num_iters, FUNC_CALLER, ...) \ + struct timespec st, et;\ + clock_gettime(CLOCK_MONOTONIC, &st);\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + FUNC_CALLER(0, size, ##__VA_ARGS__)\ + }\ + clock_gettime(CLOCK_MONOTONIC, &et);\ + double nsec = (double)(et.tv_nsec - st.tv_nsec) + 1.0e9 * (double)\ + (et.tv_sec - st.tv_sec);\ + printf("Avg. Perf.(repeat on the same data): %.2e G elements per second\n",\ + (double)size * (double)(num_iters - 1) / nsec);\ + clock_gettime(CLOCK_MONOTONIC, &st);\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + FUNC_CALLER(pos, size, ##__VA_ARGS__)\ + }\ + clock_gettime(CLOCK_MONOTONIC, &et);\ + nsec = (double)(et.tv_nsec - st.tv_nsec) + 1.0e9 * (double)\ + (et.tv_sec - st.tv_sec);\ + printf("Avg. Perf.(no repeat of data region): %.2e G elements per second\n",\ + (double)size * (double)(num_iters - 1) / nsec); + +#define FUNC_CALLER_QUANT_UNSYM(pos, size, inbits, outbits,\ + src, tst_u, zero_addr, scale_addr) \ + quantize_asymmetric_f##inbits##_u##outbits(\ + src + pos * size, tst_u + pos * size,\ + zero_addr, scale_addr, size, 0, -1); + +#define TEST_QUANT_UNSYM(inbits, outbits) \ +static void test_quant_asym_f##inbits##_u##outbits(uint32_t size) {\ + if (size < 4) size = 4;\ + printf("Test unsymmetrical quantization fp"#inbits" -> uint"#outbits":\n");\ + printf("num_elements = %u\n", size);\ +\ + const uint32_t num_iters = 40000000 / size;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ +\ + uint##outbits##_t * const ref_u =\ + (uint##outbits##_t *)malloc(size * (outbits >> 3));\ + uint##outbits##_t * const tst_u =\ + (uint##outbits##_t *)malloc(num_iters * size * (outbits >> 3));\ + float##inbits##_t * const src =\ + (float##inbits##_t *)malloc(num_iters * size * (inbits >> 3));\ +\ + srand(time(NULL));\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + ref_u[pos] = rand();\ + }\ + uint32_t min_pos = rand() % size;\ + uint32_t max_pos = min_pos;\ + while (max_pos == min_pos) {\ + max_pos = rand() % size;\ + }\ + ref_u[min_pos] = 0;\ + ref_u[max_pos] = (uint##outbits##_t)-1;\ + const float##inbits##_t ref_scale =\ + (float##inbits##_t)(rand() + 1) / (float##inbits##_t)(RAND_MAX >> 2);\ + const uint##outbits##_t ref_zero = rand();\ + printf("Generate src data with ref_zero = %u and ref_scale = %.2e\n",\ + ref_zero, ref_scale);\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + float##inbits##_t fluc =\ + ((float##inbits##_t)rand() / RAND_MAX - (float##inbits##_t)0.5) *\ + (float##inbits##_t)0.9875;\ + if (pos == max_pos || pos == min_pos) fluc = 0.0;\ + else if (ref_u[pos] == (uint##outbits##_t)-1 && fluc > 0) fluc *= -1.0;\ + else if (ref_u[pos] == 0 && fluc < 0) fluc *= -1.0;\ + src[pos] = ((float##inbits##_t)((long)ref_u[pos] - (long)ref_zero) + fluc)\ + * 
ref_scale;\ + }\ + printf("First 4 elements of ref_u"#outbits"\n: %u, %u, %u, %u\n",\ + ref_u[0], ref_u[1], ref_u[2], ref_u[3]);\ + printf("First 4 elements of src_f"#inbits"\n: %.2e, %.2e, %.2e, %.2e\n",\ + src[0], src[1], src[2], src[3]);\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + memcpy(src + pos * size, src, size * (inbits >> 3));\ + }\ +\ + uint##outbits##_t tst_zero;\ + float##inbits##_t tst_scale;\ + quantize_asymmetric_f##inbits##_u##outbits(\ + src, tst_u, &tst_zero, &tst_scale, size, 0, -1);\ +\ + if (tst_zero != ref_zero) {\ + printf("tst_zero = %u, mismatch with ref_zero\n", tst_zero);\ + }\ + printf("relative difference between ref_scale and tst_scale: %.2e\n",\ + (tst_scale - ref_scale) / ref_scale);\ + int eql = 1;\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + if (eql != 0 && tst_u[pos] != ref_u[pos]) {\ + eql = 0;\ + printf("u"#outbits" results at pos %u are inconsistent: ref = %u, tst = %u\n",\ + pos, ref_u[pos], tst_u[pos]);\ + break;\ + }\ + }\ + if (eql != 0) {\ + printf("u"#outbits" results are equal\n");\ + TEST_1D_OPERATION_PERF(size, num_iters, FUNC_CALLER_QUANT_UNSYM,\ + inbits, outbits, src, tst_u, &tst_zero, &tst_scale)\ + }\ +\ + free(src);\ + free(ref_u);\ + free(tst_u);\ +} + +#define FUNC_CALLER_QUANT_SYM(pos, size, inbits, outbits, src, tst_s, scale_addr)\ + quantize_symmetric_f##inbits##_s##outbits(src + pos * size, tst_s + pos * size,\ + scale_addr, size, 0, -1); + +#define TEST_QUANT_SYM(inbits, outbits) \ +static void test_quant_sym_f##inbits##_s##outbits(uint32_t size) {\ + if (size < 4) size = 4;\ + printf("Test symmetrical quantization f"#inbits" -> s"#outbits":\n");\ + printf("num_elements = %u\n", size);\ +\ + const uint32_t num_iters = 40000000 / size;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ +\ + int##outbits##_t * const ref_s =\ + (int##outbits##_t *)malloc(size * (outbits >> 3));\ + int##outbits##_t * const tst_s =\ + (int##outbits##_t *)malloc(num_iters * size * (outbits >> 3));\ + float##inbits##_t * const src =\ + (float##inbits##_t *)malloc(num_iters * size * (inbits >> 3));\ +\ + const long sint_max = (uint##outbits##_t)-1 >> 1;\ + const long sint_min = (-sint_max) + (-1);\ + srand(time(NULL));\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + ref_s[pos] = (long)rand() % (2 * sint_max + 2) + sint_min;\ + }\ + const uint32_t extreme_pos = rand() % size;\ + ref_s[extreme_pos] = (rand() & 1) ? 
sint_min : sint_max;\ + const float##inbits##_t ref_scale =\ + (float##inbits##_t)(rand() + 1) / (RAND_MAX >> 2);\ + printf("Generate fp"#inbits" src data with ref_scale = %.2e\n", ref_scale);\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + float##inbits##_t fluc =\ + ((float##inbits##_t)rand() / RAND_MAX - (float##inbits##_t)0.5)\ + * (float##inbits##_t)0.9875;\ + if (pos == extreme_pos) fluc = 0.0;\ + else if (ref_s[pos] == sint_min && fluc < 0) fluc *= -1.0;\ + else if (ref_s[pos] == sint_max && fluc > 0) fluc *= -1.0;\ + src[pos] = ((float##inbits##_t)ref_s[pos] + fluc) * ref_scale;\ + }\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + memcpy(src + pos * size, src, size * (inbits >> 3));\ + }\ + printf("First 4 elements of fp"#inbits" src:\n%.2e, %.2e, %.2e, %.2e\n",\ + src[0], src[1], src[2], src[3]);\ + printf("First 4 elements of s"#outbits" ref_dst:\n%d, %d, %d, %d\n",\ + ref_s[0], ref_s[1], ref_s[2], ref_s[3]);\ +\ + float##inbits##_t tst_scale;\ + quantize_symmetric_f##inbits##_s##outbits(\ + src, tst_s, &tst_scale, size, 0, -1);\ +\ + printf("relative difference between ref_scale and tst_scale: %.2e\n",\ + (tst_scale - ref_scale) / ref_scale);\ + int eql = 1;\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + if (eql != 0 && tst_s[pos] != ref_s[pos]) {\ + eql = 0;\ + printf("s"#outbits" results at pos %u are inconsistent: ref = %d, tst = %d\n",\ + pos, ref_s[pos], tst_s[pos]);\ + break;\ + }\ + }\ + if (eql != 0) {\ + printf("s"#outbits" results are equal\n");\ + TEST_1D_OPERATION_PERF(size, num_iters, FUNC_CALLER_QUANT_SYM,\ + inbits, outbits, src, tst_s, &tst_scale)\ + }\ +\ + free(src);\ + free(ref_s);\ + free(tst_s);\ +} + +#define FUNC_CALLER_DEQUANT(pos, size, inbits, outbits, src, tst_f, scale) \ + dequantize_symmetric_f##outbits##_s##inbits(src + pos * size,\ + tst_f + pos * size, scale, size); + +#define TEST_DEQUANT_SYM(inbits, outbits) \ +static void test_dequant_sym_f##outbits##_s##inbits(uint32_t size) {\ + if (size < 4) size = 4;\ + printf("Test dequantization s"#inbits" -> f"#outbits":\n");\ + printf("num_elements = %u\n", size);\ +\ + const uint32_t num_iters = 40000000 / size;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ +\ + int##inbits##_t * const src =\ + (int##inbits##_t *)malloc(num_iters * size * (inbits >> 3));\ + float##outbits##_t * const ref_f =\ + (float##outbits##_t *)malloc(size * (outbits >> 3));\ + float##outbits##_t * const tst_f =\ + (float##outbits##_t *)malloc(num_iters * size * (outbits >> 3));\ +\ + srand(time(NULL));\ + const float##outbits##_t scale = (float##outbits##_t)rand() / RAND_MAX;\ + printf("Generate src with scale = %.2e\n", scale);\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + src[pos] = (long long)rand() - (long long)(RAND_MAX >> 1);\ + ref_f[pos] = scale * src[pos];\ + }\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + memcpy(src + pos * size, src, size * (inbits >> 3));\ + }\ + printf("First 4 elements of src:\n%d, %d, %d, %d\n",\ + src[0], src[1], src[2], src[3]);\ + printf("First 4 elements of ref:\n%.2e, %.2e, %.2e, %.2e\n",\ + ref_f[0], ref_f[1], ref_f[2], ref_f[3]);\ +\ + dequantize_symmetric_f##outbits##_s##inbits(src, tst_f, scale, size);\ +\ + float##outbits##_t max_diff = 0.0;\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + float##outbits##_t tmp = tst_f[pos] - ref_f[pos];\ + if (tmp < 0) tmp *= -1.0;\ + if (tmp > max_diff) max_diff = tmp;\ + }\ + printf("Max diff. between tst. 
and ref.: %.2e\n", max_diff);\ +\ + TEST_1D_OPERATION_PERF(size, num_iters, FUNC_CALLER_DEQUANT, inbits, outbits,\ + src, tst_f, scale)\ +\ + free(src);\ + free(ref_f);\ + free(tst_f);\ +} + +#define FUNC_CALLER_REQUANT_UNSYM(pos, size, inbits, fp, outbits,\ + src, dst, org_scale, zero_addr) \ + fp tmp_scale = org_scale;\ + requantize_asymmetric_##inbits##to##outbits(\ + src + pos * size, dst + pos * size, &tmp_scale, zero_addr, size, 0, -1); + +#define TEST_REQUANT_UNSYM(fp, inbits, outbits) \ +static void test_requant_int##inbits##_t_##fp##_uint##outbits##_t(\ + uint32_t size, int##inbits##_t min_src, int##inbits##_t max_src,\ + fp org_scale) {\ +\ + if (max_src < min_src) {\ + int##inbits##_t tmp = min_src;\ + min_src = max_src;\ + max_src = tmp;\ + }\ + if (size < 4) size = 4;\ + printf("Test unsymmetrical requantization int"#inbits"_t -> uint"#outbits"_t:\n");\ + printf("Range of src: %lld - %lld\n", (long long)min_src, (long long)max_src);\ + printf("original_scale = %.2e\n", org_scale);\ + printf("num_elements = %u\n", size);\ +\ + const uint32_t num_iters = 40000000 / size;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ +\ + int##inbits##_t * const src = (int##inbits##_t *)malloc(\ + num_iters * size * sizeof(int##inbits##_t));\ + uint##outbits##_t * const dst = (uint##outbits##_t *)malloc(\ + num_iters * size * sizeof(uint##outbits##_t));\ +\ + const double range = (long long)max_src - (long long)min_src;\ + srand(time(NULL));\ + if (range == 0) {\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + src[pos] = min_src;\ + }\ + } else {\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + double rv = (double)rand() / (double)RAND_MAX;\ + double v = rv * range + (double)min_src;\ + int##inbits##_t iv = v;\ + if (iv < min_src) iv = min_src;\ + if (iv > max_src) iv = max_src;\ + src[pos] = iv;\ + }\ + uint32_t min_pos = rand() % size;\ + uint32_t max_pos = rand() % size;\ + while(max_pos == min_pos) {\ + max_pos = rand() % size;\ + }\ + src[min_pos] = min_src;\ + src[max_pos] = max_src;\ + }\ + printf("First 4 src elements: %lld, %lld, %lld, %lld\n",\ + (long long)src[0], (long long)src[1], (long long)src[2], (long long)src[3]);\ + for (uint32_t it = 1; it < num_iters; ++it) {\ + memcpy(src + it * size, src, size * sizeof(int##inbits##_t));\ + }\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + dst[pos] = rand();\ + }\ +\ + const long long renorm_min_src = min_src > 0 ? 0 : min_src;\ + const long long renorm_max_src = max_src < 0 ? 
0 : max_src;\ + const fp ref_scale = (double)org_scale * \ + (double)(renorm_max_src - renorm_min_src) / ((uint##outbits##_t)-1);\ + printf("ref_scale = %.2e\n", ref_scale);\ +\ + uint##outbits##_t zero_point;\ + fp new_scale = org_scale;\ + requantize_asymmetric_##inbits##to##outbits(src, dst,\ + &new_scale, &zero_point, size, 0, -1);\ +\ + printf("tst_zero = %u\n", zero_point);\ + printf("tst_scale - ref_scale = %.2e\n", new_scale - ref_scale);\ + long min_out, max_out;\ + double max_diff_out = 0.0;\ + min_out = max_out = dst[0];\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + long ld = dst[pos];\ + if (ld < min_out) min_out = ld;\ + if (ld > max_out) max_out = ld;\ + double curr_fp = src[pos] * (double)org_scale;\ + double curr_i8 = curr_fp / (double)new_scale;\ + double curr_u8 = curr_i8 + (double)zero_point;\ + double tmp_diff_out = (double)ld - curr_u8;\ + if (tmp_diff_out < 0) tmp_diff_out *= -1.0;\ + if (tmp_diff_out > max_diff_out) max_diff_out = tmp_diff_out;\ + }\ + printf("range of requant u"#outbits": [%ld, %ld]\n", min_out, max_out);\ + printf("max deviation of requant u"#outbits": %.2e\n", max_diff_out);\ +\ + TEST_1D_OPERATION_PERF(size, num_iters, FUNC_CALLER_REQUANT_UNSYM,\ + inbits, fp, outbits, src, dst, org_scale, &zero_point)\ +\ + free(src);\ + free(dst);\ +} + +#define FUNC_CALLER_REQUANT_SYM(pos, size, inbits, fp, outbits, src, dst, org_scale) \ + fp tmp_scale = org_scale;\ + requantize_symmetric_##inbits##to##outbits(\ + src + pos * size, dst + pos * size, &tmp_scale, size, 0, -1); + +#define TEST_REQUANT_SYM(fp, inbits, outbits) \ +static void test_requant_int##inbits##_t_##fp##_int##outbits##_t(\ + uint32_t size, int##inbits##_t max_abs, fp org_scale) {\ +\ + if (max_abs < 0) max_abs = -max_abs;\ + if (size < 4) size = 4;\ + printf("Test symmetrical requantization int"#inbits"_t -> int"#outbits"_t:\n");\ + printf("Range of src: %d - %d\n", -max_abs, max_abs);\ + printf("original_scale = %.2e\n", org_scale);\ + printf("num_elements = %u\n", size);\ +\ + const uint32_t num_iters = 40000000 / size;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ +\ + int##inbits##_t * const src = (int##inbits##_t *)malloc(\ + num_iters * size * sizeof(int##inbits##_t));\ + int##outbits##_t * const dst = (int##outbits##_t *)malloc(\ + num_iters * size * sizeof(int##outbits##_t));\ +\ + srand(time(NULL));\ + if (max_abs == 0) {\ + memset(src, 0, size * sizeof(int##inbits##_t));\ + } else {\ + const double rand_range = 2.0 * (double)max_abs + 1.0;\ + const double rand_offset = -1.0 * (double)max_abs;\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + double rv = (double)rand() / (double)RAND_MAX;\ + double ra = rv * rand_range + rand_offset;\ + int##inbits##_t ia = ra;\ + if (ia < -max_abs) ia = -max_abs;\ + if (ia > max_abs) ia = max_abs;\ + src[pos] = ia;\ + }\ + uint32_t max_rand_pos = rand() % size;\ + src[max_rand_pos] = (rand() & 1) ? max_abs : -max_abs;\ + }\ + printf("The first 4 elements of src: %lld, %lld, %lld, %lld\n",\ + (long long)src[0], (long long)src[1], (long long)src[2], (long long)src[3]);\ + for (uint32_t it = 1; it < num_iters; ++it) {\ + memcpy(src + it * size, src, size * sizeof(int##inbits##_t));\ + }\ +\ + const fp ref_scale = (double)org_scale * (double)max_abs / \ + (double)(((uint##outbits##_t)-1) >> 1);\ + printf("ref_scale = %.2e\n", ref_scale);\ + fp new_scale = org_scale;\ + requantize_symmetric_##inbits##to##outbits(src, dst, &new_scale, size, 0, -1);\ + printf("diff. 
between ref_scale and tst_scale: %.2e\n",\ + new_scale - ref_scale);\ +\ + int##outbits##_t max_out_abs = 0;\ + double max_out_dev = 0.0;\ + for (uint32_t pos = 0; pos < size; ++pos) {\ + int##outbits##_t l1 = dst[pos];\ + if (l1 > max_out_abs) max_out_abs = l1;\ + if (-l1 > max_out_abs) max_out_abs = -l1;\ + if (new_scale != 0.0) {\ + double expected = (double)src[pos] * (double)org_scale / \ + (double)new_scale;\ + double tmp_dev = expected - (double)dst[pos];\ + if (tmp_dev < 0) tmp_dev *= -1.0;\ + if (tmp_dev > max_out_dev) max_out_dev = tmp_dev;\ + }\ + }\ + printf("max abs of output int"#outbits": %d\n", max_out_abs);\ + if (new_scale == 0.0) {\ + printf("max deviation of output int"#outbits" not determined.\n");\ + } else {\ + printf("max deviation of output int"#outbits": %.2e\n", max_out_dev);\ + }\ +\ + TEST_1D_OPERATION_PERF(size, num_iters, FUNC_CALLER_REQUANT_SYM,\ + inbits, fp, outbits, src, dst, org_scale)\ +\ + free(src);\ + free(dst);\ +} + +#endif diff --git a/include/common/ExpandMacro.h b/include/common/ExpandMacro.h new file mode 100644 index 0000000..62403e7 --- /dev/null +++ b/include/common/ExpandMacro.h @@ -0,0 +1,932 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +/****************************************************************************** + * File: ExpandMacro.h + * Description: Smart macros for manual unroll of tiny loops + * Example: Original loop: + * INITIALIZATION(parm1, parm2) + * for (int i = 1; i <= 8; ++i) { + * LOOP_ITEM(i, parm1, parm2) + * } + * Using macros to manually unroll the loop: + * MACRO_EXP_8(INITIALIZATION, LOOP_ITEM, parm1, parm2) + * Which is identical to the original loop. + *****************************************************************************/ +#ifndef INCLUDE_EXPAND_MACRO +#define INCLUDE_EXPAND_MACRO + +#define VOID_BASE(...) /* */ + +#define MACRO_EXP_0(BASE, ADD_ITEM, ...) \ + BASE(__VA_ARGS__) + +#define MACRO_EXP_1(BASE, ADD_ITEM, ...) \ + BASE(__VA_ARGS__)\ + ADD_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXP_2(BASE, ADD_ITEM, ...) \ + MACRO_EXP_1(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(2, ##__VA_ARGS__) + +#define MACRO_EXP_3(BASE, ADD_ITEM, ...) \ + MACRO_EXP_2(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(3, ##__VA_ARGS__) + +#define MACRO_EXP_4(BASE, ADD_ITEM, ...) \ + MACRO_EXP_3(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(4, ##__VA_ARGS__) + +#define MACRO_EXP_5(BASE, ADD_ITEM, ...) \ + MACRO_EXP_4(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(5, ##__VA_ARGS__) + +#define MACRO_EXP_6(BASE, ADD_ITEM, ...) \ + MACRO_EXP_5(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(6, ##__VA_ARGS__) + +#define MACRO_EXP_7(BASE, ADD_ITEM, ...) \ + MACRO_EXP_6(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(7, ##__VA_ARGS__) + +#define MACRO_EXP_8(BASE, ADD_ITEM, ...) 
\ + MACRO_EXP_7(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(8, ##__VA_ARGS__) + +#define MACRO_EXP_9(BASE, ADD_ITEM, ...) \ + MACRO_EXP_8(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(9, ##__VA_ARGS__) + +#define MACRO_EXP_10(BASE, ADD_ITEM, ...) \ + MACRO_EXP_9(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(10, ##__VA_ARGS__) + +#define MACRO_EXP_11(BASE, ADD_ITEM, ...) \ + MACRO_EXP_10(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(11, ##__VA_ARGS__) + +#define MACRO_EXP_12(BASE, ADD_ITEM, ...) \ + MACRO_EXP_11(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(12, ##__VA_ARGS__) + +#define MACRO_EXP_13(BASE, ADD_ITEM, ...) \ + MACRO_EXP_12(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(13, ##__VA_ARGS__) + +#define MACRO_EXP_14(BASE, ADD_ITEM, ...) \ + MACRO_EXP_13(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(14, ##__VA_ARGS__) + +#define MACRO_EXP_15(BASE, ADD_ITEM, ...) \ + MACRO_EXP_14(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(15, ##__VA_ARGS__) + +#define MACRO_EXP_16(BASE, ADD_ITEM, ...) \ + MACRO_EXP_15(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(16, ##__VA_ARGS__) + +#define MACRO_EXP_17(BASE, ADD_ITEM, ...) \ + MACRO_EXP_16(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(17, ##__VA_ARGS__) + +#define MACRO_EXP_18(BASE, ADD_ITEM, ...) \ + MACRO_EXP_17(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(18, ##__VA_ARGS__) + +#define MACRO_EXP_19(BASE, ADD_ITEM, ...) \ + MACRO_EXP_18(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(19, ##__VA_ARGS__) + +#define MACRO_EXP_20(BASE, ADD_ITEM, ...) \ + MACRO_EXP_19(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(20, ##__VA_ARGS__) + +#define MACRO_EXP_21(BASE, ADD_ITEM, ...) \ + MACRO_EXP_20(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(21, ##__VA_ARGS__) + +#define MACRO_EXP_22(BASE, ADD_ITEM, ...) \ + MACRO_EXP_21(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(22, ##__VA_ARGS__) + +#define MACRO_EXP_23(BASE, ADD_ITEM, ...) \ + MACRO_EXP_22(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(23, ##__VA_ARGS__) + +#define MACRO_EXP_24(BASE, ADD_ITEM, ...) \ + MACRO_EXP_23(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(24, ##__VA_ARGS__) + +#define MACRO_EXP_25(BASE, ADD_ITEM, ...) \ + MACRO_EXP_24(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(25, ##__VA_ARGS__) + +#define MACRO_EXP_26(BASE, ADD_ITEM, ...) \ + MACRO_EXP_25(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(26, ##__VA_ARGS__) + +#define MACRO_EXP_27(BASE, ADD_ITEM, ...) \ + MACRO_EXP_26(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(27, ##__VA_ARGS__) + +#define MACRO_EXP_28(BASE, ADD_ITEM, ...) \ + MACRO_EXP_27(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(28, ##__VA_ARGS__) + +#define MACRO_EXP_29(BASE, ADD_ITEM, ...) \ + MACRO_EXP_28(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(29, ##__VA_ARGS__) + +#define MACRO_EXP_30(BASE, ADD_ITEM, ...) \ + MACRO_EXP_29(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(30, ##__VA_ARGS__) + +#define MACRO_EXP_31(BASE, ADD_ITEM, ...) \ + MACRO_EXP_30(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(31, ##__VA_ARGS__) + +#define MACRO_EXP_32(BASE, ADD_ITEM, ...) \ + MACRO_EXP_31(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(32, ##__VA_ARGS__) + +#define MACRO_EXP_33(BASE, ADD_ITEM, ...) \ + MACRO_EXP_32(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(33, ##__VA_ARGS__) + +#define MACRO_EXP_34(BASE, ADD_ITEM, ...) \ + MACRO_EXP_33(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(34, ##__VA_ARGS__) + +#define MACRO_EXP_35(BASE, ADD_ITEM, ...) \ + MACRO_EXP_34(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(35, ##__VA_ARGS__) + +#define MACRO_EXP_36(BASE, ADD_ITEM, ...) 
\ + MACRO_EXP_35(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(36, ##__VA_ARGS__) + +#define MACRO_EXP_37(BASE, ADD_ITEM, ...) \ + MACRO_EXP_36(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(37, ##__VA_ARGS__) + +#define MACRO_EXP_38(BASE, ADD_ITEM, ...) \ + MACRO_EXP_37(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(38, ##__VA_ARGS__) + +#define MACRO_EXP_39(BASE, ADD_ITEM, ...) \ + MACRO_EXP_38(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(39, ##__VA_ARGS__) + +#define MACRO_EXP_40(BASE, ADD_ITEM, ...) \ + MACRO_EXP_39(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(40, ##__VA_ARGS__) + +#define MACRO_EXP_41(BASE, ADD_ITEM, ...) \ + MACRO_EXP_40(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(41, ##__VA_ARGS__) + +#define MACRO_EXP_42(BASE, ADD_ITEM, ...) \ + MACRO_EXP_41(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(42, ##__VA_ARGS__) + +#define MACRO_EXP_43(BASE, ADD_ITEM, ...) \ + MACRO_EXP_42(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(43, ##__VA_ARGS__) + +#define MACRO_EXP_44(BASE, ADD_ITEM, ...) \ + MACRO_EXP_43(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(44, ##__VA_ARGS__) + +#define MACRO_EXP_45(BASE, ADD_ITEM, ...) \ + MACRO_EXP_44(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(45, ##__VA_ARGS__) + +#define MACRO_EXP_46(BASE, ADD_ITEM, ...) \ + MACRO_EXP_45(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(46, ##__VA_ARGS__) + +#define MACRO_EXP_47(BASE, ADD_ITEM, ...) \ + MACRO_EXP_46(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(47, ##__VA_ARGS__) + +#define MACRO_EXP_48(BASE, ADD_ITEM, ...) \ + MACRO_EXP_47(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(48, ##__VA_ARGS__) + +#define MACRO_EXP_49(BASE, ADD_ITEM, ...) \ + MACRO_EXP_48(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(49, ##__VA_ARGS__) + +#define MACRO_EXP_50(BASE, ADD_ITEM, ...) \ + MACRO_EXP_49(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(50, ##__VA_ARGS__) + +#define MACRO_EXP_51(BASE, ADD_ITEM, ...) \ + MACRO_EXP_50(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(51, ##__VA_ARGS__) + +#define MACRO_EXP_52(BASE, ADD_ITEM, ...) \ + MACRO_EXP_51(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(52, ##__VA_ARGS__) + +/* MACRO_EXPANSION_X does the same as MACRO_EXP_X. + * this macro is useful when 2D unrolling is needed. + * Example: + * to unroll LOOP_ITEM(x, y, parms) to x = 1-5 + * and y = 1-3, just write this: + * #define LOOP_X_UNROLL(y, params)\ + * MACRO_EXP_5(LOOP_ITEM, y, params) + * MACRO_EXPANSION_3(LOOP_X_UNROLL, params) + * //you can't use MACRO_EXP_3 here because + * //recursion cannot occur in macro expansion + */ +#define MACRO_EXPANSION_0(BASE, ADD_ITEM, ...) \ + BASE(__VA_ARGS__) + +#define MACRO_EXPANSION_1(BASE, ADD_ITEM, ...) \ + BASE(__VA_ARGS__)\ + ADD_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_2(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_1(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(2, ##__VA_ARGS__) + +#define MACRO_EXPANSION_3(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_2(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(3, ##__VA_ARGS__) + +#define MACRO_EXPANSION_4(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_3(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(4, ##__VA_ARGS__) + +#define MACRO_EXPANSION_5(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_4(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(5, ##__VA_ARGS__) + +#define MACRO_EXPANSION_6(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_5(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(6, ##__VA_ARGS__) + +#define MACRO_EXPANSION_7(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_6(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(7, ##__VA_ARGS__) + +#define MACRO_EXPANSION_8(BASE, ADD_ITEM, ...) 
\ + MACRO_EXPANSION_7(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(8, ##__VA_ARGS__) + +#define MACRO_EXPANSION_9(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_8(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(9, ##__VA_ARGS__) + +#define MACRO_EXPANSION_10(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_9(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(10, ##__VA_ARGS__) + +#define MACRO_EXPANSION_11(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_10(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(11, ##__VA_ARGS__) + +#define MACRO_EXPANSION_12(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_11(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(12, ##__VA_ARGS__) + +#define MACRO_EXPANSION_13(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_12(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(13, ##__VA_ARGS__) + +#define MACRO_EXPANSION_14(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_13(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(14, ##__VA_ARGS__) + +#define MACRO_EXPANSION_15(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_14(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(15, ##__VA_ARGS__) + +#define MACRO_EXPANSION_16(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_15(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(16, ##__VA_ARGS__) + +#define MACRO_EXPANSION_17(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_16(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(17, ##__VA_ARGS__) + +#define MACRO_EXPANSION_18(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_17(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(18, ##__VA_ARGS__) + +#define MACRO_EXPANSION_19(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_18(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(19, ##__VA_ARGS__) + +#define MACRO_EXPANSION_20(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_19(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(20, ##__VA_ARGS__) + +#define MACRO_EXPANSION_21(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_20(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(21, ##__VA_ARGS__) + +#define MACRO_EXPANSION_22(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_21(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(22, ##__VA_ARGS__) + +#define MACRO_EXPANSION_23(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_22(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(23, ##__VA_ARGS__) + +#define MACRO_EXPANSION_24(BASE, ADD_ITEM, ...) \ + MACRO_EXPANSION_23(BASE, ADD_ITEM, ##__VA_ARGS__)\ + ADD_ITEM(24, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_1(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + BASE(__VA_ARGS__)\ + ADD_ITEM_1(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_2(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_1(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(2, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_3(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_2(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(3, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_4(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + BASE(__VA_ARGS__)\ + ADD_ITEM_4(1, 2, 3, 4, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_5(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_4(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(5, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_6(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_5(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(6, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_7(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_6(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(7, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_8(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) 
\ + MACRO_EXPANSION_Q_4(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_4(5, 6, 7, 8, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_9(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_8(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(9, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_10(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_9(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(10, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_11(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_10(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_1(11, ##__VA_ARGS__) + +#define MACRO_EXPANSION_Q_12(BASE, ADD_ITEM_1, ADD_ITEM_4, ...) \ + MACRO_EXPANSION_Q_8(BASE, ADD_ITEM_1, ADD_ITEM_4, ##__VA_ARGS__)\ + ADD_ITEM_4(9, 10, 11, 12, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_1(LOOP_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_2(LOOP_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) MACRO_EXPANSION_E_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_4(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXPANSION_E_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_6(LOOP_ITEM, ...) \ + LOOP_ITEM(6, ##__VA_ARGS__) MACRO_EXPANSION_E_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_8(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_E_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_12(LOOP_ITEM, ...) \ + LOOP_ITEM(12, ##__VA_ARGS__) MACRO_EXPANSION_E_8(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_16(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_E_8(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_E_24(LOOP_ITEM, ...) \ + LOOP_ITEM(24, ##__VA_ARGS__) MACRO_EXPANSION_E_12(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_1(LOOP_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXP_E_2(LOOP_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) MACRO_EXP_E_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_4(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXP_E_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_6(LOOP_ITEM, ...) \ + LOOP_ITEM(6, ##__VA_ARGS__) MACRO_EXP_E_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_8(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_E_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_12(LOOP_ITEM, ...) \ + LOOP_ITEM(12, ##__VA_ARGS__) MACRO_EXP_E_8(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_16(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_E_8(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_E_24(LOOP_ITEM, ...) \ + LOOP_ITEM(24, ##__VA_ARGS__) MACRO_EXP_E_12(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_0(LOOP_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXP_M_1(LOOP_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXP_M_2(LOOP_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) MACRO_EXP_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_3(LOOP_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) MACRO_EXP_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_4(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXP_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_5(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXP_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_6(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXP_M_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_7(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXP_M_3(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_8(LOOP_ITEM, ...) 
\ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_9(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_10(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_11(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_3(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_12(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_13(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_5(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_14(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_6(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_15(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXP_M_7(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_16(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_17(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_18(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_19(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_3(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_20(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_21(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_5(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_22(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_6(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_23(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_7(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_24(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_8(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_25(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_9(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_26(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_10(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_27(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_11(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_28(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_12(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_29(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_13(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_30(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_14(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_31(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXP_M_15(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_0(LOOP_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_1(LOOP_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_2(LOOP_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) MACRO_EXPANSION_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_3(LOOP_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) MACRO_EXPANSION_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_4(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXPANSION_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_5(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXPANSION_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_6(LOOP_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXPANSION_M_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_7(LOOP_ITEM, ...) 
\ + LOOP_ITEM(4, ##__VA_ARGS__) MACRO_EXPANSION_M_3(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_8(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_9(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_10(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_11(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_3(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_12(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_13(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_5(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_14(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_6(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_15(LOOP_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) MACRO_EXPANSION_M_7(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_16(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_0(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_17(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_1(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_18(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_2(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_19(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_3(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_20(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_4(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_21(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_5(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_22(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_6(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_23(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_7(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_24(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_8(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_25(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_9(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_26(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_10(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_27(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_11(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_28(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_12(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_29(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_13(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_30(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_14(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_M_31(LOOP_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) MACRO_EXPANSION_M_15(LOOP_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_0(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_1(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(1, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_2(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(2, ##__VA_ARGS__) CROSS_ITEM(2, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_0(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_3(LOOP_ITEM, CROSS_ITEM, ...) 
\ + LOOP_ITEM(2, ##__VA_ARGS__) CROSS_ITEM(2, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_1(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_4(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) CROSS_ITEM(4, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_0(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_5(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) CROSS_ITEM(4, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_1(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_6(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) CROSS_ITEM(4, 2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_2(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_7(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(4, ##__VA_ARGS__) CROSS_ITEM(4, 2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_3(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_8(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_0(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_9(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_1(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_10(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_2(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_11(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_3(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_12(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_4(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_13(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_5(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_14(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_6(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_15(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(8, ##__VA_ARGS__) CROSS_ITEM(8, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_7(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_16(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_0(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_17(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_1(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_18(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_2(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_19(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_3(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_20(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_4(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_21(LOOP_ITEM, CROSS_ITEM, ...) 
\ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_5(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_22(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_6(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_23(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_7(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_24(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_8(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_25(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_9(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_26(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_10(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_27(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_11(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_28(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_12(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_29(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_13(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_30(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_14(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_MX_31(LOOP_ITEM, CROSS_ITEM, ...) \ + LOOP_ITEM(16, ##__VA_ARGS__) CROSS_ITEM(16, 8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_15(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_0(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_0(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_1(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(1, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_1(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_2(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_2(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_3(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(2, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_3(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_4(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_4(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_5(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_5(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_6(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_6(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_7(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(4, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_7(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_8(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) 
\ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_8(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_9(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_9(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_10(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_10(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_11(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_11(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_12(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_12(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_13(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_13(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_14(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_14(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_15(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(8, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_15(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_16(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_16(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_17(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_17(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_18(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_18(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_19(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_19(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_20(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_20(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_21(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_21(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_22(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_22(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_23(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_23(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_24(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_24(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_25(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_25(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_26(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_26(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_27(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_27(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_28(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_28(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_29(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) 
\ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_29(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_30(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_30(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXPANSION_IMX_31(INIT_ITEM, LOOP_ITEM, CROSS_ITEM, ...) \ + INIT_ITEM(16, ##__VA_ARGS__)\ + MACRO_EXPANSION_MX_31(LOOP_ITEM, CROSS_ITEM, ##__VA_ARGS__) + +#define MACRO_EXP_M_FIRSTITEM_0 1 +#define MACRO_EXP_M_FIRSTITEM_1 1 +#define MACRO_EXP_M_FIRSTITEM_2 2 +#define MACRO_EXP_M_FIRSTITEM_3 2 +#define MACRO_EXP_M_FIRSTITEM_4 4 +#define MACRO_EXP_M_FIRSTITEM_5 4 +#define MACRO_EXP_M_FIRSTITEM_6 4 +#define MACRO_EXP_M_FIRSTITEM_7 4 +#define MACRO_EXP_M_FIRSTITEM_8 8 +#define MACRO_EXP_M_FIRSTITEM_9 8 +#define MACRO_EXP_M_FIRSTITEM_10 8 +#define MACRO_EXP_M_FIRSTITEM_11 8 +#define MACRO_EXP_M_FIRSTITEM_12 8 +#define MACRO_EXP_M_FIRSTITEM_13 8 +#define MACRO_EXP_M_FIRSTITEM_14 8 +#define MACRO_EXP_M_FIRSTITEM_15 8 +#define MACRO_EXP_M_FIRSTITEM_16 16 +#define MACRO_EXP_M_FIRSTITEM_17 16 +#define MACRO_EXP_M_FIRSTITEM_18 16 +#define MACRO_EXP_M_FIRSTITEM_19 16 +#define MACRO_EXP_M_FIRSTITEM_20 16 +#define MACRO_EXP_M_FIRSTITEM_21 16 +#define MACRO_EXP_M_FIRSTITEM_22 16 +#define MACRO_EXP_M_FIRSTITEM_23 16 +#define MACRO_EXP_M_FIRSTITEM_24 16 +#define MACRO_EXP_M_FIRSTITEM_25 16 +#define MACRO_EXP_M_FIRSTITEM_26 16 +#define MACRO_EXP_M_FIRSTITEM_27 16 +#define MACRO_EXP_M_FIRSTITEM_28 16 +#define MACRO_EXP_M_FIRSTITEM_29 16 +#define MACRO_EXP_M_FIRSTITEM_30 16 +#define MACRO_EXP_M_FIRSTITEM_31 16 + +#endif + diff --git a/include/neon_armv7a/Bias.h b/include/neon_armv7a/Bias.h new file mode 100644 index 0000000..6fa7b9a --- /dev/null +++ b/include/neon_armv7a/Bias.h @@ -0,0 +1,35 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void bias_float(float *dst, float bias_dim0, + const float *bias_dim1, float bias_dim1_scale, + const float *bias_dim2, float bias_dim2_scale, + uint32_t dim1, uint32_t dim2); + +void bias_int32_t(int32_t *dst, int32_t bias_dim0, + const int32_t *bias_dim1, int32_t bias_dim1_scale, + const int32_t *bias_dim2, int32_t bias_dim2_scale, + uint32_t dim1, uint32_t dim2); + +void u8u32_sum(const uint8_t *src, uint32_t *dst, + uint32_t dim1, uint32_t dim2, uint8_t direction); + +void s16_sumsquare(const int16_t *dat, int32_t *sum, + int64_t *sumsquare, uint32_t size); + diff --git a/include/neon_armv7a/I8I32MlaGemmKernel.h b/include/neon_armv7a/I8I32MlaGemmKernel.h new file mode 100644 index 0000000..9ff9a55 --- /dev/null +++ b/include/neon_armv7a/I8I32MlaGemmKernel.h @@ -0,0 +1,242 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/NeonI8I32MlaGemmKernel.h" + +#ifndef INCLUDE_ARMV7A_I8I32MLA_KERNEL +#define INCLUDE_ARMV7A_I8I32MLA_KERNEL + +#define KERNEL_M6N8_UNIT(a_head, b_head) \ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06;\ + I32X4 cq07, cq08, cq09, cq10, cq11, cq12;\ + COMMON_KERNEL_HEADER(a_head, b_head)\ + __asm__ __volatile__(\ + "vmov.i8 %q[cq01],#0; vmov.i8 %q[cq02],#0\n\t"\ + "vmov.i8 %q[cq03],#0; vmov.i8 %q[cq04],#0\n\t"\ + "vmov.i8 %q[cq05],#0; vmov.i8 %q[cq06],#0\n\t"\ + "vmov.i8 %q[cq07],#0; vmov.i8 %q[cq08],#0\n\t"\ + "vmov.i8 %q[cq09],#0; vmov.i8 %q[cq10],#0\n\t"\ + "vmov.i8 %q[cq11],#0; vmov.i8 %q[cq12],#0\n\t"\ + "cmp %[k_left],#2; blt 4f\n\t"\ + "vldr d0,[%[a_ptr]]; ldr r0,[%[a_ptr],#8]\n\t"\ + "ldr r1,[%[a_ptr],#12]; add %[a_ptr],%[a_ptr],#24\n\t"\ + "vldr d4,[%[b_ptr]]; vldr d5,[%[b_ptr],#8]; ldr r2,[%[b_ptr],#16]\n\t"\ + "ldr r3,[%[b_ptr],#20]; add %[b_ptr],%[b_ptr],#32\n\t"\ + "cmp %[k_left],#6; blt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + "vmov d6,r2,r3; vldr d7,[%[b_ptr],#-8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d4,d0[0]; ldr r2,[%[b_ptr]]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d5,d0[0]; ldr r3,[%[b_ptr],#4]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d4,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d5,d0[1]\n\t"\ + "vmov d1,r0,r1; vldr d2,[%[a_ptr],#-8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d4,d0[2]; ldr r0,[%[a_ptr]]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d5,d0[2]; ldr r1,[%[a_ptr],#4]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d4,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d5,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d4,d1[0]; pld [%[a_ptr],#128]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d5,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d4,d1[1]; pld [%[b_ptr],#128]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d5,d1[1]\n\t"\ + "vmov d4,r2,r3; vldr d5,[%[b_ptr],#8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d6,d1[2]; ldr r2,[%[b_ptr],#16]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d7,d1[2]; ldr r3,[%[b_ptr],#20]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d6,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d7,d1[3]\n\t"\ + "vmov d0,r0,r1; vldr d1,[%[a_ptr],#8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d6,d2[0]; ldr r0,[%[a_ptr],#16]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d7,d2[0]; ldr r1,[%[a_ptr],#20]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d6,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d7,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d6,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d7,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d6,d2[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d7,d2[3]\n\t"\ + "vmov d6,r2,r3; vldr d7,[%[b_ptr],#24]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d4,d0[0]; ldr r2,[%[b_ptr],#32]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d5,d0[0]; ldr r3,[%[b_ptr],#36]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d4,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d5,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d4,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d5,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d4,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d5,d0[3]\n\t"\ + "vmov d2,r0,r1; vldr 
d0,[%[a_ptr],#24]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d4,d1[0]; ldr r0,[%[a_ptr],#32]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d5,d1[0]; ldr r1,[%[a_ptr],#36]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d4,d1[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d5,d1[1]\n\t"\ + "vmov d4,r2,r3; vldr d5,[%[b_ptr],#40]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d6,d1[2]; ldr r2,[%[b_ptr],#48]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d7,d1[2]; ldr r3,[%[b_ptr],#52]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d6,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d7,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d6,d2[0]; add %[a_ptr],%[a_ptr],#48\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d7,d2[0]; add %[b_ptr],%[b_ptr],#64\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d6,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d7,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d6,d2[2]; sub %[k_left],%[k_left],#4\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d7,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d6,d2[3]; cmp %[k_left],#6\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d7,d2[3]; bge 1b\n\t"\ + "2:\n\t"\ + "cmp %[k_left],#4; blt 3f\n\t"\ + "vmov d6,r2,r3; vldr d7,[%[b_ptr],#-8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d4,d0[0]; ldr r2,[%[b_ptr]]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d5,d0[0]; ldr r3,[%[b_ptr],#4]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d4,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d5,d0[1]\n\t"\ + "vmov d1,r0,r1; vldr d2,[%[a_ptr],#-8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d4,d0[2]; ldr r0,[%[a_ptr]]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d5,d0[2]; ldr r1,[%[a_ptr],#4]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d4,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d5,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d4,d1[0]; pld [%[a_ptr],#128]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d5,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d4,d1[1]; pld [%[b_ptr],#128]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d5,d1[1]\n\t"\ + "vmov d4,r2,r3; vldr d5,[%[b_ptr],#8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d6,d1[2]; ldr r2,[%[b_ptr],#16]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d7,d1[2]; ldr r3,[%[b_ptr],#20]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d6,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d7,d1[3]\n\t"\ + "vmov d0,r0,r1; vldr d1,[%[a_ptr],#8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d6,d2[0]; ldr r0,[%[a_ptr],#16]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d7,d2[0]; ldr r1,[%[a_ptr],#20]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d6,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d7,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d6,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d7,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d6,d2[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d7,d2[3]\n\t"\ + "vmov d6,r2,r3; vldr d7,[%[b_ptr],#24]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d4,d0[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d5,d0[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d4,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d5,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d4,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d5,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d4,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d5,d0[3]\n\t"\ + "vmov d2,r0,r1\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d4,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d5,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d4,d1[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d5,d1[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d6,d1[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d7,d1[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d6,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d7,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d6,d2[0]; add %[a_ptr],%[a_ptr],#24\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d7,d2[0]; add %[b_ptr],%[b_ptr],#32\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d6,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d7,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d6,d2[2]; sub %[k_left],%[k_left],#4\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d7,d2[2]\n\t"\ + 
""ASM_VMLAL_I16" %q[cq11],d6,d2[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d7,d2[3]; b 4f\n\t"\ + "3:\n\t"\ + "vmov d6,r2,r3; vldr d7,[%[b_ptr],#-8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d4,d0[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d5,d0[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d4,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d5,d0[1]\n\t"\ + "vmov d1,r0,r1; vldr d2,[%[a_ptr],#-8]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d4,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d5,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d4,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d5,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d4,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d5,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d4,d1[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d5,d1[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d6,d1[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq02],d7,d1[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d6,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq04],d7,d1[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d6,d2[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq06],d7,d2[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d6,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq08],d7,d2[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d6,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq10],d7,d2[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d6,d2[3]; sub %[k_left],%[k_left],#2\n\t"\ + ""ASM_VMLAL_I16" %q[cq12],d7,d2[3]\n\t"\ + "4:\n\t"\ + "cmp %[k_left],#1; blt 5f\n\t"\ + "vldr d4,[%[b_ptr]]; vldr d5,[%[b_ptr],#8]; add %[b_ptr],%[b_ptr],#16\n\t"\ + "vldr d0,[%[a_ptr]]; vldr s2,[%[a_ptr],#8]\n\t"\ + "add %[a_ptr],%[a_ptr],#12\n\t"\ + ""ASM_VMLAL_I16" %q[cq01],d4,d0[0]; "ASM_VMLAL_I16" %q[cq02],d5,d0[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq03],d4,d0[1]; "ASM_VMLAL_I16" %q[cq04],d5,d0[1]\n\t"\ + ""ASM_VMLAL_I16" %q[cq05],d4,d0[2]; "ASM_VMLAL_I16" %q[cq06],d5,d0[2]\n\t"\ + ""ASM_VMLAL_I16" %q[cq07],d4,d0[3]; "ASM_VMLAL_I16" %q[cq08],d5,d0[3]\n\t"\ + ""ASM_VMLAL_I16" %q[cq09],d4,d1[0]; "ASM_VMLAL_I16" %q[cq10],d5,d1[0]\n\t"\ + ""ASM_VMLAL_I16" %q[cq11],d4,d1[1]; "ASM_VMLAL_I16" %q[cq12],d5,d1[1]\n\t"\ + "sub %[k_left],%[k_left],#1\n\t"\ + "5:\n\t"\ + :[a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr), [k_left]"+r"(k_left),\ + [cq01]"=w"(cq01), [cq02]"=w"(cq02), [cq03]"=w"(cq03), [cq04]"=w"(cq04),\ + [cq05]"=w"(cq05), [cq06]"=w"(cq06), [cq07]"=w"(cq07), [cq08]"=w"(cq08),\ + [cq09]"=w"(cq09), [cq10]"=w"(cq10), [cq11]"=w"(cq11), [cq12]"=w"(cq12)\ + ::"r0","r1","r2","r3","cc","memory","q0","q1","q2","q3"); + +static inline void pldw_c_6(const I32 *c) { + __asm__("pld [%0]; pld [%0,#20]\n\t"::"r"(c):); +} + +static inline void pldw_c_8(const I32 *c) { + __asm__("pld [%0]; pld [%0,#28]\n\t"::"r"(c):); +} + +#define KERNEL_M6N8 \ + I32 *c_pref = c_ptr;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref); c_pref += ldc;\ + pldw_c_6(c_pref);\ + KERNEL_M6N8_UNIT(a_head, b_head) + +#define KERNEL_M8N6 \ + I32 *c_pref = c_ptr;\ + pldw_c_8(c_pref); c_pref += ldc;\ + pldw_c_8(c_pref); c_pref += ldc;\ + pldw_c_8(c_pref); c_pref += ldc;\ + pldw_c_8(c_pref); c_pref += ldc;\ + pldw_c_8(c_pref); c_pref += ldc;\ + pldw_c_8(c_pref);\ + KERNEL_M6N8_UNIT(b_head, a_head) + +#define SAVE_M6N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M6N4(cq01, cq03, cq05, cq07, cq09, cq11)\ + UNIT_SAVE_M6N4(cq02, cq04, cq06, cq08, cq10, cq12) + +#define SAVE_M8N6 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq01, cq02, cq03, cq04)\ + UNIT_SAVE_M8N2(cq05, cq06, cq07, cq08)\ + UNIT_SAVE_M8N2(cq09, cq10, cq11, cq12) + +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(6, 8, I16, I32) 
+IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 6, I16, I32) + +#endif diff --git a/include/neon_armv7a/S8S32MlaGemmCopy.h b/include/neon_armv7a/S8S32MlaGemmCopy.h new file mode 100644 index 0000000..47c5052 --- /dev/null +++ b/include/neon_armv7a/S8S32MlaGemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void s8s32mlagemm_int8_t_int16_t_ncopy_unroll6(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32mlagemm_int8_t_int16_t_ncopy_unroll8(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32mlagemm_int8_t_int16_t_tcopy_unroll6(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32mlagemm_int8_t_int16_t_tcopy_unroll8(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv7a/S8S32MlaGemmDriver.h b/include/neon_armv7a/S8S32MlaGemmDriver.h new file mode 100644 index 0000000..26121fa --- /dev/null +++ b/include/neon_armv7a/S8S32MlaGemmDriver.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +int s8s32mlagemm_serial(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, int32_t beta_inp); + +int s8s32mlagemm(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/S8S32MlaGemmKernel.h b/include/neon_armv7a/S8S32MlaGemmKernel.h new file mode 100644 index 0000000..4dd3469 --- /dev/null +++ b/include/neon_armv7a/S8S32MlaGemmKernel.h @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
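/* Sketch of what the ncopy packing routines above are assumed to do: the
 * int8 source is widened to int16 at pack time (so the MLA kernel can use
 * 16-bit multiplies accumulating into int32) and rearranged into panels of
 * 6 (or 8) rows that the kernel then streams sequentially, 6 values per k
 * step. The traversal below is illustrative only; the real unroll6/unroll8
 * layouts are defined by the matching kernels, and the partial-panel tail
 * is omitted for brevity. */
static void ncopy_unroll6_sketch(const int8_t *src, int16_t *dst,
                                 uint32_t ld_dim, uint32_t dim1, uint32_t dim2) {
  for (uint32_t r0 = 0; r0 + 6 <= dim1; r0 += 6)      /* panel of 6 rows        */
    for (uint32_t k = 0; k < dim2; ++k)               /* walk the shared dim    */
      for (uint32_t r = 0; r < 6; ++r)                /* 6 widened values per k */
        *dst++ = (int16_t)src[(r0 + r) * ld_dim + k];
}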
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void s8s32mlagemm_kernel_lm_m6n8(uint32_t M, uint32_t N, uint32_t K, + int32_t beta, + const int16_t * __restrict__ sa, const int16_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc); + +void s8s32mlagemm_kernel_ln_m8n6(uint32_t M, uint32_t N, uint32_t K, + int32_t beta, + const int16_t * __restrict__ sa, const int16_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv7a/S8S32MlaGemmSkinnyDot.h b/include/neon_armv7a/S8S32MlaGemmSkinnyDot.h new file mode 100644 index 0000000..1d4765e --- /dev/null +++ b/include/neon_armv7a/S8S32MlaGemmSkinnyDot.h @@ -0,0 +1,47 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n1(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n2(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n3(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n4(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n1_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n2_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n3_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n4_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/S8S32MlaGemmSkinnyGer.h b/include/neon_armv7a/S8S32MlaGemmSkinnyGer.h new file mode 100644 index 0000000..79e73a9 --- /dev/null +++ b/include/neon_armv7a/S8S32MlaGemmSkinnyGer.h @@ -0,0 +1,47 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n1(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n2(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n3(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n4(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n1_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n2_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n3_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n4_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/SgemmCopy.h b/include/neon_armv7a/SgemmCopy.h new file mode 100644 index 0000000..ec11f82 --- /dev/null +++ b/include/neon_armv7a/SgemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_float_float_ncopy_unroll6(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void sgemm_float_float_ncopy_unroll8(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void sgemm_float_float_tcopy_unroll6(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void sgemm_float_float_tcopy_unroll8(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv7a/SgemmDriver.h b/include/neon_armv7a/SgemmDriver.h new file mode 100644 index 0000000..bfc4217 --- /dev/null +++ b/include/neon_armv7a/SgemmDriver.h @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +int sgemm_serial(int a_rowmajor, int b_rowmajor, + const float *A, const float *B, float *C, + uint32_t M, uint32_t N, uint32_t K, float beta_inp); + +int sgemm(int a_rowmajor, int b_rowmajor, + const float *A, const float *B, float *C, + uint32_t M, uint32_t N, uint32_t K, float beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/SgemmKernel.h b/include/neon_armv7a/SgemmKernel.h new file mode 100644 index 0000000..4535041 --- /dev/null +++ b/include/neon_armv7a/SgemmKernel.h @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_kernel_lm_m6n8(uint32_t M, uint32_t N, uint32_t K, float beta, + const float * __restrict__ sa, const float * __restrict__ sb, + float * __restrict__ C, uint32_t ldc); + +void sgemm_kernel_ln_m8n6(uint32_t M, uint32_t N, uint32_t K, float beta, + const float * __restrict__ sa, const float * __restrict__ sb, + float * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv7a/SgemmSkinnyDot.h b/include/neon_armv7a/SgemmSkinnyDot.h new file mode 100644 index 0000000..950d576 --- /dev/null +++ b/include/neon_armv7a/SgemmSkinnyDot.h @@ -0,0 +1,67 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
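/* Assumed usage sketch for the sgemm driver declared above: the parameter
 * names suggest a_rowmajor / b_rowmajor select row-major input layouts when
 * non-zero, and beta_inp behaves like the conventional GEMM beta (0 here,
 * so C is overwritten rather than accumulated). sgemm_serial is the
 * single-threaded entry; sgemm takes an explicit thread count. */
static int example_sgemm(const float *A, const float *B, float *C,
                         uint32_t M, uint32_t N, uint32_t K) {
  /* C = A * B with row-major A (M x K) and B (K x N), using 4 threads */
  return sgemm(1, 1, A, B, C, M, N, K, 0.0f, 4);
}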
*/ +/*****************************************************************************/ + + +#include + +void sgemm_arowmajor_bskinny_afloat_bfloat_n1(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n2(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n3(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n4(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n5(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n6(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n7(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n8(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n1_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n2_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n3_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n4_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n5_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n6_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n7_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n8_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/SgemmSkinnyGer.h b/include/neon_armv7a/SgemmSkinnyGer.h new file mode 100644 index 0000000..6466d79 --- /dev/null +++ b/include/neon_armv7a/SgemmSkinnyGer.h @@ -0,0 +1,67 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
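/* Assumed contrast between the two skinny-B paths declared above and below,
 * inferred from the SkinnyDot / SkinnyGer names: with A row-major each
 * output element is one K-long dot product, while with A column-major the
 * kernel instead accumulates K rank-1 (GER-style) updates so that A is
 * still read contiguously. Scalar sketches for the N = 1 case (beta and
 * the b_c_order flag are left out for brevity): */
static void skinny_dot_n1_ref(const float *A, const float *B, float *C,
                              uint32_t M, uint32_t K) {
  for (uint32_t i = 0; i < M; ++i) {          /* row i of A is contiguous */
    float acc = 0.0f;
    for (uint32_t k = 0; k < K; ++k) acc += A[i * K + k] * B[k];
    C[i] = acc;
  }
}
static void skinny_ger_n1_ref(const float *A, const float *B, float *C,
                              uint32_t M, uint32_t K) {
  for (uint32_t i = 0; i < M; ++i) C[i] = 0.0f;
  for (uint32_t k = 0; k < K; ++k)            /* column k of A is contiguous */
    for (uint32_t i = 0; i < M; ++i) C[i] += A[k * M + i] * B[k];
}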
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_acolmajor_bskinny_afloat_bfloat_n1(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n2(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n3(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n4(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n5(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n6(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n7(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n8(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n1_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n2_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n3_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n4_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n5_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n6_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n7_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n8_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/U8U32MlaGemmCopy.h b/include/neon_armv7a/U8U32MlaGemmCopy.h new file mode 100644 index 0000000..cb78832 --- /dev/null +++ b/include/neon_armv7a/U8U32MlaGemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_uint8_t_uint16_t_ncopy_unroll6(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32mlagemm_uint8_t_uint16_t_ncopy_unroll8(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32mlagemm_uint8_t_uint16_t_tcopy_unroll6(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32mlagemm_uint8_t_uint16_t_tcopy_unroll8(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv7a/U8U32MlaGemmDriver.h b/include/neon_armv7a/U8U32MlaGemmDriver.h new file mode 100644 index 0000000..9477c3d --- /dev/null +++ b/include/neon_armv7a/U8U32MlaGemmDriver.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +int u8u32mlagemm_serial(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, uint32_t beta_inp); + +int u8u32mlagemm(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/U8U32MlaGemmKernel.h b/include/neon_armv7a/U8U32MlaGemmKernel.h new file mode 100644 index 0000000..c0b79b8 --- /dev/null +++ b/include/neon_armv7a/U8U32MlaGemmKernel.h @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_kernel_lm_m6n8(uint32_t M, uint32_t N, uint32_t K, + uint32_t beta, + const uint16_t * __restrict__ sa, const uint16_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc); + +void u8u32mlagemm_kernel_ln_m8n6(uint32_t M, uint32_t N, uint32_t K, + uint32_t beta, + const uint16_t * __restrict__ sa, const uint16_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv7a/U8U32MlaGemmSkinnyDot.h b/include/neon_armv7a/U8U32MlaGemmSkinnyDot.h new file mode 100644 index 0000000..59df381 --- /dev/null +++ b/include/neon_armv7a/U8U32MlaGemmSkinnyDot.h @@ -0,0 +1,47 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n1(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n2(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n3(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n4(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n1_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n2_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n3_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n4_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv7a/U8U32MlaGemmSkinnyGer.h b/include/neon_armv7a/U8U32MlaGemmSkinnyGer.h new file mode 100644 index 0000000..5c121de --- /dev/null +++ b/include/neon_armv7a/U8U32MlaGemmSkinnyGer.h @@ -0,0 +1,47 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n1(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n2(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n3(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n4(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n1_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n2_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n3_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n4_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/Bias.h b/include/neon_armv8a/Bias.h new file mode 100644 index 0000000..4ee8eb5 --- /dev/null +++ b/include/neon_armv8a/Bias.h @@ -0,0 +1,36 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +void bias_float(float *dst, float bias_dim0, + const float *bias_dim1, float bias_dim1_scale, + const float *bias_dim2, float bias_dim2_scale, + uint32_t dim1, uint32_t dim2); + +void bias_int32_t(int32_t *dst, int32_t bias_dim0, + const int32_t *bias_dim1, int32_t bias_dim1_scale, + const int32_t *bias_dim2, int32_t bias_dim2_scale, + uint32_t dim1, uint32_t dim2); + +void u8u32_sum(const uint8_t *src, uint32_t *dst, + uint32_t dim1, uint32_t dim2, uint8_t direction); + +void s16_sumsquare(const int16_t *dat, int32_t *sum, + int64_t *sumsquare, uint32_t size); + diff --git a/include/neon_armv8a/HgemmCopy.h b/include/neon_armv8a/HgemmCopy.h new file mode 100644 index 0000000..7ce17cd --- /dev/null +++ b/include/neon_armv8a/HgemmCopy.h @@ -0,0 +1,32 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
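/* Assumed reference for u8u32_sum declared in Bias.h above. The semantics
 * are inferred from the name: per-row or per-column sums of a uint8
 * dim1 x dim2 block accumulated in uint32, with `direction` choosing the
 * dimension (which value selects which axis is an assumption here). Such
 * sums are the usual ingredient of zero-point corrections in quantized
 * GEMM. */
static void u8u32_sum_ref(const uint8_t *src, uint32_t *dst,
                          uint32_t dim1, uint32_t dim2, uint8_t direction) {
  if (direction == 0) {                     /* assumed: one sum per dim1 index */
    for (uint32_t i = 0; i < dim1; ++i) {
      uint32_t s = 0;
      for (uint32_t j = 0; j < dim2; ++j) s += src[i * dim2 + j];
      dst[i] = s;
    }
  } else {                                  /* assumed: one sum per dim2 index */
    for (uint32_t j = 0; j < dim2; ++j) dst[j] = 0;
    for (uint32_t i = 0; i < dim1; ++i)
      for (uint32_t j = 0; j < dim2; ++j) dst[j] += src[i * dim2 + j];
  }
}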
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +void hgemm_float16_t_float16_t_ncopy_unroll8(const float16_t * __restrict__ src, + float16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void hgemm_float16_t_float16_t_ncopy_unroll16(const float16_t * __restrict__ src, + float16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void hgemm_float16_t_float16_t_tcopy_unroll8(const float16_t * __restrict__ src, + float16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void hgemm_float16_t_float16_t_tcopy_unroll16(const float16_t * __restrict__ src, + float16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv8a/HgemmDriver.h b/include/neon_armv8a/HgemmDriver.h new file mode 100644 index 0000000..931a1e9 --- /dev/null +++ b/include/neon_armv8a/HgemmDriver.h @@ -0,0 +1,25 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +int hgemm_serial(uint8_t transAB, const float16_t *A, const float16_t *B, float16_t *C, + uint32_t M, uint32_t N, uint32_t K, float16_t beta_inp); + +int hgemm(uint8_t transAB, const float16_t *A, const float16_t *B, float16_t *C, + uint32_t M, uint32_t N, uint32_t K, float16_t beta_inp, uint32_t num_threads); diff --git a/include/neon_armv8a/HgemmKernel.h b/include/neon_armv8a/HgemmKernel.h new file mode 100644 index 0000000..c778284 --- /dev/null +++ b/include/neon_armv8a/HgemmKernel.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
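/* Assumed usage sketch for hgemm declared above. Unlike sgemm, a single
 * transAB byte presumably encodes the layout of both A and B; the encoding
 * is not visible in this header, so transAB = 0 is used here on the
 * assumption that it means "no transposition". beta_inp = 0 overwrites C. */
static int example_hgemm(const float16_t *A, const float16_t *B, float16_t *C,
                         uint32_t M, uint32_t N, uint32_t K) {
  return hgemm(0, A, B, C, M, N, K, (float16_t)0.0f, 1);  /* single thread */
}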
*/ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +void hgemm_kernel_lm_m8n16(uint32_t M, uint32_t N, uint32_t K, float16_t beta, + const float16_t * __restrict__ sa, const float16_t * __restrict__ sb, + float16_t * __restrict__ C, uint32_t ldc); + +void hgemm_kernel_ln_m16n8(uint32_t M, uint32_t N, uint32_t K, float16_t beta, + const float16_t * __restrict__ sa, const float16_t * __restrict__ sb, + float16_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv8a/HgemmSkinnyDot.h b/include/neon_armv8a/HgemmSkinnyDot.h new file mode 100644 index 0000000..601c56b --- /dev/null +++ b/include/neon_armv8a/HgemmSkinnyDot.h @@ -0,0 +1,116 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n1(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n2(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n3(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n4(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n5(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n6(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n7(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n8(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n9(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n10(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n11(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t 
K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n12(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n1_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n2_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n3_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n4_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n5_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n6_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n7_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n8_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n9_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n10_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n11_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_arowmajor_bskinny_afloat16_t_bfloat16_t_n12_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/HgemmSkinnyGer.h b/include/neon_armv8a/HgemmSkinnyGer.h new file mode 100644 index 0000000..44a1d2b --- /dev/null +++ b/include/neon_armv8a/HgemmSkinnyGer.h @@ -0,0 +1,116 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n1(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n2(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n3(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n4(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n5(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n6(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n7(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n8(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n9(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n10(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n11(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n12(const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, float16_t beta_inp); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n1_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n2_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n3_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n4_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n5_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n6_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + 
float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n7_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n8_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n9_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n10_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n11_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + +void hgemm_acolmajor_bskinny_afloat16_t_bfloat16_t_n12_omp( + const float16_t *A, const float16_t *B, + float16_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + float16_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/I8I32DotGemmCopy.h b/include/neon_armv8a/I8I32DotGemmCopy.h new file mode 100644 index 0000000..aa8addd --- /dev/null +++ b/include/neon_armv8a/I8I32DotGemmCopy.h @@ -0,0 +1,454 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "arm_neon/NeonIntOpSign.h" + +#ifndef INCLUDE_I8I32DOT_COPY +#define INCLUDE_I8I32DOT_COPY + +static inline void pref_ab(const I8 *dat) { + __asm__ ("prfm pldl1keep,[%0,#64]\n\t"::"r"(dat):); +} + +#define NCOPY_NEON_LOOP_K16_UNROLL4(inc, dst_ptr, src1, src2, src3, src4) \ + for (dim1_count = dim1; dim1_count > 15; dim1_count -= 16) {\ + I32X4X4 t1;\ + t1.val[0] = VREINTERPRETQ_I32_I8(VLD1Q_I8(src1));\ + src1 += 16; pref_ab(src1);\ + t1.val[1] = VREINTERPRETQ_I32_I8(VLD1Q_I8(src2));\ + src2 += 16; pref_ab(src2);\ + t1.val[2] = VREINTERPRETQ_I32_I8(VLD1Q_I8(src3));\ + src3 += 16; pref_ab(src3);\ + t1.val[3] = VREINTERPRETQ_I32_I8(VLD1Q_I8(src4));\ + src4 += 16; pref_ab(src4);\ + VST4Q_LANE_I32(dst_ptr, t1, 0);\ + VST4Q_LANE_I32(dst_ptr + inc, t1, 1);\ + VST4Q_LANE_I32(dst_ptr + inc * 2, t1, 2);\ + VST4Q_LANE_I32(dst_ptr + inc * 3, t1, 3);\ + dst_ptr += inc * 4;\ + }\ + if (dim1_count > 7) {\ + I32X2X4 t1;\ + t1.val[0] = VREINTERPRET_I32_I8(VLD1_I8(src1)); src1 += 8;\ + t1.val[1] = VREINTERPRET_I32_I8(VLD1_I8(src2)); src2 += 8;\ + t1.val[2] = VREINTERPRET_I32_I8(VLD1_I8(src3)); src3 += 8;\ + t1.val[3] = VREINTERPRET_I32_I8(VLD1_I8(src4)); src4 += 8;\ + VST4_LANE_I32(dst_ptr, t1, 0);\ + VST4_LANE_I32(dst_ptr + inc, t1, 1);\ + dst_ptr += inc * 2; dim1_count -= 8;\ + }\ + if (dim1_count > 3) {\ + __asm__(\ + "ldr w0,[%0],#4; ldr w1,[%1],#4; ldr w2,[%2],#4; ldr w3,[%3],#4\n\t"\ + "str w0,[%4]; str w1,[%4,#4]; str w2,[%4,#8]; str w3,[%4,#12]\n\t"\ + :"+r"(src1),"+r"(src2),"+r"(src3),"+r"(src4):"r"(dst_ptr)\ + :"cc","memory","x0","x1","x2","x3");\ + dst_ptr += inc; dim1_count -= 4;\ + }\ + if (dim1_count > 0) {\ + uint32_t *dst_cast = (uint32_t *)dst_ptr; dst_ptr += inc;\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += dim1_count;\ + uint8_t *src2_cast = (uint8_t *)src2; src2 += dim1_count;\ + uint8_t *src3_cast = (uint8_t *)src3; src3 += dim1_count;\ + uint8_t *src4_cast = (uint8_t *)src4; src4 += dim1_count;\ + uint32_t d0, d1, d2, d3;\ + d0 = *src1_cast; d1 = *src2_cast;\ + d2 = *src3_cast; d3 = *src4_cast;\ + if (dim1_count >= 2) {\ + d0 |= ((uint32_t)src1_cast[1]) << 8;\ + d1 |= ((uint32_t)src2_cast[1]) << 8;\ + d2 |= ((uint32_t)src3_cast[1]) << 8;\ + d3 |= ((uint32_t)src4_cast[1]) << 8;\ + }\ + if (dim1_count >= 3) {\ + d0 |= ((uint32_t)src1_cast[2]) << 16;\ + d1 |= ((uint32_t)src2_cast[2]) << 16;\ + d2 |= ((uint32_t)src3_cast[2]) << 16;\ + d3 |= ((uint32_t)src4_cast[2]) << 16;\ + }\ + dst_cast[0] = d0; dst_cast[1] = d1;\ + dst_cast[2] = d2; dst_cast[3] = d3;\ + } + +#define NCOPY_UNROLL_12 {\ + I32 *dst_h1 = dst1;\ + NCOPY_NEON_LOOP_K16_UNROLL4(12, dst_h1, src1, src2, src3, src4)\ + dst_h1 = dst1 + 4;\ + NCOPY_NEON_LOOP_K16_UNROLL4(12, dst_h1, src5, src6, src7, src8)\ + dst_h1 = dst1 + 8;\ + NCOPY_NEON_LOOP_K16_UNROLL4(12, dst_h1, src9, src10, src11, src12)\ + dst1 = dst_h1 - 8;\ +} + +#define NCOPY_UNROLL_8 {\ + I32 *dst_h1 = dst1;\ + NCOPY_NEON_LOOP_K16_UNROLL4(8, dst_h1, src1, src2, src3, src4)\ + dst_h1 = dst1 + 4;\ + NCOPY_NEON_LOOP_K16_UNROLL4(8, dst_h1, src5, src6, src7, src8)\ + dst1 = dst_h1 - 4;\ +} + +#define NCOPY_UNROLL_4 {\ + NCOPY_NEON_LOOP_K16_UNROLL4(4, dst1, src1, src2, src3, src4)\ +} + +#define NCOPY_UNROLL_2 {\ + for (dim1_count = dim1; dim1_count > 15; dim1_count -= 16) {\ + I32X4X2 t1;\ + t1.val[0] = VREINTERPRETQ_I32_I8(VLD1Q_I8(src1));\ + src1 += 16; pref_ab(src1);\ + t1.val[1] = VREINTERPRETQ_I32_I8(VLD1Q_I8(src2));\ + src2 += 16; pref_ab(src2);\ + VST2Q_I32(dst1, 
t1);\ + dst1 += 8;\ + }\ + if (dim1_count > 7) {\ + I32X2X2 t1;\ + t1.val[0] = VREINTERPRET_I32_I8(VLD1_I8(src1)); src1 += 8;\ + t1.val[1] = VREINTERPRET_I32_I8(VLD1_I8(src2)); src2 += 8;\ + VST2_I32(dst1, t1);\ + dst1 += 4; dim1_count -= 8;\ + }\ + if (dim1_count > 3) {\ + __asm__(\ + "ldr w0,[%0],#4; ldr w1,[%1],#4\n\t"\ + "str w0,[%2]; str w1,[%2,#4]\n\t"\ + :"+r"(src1),"+r"(src2):"r"(dst1)\ + :"cc","memory","x0","x1");\ + dst1 += 2; dim1_count -= 4;\ + }\ + if (dim1_count > 0) {\ + uint32_t *dst_cast = (uint32_t *)dst1; dst1 += 2;\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += dim1_count;\ + uint8_t *src2_cast = (uint8_t *)src2; src2 += dim1_count;\ + uint32_t d0, d1;\ + d0 = *src1_cast; d1 = *src2_cast;\ + if (dim1_count >= 2) {\ + d0 |= ((uint32_t)src1_cast[1]) << 8;\ + d1 |= ((uint32_t)src2_cast[1]) << 8;\ + }\ + if (dim1_count >= 3) {\ + d0 |= ((uint32_t)src1_cast[2]) << 16;\ + d1 |= ((uint32_t)src2_cast[2]) << 16;\ + }\ + dst_cast[0] = d0; dst_cast[1] = d1;\ + }\ +} + +#define NCOPY_UNROLL_1 {\ + for (dim1_count = dim1; dim1_count > 15; dim1_count -= 16) {\ + I32X4 t1 = VREINTERPRETQ_I32_I8(VLD1Q_I8(src1));\ + src1 += 16;\ + VST1Q_I32(dst1, t1);\ + dst1 += 4;\ + }\ + if (dim1_count > 7) {\ + I32X2 t1 = VREINTERPRET_I32_I8(VLD1_I8(src1)); src1 += 8;\ + VST1_I32(dst1, t1);\ + dst1 += 2; dim1_count -= 8;\ + }\ + if (dim1_count > 3) {\ + __asm__(\ + "ldr w0,[%0],#4; str w0,[%1]\n\t"\ + :"+r"(src1):"r"(dst1)\ + :"cc","memory","x0","x1");\ + dst1++; dim1_count -= 4;\ + }\ + if (dim1_count > 0) {\ + uint32_t *dst_cast = (uint32_t *)dst1; dst1++;\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += dim1_count;\ + uint32_t d0 = *src1_cast;\ + if (dim1_count >= 2) {\ + d0 |= ((uint32_t)src1_cast[1]) << 8;\ + }\ + if (dim1_count >= 3) {\ + d0 |= ((uint32_t)src1_cast[2]) << 16;\ + }\ + dst_cast[0] = d0;\ + }\ +} + +#ifdef GEMM_UNSIGNED_INT +#define NCOPY_uint8_t_uint32_t(unroll) NCOPY_UNROLL_##unroll +#else +#define NCOPY_int8_t_int32_t(unroll) NCOPY_UNROLL_##unroll +#endif + +#define TCOPY_K4N8 {\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += 8; pref_ab(src1);\ + uint8_t *src2_cast = (uint8_t *)src2; src2 += 8; pref_ab(src2);\ + uint8_t *src3_cast = (uint8_t *)src3; src3 += 8; pref_ab(src3);\ + uint8_t *src4_cast = (uint8_t *)src4; src4 += 8; pref_ab(src4);\ + uint8_t *dst1_cast = (uint8_t *)dst1; dst1 += 8;\ + uint8x8x4_t t1;\ + t1.val[0] = vld1_u8(src1_cast);\ + t1.val[1] = vld1_u8(src2_cast);\ + t1.val[2] = vld1_u8(src3_cast);\ + t1.val[3] = vld1_u8(src4_cast);\ + vst4_u8(dst1_cast, t1);\ +} + +#define TCOPY_K3N8 {\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += 8; pref_ab(src1);\ + uint8_t *src2_cast = (uint8_t *)src2; src2 += 8; pref_ab(src2);\ + uint8_t *src3_cast = (uint8_t *)src3; src3 += 8; pref_ab(src3);\ + uint8_t *dst1_cast = (uint8_t *)dst1; dst1 += 8;\ + uint8x8x4_t t1;\ + t1.val[0] = vld1_u8(src1_cast);\ + t1.val[1] = vld1_u8(src2_cast);\ + t1.val[2] = vld1_u8(src3_cast);\ + t1.val[3] = vdup_n_u8(0);\ + vst4_u8(dst1_cast, t1);\ +} + +#define TCOPY_K2N8 {\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += 8; pref_ab(src1);\ + uint8_t *src2_cast = (uint8_t *)src2; src2 += 8; pref_ab(src2);\ + uint8_t *dst1_cast = (uint8_t *)dst1; dst1 += 8;\ + uint8x8x4_t t1;\ + t1.val[0] = vld1_u8(src1_cast);\ + t1.val[1] = vld1_u8(src2_cast);\ + t1.val[2] = vdup_n_u8(0);\ + t1.val[3] = vdup_n_u8(0);\ + vst4_u8(dst1_cast, t1);\ +} + +#define TCOPY_K1N8 {\ + uint8_t *src1_cast = (uint8_t *)src1; src1 += 8;\ + uint8_t *dst1_cast = (uint8_t *)dst1; dst1 += 8;\ + uint8x8x4_t t1;\ + t1.val[0] = 
vld1_u8(src1_cast);\ + t1.val[1] = vdup_n_u8(0);\ + t1.val[2] = vdup_n_u8(0);\ + t1.val[3] = vdup_n_u8(0);\ + vst4_u8(dst1_cast, t1);\ +} + +#define LOAD_4_INCPTR_I8(ptr, v) \ + __asm__ __volatile__("ldr %s["#v"],[%["#ptr"]],#4\n\t"\ + :[v]"=w"(v),[ptr]"+r"(ptr)::"memory"); + +#define STORE_4X4_INTERLEAVE_I8(v1, v2, v3, v4, dst) \ + __asm__ __volatile__(\ + "zip1 %["#v1"].8b,%["#v1"].8b,%["#v2"].8b\n\t"\ + "zip1 %["#v3"].8b,%["#v3"].8b,%["#v4"].8b\n\t"\ + "zip1 %["#v1"].8h,%["#v1"].8h,%["#v3"].8h\n\t"\ + "str %q["#v1"],[%["#dst"]],#16\n\t"\ + :[v1]"+w"(v1), [v2]"+w"(v2), [v3]"+w"(v3), [v4]"+w"(v4), [dst]"+r"(dst)\ + ::"memory"); + +#define TCOPY_K4N4 {\ + I8X8 t1, t2, t3, t4;\ + LOAD_4_INCPTR_I8(src1, t1)\ + LOAD_4_INCPTR_I8(src2, t2)\ + LOAD_4_INCPTR_I8(src3, t3)\ + LOAD_4_INCPTR_I8(src4, t4)\ + STORE_4X4_INTERLEAVE_I8(t1, t2, t3, t4, dst1)\ +} + +#define TCOPY_K3N4 {\ + I8X8 t1, t2, t3, t4;\ + LOAD_4_INCPTR_I8(src1, t1)\ + LOAD_4_INCPTR_I8(src2, t2)\ + LOAD_4_INCPTR_I8(src3, t3)\ + t4 = VDUP_N_I8(0);\ + STORE_4X4_INTERLEAVE_I8(t1, t2, t3, t4, dst1)\ +} + +#define TCOPY_K2N4 {\ + I8X8 t1, t2, t3, t4;\ + LOAD_4_INCPTR_I8(src1, t1)\ + LOAD_4_INCPTR_I8(src2, t2)\ + t3 = VDUP_N_I8(0);\ + t4 = VDUP_N_I8(0);\ + STORE_4X4_INTERLEAVE_I8(t1, t2, t3, t4, dst1)\ +} + +#define TCOPY_K1N4 {\ + I8X8 t1, t2, t3, t4;\ + LOAD_4_INCPTR_I8(src1, t1)\ + t2 = VDUP_N_I8(0);\ + t3 = VDUP_N_I8(0);\ + t4 = VDUP_N_I8(0);\ + STORE_4X4_INTERLEAVE_I8(t1, t2, t3, t4, dst1)\ +} + +#define TCOPY_K4N2 \ + __asm__ __volatile__(\ + "ldr h0,[%0],#2; ldr h1,[%1],#2\n\t"\ + "ldr h2,[%2],#2; ldr h3,[%3],#2\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%4],#4\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[1],[%4],#4\n\t"\ + :"+r"(src1),"+r"(src2),"+r"(src3),"+r"(src4),"+r"(dst1)\ + ::"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K3N2 \ + __asm__ __volatile__(\ + "ldr h0,[%0],#2; ldr h1,[%1],#2\n\t"\ + "ldr h2,[%2],#2; movi v3.8b,#0\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%3],#4\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[1],[%3],#4\n\t"\ + :"+r"(src1),"+r"(src2),"+r"(src3),"+r"(dst1)\ + ::"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K2N2 \ + __asm__ __volatile__(\ + "ldr h0,[%0],#2; ldr h1,[%1],#2\n\t"\ + "movi v2.8b,#0; movi v3.8b,#0\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%2],#4\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[1],[%2],#4\n\t"\ + :"+r"(src1),"+r"(src2),"+r"(dst1)\ + ::"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K1N2 \ + __asm__ __volatile__(\ + "ldr h0,[%0],#2; movi v1.8b,#0\n\t"\ + "movi v2.8b,#0; movi v3.8b,#0\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%1],#4\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[1],[%1],#4\n\t"\ + :"+r"(src1),"+r"(dst1)\ + ::"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K4N1 \ + __asm__ __volatile__(\ + "ldr b0,[%0],#1; ldr b1,[%1],#1\n\t"\ + "ldr b2,[%2],#1; ldr b3,[%3],#1\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%4]\n\t"\ + :"+r"(src1),"+r"(src2),"+r"(src3),"+r"(src4):"r"(dst1)\ + :"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K3N1 \ + __asm__ __volatile__(\ + "ldr b0,[%0],#1; ldr b1,[%1],#1\n\t"\ + "ldr b2,[%2],#1; movi v3.8b,#0\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%3]\n\t"\ + :"+r"(src1),"+r"(src2),"+r"(src3):"r"(dst1)\ + :"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K2N1 \ + __asm__ __volatile__(\ + "ldr b0,[%0],#1; ldr b1,[%1],#1\n\t"\ + "movi v2.8b,#0; movi v3.8b,#0\n\t"\ + "st4 {v0.b,v1.b,v2.b,v3.b}[0],[%2]\n\t"\ + :"+r"(src1),"+r"(src2):"r"(dst1)\ + :"cc","memory","v0","v1","v2","v3"); + +#define TCOPY_K1N1 \ + __asm__ __volatile__(\ + "ldr b0,[%0],#1; str s0,[%1]\n\t"\ + 
:"+r"(src1):"r"(dst1)\ + :"cc","memory","v0"); + + +#define TCOPY_NMAX12_TEMPLATE(kdim) \ + dst1 = dst + chunk_k_pass * 12;\ + for (; dim1_count > 11; dim1_count -= 12) {\ + TCOPY_K##kdim##N4 TCOPY_K##kdim##N8\ + dst1 += chunk_k_num * 12 - 12;\ + }\ + dst1 -= chunk_k_pass * 4;\ + if (dim1_count > 7) {\ + TCOPY_K##kdim##N8\ + dst1 += chunk_k_num * 8 - 8;\ + dim1_count -= 8;\ + }\ + dst1 -= chunk_k_pass * 4;\ + if (dim1_count > 3) {\ + TCOPY_K##kdim##N4\ + dst1 += chunk_k_num * 4 - 4;\ + dim1_count -= 4;\ + }\ + dst1 -= chunk_k_pass * 2;\ + if (dim1_count > 1) {\ + TCOPY_K##kdim##N2\ + dst1 += chunk_k_num * 2 - 2;\ + dim1_count -= 2;\ + }\ + dst1 -= chunk_k_pass;\ + if (dim1_count > 0) {\ + TCOPY_K##kdim##N1\ + } + +#define TCOPY_NMAX8_TEMPLATE(kdim) \ + dst1 = dst + chunk_k_pass * 8;\ + for (; dim1_count > 7; dim1_count -= 8) {\ + TCOPY_K##kdim##N8\ + dst1 += chunk_k_num * 8 - 8;\ + }\ + dst1 -= chunk_k_pass * 4;\ + if (dim1_count > 3) {\ + TCOPY_K##kdim##N4\ + dst1 += chunk_k_num * 4 - 4;\ + dim1_count -= 4;\ + }\ + dst1 -= chunk_k_pass * 2;\ + if (dim1_count > 1) {\ + TCOPY_K##kdim##N2\ + dst1 += chunk_k_num * 2 - 2;\ + dim1_count -= 2;\ + }\ + dst1 -= chunk_k_pass;\ + if (dim1_count > 0) {\ + TCOPY_K##kdim##N1\ + } + + +#define TCOPY_FUNC_TEMPLATE(funcname, maxunroll) \ +void funcname##maxunroll(\ + const I8 * __restrict__ src,\ + I32 * __restrict__ dst, uint32_t ld_dim,\ + uint32_t dim1, uint32_t dim2) {\ + if (!dim2) return;\ + uint32_t dim2_count = dim2;\ + const uint32_t chunk_k_num = ((dim2 - 1) >> 2) + 1;\ + const I8 *src0 = src;\ + for (; dim2_count > 3; dim2_count -= 4) {\ + const I8 *src1 = src0;\ + const I8 *src2 = src0 + ld_dim;\ + const I8 *src3 = src0 + ld_dim * 2;\ + const I8 *src4 = src0 + ld_dim * 3;\ + src0 += ld_dim * 4;\ + I32 *dst1;\ + uint32_t dim1_count = dim1;\ + const uint32_t chunk_k_pass = (dim2 - dim2_count) / 4;\ + TCOPY_NMAX##maxunroll##_TEMPLATE(4)\ + }\ + if (dim2_count == 3) {\ + const I8 *src1 = src0;\ + const I8 *src2 = src0 + ld_dim;\ + const I8 *src3 = src0 + ld_dim * 2;\ + I32 *dst1;\ + uint32_t dim1_count = dim1;\ + const uint32_t chunk_k_pass = chunk_k_num - 1;\ + TCOPY_NMAX##maxunroll##_TEMPLATE(3)\ + } else if (dim2_count == 2) {\ + const I8 *src1 = src0;\ + const I8 *src2 = src0 + ld_dim;\ + I32 *dst1;\ + uint32_t dim1_count = dim1;\ + const uint32_t chunk_k_pass = chunk_k_num - 1;\ + TCOPY_NMAX##maxunroll##_TEMPLATE(2)\ + } else if (dim2_count == 1) {\ + const I8 *src1 = src0;\ + I32 *dst1;\ + uint32_t dim1_count = dim1;\ + const uint32_t chunk_k_pass = chunk_k_num - 1;\ + TCOPY_NMAX##maxunroll##_TEMPLATE(1)\ + }\ +} + +#endif diff --git a/include/neon_armv8a/I8I32DotGemmKernel.h b/include/neon_armv8a/I8I32DotGemmKernel.h new file mode 100644 index 0000000..104f8f7 --- /dev/null +++ b/include/neon_armv8a/I8I32DotGemmKernel.h @@ -0,0 +1,1030 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "arm_neon/NeonIntOpSign.h" + +#ifndef INCLUDE_I8I32DOT_GEMM_KERNEL +#define INCLUDE_I8I32DOT_GEMM_KERNEL + +static inline void pref_c(I32 *dat) { + __asm__ ("prfm pstl1keep,[%0]\n\t"::"r"(dat):); +} + +#define PREF_N1 pref_c(c_pref); c_pref += ldc; +#define PREF_N2 PREF_N1 PREF_N1 +#define PREF_N4 PREF_N2 PREF_N2 +#define PREF_N8 PREF_N4 PREF_N4 +#define PREF_N12 PREF_N8 PREF_N4 + +/* NOTE: throughout this file, K denotes k/4, i.e. the number of 4-element groups along the k dimension (each dot-product step consumes 4 elements) */ + +/* unaligned load of 4 8-bit ints to an S register */ +#define UNALIGNED_LD4B_SREG(var, ptr) \ + __asm__("ldr %s0,[%1]\n\t":"=w"(var):"r"(ptr):"memory") + +#define VLD1(ptr) VREINTERPRET_I8_I32(VLD1_I32(ptr)) + +#define VLD1Q(ptr) VREINTERPRETQ_I8_I32(VLD1Q_I32(ptr)) + +#define NORMAL_KERNEL_SETUP(a_head, b_head) \ + uint32_t kdiv4_left = K;\ + const I32 *a_rd = a_head;\ + const I32 *b_rd = b_head; + +#define KERNEL_M1N1 \ + I32X4 cq1, cq2;\ + cq1 = cq2 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, bq1, bq2;\ + if (kdiv4_left > 3) {\ + aq1 = VLD1Q(a_rd); a_rd += 4;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + for (; kdiv4_left > 11; kdiv4_left -= 8) {\ + aq2 = VLD1Q(a_rd);\ + bq2 = VLD1Q(b_rd);\ + cq1 = VDOTQ_I32(cq1, aq1, bq1);\ + aq1 = VLD1Q(a_rd + 4); a_rd += 8;\ + bq1 = VLD1Q(b_rd + 4); b_rd += 8;\ + cq2 = VDOTQ_I32(cq2, aq2, bq2);\ + }\ + if (kdiv4_left > 7) {\ + aq2 = VLD1Q(a_rd); a_rd += 4;\ + bq2 = VLD1Q(b_rd); b_rd += 4;\ + cq1 = VDOTQ_I32(cq1, aq1, bq1);\ + cq2 = VDOTQ_I32(cq2, aq2, bq2);\ + kdiv4_left -= 8;\ + } else if (kdiv4_left > 3) {\ + cq1 = VDOTQ_I32(cq1, aq1, bq1);\ + kdiv4_left -= 4;\ + }\ + cq1 = VADDQ_I32(cq1, cq2);\ + I32X2 cd1 = VADD_I32(VGET_LOW_I32(cq1), VGET_HIGH_I32(cq1));\ + if (kdiv4_left > 1) {\ + I8X8 ad1 = VLD1(a_rd); a_rd += 2;\ + I8X8 bd1 = VLD1(b_rd); b_rd += 2;\ + cd1 = VDOT_I32(cd1, ad1, bd1);\ + kdiv4_left -= 2;\ + }\ + if (kdiv4_left > 0) {\ + I8X8 ad1, bd1;\ + UNALIGNED_LD4B_SREG(ad1, a_rd); a_rd++;\ + UNALIGNED_LD4B_SREG(bd1, b_rd); b_rd++;\ + cd1 = VDOT_I32(cd1, ad1, bd1);\ + }\ + I32 cs1 = VGET_LANE_I32(cd1, 0) + VGET_LANE_I32(cd1, 1); + +#define SAVE_M1N1 *c_ptr = c_ptr[0] * beta + cs1; + +#define KERNEL_M2N1_UNIT(a_head, b_head) \ + I32X2 cd1, cd2;\ + cd1 = cd2 = VDUP_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X8 ad1, ad2, bd1;\ + if (kdiv4_left > 1) {\ + ad1 = VLD1(a_rd); ad2 = VLD1(a_rd + 2); a_rd += 4;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 3; kdiv4_left -= 2) {\ + cd1 = VDOT_LANE_I32(cd1, ad1, bd1, 0); ad1 = VLD1(a_rd);\ + cd2 = VDOT_LANE_I32(cd2, ad2, bd1, 1); ad2 = VLD1(a_rd + 2);\ + a_rd += 4; bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 1) {\ + cd1 = VDOT_LANE_I32(cd1, ad1, bd1, 0);\ + cd2 = VDOT_LANE_I32(cd2, ad2, bd1, 1);\ + kdiv4_left -= 2;\ + }\ + cd1 = VADD_I32(cd1, cd2);\ + if (kdiv4_left > 0) {\ + UNALIGNED_LD4B_SREG(bd1, b_rd); b_rd++;\ + ad1 = VLD1(a_rd); a_rd += 2;\ + cd1 = VDOT_LANE_I32(cd1, ad1, bd1, 0);\ + } + +#define KERNEL_M2N1 KERNEL_M2N1_UNIT(a_head, b_head) +#define KERNEL_M1N2 KERNEL_M2N1_UNIT(b_head, a_head) + +#define SAVE_M2N1 \ + cd1 = VMLA_N_I32(cd1, VLD1_I32(c_ptr), beta);\ + VST1_I32(c_ptr, cd1); + +#define SAVE_M1N2 \ + c_ptr[0] = c_ptr[0] * beta + VGET_LANE_I32(cd1, 0);\ + c_ptr[ldc] = c_ptr[ldc] * beta + VGET_LANE_I32(cd1, 1); + +#define KERNEL_M2N2 \ + I32X2 cd1, cd2;\ + cd1 = cd2 = VDUP_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X8 ad1, bd1;\ + if (kdiv4_left > 0) {\ + ad1 = VLD1(a_rd); a_rd += 2;\ +
bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cd1 = VDOT_LANE_I32(cd1, ad1, bd1, 0);\ + cd2 = VDOT_LANE_I32(cd2, ad1, bd1, 1);\ + ad1 = VLD1(a_rd); a_rd += 2;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 0) {\ + cd1 = VDOT_LANE_I32(cd1, ad1, bd1, 0);\ + cd2 = VDOT_LANE_I32(cd2, ad1, bd1, 1);\ + } + +#define SAVE_M2N2 \ + cd1 = VMLA_N_I32(cd1, VLD1_I32(c_ptr), beta);\ + cd2 = VMLA_N_I32(cd2, VLD1_I32(c_ptr + ldc), beta);\ + VST1_I32(c_ptr, cd1); VST1_I32(c_ptr + ldc, cd2); + +#define KERNEL_M4N1_UNIT(a_head, b_head) \ + I32X4 cq1, cq2;\ + cq1 = cq2 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2;\ + I8X8 bd1;\ + if (kdiv4_left > 1) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 3; kdiv4_left -= 2) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0); aq1 = VLD1Q(a_rd);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 1); aq2 = VLD1Q(a_rd + 4);\ + a_rd += 8; bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 1) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 1);\ + kdiv4_left -= 2;\ + }\ + cq1 = VADDQ_I32(cq1, cq2);\ + if (kdiv4_left > 0) {\ + UNALIGNED_LD4B_SREG(bd1, b_rd); b_rd++;\ + aq1 = VLD1Q(a_rd); a_rd += 4;\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + } + +#define KERNEL_M4N1 KERNEL_M4N1_UNIT(a_head, b_head) +#define KERNEL_M1N4 KERNEL_M4N1_UNIT(b_head, a_head) + +#define SAVE_M4N1 \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + VST1Q_I32(c_ptr, cq1); + +#define UNIT_SAVE_M1N4(cq1) \ + c_tmp[0] = c_tmp[0] * beta + VGETQ_LANE_I32(cq1, 0);\ + c_tmp[ldc] = c_tmp[ldc] * beta + VGETQ_LANE_I32(cq1, 1);\ + c_tmp += ldc * 2;\ + c_tmp[0] = c_tmp[0] * beta + VGETQ_LANE_I32(cq1, 2);\ + c_tmp[ldc] = c_tmp[ldc] * beta + VGETQ_LANE_I32(cq1, 3);\ + c_tmp += ldc * 2; + +#define SAVE_M1N4 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M1N4(cq1) + +#define KERNEL_M4N2_UNIT(a_head, b_head) \ + I32X4 cq1, cq2;\ + cq1 = cq2 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1; I8X8 bd1;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); a_rd += 4;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq1, bd1, 1);\ + aq1 = VLD1Q(a_rd); a_rd += 4;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 0) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq1, bd1, 1);\ + } + +#define KERNEL_M4N2 KERNEL_M4N2_UNIT(a_head, b_head) +#define KERNEL_M2N4 KERNEL_M4N2_UNIT(b_head, a_head) + +#define UNIT_SAVE_M4N2(cq1, cq2) \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_tmp), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_tmp + ldc), beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + ldc, cq2);\ + c_tmp += ldc * 2; + +#define SAVE_M4N2 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M4N2(cq1, cq2) + +#define UNIT_SAVE_M2N4(cq1, cq2) {\ + I32X4 t1 = VZIP1Q_I32(cq1, cq2);\ + I32X4 t2 = VZIP2Q_I32(cq1, cq2);\ + I32X2 l1 = VMLA_N_I32(VGET_LOW_I32(t1), VLD1_I32(c_tmp), beta);\ + I32X2 l2 = VMLA_N_I32(VGET_HIGH_I32(t1), VLD1_I32(c_tmp + ldc), beta);\ + VST1_I32(c_tmp, l1); VST1_I32(c_tmp + ldc, l2); c_tmp += ldc * 2;\ + I32X2 l3 = VMLA_N_I32(VGET_LOW_I32(t2), VLD1_I32(c_tmp), beta);\ + I32X2 l4 = VMLA_N_I32(VGET_HIGH_I32(t2), VLD1_I32(c_tmp + ldc), beta);\ + VST1_I32(c_tmp, l3); VST1_I32(c_tmp + ldc, l4); c_tmp += ldc * 2;\ +} + +#define SAVE_M2N4 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq2) + +#define KERNEL_M4N4 \ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 
= cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, bq1;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); a_rd += 4;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq1 = VDOTQ_LANEQ_I32(cq1, aq1, bq1, 0);\ + cq2 = VDOTQ_LANEQ_I32(cq2, aq1, bq1, 1);\ + cq3 = VDOTQ_LANEQ_I32(cq3, aq1, bq1, 2);\ + cq4 = VDOTQ_LANEQ_I32(cq4, aq1, bq1, 3);\ + aq1 = VLD1Q(a_rd); a_rd += 4;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + if (kdiv4_left > 0) {\ + cq1 = VDOTQ_LANEQ_I32(cq1, aq1, bq1, 0);\ + cq2 = VDOTQ_LANEQ_I32(cq2, aq1, bq1, 1);\ + cq3 = VDOTQ_LANEQ_I32(cq3, aq1, bq1, 2);\ + cq4 = VDOTQ_LANEQ_I32(cq4, aq1, bq1, 3);\ + } + +#define SAVE_M4N4 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M4N2(cq1, cq2) UNIT_SAVE_M4N2(cq3, cq4) + +#define KERNEL_M8N1_UNIT(a_head, b_head) \ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, aq3, aq4;\ + I8X8 bd1;\ + if (kdiv4_left > 1) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4);\ + aq3 = VLD1Q(a_rd + 8); aq4 = VLD1Q(a_rd + 12); a_rd += 16;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 3; kdiv4_left -= 2) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0); aq1 = VLD1Q(a_rd);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0); aq2 = VLD1Q(a_rd + 4);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 1); aq3 = VLD1Q(a_rd + 8);\ + cq4 = VDOTQ_LANE_I32(cq4, aq4, bd1, 1); aq4 = VLD1Q(a_rd + 12);\ + a_rd += 16; bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 1) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 1);\ + cq4 = VDOTQ_LANE_I32(cq4, aq4, bd1, 1);\ + kdiv4_left -= 2;\ + }\ + cq1 = VADDQ_I32(cq1, cq3); cq2 = VADDQ_I32(cq2, cq4);\ + if (kdiv4_left > 0) {\ + UNALIGNED_LD4B_SREG(bd1, b_rd); b_rd++;\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + } + +#define KERNEL_M8N1 KERNEL_M8N1_UNIT(a_head, b_head) +#define KERNEL_M1N8 KERNEL_M8N1_UNIT(b_head, a_head) + +#define SAVE_M8N1 \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_ptr + 4), beta);\ + VST1Q_I32(c_ptr, cq1); VST1Q_I32(c_ptr + 4, cq2); + +#define SAVE_M1N8 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M1N4(cq1) UNIT_SAVE_M1N4(cq2) + +#define KERNEL_M8N2_UNIT(a_head, b_head) \ + I32X4 cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2;\ + I8X8 bd1;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq3 = VDOTQ_LANE_I32(cq3, aq1, bd1, 1);\ + aq1 = VLD1Q(a_rd);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq4 = VDOTQ_LANE_I32(cq4, aq2, bd1, 1);\ + aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 0) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq3 = VDOTQ_LANE_I32(cq3, aq1, bd1, 1);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq4 = VDOTQ_LANE_I32(cq4, aq2, bd1, 1);\ + } + +#define KERNEL_M8N2 KERNEL_M8N2_UNIT(a_head, b_head) +#define KERNEL_M2N8 KERNEL_M8N2_UNIT(b_head, a_head) + +#define UNIT_SAVE_M8N2(cq1, cq2, cq3, cq4) \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_tmp), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_tmp + 4), beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_tmp + ldc), beta);\ + cq4 = VMLAQ_N_I32(cq4, VLD1Q_I32(c_tmp + 
ldc + 4), beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + 4, cq2);\ + VST1Q_I32(c_tmp + ldc, cq3); VST1Q_I32(c_tmp + ldc + 4, cq4);\ + c_tmp += ldc * 2; + +#define SAVE_M8N2 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M8N2(cq1, cq2, cq3, cq4) + +#define SAVE_M2N8 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq3) UNIT_SAVE_M2N4(cq2, cq4) + +#define KERNEL_M8N4_UNIT(a_head, b_head) \ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = cq7 = cq8 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, bq1;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq1 = VDOTQ_LANEQ_I32(cq1, aq1, bq1, 0);\ + cq3 = VDOTQ_LANEQ_I32(cq3, aq1, bq1, 1);\ + cq5 = VDOTQ_LANEQ_I32(cq5, aq1, bq1, 2);\ + cq7 = VDOTQ_LANEQ_I32(cq7, aq1, bq1, 3);\ + aq1 = VLD1Q(a_rd);\ + cq2 = VDOTQ_LANEQ_I32(cq2, aq2, bq1, 0);\ + cq4 = VDOTQ_LANEQ_I32(cq4, aq2, bq1, 1);\ + cq6 = VDOTQ_LANEQ_I32(cq6, aq2, bq1, 2);\ + cq8 = VDOTQ_LANEQ_I32(cq8, aq2, bq1, 3);\ + aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + if (kdiv4_left > 0) {\ + cq1 = VDOTQ_LANEQ_I32(cq1, aq1, bq1, 0);\ + cq3 = VDOTQ_LANEQ_I32(cq3, aq1, bq1, 1);\ + cq5 = VDOTQ_LANEQ_I32(cq5, aq1, bq1, 2);\ + cq7 = VDOTQ_LANEQ_I32(cq7, aq1, bq1, 3);\ + cq2 = VDOTQ_LANEQ_I32(cq2, aq2, bq1, 0);\ + cq4 = VDOTQ_LANEQ_I32(cq4, aq2, bq1, 1);\ + cq6 = VDOTQ_LANEQ_I32(cq6, aq2, bq1, 2);\ + cq8 = VDOTQ_LANEQ_I32(cq8, aq2, bq1, 3);\ + } + +#define KERNEL_M8N4 KERNEL_M8N4_UNIT(a_head, b_head) +#define KERNEL_M4N8 KERNEL_M8N4_UNIT(b_head, a_head) + +#define SAVE_M8N4 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq1, cq2, cq3, cq4) UNIT_SAVE_M8N2(cq5, cq6, cq7, cq8) + +#define UNIT_SAVE_M4N4_TRANS(cq1, cq2, cq3, cq4) {\ + I32X4 l1 = VLD1Q_I32(c_tmp);\ + I32X4 l2 = VLD1Q_I32(c_tmp + ldc);\ + I32X4 l3 = VLD1Q_I32(c_tmp + ldc * 2);\ + I32X4 l4 = VLD1Q_I32(c_tmp + ldc * 3);\ + I64X2 t1 = VREINTERPRETQ_I64_I32(VZIP1Q_I32(cq1, cq2));\ + I64X2 t2 = VREINTERPRETQ_I64_I32(VZIP1Q_I32(cq3, cq4));\ + I64X2 t3 = VREINTERPRETQ_I64_I32(VZIP2Q_I32(cq1, cq2));\ + I64X2 t4 = VREINTERPRETQ_I64_I32(VZIP2Q_I32(cq3, cq4));\ + cq1 = VREINTERPRETQ_I32_I64(VZIP1Q_I64(t1, t2));\ + cq2 = VREINTERPRETQ_I32_I64(VZIP2Q_I64(t1, t2));\ + cq3 = VREINTERPRETQ_I32_I64(VZIP1Q_I64(t3, t4));\ + cq4 = VREINTERPRETQ_I32_I64(VZIP2Q_I64(t3, t4));\ + cq1 = VMLAQ_N_I32(cq1, l1, beta); cq2 = VMLAQ_N_I32(cq2, l2, beta);\ + cq3 = VMLAQ_N_I32(cq3, l3, beta); cq4 = VMLAQ_N_I32(cq4, l4, beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + ldc, cq2);\ + VST1Q_I32(c_tmp + ldc * 2, cq3); VST1Q_I32(c_tmp + ldc * 3, cq4);\ + c_tmp += ldc * 4;\ +} + +#define SAVE_M4N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M4N4_TRANS(cq1, cq3, cq5, cq7)\ + UNIT_SAVE_M4N4_TRANS(cq2, cq4, cq6, cq8) + +#define KERNEL_M8N8 \ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + I32X4 cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = cq07 = cq08 = VDUPQ_N_I32(0);\ + cq09 = cq10 = cq11 = cq12 = cq13 = cq14 = cq15 = cq16 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, bq1, bq2;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bq1 = VLD1Q(b_rd); bq2 = VLD1Q(b_rd + 4); b_rd += 8;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq01 = VDOTQ_LANEQ_I32(cq01, aq1, bq1, 0);\ + cq03 = VDOTQ_LANEQ_I32(cq03, aq1, bq1, 1);\ + cq05 = VDOTQ_LANEQ_I32(cq05, aq1, bq1, 2);\ + cq07 = 
VDOTQ_LANEQ_I32(cq07, aq1, bq1, 3);\ + cq09 = VDOTQ_LANEQ_I32(cq09, aq1, bq2, 0);\ + cq11 = VDOTQ_LANEQ_I32(cq11, aq1, bq2, 1);\ + cq13 = VDOTQ_LANEQ_I32(cq13, aq1, bq2, 2);\ + cq15 = VDOTQ_LANEQ_I32(cq15, aq1, bq2, 3);\ + aq1 = VLD1Q(a_rd);\ + cq02 = VDOTQ_LANEQ_I32(cq02, aq2, bq1, 0);\ + cq04 = VDOTQ_LANEQ_I32(cq04, aq2, bq1, 1);\ + cq06 = VDOTQ_LANEQ_I32(cq06, aq2, bq1, 2);\ + cq08 = VDOTQ_LANEQ_I32(cq08, aq2, bq1, 3);\ + bq1 = VLD1Q(b_rd);\ + cq10 = VDOTQ_LANEQ_I32(cq10, aq2, bq2, 0);\ + cq12 = VDOTQ_LANEQ_I32(cq12, aq2, bq2, 1);\ + cq14 = VDOTQ_LANEQ_I32(cq14, aq2, bq2, 2);\ + cq16 = VDOTQ_LANEQ_I32(cq16, aq2, bq2, 3);\ + aq2 = VLD1Q(a_rd + 4); a_rd += 8;\ + bq2 = VLD1Q(b_rd + 4); b_rd += 8;\ + }\ + if (kdiv4_left > 0) {\ + cq01 = VDOTQ_LANEQ_I32(cq01, aq1, bq1, 0);\ + cq03 = VDOTQ_LANEQ_I32(cq03, aq1, bq1, 1);\ + cq05 = VDOTQ_LANEQ_I32(cq05, aq1, bq1, 2);\ + cq07 = VDOTQ_LANEQ_I32(cq07, aq1, bq1, 3);\ + cq09 = VDOTQ_LANEQ_I32(cq09, aq1, bq2, 0);\ + cq11 = VDOTQ_LANEQ_I32(cq11, aq1, bq2, 1);\ + cq13 = VDOTQ_LANEQ_I32(cq13, aq1, bq2, 2);\ + cq15 = VDOTQ_LANEQ_I32(cq15, aq1, bq2, 3);\ + cq02 = VDOTQ_LANEQ_I32(cq02, aq2, bq1, 0);\ + cq04 = VDOTQ_LANEQ_I32(cq04, aq2, bq1, 1);\ + cq06 = VDOTQ_LANEQ_I32(cq06, aq2, bq1, 2);\ + cq08 = VDOTQ_LANEQ_I32(cq08, aq2, bq1, 3);\ + cq10 = VDOTQ_LANEQ_I32(cq10, aq2, bq2, 0);\ + cq12 = VDOTQ_LANEQ_I32(cq12, aq2, bq2, 1);\ + cq14 = VDOTQ_LANEQ_I32(cq14, aq2, bq2, 2);\ + cq16 = VDOTQ_LANEQ_I32(cq16, aq2, bq2, 3);\ + } + +#define SAVE_M8N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq01, cq02, cq03, cq04)\ + UNIT_SAVE_M8N2(cq05, cq06, cq07, cq08)\ + UNIT_SAVE_M8N2(cq09, cq10, cq11, cq12)\ + UNIT_SAVE_M8N2(cq13, cq14, cq15, cq16) + +#define KERNEL_M12N1_UNIT(a_head, b_head) \ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, aq3, aq4, aq5, aq6;\ + I8X8 bd1;\ + if (kdiv4_left > 1) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4);\ + aq3 = VLD1Q(a_rd + 8); aq4 = VLD1Q(a_rd + 12);\ + aq5 = VLD1Q(a_rd + 16); aq6 = VLD1Q(a_rd + 20); a_rd += 24;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 3; kdiv4_left -= 2) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0); aq1 = VLD1Q(a_rd);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0); aq2 = VLD1Q(a_rd + 4);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 0); aq3 = VLD1Q(a_rd + 8);\ + cq4 = VDOTQ_LANE_I32(cq4, aq4, bd1, 1); aq4 = VLD1Q(a_rd + 12);\ + cq5 = VDOTQ_LANE_I32(cq5, aq5, bd1, 1); aq5 = VLD1Q(a_rd + 16);\ + cq6 = VDOTQ_LANE_I32(cq6, aq6, bd1, 1); aq6 = VLD1Q(a_rd + 20);\ + a_rd += 24; bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 1) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 0);\ + cq4 = VDOTQ_LANE_I32(cq4, aq4, bd1, 1);\ + cq5 = VDOTQ_LANE_I32(cq5, aq5, bd1, 1);\ + cq6 = VDOTQ_LANE_I32(cq6, aq6, bd1, 1);\ + kdiv4_left -= 2;\ + }\ + cq1 = VADDQ_I32(cq1, cq4);\ + cq2 = VADDQ_I32(cq2, cq5);\ + cq3 = VADDQ_I32(cq3, cq6);\ + if (kdiv4_left > 0) {\ + UNALIGNED_LD4B_SREG(bd1, b_rd); b_rd++;\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4);\ + aq3 = VLD1Q(a_rd + 8); a_rd += 12;\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 0);\ + } + +#define KERNEL_M12N1 KERNEL_M12N1_UNIT(a_head, b_head) +#define KERNEL_M1N12 KERNEL_M12N1_UNIT(b_head, a_head) + +#define SAVE_M12N1 \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_ptr + 4), 
beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_ptr + 8), beta);\ + VST1Q_I32(c_ptr, cq1); VST1Q_I32(c_ptr + 4, cq2); VST1Q_I32(c_ptr + 8, cq3); + +#define SAVE_M1N12 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M1N4(cq1)\ + UNIT_SAVE_M1N4(cq2) UNIT_SAVE_M1N4(cq3) + +#define KERNEL_M12N2_UNIT(a_head, b_head) \ + I32X4 cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, aq3;\ + I8X8 bd1;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4);\ + aq3 = VLD1Q(a_rd + 8); a_rd += 12;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq4 = VDOTQ_LANE_I32(cq4, aq1, bd1, 1);\ + aq1 = VLD1Q(a_rd);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq5 = VDOTQ_LANE_I32(cq5, aq2, bd1, 1);\ + aq2 = VLD1Q(a_rd + 4);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 0);\ + cq6 = VDOTQ_LANE_I32(cq6, aq3, bd1, 1);\ + aq3 = VLD1Q(a_rd + 8); a_rd += 12;\ + bd1 = VLD1(b_rd); b_rd += 2;\ + }\ + if (kdiv4_left > 0) {\ + cq1 = VDOTQ_LANE_I32(cq1, aq1, bd1, 0);\ + cq4 = VDOTQ_LANE_I32(cq4, aq1, bd1, 1);\ + cq2 = VDOTQ_LANE_I32(cq2, aq2, bd1, 0);\ + cq5 = VDOTQ_LANE_I32(cq5, aq2, bd1, 1);\ + cq3 = VDOTQ_LANE_I32(cq3, aq3, bd1, 0);\ + cq6 = VDOTQ_LANE_I32(cq6, aq3, bd1, 1);\ + } + +#define KERNEL_M12N2 KERNEL_M12N2_UNIT(a_head, b_head) +#define KERNEL_M2N12 KERNEL_M12N2_UNIT(b_head, a_head) + +#define UNIT_SAVE_M12N2(cq1, cq2, cq3, cq4, cq5, cq6) \ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_tmp), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_tmp + 4), beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_tmp + 8), beta);\ + cq4 = VMLAQ_N_I32(cq4, VLD1Q_I32(c_tmp + ldc), beta);\ + cq5 = VMLAQ_N_I32(cq5, VLD1Q_I32(c_tmp + ldc + 4), beta);\ + cq6 = VMLAQ_N_I32(cq6, VLD1Q_I32(c_tmp + ldc + 8), beta);\ + VST1Q_I32(c_tmp, cq1); VST1Q_I32(c_tmp + 4, cq2);\ + VST1Q_I32(c_tmp + 8, cq3); VST1Q_I32(c_tmp + ldc, cq4);\ + VST1Q_I32(c_tmp + ldc + 4, cq5); VST1Q_I32(c_tmp + ldc + 8, cq6);\ + c_tmp += ldc * 2; + +#define SAVE_M12N2 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M12N2(cq1, cq2, cq3, cq4, cq5, cq6) + +#define SAVE_M2N12 \ + I32 *c_tmp = c_ptr; UNIT_SAVE_M2N4(cq1, cq4) UNIT_SAVE_M2N4(cq2, cq5)\ + UNIT_SAVE_M2N4(cq3, cq6) + +#define KERNEL_M12N4_UNIT(a_head, b_head) \ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06;\ + I32X4 cq07, cq08, cq09, cq10, cq11, cq12;\ + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = VDUPQ_N_I32(0);\ + cq07 = cq08 = cq09 = cq10 = cq11 = cq12 = VDUPQ_N_I32(0);\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + I8X16 aq1, aq2, aq3, bq1;\ + if (kdiv4_left > 0) {\ + aq1 = VLD1Q(a_rd); aq2 = VLD1Q(a_rd + 4);\ + aq3 = VLD1Q(a_rd + 8); a_rd += 12;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + for (; kdiv4_left > 1; kdiv4_left--) {\ + cq01 = VDOTQ_LANEQ_I32(cq01, aq1, bq1, 0);\ + cq04 = VDOTQ_LANEQ_I32(cq04, aq1, bq1, 1);\ + cq07 = VDOTQ_LANEQ_I32(cq07, aq1, bq1, 2);\ + cq10 = VDOTQ_LANEQ_I32(cq10, aq1, bq1, 3);\ + aq1 = VLD1Q(a_rd);\ + cq02 = VDOTQ_LANEQ_I32(cq02, aq2, bq1, 0);\ + cq05 = VDOTQ_LANEQ_I32(cq05, aq2, bq1, 1);\ + cq08 = VDOTQ_LANEQ_I32(cq08, aq2, bq1, 2);\ + cq11 = VDOTQ_LANEQ_I32(cq11, aq2, bq1, 3);\ + aq2 = VLD1Q(a_rd + 4);\ + cq03 = VDOTQ_LANEQ_I32(cq03, aq3, bq1, 0);\ + cq06 = VDOTQ_LANEQ_I32(cq06, aq3, bq1, 1);\ + cq09 = VDOTQ_LANEQ_I32(cq09, aq3, bq1, 2);\ + cq12 = VDOTQ_LANEQ_I32(cq12, aq3, bq1, 3);\ + aq3 = VLD1Q(a_rd + 8); a_rd += 12;\ + bq1 = VLD1Q(b_rd); b_rd += 4;\ + }\ + if (kdiv4_left > 0) {\ + cq01 = VDOTQ_LANEQ_I32(cq01, aq1, bq1, 0);\ + cq04 = VDOTQ_LANEQ_I32(cq04, aq1, 
bq1, 1);\ + cq07 = VDOTQ_LANEQ_I32(cq07, aq1, bq1, 2);\ + cq10 = VDOTQ_LANEQ_I32(cq10, aq1, bq1, 3);\ + cq02 = VDOTQ_LANEQ_I32(cq02, aq2, bq1, 0);\ + cq05 = VDOTQ_LANEQ_I32(cq05, aq2, bq1, 1);\ + cq08 = VDOTQ_LANEQ_I32(cq08, aq2, bq1, 2);\ + cq11 = VDOTQ_LANEQ_I32(cq11, aq2, bq1, 3);\ + cq03 = VDOTQ_LANEQ_I32(cq03, aq3, bq1, 0);\ + cq06 = VDOTQ_LANEQ_I32(cq06, aq3, bq1, 1);\ + cq09 = VDOTQ_LANEQ_I32(cq09, aq3, bq1, 2);\ + cq12 = VDOTQ_LANEQ_I32(cq12, aq3, bq1, 3);\ + } + +#define KERNEL_M12N4 KERNEL_M12N4_UNIT(a_head, b_head) +#define KERNEL_M4N12 KERNEL_M12N4_UNIT(b_head, a_head) + +#define SAVE_M12N4 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M12N2(cq01, cq02, cq03, cq04, cq05, cq06)\ + UNIT_SAVE_M12N2(cq07, cq08, cq09, cq10, cq11, cq12) + +#define SAVE_M4N12 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M4N4_TRANS(cq01, cq04, cq07, cq10)\ + UNIT_SAVE_M4N4_TRANS(cq02, cq05, cq08, cq11)\ + UNIT_SAVE_M4N4_TRANS(cq03, cq06, cq09, cq12) + +#define LDQ_STEP1_IDX_A55(v, ptr, idx) "ldr d"#v",["#ptr"],#"#idx"\n\t" +#define LDQ_STEP1_OFF_A55(v, ptr, off) "ldr d"#v",["#ptr",#"#off"]\n\t" +#define LDQ_STEP2_A55(r, ptr, off) "ldr x"#r",["#ptr",#"#off"]\n\t" +#define LDQ_STEP3_A55(r, v) "fmov v"#v".d[1],x"#r"\n\t" +#define LDQ_STEP1_IDX_A76(v, ptr, idx) "ldr q"#v",["#ptr"],#"#idx"\n\t" +#define LDQ_STEP1_OFF_A76(v, ptr, off) "ldr q"#v",["#ptr",#"#off"]\n\t" +#define LDQ_STEP2_A76(r, ptr, off) "" +#define LDQ_STEP3_A76(r, v) "" + +#define KERNEL_M8N12_TEMPLATE(cpu) \ + I32 *c_pref = c_ptr + 7; PREF_N12\ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06;\ + I32X4 cq07, cq08, cq09, cq10, cq11, cq12;\ + I32X4 cq13, cq14, cq15, cq16, cq17, cq18;\ + I32X4 cq19, cq20, cq21, cq22, cq23, cq24;\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + __asm__ __volatile__(\ + "movi %[cq01].16b,#0; movi %[cq02].16b,#0\n\t"\ + "movi %[cq03].16b,#0; movi %[cq04].16b,#0\n\t"\ + "movi %[cq05].16b,#0; movi %[cq06].16b,#0\n\t"\ + "movi %[cq07].16b,#0; movi %[cq08].16b,#0\n\t"\ + "movi %[cq09].16b,#0; movi %[cq10].16b,#0\n\t"\ + "movi %[cq11].16b,#0; movi %[cq12].16b,#0\n\t"\ + "movi %[cq13].16b,#0; movi %[cq14].16b,#0\n\t"\ + "movi %[cq15].16b,#0; movi %[cq16].16b,#0\n\t"\ + "movi %[cq17].16b,#0; movi %[cq18].16b,#0\n\t"\ + "movi %[cq19].16b,#0; movi %[cq20].16b,#0\n\t"\ + "movi %[cq21].16b,#0; movi %[cq22].16b,#0\n\t"\ + "movi %[cq23].16b,#0; movi %[cq24].16b,#0\n\t"\ + "cmp %w[kdiv4_left],#1; b.lt 4f\n\t"\ + "ldr q0,[%[a_rd]]; ldr q1,[%[a_rd],#16]; add %[a_rd],%[a_rd],#32\n\t"\ + "ldr q4,[%[b_rd]]; ldr q5,[%[b_rd],#16]; add %[b_rd],%[b_rd],#48\n\t"\ + "cmp %w[kdiv4_left],#3; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IDOT" %[cq01].4s,v0.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(6, %[b_rd], -16)\ + ""IDOT" %[cq02].4s,v1.16b,v4.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -8)\ + ""IDOT" %[cq03].4s,v0.16b,v4.4b[1]\n\t" LDQ_STEP1_IDX_##cpu(2, %[a_rd], 64)\ + ""IDOT" %[cq04].4s,v1.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq05].4s,v0.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq06].4s,v1.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq07].4s,v0.16b,v4.4b[3]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -56)\ + ""IDOT" %[cq08].4s,v1.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq09].4s,v0.16b,v5.4b[0]\n\t" LDQ_STEP1_IDX_##cpu(4, %[b_rd], 96)\ + ""IDOT" %[cq10].4s,v1.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq11].4s,v0.16b,v5.4b[1]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -88)\ + ""IDOT" %[cq12].4s,v1.16b,v5.4b[1]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq13].4s,v0.16b,v5.4b[2]\n\t" LDQ_STEP1_OFF_##cpu(3, %[a_rd], -48)\ + ""IDOT" %[cq14].4s,v1.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq15].4s,v0.16b,v5.4b[3]\n\t"\ + ""IDOT" 
%[cq16].4s,v1.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(1, 4)\ + ""IDOT" %[cq17].4s,v0.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(5, %[b_rd], -80)\ + ""IDOT" %[cq18].4s,v1.16b,v6.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -72)\ + ""IDOT" %[cq19].4s,v0.16b,v6.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -40)\ + ""IDOT" %[cq20].4s,v1.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v0.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq22].4s,v1.16b,v6.4b[2]\n\t" LDQ_STEP3_##cpu(0, 3)\ + ""IDOT" %[cq23].4s,v0.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq24].4s,v1.16b,v6.4b[3]\n\t" LDQ_STEP3_##cpu(1, 5)\ + ""IDOT" %[cq01].4s,v2.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(6, %[b_rd], -64)\ + ""IDOT" %[cq02].4s,v3.16b,v4.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -56)\ + ""IDOT" %[cq03].4s,v2.16b,v4.4b[1]\n\t" LDQ_STEP1_OFF_##cpu(0, %[a_rd], -32)\ + ""IDOT" %[cq04].4s,v3.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq05].4s,v2.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq06].4s,v3.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq07].4s,v2.16b,v4.4b[3]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -24)\ + ""IDOT" %[cq08].4s,v3.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq09].4s,v2.16b,v5.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(4, %[b_rd], -48)\ + ""IDOT" %[cq10].4s,v3.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq11].4s,v2.16b,v5.4b[1]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -40)\ + ""IDOT" %[cq12].4s,v3.16b,v5.4b[1]\n\t" LDQ_STEP3_##cpu(0, 0)\ + ""IDOT" %[cq13].4s,v2.16b,v5.4b[2]\n\t" LDQ_STEP1_OFF_##cpu(1, %[a_rd], -16)\ + ""IDOT" %[cq14].4s,v3.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq16].4s,v3.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(1, 4)\ + ""IDOT" %[cq17].4s,v2.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(5, %[b_rd], -32)\ + ""IDOT" %[cq18].4s,v3.16b,v6.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -24)\ + ""IDOT" %[cq19].4s,v2.16b,v6.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -8)\ + ""IDOT" %[cq20].4s,v3.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v2.16b,v6.4b[2]; sub %w[kdiv4_left],%w[kdiv4_left],#2\n\t"\ + ""IDOT" %[cq22].4s,v3.16b,v6.4b[2]\n\t" LDQ_STEP3_##cpu(0, 1)\ + ""IDOT" %[cq23].4s,v2.16b,v6.4b[3]; cmp %w[kdiv4_left],#3\n\t"\ + ""IDOT" %[cq24].4s,v3.16b,v6.4b[3]\n\t" LDQ_STEP3_##cpu(1, 5)\ + "b.ge 1b; 2:\n\t"\ + "cmp %w[kdiv4_left],#2; b.lt 3f\n\t"\ + ""IDOT" %[cq01].4s,v0.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(6, %[b_rd], -16)\ + ""IDOT" %[cq02].4s,v1.16b,v4.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -8)\ + ""IDOT" %[cq03].4s,v0.16b,v4.4b[1]\n\t" LDQ_STEP1_IDX_##cpu(2, %[a_rd], 32)\ + ""IDOT" %[cq04].4s,v1.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq05].4s,v0.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq06].4s,v1.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq07].4s,v0.16b,v4.4b[3]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -24)\ + ""IDOT" %[cq08].4s,v1.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq09].4s,v0.16b,v5.4b[0]\n\t" LDQ_STEP1_IDX_##cpu(4, %[b_rd], 48)\ + ""IDOT" %[cq10].4s,v1.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq11].4s,v0.16b,v5.4b[1]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -40)\ + ""IDOT" %[cq12].4s,v1.16b,v5.4b[1]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq13].4s,v0.16b,v5.4b[2]\n\t" LDQ_STEP1_OFF_##cpu(3, %[a_rd], -16)\ + ""IDOT" %[cq14].4s,v1.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq15].4s,v0.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq16].4s,v1.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(1, 4)\ + ""IDOT" %[cq17].4s,v0.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(5, %[b_rd], -32)\ + ""IDOT" %[cq18].4s,v1.16b,v6.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -24)\ + ""IDOT" %[cq19].4s,v0.16b,v6.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -8)\ + ""IDOT" %[cq20].4s,v1.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v0.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq22].4s,v1.16b,v6.4b[2]\n\t" LDQ_STEP3_##cpu(0, 3)\ + ""IDOT" 
%[cq23].4s,v0.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq24].4s,v1.16b,v6.4b[3]\n\t" LDQ_STEP3_##cpu(1, 5)\ + ""IDOT" %[cq01].4s,v2.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(6, %[b_rd], -16)\ + ""IDOT" %[cq02].4s,v3.16b,v4.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -8)\ + ""IDOT" %[cq03].4s,v2.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq04].4s,v3.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq05].4s,v2.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq06].4s,v3.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq07].4s,v2.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq08].4s,v3.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq09].4s,v2.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq10].4s,v3.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq11].4s,v2.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq12].4s,v3.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq13].4s,v2.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq14].4s,v3.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq16].4s,v3.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq17].4s,v2.16b,v6.4b[0]\n\t"\ + ""IDOT" %[cq18].4s,v3.16b,v6.4b[0]\n\t"\ + ""IDOT" %[cq19].4s,v2.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v3.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v2.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq22].4s,v3.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v2.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq24].4s,v3.16b,v6.4b[3]\n\t"\ + "b 4f; 3:\n\t"\ + ""IDOT" %[cq01].4s,v0.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(6, %[b_rd], -16)\ + ""IDOT" %[cq02].4s,v1.16b,v4.4b[0]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -8)\ + ""IDOT" %[cq03].4s,v0.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq04].4s,v1.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq05].4s,v0.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq06].4s,v1.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq07].4s,v0.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq08].4s,v1.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq09].4s,v0.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq10].4s,v1.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq11].4s,v0.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq12].4s,v1.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq13].4s,v0.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq14].4s,v1.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq15].4s,v0.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq16].4s,v1.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq17].4s,v0.16b,v6.4b[0]\n\t"\ + ""IDOT" %[cq18].4s,v1.16b,v6.4b[0]\n\t"\ + ""IDOT" %[cq19].4s,v0.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v1.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v0.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq22].4s,v1.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v0.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq24].4s,v1.16b,v6.4b[3]\n\t"\ + "4:\n\t"\ + :[cq01]"=w"(cq01),[cq02]"=w"(cq02),[cq03]"=w"(cq03),[cq04]"=w"(cq04),\ + [cq05]"=w"(cq05),[cq06]"=w"(cq06),[cq07]"=w"(cq07),[cq08]"=w"(cq08),\ + [cq09]"=w"(cq09),[cq10]"=w"(cq10),[cq11]"=w"(cq11),[cq12]"=w"(cq12),\ + [cq13]"=w"(cq13),[cq14]"=w"(cq14),[cq15]"=w"(cq15),[cq16]"=w"(cq16),\ + [cq17]"=w"(cq17),[cq18]"=w"(cq18),[cq19]"=w"(cq19),[cq20]"=w"(cq20),\ + [cq21]"=w"(cq21),[cq22]"=w"(cq22),[cq23]"=w"(cq23),[cq24]"=w"(cq24),\ + [kdiv4_left]"+r"(kdiv4_left),[a_rd]"+r"(a_rd),[b_rd]"+r"(b_rd)\ + ::"cc","memory","x0","x1","v0","v1","v2","v3","v4","v5","v6"); + +#define SAVE_M8N12 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq01, cq02, cq03, cq04)\ + UNIT_SAVE_M8N2(cq05, cq06, cq07, cq08)\ + UNIT_SAVE_M8N2(cq09, cq10, cq11, cq12)\ + UNIT_SAVE_M8N2(cq13, cq14, cq15, cq16)\ + UNIT_SAVE_M8N2(cq17, cq18, cq19, cq20)\ + UNIT_SAVE_M8N2(cq21, cq22, cq23, cq24) + +#define KERNEL_M12N8_TEMPLATE(cpu) \ + I32 *c_pref = c_ptr + 11; PREF_N8\ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06;\ + I32X4 cq07, cq08, cq09, cq10, cq11, cq12;\ + I32X4 cq13, cq14, cq15, cq16, cq17, cq18;\ + I32X4 cq19, cq20, cq21, cq22, cq23, cq24;\ + NORMAL_KERNEL_SETUP(a_head, b_head)\ + __asm__ 
__volatile__(\ + "movi %[cq01].16b,#0; movi %[cq02].16b,#0\n\t"\ + "movi %[cq03].16b,#0; movi %[cq04].16b,#0\n\t"\ + "movi %[cq05].16b,#0; movi %[cq06].16b,#0\n\t"\ + "movi %[cq07].16b,#0; movi %[cq08].16b,#0\n\t"\ + "movi %[cq09].16b,#0; movi %[cq10].16b,#0\n\t"\ + "movi %[cq11].16b,#0; movi %[cq12].16b,#0\n\t"\ + "movi %[cq13].16b,#0; movi %[cq14].16b,#0\n\t"\ + "movi %[cq15].16b,#0; movi %[cq16].16b,#0\n\t"\ + "movi %[cq17].16b,#0; movi %[cq18].16b,#0\n\t"\ + "movi %[cq19].16b,#0; movi %[cq20].16b,#0\n\t"\ + "movi %[cq21].16b,#0; movi %[cq22].16b,#0\n\t"\ + "movi %[cq23].16b,#0; movi %[cq24].16b,#0\n\t"\ + "cmp %w[kdiv4_left],#1; b.lt 4f\n\t"\ + "ldr q0,[%[a_rd]]; ldr q1,[%[a_rd],#16]; add %[a_rd],%[a_rd],#48\n\t"\ + "ldr q4,[%[b_rd]]; ldr q5,[%[b_rd],#16]; add %[b_rd],%[b_rd],#32\n\t"\ + "cmp %w[kdiv4_left],#3; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IDOT" %[cq01].4s,v0.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(2, %[a_rd], -16)\ + ""IDOT" %[cq04].4s,v0.16b,v4.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -8)\ + ""IDOT" %[cq07].4s,v0.16b,v4.4b[2]\n\t" LDQ_STEP1_IDX_##cpu(6, %[b_rd], 64)\ + ""IDOT" %[cq10].4s,v0.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq13].4s,v0.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq16].4s,v0.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq19].4s,v0.16b,v5.4b[2]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -56)\ + ""IDOT" %[cq22].4s,v0.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq02].4s,v1.16b,v4.4b[0]\n\t" LDQ_STEP1_IDX_##cpu(0, %[a_rd], 96)\ + ""IDOT" %[cq05].4s,v1.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq08].4s,v1.16b,v4.4b[2]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -88)\ + ""IDOT" %[cq11].4s,v1.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq14].4s,v1.16b,v5.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(7, %[b_rd], -48)\ + ""IDOT" %[cq17].4s,v1.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v1.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v1.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 0)\ + ""IDOT" %[cq03].4s,v2.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(1, %[a_rd], -80)\ + ""IDOT" %[cq06].4s,v2.16b,v4.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -72)\ + ""IDOT" %[cq09].4s,v2.16b,v4.4b[2]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -40)\ + ""IDOT" %[cq12].4s,v2.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq18].4s,v2.16b,v5.4b[1]\n\t" LDQ_STEP3_##cpu(1, 7)\ + ""IDOT" %[cq21].4s,v2.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq24].4s,v2.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 1)\ + ""IDOT" %[cq01].4s,v0.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(2, %[a_rd], -64)\ + ""IDOT" %[cq04].4s,v0.16b,v6.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -56)\ + ""IDOT" %[cq07].4s,v0.16b,v6.4b[2]\n\t" LDQ_STEP1_OFF_##cpu(4, %[b_rd], -32)\ + ""IDOT" %[cq10].4s,v0.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq13].4s,v0.16b,v7.4b[0]\n\t"\ + ""IDOT" %[cq16].4s,v0.16b,v7.4b[1]\n\t"\ + ""IDOT" %[cq19].4s,v0.16b,v7.4b[2]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -24)\ + ""IDOT" %[cq22].4s,v0.16b,v7.4b[3]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq02].4s,v1.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(0, %[a_rd], -48)\ + ""IDOT" %[cq05].4s,v1.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq08].4s,v1.16b,v6.4b[2]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -40)\ + ""IDOT" %[cq11].4s,v1.16b,v6.4b[3]\n\t" LDQ_STEP3_##cpu(1, 4)\ + ""IDOT" %[cq14].4s,v1.16b,v7.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(5, %[b_rd], -16)\ + ""IDOT" %[cq17].4s,v1.16b,v7.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v1.16b,v7.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v1.16b,v7.4b[3]\n\t" LDQ_STEP3_##cpu(0, 0)\ + ""IDOT" %[cq03].4s,v2.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(1, %[a_rd], -32)\ + ""IDOT" %[cq06].4s,v2.16b,v6.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -24)\ + ""IDOT" %[cq09].4s,v2.16b,v6.4b[2]\n\t" 
LDQ_STEP2_##cpu(1, %[b_rd], -8)\ + ""IDOT" %[cq12].4s,v2.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v7.4b[0]; sub %w[kdiv4_left],%w[kdiv4_left],#2\n\t"\ + ""IDOT" %[cq18].4s,v2.16b,v7.4b[1]\n\t" LDQ_STEP3_##cpu(1, 5)\ + ""IDOT" %[cq21].4s,v2.16b,v7.4b[2]; cmp %w[kdiv4_left],#3\n\t"\ + ""IDOT" %[cq24].4s,v2.16b,v7.4b[3]\n\t" LDQ_STEP3_##cpu(0, 1)\ + "b.ge 1b; 2:\n\t"\ + "cmp %w[kdiv4_left],#2; b.lt 3f\n\t"\ + ""IDOT" %[cq01].4s,v0.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(2, %[a_rd], -16)\ + ""IDOT" %[cq04].4s,v0.16b,v4.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -8)\ + ""IDOT" %[cq07].4s,v0.16b,v4.4b[2]\n\t" LDQ_STEP1_IDX_##cpu(6, %[b_rd], 32)\ + ""IDOT" %[cq10].4s,v0.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq13].4s,v0.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq16].4s,v0.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq19].4s,v0.16b,v5.4b[2]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -24)\ + ""IDOT" %[cq22].4s,v0.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq02].4s,v1.16b,v4.4b[0]\n\t" LDQ_STEP1_IDX_##cpu(0, %[a_rd], 48)\ + ""IDOT" %[cq05].4s,v1.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq08].4s,v1.16b,v4.4b[2]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -40)\ + ""IDOT" %[cq11].4s,v1.16b,v4.4b[3]\n\t" LDQ_STEP3_##cpu(1, 6)\ + ""IDOT" %[cq14].4s,v1.16b,v5.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(7, %[b_rd], -16)\ + ""IDOT" %[cq17].4s,v1.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v1.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v1.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 0)\ + ""IDOT" %[cq03].4s,v2.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(1, %[a_rd], -32)\ + ""IDOT" %[cq06].4s,v2.16b,v4.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -24)\ + ""IDOT" %[cq09].4s,v2.16b,v4.4b[2]\n\t" LDQ_STEP2_##cpu(1, %[b_rd], -8)\ + ""IDOT" %[cq12].4s,v2.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq18].4s,v2.16b,v5.4b[1]\n\t" LDQ_STEP3_##cpu(1, 7)\ + ""IDOT" %[cq21].4s,v2.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq24].4s,v2.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 1)\ + ""IDOT" %[cq01].4s,v0.16b,v6.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(2, %[a_rd], -16)\ + ""IDOT" %[cq04].4s,v0.16b,v6.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -8)\ + ""IDOT" %[cq07].4s,v0.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq10].4s,v0.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq13].4s,v0.16b,v7.4b[0]\n\t"\ + ""IDOT" %[cq16].4s,v0.16b,v7.4b[1]\n\t"\ + ""IDOT" %[cq19].4s,v0.16b,v7.4b[2]\n\t"\ + ""IDOT" %[cq22].4s,v0.16b,v7.4b[3]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq02].4s,v1.16b,v6.4b[0]\n\t"\ + ""IDOT" %[cq05].4s,v1.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq08].4s,v1.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq11].4s,v1.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq14].4s,v1.16b,v7.4b[0]\n\t"\ + ""IDOT" %[cq17].4s,v1.16b,v7.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v1.16b,v7.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v1.16b,v7.4b[3]\n\t"\ + ""IDOT" %[cq03].4s,v2.16b,v6.4b[0]\n\t"\ + ""IDOT" %[cq06].4s,v2.16b,v6.4b[1]\n\t"\ + ""IDOT" %[cq09].4s,v2.16b,v6.4b[2]\n\t"\ + ""IDOT" %[cq12].4s,v2.16b,v6.4b[3]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v7.4b[0]\n\t"\ + ""IDOT" %[cq18].4s,v2.16b,v7.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v2.16b,v7.4b[2]\n\t"\ + ""IDOT" %[cq24].4s,v2.16b,v7.4b[3]\n\t"\ + "b 4f; 3:\n\t"\ + ""IDOT" %[cq01].4s,v0.16b,v4.4b[0]\n\t" LDQ_STEP1_OFF_##cpu(2, %[a_rd], -16)\ + ""IDOT" %[cq04].4s,v0.16b,v4.4b[1]\n\t" LDQ_STEP2_##cpu(0, %[a_rd], -8)\ + ""IDOT" %[cq07].4s,v0.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq10].4s,v0.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq13].4s,v0.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq16].4s,v0.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq19].4s,v0.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq22].4s,v0.16b,v5.4b[3]\n\t" LDQ_STEP3_##cpu(0, 2)\ + ""IDOT" %[cq02].4s,v1.16b,v4.4b[0]\n\t"\ + ""IDOT" 
%[cq05].4s,v1.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq08].4s,v1.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq11].4s,v1.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq14].4s,v1.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq17].4s,v1.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq20].4s,v1.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq23].4s,v1.16b,v5.4b[3]\n\t"\ + ""IDOT" %[cq03].4s,v2.16b,v4.4b[0]\n\t"\ + ""IDOT" %[cq06].4s,v2.16b,v4.4b[1]\n\t"\ + ""IDOT" %[cq09].4s,v2.16b,v4.4b[2]\n\t"\ + ""IDOT" %[cq12].4s,v2.16b,v4.4b[3]\n\t"\ + ""IDOT" %[cq15].4s,v2.16b,v5.4b[0]\n\t"\ + ""IDOT" %[cq18].4s,v2.16b,v5.4b[1]\n\t"\ + ""IDOT" %[cq21].4s,v2.16b,v5.4b[2]\n\t"\ + ""IDOT" %[cq24].4s,v2.16b,v5.4b[3]\n\t"\ + "4:\n\t"\ + :[cq01]"=w"(cq01),[cq02]"=w"(cq02),[cq03]"=w"(cq03),[cq04]"=w"(cq04),\ + [cq05]"=w"(cq05),[cq06]"=w"(cq06),[cq07]"=w"(cq07),[cq08]"=w"(cq08),\ + [cq09]"=w"(cq09),[cq10]"=w"(cq10),[cq11]"=w"(cq11),[cq12]"=w"(cq12),\ + [cq13]"=w"(cq13),[cq14]"=w"(cq14),[cq15]"=w"(cq15),[cq16]"=w"(cq16),\ + [cq17]"=w"(cq17),[cq18]"=w"(cq18),[cq19]"=w"(cq19),[cq20]"=w"(cq20),\ + [cq21]"=w"(cq21),[cq22]"=w"(cq22),[cq23]"=w"(cq23),[cq24]"=w"(cq24),\ + [kdiv4_left]"+r"(kdiv4_left),[a_rd]"+r"(a_rd),[b_rd]"+r"(b_rd)\ + ::"cc","memory","x0","x1","v0","v1","v2","v4","v5","v6","v7"); + +#define SAVE_M12N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M12N2(cq01, cq02, cq03, cq04, cq05, cq06)\ + UNIT_SAVE_M12N2(cq07, cq08, cq09, cq10, cq11, cq12)\ + UNIT_SAVE_M12N2(cq13, cq14, cq15, cq16, cq17, cq18)\ + UNIT_SAVE_M12N2(cq19, cq20, cq21, cq22, cq23, cq24) + +#define NEON_IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim, srcint, dstint) \ +static inline void\ + inline_dualpack_gemm_a##srcint##_b##srcint##_c##dstint##_m##mdim##_n##ndim(\ + const srcint *a_head, const srcint *b_head, dstint *c_ptr,\ + uint32_t K, dstint beta, uint32_t ldc) {\ + KERNEL_M##mdim##N##ndim\ + SAVE_M##mdim##N##ndim\ +} + +#define IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim, srcint, dstint)\ + NEON_IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim, srcint, dstint) + + +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 1, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 2, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 1, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 2, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 4, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 4, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 1, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 2, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 4, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 8, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 8, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 8, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 1, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 2, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 4, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 8, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 12, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 12, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 12, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 1, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 2, I32, I32) +IDOTGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 4, I32, I32) + +#endif diff --git a/include/neon_armv8a/I8I32MlaGemmKernel.h b/include/neon_armv8a/I8I32MlaGemmKernel.h new file mode 100644 index 0000000..5ed9fe6 --- /dev/null +++ b/include/neon_armv8a/I8I32MlaGemmKernel.h @@ -0,0 +1,378 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/NeonI8I32MlaGemmKernel.h" + +#ifndef INCLUDE_ARMV8A_I8I32MLA_ASM_KERNEL +#define INCLUDE_ARMV8A_I8I32MLA_ASM_KERNEL + +static inline void pref_c_8(const I32 *c) { + __asm__("prfm pstl1keep,[%0]; prfm pstl1keep,[%0,#32]\n\t"::"r"(c):); +} + +static inline void pref_c_12(const I32 *c) { + __asm__("prfm pstl1keep,[%0]; prfm pstl1keep,[%0,#48]\n\t"::"r"(c):); +} + +#define KERNEL_M8N12 \ + const I32 *c_pref = c_ptr;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref); c_pref += ldc;\ + pref_c_8(c_pref);\ + COMMON_KERNEL_HEADER(a_head, b_head)\ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + I32X4 cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16;\ + I32X4 cq17, cq18, cq19, cq20, cq21, cq22, cq23, cq24;\ + __asm__ __volatile__(\ + "movi %[cq01].16b,#0; movi %[cq02].16b,#0\n\t"\ + "movi %[cq03].16b,#0; movi %[cq04].16b,#0\n\t"\ + "movi %[cq05].16b,#0; movi %[cq06].16b,#0\n\t"\ + "movi %[cq07].16b,#0; movi %[cq08].16b,#0\n\t"\ + "movi %[cq09].16b,#0; movi %[cq10].16b,#0\n\t"\ + "movi %[cq11].16b,#0; movi %[cq12].16b,#0\n\t"\ + "movi %[cq13].16b,#0; movi %[cq14].16b,#0\n\t"\ + "movi %[cq15].16b,#0; movi %[cq16].16b,#0\n\t"\ + "movi %[cq17].16b,#0; movi %[cq18].16b,#0\n\t"\ + "movi %[cq19].16b,#0; movi %[cq20].16b,#0\n\t"\ + "movi %[cq21].16b,#0; movi %[cq22].16b,#0\n\t"\ + "movi %[cq23].16b,#0; movi %[cq24].16b,#0\n\t"\ + "cmp %w[k_left],#1; b.lt 4f\n\t"\ + "ldr q0,[%[a_ptr]],#16\n\t"\ + "ldr q2,[%[b_ptr]]; ldr d3,[%[b_ptr],#16]; add %[b_ptr],%[b_ptr],#24\n\t"\ + "cmp %w[k_left],#3; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IMLAL" %[cq01].4s,v0.4h,v2.h[0]; ldr x0,[%[b_ptr]],#48\n\t"\ + ""IMLAL2" %[cq02].4s,v0.8h,v2.h[0]\n\t"\ + ""IMLAL" %[cq03].4s,v0.4h,v2.h[1]\n\t"\ + ""IMLAL2" %[cq04].4s,v0.8h,v2.h[1]\n\t"\ + ""IMLAL" %[cq05].4s,v0.4h,v2.h[2]\n\t"\ + ""IMLAL2" %[cq06].4s,v0.8h,v2.h[2]\n\t"\ + ""IMLAL" %[cq07].4s,v0.4h,v2.h[3]\n\t"\ + ""IMLAL2" %[cq08].4s,v0.8h,v2.h[3]\n\t"\ + "fmov v3.d[1],x0; ldr d1,[%[a_ptr]],#32\n\t"\ + ""IMLAL" %[cq09].4s,v0.4h,v2.h[4]\n\t"\ + ""IMLAL2" %[cq10].4s,v0.8h,v2.h[4]; ldr x0,[%[a_ptr],#-24]\n\t"\ + ""IMLAL" %[cq11].4s,v0.4h,v2.h[5]\n\t"\ + ""IMLAL2" %[cq12].4s,v0.8h,v2.h[5]\n\t"\ + ""IMLAL" %[cq13].4s,v0.4h,v2.h[6]\n\t"\ + ""IMLAL2" %[cq14].4s,v0.8h,v2.h[6]\n\t"\ + ""IMLAL" %[cq15].4s,v0.4h,v2.h[7]\n\t"\ + ""IMLAL2" %[cq16].4s,v0.8h,v2.h[7]\n\t"\ + "fmov v1.d[1],x0; ldr d4,[%[b_ptr],#-40]\n\t"\ + ""IMLAL" %[cq17].4s,v0.4h,v3.h[0]; ldr x0,[%[b_ptr],#-32]\n\t"\ + ""IMLAL2" %[cq18].4s,v0.8h,v3.h[0]\n\t"\ + ""IMLAL" 
%[cq19].4s,v0.4h,v3.h[1]\n\t"\ + ""IMLAL2" %[cq20].4s,v0.8h,v3.h[1]\n\t"\ + ""IMLAL" %[cq21].4s,v0.4h,v3.h[2]\n\t"\ + ""IMLAL2" %[cq22].4s,v0.8h,v3.h[2]\n\t"\ + ""IMLAL" %[cq23].4s,v0.4h,v3.h[3]\n\t"\ + ""IMLAL2" %[cq24].4s,v0.8h,v3.h[3]\n\t"\ + "fmov v4.d[1],x0; ldr d0,[%[a_ptr],#-16]\n\t"\ + ""IMLAL" %[cq01].4s,v1.4h,v3.h[4]; ldr x0,[%[a_ptr],#-8]\n\t"\ + ""IMLAL2" %[cq02].4s,v1.8h,v3.h[4]\n\t"\ + ""IMLAL" %[cq03].4s,v1.4h,v3.h[5]\n\t"\ + ""IMLAL2" %[cq04].4s,v1.8h,v3.h[5]\n\t"\ + ""IMLAL" %[cq05].4s,v1.4h,v3.h[6]\n\t"\ + ""IMLAL2" %[cq06].4s,v1.8h,v3.h[6]\n\t"\ + ""IMLAL" %[cq07].4s,v1.4h,v3.h[7]\n\t"\ + ""IMLAL2" %[cq08].4s,v1.8h,v3.h[7]\n\t"\ + "fmov v0.d[1],x0; ldr d2,[%[b_ptr],#-24]\n\t"\ + ""IMLAL" %[cq09].4s,v1.4h,v4.h[0]; ldr x0,[%[b_ptr],#-16]\n\t"\ + ""IMLAL2" %[cq10].4s,v1.8h,v4.h[0]\n\t"\ + ""IMLAL" %[cq11].4s,v1.4h,v4.h[1]\n\t"\ + ""IMLAL2" %[cq12].4s,v1.8h,v4.h[1]\n\t"\ + ""IMLAL" %[cq13].4s,v1.4h,v4.h[2]\n\t"\ + ""IMLAL2" %[cq14].4s,v1.8h,v4.h[2]\n\t"\ + ""IMLAL" %[cq15].4s,v1.4h,v4.h[3]\n\t"\ + ""IMLAL2" %[cq16].4s,v1.8h,v4.h[3]\n\t"\ + "fmov v2.d[1],x0; ldr d3,[%[b_ptr],#-8]\n\t"\ + ""IMLAL" %[cq17].4s,v1.4h,v4.h[4]\n\t"\ + ""IMLAL2" %[cq18].4s,v1.8h,v4.h[4]\n\t"\ + ""IMLAL" %[cq19].4s,v1.4h,v4.h[5]\n\t"\ + ""IMLAL2" %[cq20].4s,v1.8h,v4.h[5]; sub %w[k_left],%w[k_left],#2\n\t"\ + ""IMLAL" %[cq21].4s,v1.4h,v4.h[6]\n\t"\ + ""IMLAL2" %[cq22].4s,v1.8h,v4.h[6]; cmp %w[k_left],#3\n\t"\ + ""IMLAL" %[cq23].4s,v1.4h,v4.h[7]\n\t"\ + ""IMLAL2" %[cq24].4s,v1.8h,v4.h[7]; b.ge 1b\n\t"\ + "2:\n\t"\ + "cmp %w[k_left],#2; b.lt 3f\n\t"\ + ""IMLAL" %[cq01].4s,v0.4h,v2.h[0]; ldr x0,[%[b_ptr]],#24\n\t"\ + ""IMLAL2" %[cq02].4s,v0.8h,v2.h[0]\n\t"\ + ""IMLAL" %[cq03].4s,v0.4h,v2.h[1]\n\t"\ + ""IMLAL2" %[cq04].4s,v0.8h,v2.h[1]\n\t"\ + ""IMLAL" %[cq05].4s,v0.4h,v2.h[2]\n\t"\ + ""IMLAL2" %[cq06].4s,v0.8h,v2.h[2]\n\t"\ + ""IMLAL" %[cq07].4s,v0.4h,v2.h[3]\n\t"\ + ""IMLAL2" %[cq08].4s,v0.8h,v2.h[3]\n\t"\ + "fmov v3.d[1],x0; ldr d1,[%[a_ptr]],#16\n\t"\ + ""IMLAL" %[cq09].4s,v0.4h,v2.h[4]\n\t"\ + ""IMLAL2" %[cq10].4s,v0.8h,v2.h[4]; ldr x0,[%[a_ptr],#-8]\n\t"\ + ""IMLAL" %[cq11].4s,v0.4h,v2.h[5]\n\t"\ + ""IMLAL2" %[cq12].4s,v0.8h,v2.h[5]\n\t"\ + ""IMLAL" %[cq13].4s,v0.4h,v2.h[6]\n\t"\ + ""IMLAL2" %[cq14].4s,v0.8h,v2.h[6]\n\t"\ + ""IMLAL" %[cq15].4s,v0.4h,v2.h[7]\n\t"\ + ""IMLAL2" %[cq16].4s,v0.8h,v2.h[7]\n\t"\ + "fmov v1.d[1],x0; ldr d4,[%[b_ptr],#-16]\n\t"\ + ""IMLAL" %[cq17].4s,v0.4h,v3.h[0]; ldr x0,[%[b_ptr],#-8]\n\t"\ + ""IMLAL2" %[cq18].4s,v0.8h,v3.h[0]\n\t"\ + ""IMLAL" %[cq19].4s,v0.4h,v3.h[1]\n\t"\ + ""IMLAL2" %[cq20].4s,v0.8h,v3.h[1]\n\t"\ + ""IMLAL" %[cq21].4s,v0.4h,v3.h[2]\n\t"\ + ""IMLAL2" %[cq22].4s,v0.8h,v3.h[2]\n\t"\ + ""IMLAL" %[cq23].4s,v0.4h,v3.h[3]\n\t"\ + ""IMLAL2" %[cq24].4s,v0.8h,v3.h[3]\n\t"\ + "fmov v4.d[1],x0\n\t"\ + ""IMLAL" %[cq01].4s,v1.4h,v3.h[4]\n\t"\ + ""IMLAL2" %[cq02].4s,v1.8h,v3.h[4]\n\t"\ + ""IMLAL" %[cq03].4s,v1.4h,v3.h[5]\n\t"\ + ""IMLAL2" %[cq04].4s,v1.8h,v3.h[5]\n\t"\ + ""IMLAL" %[cq05].4s,v1.4h,v3.h[6]\n\t"\ + ""IMLAL2" %[cq06].4s,v1.8h,v3.h[6]\n\t"\ + ""IMLAL" %[cq07].4s,v1.4h,v3.h[7]\n\t"\ + ""IMLAL2" %[cq08].4s,v1.8h,v3.h[7]\n\t"\ + ""IMLAL" %[cq09].4s,v1.4h,v4.h[0]\n\t"\ + ""IMLAL2" %[cq10].4s,v1.8h,v4.h[0]\n\t"\ + ""IMLAL" %[cq11].4s,v1.4h,v4.h[1]\n\t"\ + ""IMLAL2" %[cq12].4s,v1.8h,v4.h[1]\n\t"\ + ""IMLAL" %[cq13].4s,v1.4h,v4.h[2]\n\t"\ + ""IMLAL2" %[cq14].4s,v1.8h,v4.h[2]\n\t"\ + ""IMLAL" %[cq15].4s,v1.4h,v4.h[3]\n\t"\ + ""IMLAL2" %[cq16].4s,v1.8h,v4.h[3]\n\t"\ + ""IMLAL" %[cq17].4s,v1.4h,v4.h[4]\n\t"\ + ""IMLAL2" %[cq18].4s,v1.8h,v4.h[4]\n\t"\ + ""IMLAL" 
%[cq19].4s,v1.4h,v4.h[5]\n\t"\ + ""IMLAL2" %[cq20].4s,v1.8h,v4.h[5]; sub %w[k_left],%w[k_left],#2\n\t"\ + ""IMLAL" %[cq21].4s,v1.4h,v4.h[6]\n\t"\ + ""IMLAL2" %[cq22].4s,v1.8h,v4.h[6]\n\t"\ + ""IMLAL" %[cq23].4s,v1.4h,v4.h[7]\n\t"\ + ""IMLAL2" %[cq24].4s,v1.8h,v4.h[7]; b 4f\n\t"\ + "3:\n\t"\ + ""IMLAL" %[cq01].4s,v0.4h,v2.h[0]; "IMLAL2" %[cq02].4s,v0.8h,v2.h[0]\n\t"\ + ""IMLAL" %[cq03].4s,v0.4h,v2.h[1]; "IMLAL2" %[cq04].4s,v0.8h,v2.h[1]\n\t"\ + ""IMLAL" %[cq05].4s,v0.4h,v2.h[2]; "IMLAL2" %[cq06].4s,v0.8h,v2.h[2]\n\t"\ + ""IMLAL" %[cq07].4s,v0.4h,v2.h[3]; "IMLAL2" %[cq08].4s,v0.8h,v2.h[3]\n\t"\ + ""IMLAL" %[cq09].4s,v0.4h,v2.h[4]; "IMLAL2" %[cq10].4s,v0.8h,v2.h[4]\n\t"\ + ""IMLAL" %[cq11].4s,v0.4h,v2.h[5]; "IMLAL2" %[cq12].4s,v0.8h,v2.h[5]\n\t"\ + ""IMLAL" %[cq13].4s,v0.4h,v2.h[6]; "IMLAL2" %[cq14].4s,v0.8h,v2.h[6]\n\t"\ + ""IMLAL" %[cq15].4s,v0.4h,v2.h[7]; "IMLAL2" %[cq16].4s,v0.8h,v2.h[7]\n\t"\ + ""IMLAL" %[cq17].4s,v0.4h,v3.h[0]; "IMLAL2" %[cq18].4s,v0.8h,v3.h[0]\n\t"\ + ""IMLAL" %[cq19].4s,v0.4h,v3.h[1]; "IMLAL2" %[cq20].4s,v0.8h,v3.h[1]\n\t"\ + ""IMLAL" %[cq21].4s,v0.4h,v3.h[2]; "IMLAL2" %[cq22].4s,v0.8h,v3.h[2]\n\t"\ + ""IMLAL" %[cq23].4s,v0.4h,v3.h[3]; "IMLAL2" %[cq24].4s,v0.8h,v3.h[3]\n\t"\ + "sub %w[k_left],%w[k_left],#1\n\t"\ + "4:\n\t"\ + :[a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr), [k_left]"+r"(k_left),\ + [cq01]"=w"(cq01), [cq02]"=w"(cq02), [cq03]"=w"(cq03), [cq04]"=w"(cq04),\ + [cq05]"=w"(cq05), [cq06]"=w"(cq06), [cq07]"=w"(cq07), [cq08]"=w"(cq08),\ + [cq09]"=w"(cq09), [cq10]"=w"(cq10), [cq11]"=w"(cq11), [cq12]"=w"(cq12),\ + [cq13]"=w"(cq13), [cq14]"=w"(cq14), [cq15]"=w"(cq15), [cq16]"=w"(cq16),\ + [cq17]"=w"(cq17), [cq18]"=w"(cq18), [cq19]"=w"(cq19), [cq20]"=w"(cq20),\ + [cq21]"=w"(cq21), [cq22]"=w"(cq22), [cq23]"=w"(cq23), [cq24]"=w"(cq24)\ + ::"cc","memory","x0","v0","v1","v2","v3","v4"); + +#define SAVE_M8N12 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M8N2(cq01, cq02, cq03, cq04)\ + UNIT_SAVE_M8N2(cq05, cq06, cq07, cq08)\ + UNIT_SAVE_M8N2(cq09, cq10, cq11, cq12)\ + UNIT_SAVE_M8N2(cq13, cq14, cq15, cq16)\ + UNIT_SAVE_M8N2(cq17, cq18, cq19, cq20)\ + UNIT_SAVE_M8N2(cq21, cq22, cq23, cq24) + +#define KERNEL_M12N8 \ + const I32 *c_pref = c_ptr;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref); c_pref += ldc;\ + pref_c_12(c_pref);\ + COMMON_KERNEL_HEADER(a_head, b_head)\ + I32X4 cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + I32X4 cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16;\ + I32X4 cq17, cq18, cq19, cq20, cq21, cq22, cq23, cq24;\ + __asm__ __volatile__(\ + "movi %[cq01].16b,#0; movi %[cq02].16b,#0\n\t"\ + "movi %[cq03].16b,#0; movi %[cq04].16b,#0\n\t"\ + "movi %[cq05].16b,#0; movi %[cq06].16b,#0\n\t"\ + "movi %[cq07].16b,#0; movi %[cq08].16b,#0\n\t"\ + "movi %[cq09].16b,#0; movi %[cq10].16b,#0\n\t"\ + "movi %[cq11].16b,#0; movi %[cq12].16b,#0\n\t"\ + "movi %[cq13].16b,#0; movi %[cq14].16b,#0\n\t"\ + "movi %[cq15].16b,#0; movi %[cq16].16b,#0\n\t"\ + "movi %[cq17].16b,#0; movi %[cq18].16b,#0\n\t"\ + "movi %[cq19].16b,#0; movi %[cq20].16b,#0\n\t"\ + "movi %[cq21].16b,#0; movi %[cq22].16b,#0\n\t"\ + "movi %[cq23].16b,#0; movi %[cq24].16b,#0\n\t"\ + "cmp %w[k_left],#1; b.lt 4f\n\t"\ + "ldr q0,[%[a_ptr]]; ldr d1,[%[a_ptr],#16]; add %[a_ptr],%[a_ptr],#24\n\t"\ + "ldr q3,[%[b_ptr]],#16\n\t"\ + "cmp %w[k_left],#3; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IMLAL" %[cq01].4s,v0.4h,v3.h[0]; ldr 
x0,[%[a_ptr]],#48\n\t"\ + ""IMLAL" %[cq04].4s,v0.4h,v3.h[1]\n\t"\ + ""IMLAL" %[cq07].4s,v0.4h,v3.h[2]\n\t"\ + ""IMLAL" %[cq10].4s,v0.4h,v3.h[3]\n\t"\ + ""IMLAL" %[cq13].4s,v0.4h,v3.h[4]\n\t"\ + ""IMLAL" %[cq16].4s,v0.4h,v3.h[5]\n\t"\ + ""IMLAL" %[cq19].4s,v0.4h,v3.h[6]\n\t"\ + ""IMLAL" %[cq22].4s,v0.4h,v3.h[7]\n\t"\ + "fmov v1.d[1],x0; ldr d4,[%[b_ptr]],#32\n\t"\ + ""IMLAL2" %[cq02].4s,v0.8h,v3.h[0]\n\t"\ + ""IMLAL2" %[cq05].4s,v0.8h,v3.h[1]; ldr x0,[%[b_ptr],#-24]\n\t"\ + ""IMLAL2" %[cq08].4s,v0.8h,v3.h[2]\n\t"\ + ""IMLAL2" %[cq11].4s,v0.8h,v3.h[3]\n\t"\ + ""IMLAL2" %[cq14].4s,v0.8h,v3.h[4]\n\t"\ + ""IMLAL2" %[cq17].4s,v0.8h,v3.h[5]\n\t"\ + ""IMLAL2" %[cq20].4s,v0.8h,v3.h[6]\n\t"\ + ""IMLAL2" %[cq23].4s,v0.8h,v3.h[7]\n\t"\ + "fmov v4.d[1],x0; ldr d2,[%[a_ptr],#-40]\n\t"\ + ""IMLAL" %[cq03].4s,v1.4h,v3.h[0]; ldr x0,[%[a_ptr],#-32]\n\t"\ + ""IMLAL" %[cq06].4s,v1.4h,v3.h[1]\n\t"\ + ""IMLAL" %[cq09].4s,v1.4h,v3.h[2]\n\t"\ + ""IMLAL" %[cq12].4s,v1.4h,v3.h[3]\n\t"\ + ""IMLAL" %[cq15].4s,v1.4h,v3.h[4]\n\t"\ + ""IMLAL" %[cq18].4s,v1.4h,v3.h[5]\n\t"\ + ""IMLAL" %[cq21].4s,v1.4h,v3.h[6]\n\t"\ + ""IMLAL" %[cq24].4s,v1.4h,v3.h[7]\n\t"\ + "fmov v2.d[1],x0; ldr d3,[%[b_ptr],#-16]\n\t"\ + ""IMLAL2" %[cq01].4s,v1.8h,v4.h[0]; ldr x0,[%[b_ptr],#-8]\n\t"\ + ""IMLAL2" %[cq04].4s,v1.8h,v4.h[1]\n\t"\ + ""IMLAL2" %[cq07].4s,v1.8h,v4.h[2]\n\t"\ + ""IMLAL2" %[cq10].4s,v1.8h,v4.h[3]\n\t"\ + ""IMLAL2" %[cq13].4s,v1.8h,v4.h[4]\n\t"\ + ""IMLAL2" %[cq16].4s,v1.8h,v4.h[5]\n\t"\ + ""IMLAL2" %[cq19].4s,v1.8h,v4.h[6]\n\t"\ + ""IMLAL2" %[cq22].4s,v1.8h,v4.h[7]\n\t"\ + "fmov v3.d[1],x0; ldr d0,[%[a_ptr],#-24]\n\t"\ + ""IMLAL" %[cq02].4s,v2.4h,v4.h[0]; ldr x0,[%[a_ptr],#-16]\n\t"\ + ""IMLAL" %[cq05].4s,v2.4h,v4.h[1]\n\t"\ + ""IMLAL" %[cq08].4s,v2.4h,v4.h[2]\n\t"\ + ""IMLAL" %[cq11].4s,v2.4h,v4.h[3]\n\t"\ + ""IMLAL" %[cq14].4s,v2.4h,v4.h[4]\n\t"\ + ""IMLAL" %[cq17].4s,v2.4h,v4.h[5]\n\t"\ + ""IMLAL" %[cq20].4s,v2.4h,v4.h[6]\n\t"\ + ""IMLAL" %[cq23].4s,v2.4h,v4.h[7]\n\t"\ + "fmov v0.d[1],x0; ldr d1,[%[a_ptr],#-8]\n\t"\ + ""IMLAL2" %[cq03].4s,v2.8h,v4.h[0]\n\t"\ + ""IMLAL2" %[cq06].4s,v2.8h,v4.h[1]\n\t"\ + ""IMLAL2" %[cq09].4s,v2.8h,v4.h[2]\n\t"\ + ""IMLAL2" %[cq12].4s,v2.8h,v4.h[3]; sub %w[k_left],%w[k_left],#2\n\t"\ + ""IMLAL2" %[cq15].4s,v2.8h,v4.h[4]\n\t"\ + ""IMLAL2" %[cq18].4s,v2.8h,v4.h[5]; cmp %w[k_left],#3\n\t"\ + ""IMLAL2" %[cq21].4s,v2.8h,v4.h[6]\n\t"\ + ""IMLAL2" %[cq24].4s,v2.8h,v4.h[7]; b.ge 1b\n\t"\ + "2:\n\t"\ + "cmp %w[k_left],#2; b.lt 3f\n\t"\ + ""IMLAL" %[cq01].4s,v0.4h,v3.h[0]; ldr x0,[%[a_ptr]],#24\n\t"\ + ""IMLAL" %[cq04].4s,v0.4h,v3.h[1]\n\t"\ + ""IMLAL" %[cq07].4s,v0.4h,v3.h[2]\n\t"\ + ""IMLAL" %[cq10].4s,v0.4h,v3.h[3]\n\t"\ + ""IMLAL" %[cq13].4s,v0.4h,v3.h[4]\n\t"\ + ""IMLAL" %[cq16].4s,v0.4h,v3.h[5]\n\t"\ + ""IMLAL" %[cq19].4s,v0.4h,v3.h[6]\n\t"\ + ""IMLAL" %[cq22].4s,v0.4h,v3.h[7]\n\t"\ + "fmov v1.d[1],x0; ldr d4,[%[b_ptr]],#16\n\t"\ + ""IMLAL2" %[cq02].4s,v0.8h,v3.h[0]\n\t"\ + ""IMLAL2" %[cq05].4s,v0.8h,v3.h[1]; ldr x0,[%[b_ptr],#-8]\n\t"\ + ""IMLAL2" %[cq08].4s,v0.8h,v3.h[2]\n\t"\ + ""IMLAL2" %[cq11].4s,v0.8h,v3.h[3]\n\t"\ + ""IMLAL2" %[cq14].4s,v0.8h,v3.h[4]\n\t"\ + ""IMLAL2" %[cq17].4s,v0.8h,v3.h[5]\n\t"\ + ""IMLAL2" %[cq20].4s,v0.8h,v3.h[6]\n\t"\ + ""IMLAL2" %[cq23].4s,v0.8h,v3.h[7]\n\t"\ + "fmov v4.d[1],x0; ldr d2,[%[a_ptr],#-16]\n\t"\ + ""IMLAL" %[cq03].4s,v1.4h,v3.h[0]; ldr x0,[%[a_ptr],#-8]\n\t"\ + ""IMLAL" %[cq06].4s,v1.4h,v3.h[1]\n\t"\ + ""IMLAL" %[cq09].4s,v1.4h,v3.h[2]\n\t"\ + ""IMLAL" %[cq12].4s,v1.4h,v3.h[3]\n\t"\ + ""IMLAL" %[cq15].4s,v1.4h,v3.h[4]\n\t"\ + ""IMLAL" 
%[cq18].4s,v1.4h,v3.h[5]\n\t"\ + ""IMLAL" %[cq21].4s,v1.4h,v3.h[6]\n\t"\ + ""IMLAL" %[cq24].4s,v1.4h,v3.h[7]\n\t"\ + "fmov v2.d[1],x0\n\t"\ + ""IMLAL2" %[cq01].4s,v1.8h,v4.h[0]\n\t"\ + ""IMLAL2" %[cq04].4s,v1.8h,v4.h[1]\n\t"\ + ""IMLAL2" %[cq07].4s,v1.8h,v4.h[2]\n\t"\ + ""IMLAL2" %[cq10].4s,v1.8h,v4.h[3]\n\t"\ + ""IMLAL2" %[cq13].4s,v1.8h,v4.h[4]\n\t"\ + ""IMLAL2" %[cq16].4s,v1.8h,v4.h[5]\n\t"\ + ""IMLAL2" %[cq19].4s,v1.8h,v4.h[6]\n\t"\ + ""IMLAL2" %[cq22].4s,v1.8h,v4.h[7]\n\t"\ + ""IMLAL" %[cq02].4s,v2.4h,v4.h[0]\n\t"\ + ""IMLAL" %[cq05].4s,v2.4h,v4.h[1]\n\t"\ + ""IMLAL" %[cq08].4s,v2.4h,v4.h[2]\n\t"\ + ""IMLAL" %[cq11].4s,v2.4h,v4.h[3]\n\t"\ + ""IMLAL" %[cq14].4s,v2.4h,v4.h[4]\n\t"\ + ""IMLAL" %[cq17].4s,v2.4h,v4.h[5]\n\t"\ + ""IMLAL" %[cq20].4s,v2.4h,v4.h[6]\n\t"\ + ""IMLAL" %[cq23].4s,v2.4h,v4.h[7]\n\t"\ + ""IMLAL2" %[cq03].4s,v2.8h,v4.h[0]\n\t"\ + ""IMLAL2" %[cq06].4s,v2.8h,v4.h[1]\n\t"\ + ""IMLAL2" %[cq09].4s,v2.8h,v4.h[2]\n\t"\ + ""IMLAL2" %[cq12].4s,v2.8h,v4.h[3]; sub %w[k_left],%w[k_left],#2\n\t"\ + ""IMLAL2" %[cq15].4s,v2.8h,v4.h[4]\n\t"\ + ""IMLAL2" %[cq18].4s,v2.8h,v4.h[5]\n\t"\ + ""IMLAL2" %[cq21].4s,v2.8h,v4.h[6]\n\t"\ + ""IMLAL2" %[cq24].4s,v2.8h,v4.h[7]; b 4f\n\t"\ + "3:\n\t"\ + ""IMLAL" %[cq01].4s,v0.4h,v3.h[0]; "IMLAL" %[cq04].4s,v0.4h,v3.h[1]\n\t"\ + ""IMLAL" %[cq07].4s,v0.4h,v3.h[2]; "IMLAL" %[cq10].4s,v0.4h,v3.h[3]\n\t"\ + ""IMLAL" %[cq13].4s,v0.4h,v3.h[4]; "IMLAL" %[cq16].4s,v0.4h,v3.h[5]\n\t"\ + ""IMLAL" %[cq19].4s,v0.4h,v3.h[6]; "IMLAL" %[cq22].4s,v0.4h,v3.h[7]\n\t"\ + ""IMLAL2" %[cq02].4s,v0.8h,v3.h[0]; "IMLAL2" %[cq05].4s,v0.8h,v3.h[1]\n\t"\ + ""IMLAL2" %[cq08].4s,v0.8h,v3.h[2]; "IMLAL2" %[cq11].4s,v0.8h,v3.h[3]\n\t"\ + ""IMLAL2" %[cq14].4s,v0.8h,v3.h[4]; "IMLAL2" %[cq17].4s,v0.8h,v3.h[5]\n\t"\ + ""IMLAL2" %[cq20].4s,v0.8h,v3.h[6]; "IMLAL2" %[cq23].4s,v0.8h,v3.h[7]\n\t"\ + ""IMLAL" %[cq03].4s,v1.4h,v3.h[0]; "IMLAL" %[cq06].4s,v1.4h,v3.h[1]\n\t"\ + ""IMLAL" %[cq09].4s,v1.4h,v3.h[2]; "IMLAL" %[cq12].4s,v1.4h,v3.h[3]\n\t"\ + ""IMLAL" %[cq15].4s,v1.4h,v3.h[4]; "IMLAL" %[cq18].4s,v1.4h,v3.h[5]\n\t"\ + ""IMLAL" %[cq21].4s,v1.4h,v3.h[6]; "IMLAL" %[cq24].4s,v1.4h,v3.h[7]\n\t"\ + "sub %w[k_left],%w[k_left],#1\n\t"\ + "4:\n\t"\ + :[a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr), [k_left]"+r"(k_left),\ + [cq01]"=w"(cq01), [cq02]"=w"(cq02), [cq03]"=w"(cq03), [cq04]"=w"(cq04),\ + [cq05]"=w"(cq05), [cq06]"=w"(cq06), [cq07]"=w"(cq07), [cq08]"=w"(cq08),\ + [cq09]"=w"(cq09), [cq10]"=w"(cq10), [cq11]"=w"(cq11), [cq12]"=w"(cq12),\ + [cq13]"=w"(cq13), [cq14]"=w"(cq14), [cq15]"=w"(cq15), [cq16]"=w"(cq16),\ + [cq17]"=w"(cq17), [cq18]"=w"(cq18), [cq19]"=w"(cq19), [cq20]"=w"(cq20),\ + [cq21]"=w"(cq21), [cq22]"=w"(cq22), [cq23]"=w"(cq23), [cq24]"=w"(cq24)\ + ::"cc","memory","x0","v0","v1","v2","v3","v4"); + +#define SAVE_M12N8 \ + I32 *c_tmp = c_ptr;\ + UNIT_SAVE_M12N2(cq01, cq02, cq03, cq04, cq05, cq06)\ + UNIT_SAVE_M12N2(cq07, cq08, cq09, cq10, cq11, cq12)\ + UNIT_SAVE_M12N2(cq13, cq14, cq15, cq16, cq17, cq18)\ + UNIT_SAVE_M12N2(cq19, cq20, cq21, cq22, cq23, cq24) + +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(12, 8, I16, I32) +IMLAGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 12, I16, I32) + +#endif diff --git a/include/neon_armv8a/I8I32MlaGemmSkinnyDot.h b/include/neon_armv8a/I8I32MlaGemmSkinnyDot.h new file mode 100644 index 0000000..049a1ca --- /dev/null +++ b/include/neon_armv8a/I8I32MlaGemmSkinnyDot.h @@ -0,0 +1,501 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/NeonI8I32MlaGemmSkinnyDot.h" + +#ifndef INCLUDE_ARMV8_I8I32_SKINNYDOT_ASM +#define INCLUDE_ARMV8_I8I32_SKINNYDOT_ASM + +#define I8I32MLA_SKINNYDOT_INLINE_M4N1(gemm) \ +static inline void inline_##gemm##_arowmajor_bskinny_m4n1(\ + const I8 *a_ptr1, const I8 *b_ptr, I32 *c_ptr,\ + uint32_t k_left, uint32_t LDK, uint32_t LDM,\ + I32 beta, bool c_rowmajor) {\ +\ + const I8 *a_ptr2 = a_ptr1 + LDK;\ + const I8 *a_ptr3 = a_ptr1 + LDK * 2;\ + const I8 *a_ptr4 = a_ptr2 + LDK * 2;\ + I32X2 cd1, cd2;\ + const uint32_t next_pref = (LDK * 4 - k_left) + 16;\ + __asm__ __volatile__ (\ + "movi v8.16b,#0; movi v9.16b,#0; movi v10.16b,#0; movi v11.16b,#0\n\t"\ + "cmp %w[k_left],#16; b.lt 3f\n\t"\ + "ldr q0,[%[a_ptr1]],#16; ldr q1,[%[a_ptr2]],#16\n\t"\ + "ldr q2,[%[a_ptr3]],#16; ldr q3,[%[a_ptr4]],#16\n\t"\ + "ldr q4,[%[b_ptr]],#16\n\t"\ + "cmp %w[k_left],#32; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IMULL" v12.8h,v0.8b,v4.8b; prfm pldl1keep,[%[a_ptr1],#64]\n\t"\ + ""IMULL" v13.8h,v1.8b,v4.8b; prfm pldl1keep,[%[a_ptr2],#64]\n\t"\ + ""IMULL" v14.8h,v2.8b,v4.8b; prfm pldl1keep,[%[a_ptr3],#64]\n\t"\ + ""IMULL" v15.8h,v3.8b,v4.8b; prfm pldl1keep,[%[a_ptr4],#64]\n\t"\ + ""IADALP" v8.4s,v12.8h; "IMULL"2 v12.8h,v0.16b,v4.16b\n\t"\ + "ldr q0,[%[a_ptr1]],#16\n\t"\ + ""IADALP" v9.4s,v13.8h; "IMULL"2 v13.8h,v1.16b,v4.16b\n\t"\ + "ldr q1,[%[a_ptr2]],#16\n\t"\ + ""IADALP" v10.4s,v14.8h; "IMULL"2 v14.8h,v2.16b,v4.16b\n\t"\ + "ldr q2,[%[a_ptr3]],#16\n\t"\ + ""IADALP" v11.4s,v15.8h; "IMULL"2 v15.8h,v3.16b,v4.16b\n\t"\ + "ldr q3,[%[a_ptr4]],#16\n\t"\ + "ldr q4,[%[b_ptr]],#16\n\t"\ + ""IADALP" v8.4s,v12.8h; sub %w[k_left],%w[k_left],#16\n\t"\ + ""IADALP" v9.4s,v13.8h\n\t"\ + ""IADALP" v10.4s,v14.8h; cmp %w[k_left],#32\n\t"\ + ""IADALP" v11.4s,v15.8h; b.ge 1b\n\t"\ + "2:\n\t"\ + ""IMULL" v12.8h,v0.8b,v4.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr1],%w[next_pref],SXTW #0]\n\t"\ + ""IMULL" v13.8h,v1.8b,v4.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr2],%w[next_pref],SXTW #0]\n\t"\ + ""IMULL" v14.8h,v2.8b,v4.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr3],%w[next_pref],SXTW #0]\n\t"\ + ""IMULL" v15.8h,v3.8b,v4.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr4],%w[next_pref],SXTW #0]\n\t"\ + ""IADALP" v8.4s,v12.8h; "IMULL"2 v12.8h,v0.16b,v4.16b\n\t"\ + ""IADALP" v9.4s,v13.8h; "IMULL"2 v13.8h,v1.16b,v4.16b\n\t"\ + ""IADALP" v10.4s,v14.8h; "IMULL"2 v14.8h,v2.16b,v4.16b\n\t"\ + ""IADALP" v11.4s,v15.8h; "IMULL"2 v15.8h,v3.16b,v4.16b\n\t"\ + ""IADALP" v8.4s,v12.8h; "IADALP" v9.4s,v13.8h\n\t"\ + "sub %w[k_left],%w[k_left],#16\n\t"\ + ""IADALP" v10.4s,v14.8h; "IADALP" v11.4s,v15.8h\n\t"\ + "3:\n\t"\ + "cmp %w[k_left],#8; b.lt 4f\n\t"\ + "ldr d0,[%[a_ptr1]],#8; ldr d1,[%[a_ptr2]],#8\n\t"\ + "ldr d2,[%[a_ptr3]],#8; ldr d3,[%[a_ptr4]],#8\n\t"\ + "ldr d4,[%[b_ptr]],#8; sub %w[k_left],%w[k_left],#8\n\t"\ + ""IMULL" v12.8h,v0.8b,v4.8b; "IMULL" v13.8h,v1.8b,v4.8b\n\t"\ + ""IMULL" 
v14.8h,v2.8b,v4.8b; "IMULL" v15.8h,v3.8b,v4.8b\n\t"\ + ""IADALP" v8.4s,v12.8h; "IADALP" v9.4s,v13.8h\n\t"\ + ""IADALP" v10.4s,v14.8h; "IADALP" v11.4s,v15.8h\n\t"\ + "4:\n\t"\ + "movi v12.16b,#0\n\t"\ + "addp v8.4s,v8.4s,v12.4s; addp v9.4s,v9.4s,v12.4s\n\t"\ + "addp v10.4s,v10.4s,v12.4s; addp v11.4s,v11.4s,v12.4s\n\t"\ + "cmp %w[k_left],#4; b.lt 5f\n\t"\ + "ldr s0,[%[a_ptr1]],#4; ldr s1,[%[a_ptr2]],#4\n\t"\ + "ldr s2,[%[a_ptr3]],#4; ldr s3,[%[a_ptr4]],#4\n\t"\ + "ldr s4,[%[b_ptr]],#4; sub %w[k_left],%w[k_left],#4\n\t"\ + ""IMULL" v12.8h,v0.8b,v4.8b; "IMULL" v13.8h,v1.8b,v4.8b\n\t"\ + ""IMULL" v14.8h,v2.8b,v4.8b; "IMULL" v15.8h,v3.8b,v4.8b\n\t"\ + ""IADALP" v8.2s,v12.4h; "IADALP" v9.2s,v13.4h\n\t"\ + ""IADALP" v10.2s,v14.4h; "IADALP" v11.2s,v15.4h\n\t"\ + "5:\n\t"\ + "cmp %w[k_left],#2; b.lt 6f\n\t"\ + "ldr h0,[%[a_ptr1]],#2; ldr h1,[%[a_ptr2]],#2\n\t"\ + "ldr h2,[%[a_ptr3]],#2; ldr h3,[%[a_ptr4]],#2\n\t"\ + "ldr h4,[%[b_ptr]],#2; sub %w[k_left],%w[k_left],#2\n\t"\ + ""IXTL" v0.8h,v0.8b; "IXTL" v1.8h,v1.8b\n\t"\ + ""IXTL" v2.8h,v2.8b; "IXTL" v3.8h,v3.8b; "IXTL" v4.8h,v4.8b\n\t"\ + ""IMLAL" v8.4s,v0.4h,v4.4h; "IMLAL" v9.4s,v1.4h,v4.4h\n\t"\ + ""IMLAL" v10.4s,v2.4h,v4.4h; "IMLAL" v11.4s,v3.4h,v4.4h\n\t"\ + "6:\n\t"\ + "addp %[cd1].2s,v8.2s,v9.2s; addp %[cd2].2s,v10.2s,v11.2s\n\t"\ + "cmp %w[k_left],#1; b.lt 7f\n\t"\ + "ldr b0,[%[a_ptr1]],#1; ldr b1,[%[a_ptr2]],#1\n\t"\ + "ldr b2,[%[a_ptr3]],#1; ldr b3,[%[a_ptr4]],#1\n\t"\ + "ldr b4,[%[b_ptr]],#1; sub %w[k_left],%w[k_left],#1\n\t"\ + "ins v0.b[1],v1.b[0]; ins v2.b[1],v3.b[0]\n\t"\ + ""IXTL" v0.8h,v0.8b; "IXTL" v2.8h,v2.8b; "IXTL" v4.8h,v4.8b\n\t"\ + ""IMLAL" %[cd1].4s,v0.4h,v4.h[0]; "IMLAL" %[cd2].4s,v2.4h,v4.h[0]\n\t"\ + "7:\n\t"\ + :[cd1]"=w"(cd1), [cd2]"=w"(cd2), [k_left]"+r"(k_left), [b_ptr]"+r"(b_ptr),\ + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2),\ + [a_ptr3]"+r"(a_ptr3), [a_ptr4]"+r"(a_ptr4)\ + :[next_pref]"r"(next_pref)\ + :"cc","memory","v0","v1","v2","v3","v4",\ + "v8","v9","v10","v11","v12","v13","v14","v15");\ +\ + cd1 = VMLA_N_I32(cd1, VLD1_I32(c_ptr), beta);\ + cd2 = VMLA_N_I32(cd2, VLD1_I32(c_ptr + 2), beta);\ + VST1_I32(c_ptr, cd1);\ + VST1_I32(c_ptr + 2, cd2);\ +} + +/* k_mask = 31 */ +#define I8I32MLA_SKINNYDOT_INLINE_M4N2(gemm) \ +static inline void inline_##gemm##_arowmajor_bskinny_m4n2(\ + const I8 *a_ptr1, const I8 *b_ptr, I32 *c_ptr,\ + uint32_t k_left, uint32_t LDK, uint32_t LDM,\ + I32 beta, bool c_rowmajor) {\ +\ + const I8 *a_ptr2 = a_ptr1 + LDK;\ + const I8 *a_ptr3 = a_ptr1 + LDK * 2;\ + const I8 *a_ptr4 = a_ptr2 + LDK * 2;\ + I32X4 cq1, cq2, cq3, cq4; /* higher 2 elements not used */\ + const uint32_t next_pref = (LDK * 4 - k_left) + 16;\ + __asm__ __volatile__ (\ + "movi v6.16b,#0; movi v7.16b,#0\n\t"\ + "movi v8.16b,#0; movi v9.16b,#0\n\t"\ + "movi v10.16b,#0; movi v11.16b,#0\n\t"\ + "movi v12.16b,#0; movi v13.16b,#0\n\t"\ + "cmp %w[k_left],#16; b.lt 3f\n\t"\ + "ldr q0,[%[a_ptr1]],#16; ldr q1,[%[a_ptr2]],#16\n\t"\ + "ldr q2,[%[a_ptr3]],#16; ldr q3,[%[a_ptr4]],#16\n\t"\ + "ldr q4,[%[b_ptr]]; ldr q5,[%[b_ptr],#16]; add %[b_ptr],%[b_ptr],#32\n\t"\ + "cmp %w[k_left],#32; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IMULL" v14.8h,v0.8b,v4.8b; "IMULL" v18.8h,v0.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr1],#64]\n\t"\ + ""IMULL" v15.8h,v1.8b,v4.8b; "IMULL" v19.8h,v1.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr2],#64]\n\t"\ + ""IMULL" v16.8h,v2.8b,v4.8b; "IMULL" v20.8h,v2.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr3],#64]\n\t"\ + ""IMULL" v17.8h,v3.8b,v4.8b; "IMULL" v21.8h,v3.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr4],#64]\n\t"\ + 
""IADALP" v6.4s,v14.8h; "IMULL"2 v14.8h,v0.16b,v4.16b\n\t"\ + ""IADALP" v10.4s,v18.8h; "IMULL"2 v18.8h,v0.16b,v5.16b\n\t"\ + "ldr q0,[%[a_ptr1]],#16\n\t"\ + ""IADALP" v7.4s,v15.8h; "IMULL"2 v15.8h,v1.16b,v4.16b\n\t"\ + ""IADALP" v11.4s,v19.8h; "IMULL"2 v19.8h,v1.16b,v5.16b\n\t"\ + "ldr q1,[%[a_ptr2]],#16\n\t"\ + ""IADALP" v8.4s,v16.8h; "IMULL"2 v16.8h,v2.16b,v4.16b\n\t"\ + ""IADALP" v12.4s,v20.8h; "IMULL"2 v20.8h,v2.16b,v5.16b\n\t"\ + "ldr q2,[%[a_ptr3]],#16\n\t"\ + ""IADALP" v9.4s,v17.8h; "IMULL"2 v17.8h,v3.16b,v4.16b\n\t"\ + ""IADALP" v13.4s,v21.8h; "IMULL"2 v21.8h,v3.16b,v5.16b\n\t"\ + "ldr q3,[%[a_ptr4]],#16\n\t"\ + ""IADALP" v6.4s,v14.8h; "IADALP" v10.4s,v18.8h\n\t"\ + "ldr q4,[%[b_ptr]],#32\n\t"\ + ""IADALP" v7.4s,v15.8h; "IADALP" v11.4s,v19.8h\n\t"\ + "ldr q5,[%[b_ptr],#-16]\n\t"\ + ""IADALP" v8.4s,v16.8h; sub %w[k_left],%w[k_left],#16\n\t"\ + ""IADALP" v12.4s,v20.8h; cmp %w[k_left],#32\n\t"\ + ""IADALP" v9.4s,v17.8h; "IADALP" v13.4s,v21.8h\n\t"\ + "b.ge 1b\n\t"\ + "2:\n\t"\ + ""IMULL" v14.8h,v0.8b,v4.8b; "IMULL" v18.8h,v0.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr1],%w[next_pref],SXTW #0]\n\t"\ + ""IMULL" v15.8h,v1.8b,v4.8b; "IMULL" v19.8h,v1.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr2],%w[next_pref],SXTW #0]\n\t"\ + ""IMULL" v16.8h,v2.8b,v4.8b; "IMULL" v20.8h,v2.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr3],%w[next_pref],SXTW #0]\n\t"\ + ""IMULL" v17.8h,v3.8b,v4.8b; "IMULL" v21.8h,v3.8b,v5.8b\n\t"\ + "prfm pldl1keep,[%[a_ptr4],%w[next_pref],SXTW #0]\n\t"\ + ""IADALP" v6.4s,v14.8h; "IMULL"2 v14.8h,v0.16b,v4.16b\n\t"\ + ""IADALP" v10.4s,v18.8h; "IMULL"2 v18.8h,v0.16b,v5.16b\n\t"\ + ""IADALP" v7.4s,v15.8h; "IMULL"2 v15.8h,v1.16b,v4.16b\n\t"\ + ""IADALP" v11.4s,v19.8h; "IMULL"2 v19.8h,v1.16b,v5.16b\n\t"\ + ""IADALP" v8.4s,v16.8h; "IMULL"2 v16.8h,v2.16b,v4.16b\n\t"\ + ""IADALP" v12.4s,v20.8h; "IMULL"2 v20.8h,v2.16b,v5.16b\n\t"\ + ""IADALP" v9.4s,v17.8h; "IMULL"2 v17.8h,v3.16b,v4.16b\n\t"\ + ""IADALP" v13.4s,v21.8h; "IMULL"2 v21.8h,v3.16b,v5.16b\n\t"\ + ""IADALP" v6.4s,v14.8h; "IADALP" v10.4s,v18.8h\n\t"\ + ""IADALP" v7.4s,v15.8h; "IADALP" v11.4s,v19.8h\n\t"\ + ""IADALP" v8.4s,v16.8h; sub %w[k_left],%w[k_left],#16\n\t"\ + ""IADALP" v12.4s,v20.8h\n\t"\ + ""IADALP" v9.4s,v17.8h; "IADALP" v13.4s,v21.8h\n\t"\ + "3:\n\t"\ + "cmp %w[k_left],#8; b.lt 4f\n\t"\ + "ldr d0,[%[a_ptr1]],#8; ldr d1,[%[a_ptr2]],#8\n\t"\ + "ldr d2,[%[a_ptr3]],#8; ldr d3,[%[a_ptr4]],#8\n\t"\ + "ldr d4,[%[b_ptr]],#16; ldr d5,[%[b_ptr],#-8]\n\t"\ + ""IMULL" v14.8h,v0.8b,v4.8b; "IMULL" v18.8h,v0.8b,v5.8b\n\t"\ + ""IMULL" v15.8h,v1.8b,v4.8b; "IMULL" v19.8h,v1.8b,v5.8b\n\t"\ + ""IMULL" v16.8h,v2.8b,v4.8b; "IMULL" v20.8h,v2.8b,v5.8b\n\t"\ + ""IMULL" v17.8h,v3.8b,v4.8b; "IMULL" v21.8h,v3.8b,v5.8b\n\t"\ + "sub %w[k_left],%w[k_left],#8\n\t"\ + ""IADALP" v6.4s,v14.8h; "IADALP" v10.4s,v18.8h\n\t"\ + ""IADALP" v7.4s,v15.8h; "IADALP" v11.4s,v19.8h\n\t"\ + ""IADALP" v8.4s,v16.8h; "IADALP" v12.4s,v20.8h\n\t"\ + ""IADALP" v9.4s,v17.8h; "IADALP" v13.4s,v21.8h\n\t"\ + "4:\n\t"\ + "addp v6.4s,v6.4s,v10.4s; addp v7.4s,v7.4s,v11.4s\n\t"\ + "addp v8.4s,v8.4s,v12.4s; addp v9.4s,v9.4s,v13.4s\n\t"\ + "cmp %w[k_left],#4; b.lt 5f\n\t"\ + "ldr s4,[%[b_ptr]],#8; ldr s5,[%[b_ptr],#-4]\n\t"\ + "ld1r {v0.2s},[%[a_ptr1]],#4; ld1r {v1.2s},[%[a_ptr2]],#4\n\t"\ + "ins v4.s[1],v5.s[0]\n\t"\ + "ld1r {v2.2s},[%[a_ptr3]],#4; ld1r {v3.2s},[%[a_ptr4]],#4\n\t"\ + "sub %w[k_left],%w[k_left],#4\n\t"\ + ""IMULL" v14.8h,v0.8b,v4.8b; "IMULL" v15.8h,v1.8b,v4.8b\n\t"\ + ""IMULL" v16.8h,v2.8b,v4.8b; "IMULL" v17.8h,v3.8b,v4.8b\n\t"\ + ""IADALP" v6.4s,v14.8h; "IADALP" 
v7.4s,v15.8h\n\t"\ + ""IADALP" v8.4s,v16.8h; "IADALP" v9.4s,v17.8h\n\t"\ + "5:\n\t"\ + "movi v14.16b,#0\n\t"\ + "addp %[cq1].4s,v6.4s,v14.4s; addp %[cq2].4s,v7.4s,v14.4s\n\t"\ + "addp %[cq3].4s,v8.4s,v14.4s; addp %[cq4].4s,v9.4s,v14.4s\n\t"\ + "cmp %w[k_left],#2; b.lt 6f\n\t"\ + "ldr h4,[%[b_ptr]],#4; ldr h5,[%[b_ptr],#-2]\n\t"\ + "ld1r {v0.4h},[%[a_ptr1]],#2; ld1r {v1.4h},[%[a_ptr2]],#2\n\t"\ + "sub %w[k_left],%w[k_left],#2\n\t"\ + "ins v4.h[1],v5.h[0]\n\t"\ + "ld1r {v2.4h},[%[a_ptr3]],#2; ld1r {v3.4h},[%[a_ptr4]],#2\n\t"\ + ""IMULL" v14.8h,v0.8b,v4.8b; "IMULL" v15.8h,v1.8b,v4.8b\n\t"\ + ""IMULL" v16.8h,v2.8b,v4.8b; "IMULL" v17.8h,v3.8b,v4.8b\n\t"\ + ""IADALP" %[cq1].2s,v14.4h; "IADALP" %[cq2].2s,v15.4h\n\t"\ + ""IADALP" %[cq3].2s,v16.4h; "IADALP" %[cq4].2s,v17.4h\n\t"\ + "6:\n\t"\ + "cmp %w[k_left],#1; b.lt 7f\n\t"\ + "ldr b0,[%[a_ptr1]],#1; ldr b1,[%[a_ptr2]],#1\n\t"\ + "ldr b2,[%[a_ptr3]],#1; ldr b3,[%[a_ptr4]],#1\n\t"\ + "ldr b4,[%[b_ptr]],#2; ldr b5,[%[b_ptr],#-1]\n\t"\ + "ins v0.b[1],v1.b[0]; ins v2.b[1],v3.b[0]; ins v4.b[1],v5.b[0]\n\t"\ + ""IXTL" v0.8h,v0.8b; "IXTL" v2.8h,v2.8b; "IXTL" v4.8h,v4.8b\n\t"\ + "sub %w[k_left],%w[k_left],#1\n\t"\ + ""IMLAL" %[cq1].4s,v4.4h,v0.h[0]; "IMLAL" %[cq2].4s,v4.4h,v0.h[1]\n\t"\ + ""IMLAL" %[cq3].4s,v4.4h,v2.h[0]; "IMLAL" %[cq4].4s,v4.4h,v2.h[1]\n\t"\ + "7:\n\t"\ + :[cq1]"=w"(cq1), [cq2]"=w"(cq2), [cq3]"=w"(cq3), [cq4]"=w"(cq4),\ + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), [a_ptr3]"+r"(a_ptr3),\ + [a_ptr4]"+r"(a_ptr4), [b_ptr]"+r"(b_ptr),\ + [k_left]"+r"(k_left)\ + :[next_pref]"r"(next_pref)\ + :"cc","memory","v0","v1","v2","v3","v4","v5",\ + "v6","v7","v8","v9","v10","v11","v12","v13",\ + "v14","v15","v16","v17","v18","v19","v20","v21");\ +\ + I32X2 cd1 = VGET_LOW_I32(cq1);\ + I32X2 cd2 = VGET_LOW_I32(cq2);\ + I32X2 cd3 = VGET_LOW_I32(cq3);\ + I32X2 cd4 = VGET_LOW_I32(cq4);\ + if (c_rowmajor) {\ + cd1 = VMLA_N_I32(cd1, VLD1_I32(c_ptr), beta);\ + cd2 = VMLA_N_I32(cd2, VLD1_I32(c_ptr + 2), beta);\ + cd3 = VMLA_N_I32(cd3, VLD1_I32(c_ptr + 4), beta);\ + cd4 = VMLA_N_I32(cd4, VLD1_I32(c_ptr + 6), beta);\ + VST1_I32(c_ptr, cd1);\ + VST1_I32(c_ptr + 2, cd2);\ + VST1_I32(c_ptr + 4, cd3);\ + VST1_I32(c_ptr + 6, cd4);\ + } else {\ + I32 *c_ptr2 = c_ptr + LDM;\ + I32X2 cdl1 = VZIP1_I32(cd1, cd2);\ + I32X2 cdl2 = VZIP1_I32(cd3, cd4);\ + I32X2 cdl3 = VZIP2_I32(cd1, cd2);\ + I32X2 cdl4 = VZIP2_I32(cd3, cd4);\ + cdl1 = VMLA_N_I32(cdl1, VLD1_I32(c_ptr), beta);\ + cdl2 = VMLA_N_I32(cdl2, VLD1_I32(c_ptr + 2), beta);\ + cdl3 = VMLA_N_I32(cdl3, VLD1_I32(c_ptr2), beta);\ + cdl4 = VMLA_N_I32(cdl4, VLD1_I32(c_ptr2 + 2), beta);\ + VST1_I32(c_ptr, cdl1); VST1_I32(c_ptr + 2, cdl2);\ + VST1_I32(c_ptr2, cdl3); VST1_I32(c_ptr2 + 2, cdl4);\ + }\ +} + +/* k_mask = 31 */ +#define I8I32MLA_SKINNYDOT_INLINE_M4N3(gemm) \ +static inline void inline_##gemm##_arowmajor_bskinny_m4n3(\ + const I8 *a_ptr1, const I8 *b_ptr, I32 *c_ptr,\ + uint32_t k_left, uint32_t LDK, uint32_t LDM,\ + I32 beta, bool c_rowmajor) {\ +\ + const I8 *a_ptr2 = a_ptr1 + LDK;\ + const I8 *a_ptr3 = a_ptr1 + LDK * 2;\ + const I8 *a_ptr4 = a_ptr2 + LDK * 2;\ + I32X4 cq1, cq2, cq3;\ + const uint32_t next_pref = (LDK * 4 - k_left) + 16;\ + __asm__ __volatile__ (\ + "movi %[q1].16b,#0; movi %[q2].16b,#0; movi %[q3].16b,#0\n\t"\ + "movi v10.16b,#0; movi v11.16b,#0; movi v12.16b,#0\n\t"\ + "movi v13.16b,#0; movi v14.16b,#0; movi v15.16b,#0\n\t"\ + "movi v16.16b,#0; movi v17.16b,#0; movi v18.16b,#0\n\t"\ + "cmp %w[k_left],#16; b.lt 3f\n\t"\ + "ldr q0,[%[a_ptr1]],#16; ldr q1,[%[a_ptr2]],#16\n\t"\ + "ldr 
q2,[%[a_ptr3]],#16; ldr q3,[%[a_ptr4]],#16\n\t"\ + "ldr q4,[%[b_ptr]]; ldr q5,[%[b_ptr],#16]\n\t"\ + "ldr q6,[%[b_ptr],#32]; add %[b_ptr],%[b_ptr],#48\n\t"\ + "cmp %w[k_left],#32; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + ""IMULL" v19.8h,v0.8b,v4.8b; "IMULL" v20.8h,v0.8b,v5.8b\n\t"\ + ""IMULL" v21.8h,v0.8b,v6.8b; prfm pldl1keep,[%[a_ptr1],#64]\n\t"\ + ""IMULL" v22.8h,v1.8b,v4.8b; "IMULL" v23.8h,v1.8b,v5.8b\n\t"\ + ""IMULL" v24.8h,v1.8b,v6.8b; prfm pldl1keep,[%[a_ptr2],#64]\n\t"\ + ""IMULL" v25.8h,v2.8b,v4.8b; "IMULL" v26.8h,v2.8b,v5.8b\n\t"\ + ""IMULL" v27.8h,v2.8b,v6.8b; prfm pldl1keep,[%[a_ptr3],#64]\n\t"\ + ""IMULL" v28.8h,v3.8b,v4.8b; "IMULL" v29.8h,v3.8b,v5.8b\n\t"\ + ""IMULL" v30.8h,v3.8b,v6.8b; prfm pldl1keep,[%[a_ptr4],#64]\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IMULL"2 v19.8h,v0.16b,v4.16b\n\t"\ + ""IADALP" %[q2].4s,v20.8h; "IMULL"2 v20.8h,v0.16b,v5.16b\n\t"\ + ""IADALP" %[q3].4s,v21.8h; "IMULL"2 v21.8h,v0.16b,v6.16b\n\t"\ + "ldr q0,[%[a_ptr1]],#16\n\t"\ + ""IADALP" v10.4s,v22.8h; "IMULL"2 v22.8h,v1.16b,v4.16b\n\t"\ + ""IADALP" v11.4s,v23.8h; "IMULL"2 v23.8h,v1.16b,v5.16b\n\t"\ + ""IADALP" v12.4s,v24.8h; "IMULL"2 v24.8h,v1.16b,v6.16b\n\t"\ + "ldr q1,[%[a_ptr2]],#16\n\t"\ + ""IADALP" v13.4s,v25.8h; "IMULL"2 v25.8h,v2.16b,v4.16b\n\t"\ + ""IADALP" v14.4s,v26.8h; "IMULL"2 v26.8h,v2.16b,v5.16b\n\t"\ + ""IADALP" v15.4s,v27.8h; "IMULL"2 v27.8h,v2.16b,v6.16b\n\t"\ + "ldr q2,[%[a_ptr3]],#16\n\t"\ + ""IADALP" v16.4s,v28.8h; "IMULL"2 v28.8h,v3.16b,v4.16b\n\t"\ + ""IADALP" v17.4s,v29.8h; "IMULL"2 v29.8h,v3.16b,v5.16b\n\t"\ + ""IADALP" v18.4s,v30.8h; "IMULL"2 v30.8h,v3.16b,v6.16b\n\t"\ + "ldr q3,[%[a_ptr4]],#16\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IADALP" %[q2].4s,v20.8h; "IADALP" %[q3].4s,v21.8h\n\t"\ + "ldr q4,[%[b_ptr]]; sub %w[k_left],%w[k_left],#16\n\t"\ + ""IADALP" v10.4s,v22.8h; "IADALP" v11.4s,v23.8h; "IADALP" v12.4s,v24.8h\n\t"\ + "ldr q5,[%[b_ptr],#16]\n\t"\ + ""IADALP" v13.4s,v25.8h; "IADALP" v14.4s,v26.8h; "IADALP" v15.4s,v27.8h\n\t"\ + "ldr q6,[%[b_ptr],#32]; add %[b_ptr],%[b_ptr],#48; cmp %w[k_left],#32\n\t"\ + ""IADALP" v16.4s,v28.8h; "IADALP" v17.4s,v29.8h; "IADALP" v18.4s,v30.8h\n\t"\ + "b.ge 1b\n\t"\ + "2:\n\t"\ + ""IMULL" v19.8h,v0.8b,v4.8b; "IMULL" v20.8h,v0.8b,v5.8b\n\t"\ + ""IMULL" v21.8h,v0.8b,v6.8b; prfm pldl1keep,[%[a_ptr1],%w[pref],SXTW #0]\n\t"\ + ""IMULL" v22.8h,v1.8b,v4.8b; "IMULL" v23.8h,v1.8b,v5.8b\n\t"\ + ""IMULL" v24.8h,v1.8b,v6.8b; prfm pldl1keep,[%[a_ptr2],%w[pref],SXTW #0]\n\t"\ + ""IMULL" v25.8h,v2.8b,v4.8b; "IMULL" v26.8h,v2.8b,v5.8b\n\t"\ + ""IMULL" v27.8h,v2.8b,v6.8b; prfm pldl1keep,[%[a_ptr3],%w[pref],SXTW #0]\n\t"\ + ""IMULL" v28.8h,v3.8b,v4.8b; "IMULL" v29.8h,v3.8b,v5.8b\n\t"\ + ""IMULL" v30.8h,v3.8b,v6.8b; prfm pldl1keep,[%[a_ptr4],%w[pref],SXTW #0]\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IMULL"2 v19.8h,v0.16b,v4.16b\n\t"\ + ""IADALP" %[q2].4s,v20.8h; "IMULL"2 v20.8h,v0.16b,v5.16b\n\t"\ + ""IADALP" %[q3].4s,v21.8h; "IMULL"2 v21.8h,v0.16b,v6.16b\n\t"\ + ""IADALP" v10.4s,v22.8h; "IMULL"2 v22.8h,v1.16b,v4.16b\n\t"\ + ""IADALP" v11.4s,v23.8h; "IMULL"2 v23.8h,v1.16b,v5.16b\n\t"\ + ""IADALP" v12.4s,v24.8h; "IMULL"2 v24.8h,v1.16b,v6.16b\n\t"\ + ""IADALP" v13.4s,v25.8h; "IMULL"2 v25.8h,v2.16b,v4.16b\n\t"\ + ""IADALP" v14.4s,v26.8h; "IMULL"2 v26.8h,v2.16b,v5.16b\n\t"\ + ""IADALP" v15.4s,v27.8h; "IMULL"2 v27.8h,v2.16b,v6.16b\n\t"\ + ""IADALP" v16.4s,v28.8h; "IMULL"2 v28.8h,v3.16b,v4.16b\n\t"\ + ""IADALP" v17.4s,v29.8h; "IMULL"2 v29.8h,v3.16b,v5.16b\n\t"\ + ""IADALP" v18.4s,v30.8h; "IMULL"2 v30.8h,v3.16b,v6.16b\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IADALP" %[q2].4s,v20.8h; 
"IADALP" %[q3].4s,v21.8h\n\t"\ + "sub %w[k_left],%w[k_left],#16\n\t"\ + ""IADALP" v10.4s,v22.8h; "IADALP" v11.4s,v23.8h; "IADALP" v12.4s,v24.8h\n\t"\ + ""IADALP" v13.4s,v25.8h; "IADALP" v14.4s,v26.8h; "IADALP" v15.4s,v27.8h\n\t"\ + ""IADALP" v16.4s,v28.8h; "IADALP" v17.4s,v29.8h; "IADALP" v18.4s,v30.8h\n\t"\ + "3:\n\t"\ + "cmp %w[k_left],#8; b.lt 4f\n\t"\ + "ldr d0,[%[a_ptr1]],#8; ldr d1,[%[a_ptr2]],#8\n\t"\ + "ldr d2,[%[a_ptr3]],#8; ldr d3,[%[a_ptr4]],#8\n\t"\ + "ldr d4,[%[b_ptr]]; ldr d5,[%[b_ptr],#8]\n\t"\ + "ldr d6,[%[b_ptr],#16]; add %[b_ptr],%[b_ptr],#24\n\t"\ + "sub %w[k_left],%w[k_left],#8\n\t"\ + ""IMULL" v19.8h,v0.8b,v4.8b; "IMULL" v20.8h,v0.8b,v5.8b\n\t"\ + ""IMULL" v21.8h,v0.8b,v6.8b; "IMULL" v22.8h,v1.8b,v4.8b\n\t"\ + ""IMULL" v23.8h,v1.8b,v5.8b; "IMULL" v24.8h,v1.8b,v6.8b\n\t"\ + ""IMULL" v25.8h,v2.8b,v4.8b; "IMULL" v26.8h,v2.8b,v5.8b\n\t"\ + ""IMULL" v27.8h,v2.8b,v6.8b; "IMULL" v28.8h,v3.8b,v4.8b\n\t"\ + ""IMULL" v29.8h,v3.8b,v5.8b; "IMULL" v30.8h,v3.8b,v6.8b\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IADALP" %[q2].4s,v20.8h; "IADALP" %[q3].4s,v21.8h\n\t"\ + ""IADALP" v10.4s,v22.8h; "IADALP" v11.4s,v23.8h; "IADALP" v12.4s,v24.8h\n\t"\ + ""IADALP" v13.4s,v25.8h; "IADALP" v14.4s,v26.8h; "IADALP" v15.4s,v27.8h\n\t"\ + ""IADALP" v16.4s,v28.8h; "IADALP" v17.4s,v29.8h; "IADALP" v18.4s,v30.8h\n\t"\ + "4:\n\t"\ + "addp %[q1].4s,%[q1].4s,v10.4s; addp v13.4s,v13.4s,v16.4s\n\t"\ + "addp %[q2].4s,%[q2].4s,v11.4s; addp v14.4s,v14.4s,v17.4s\n\t"\ + "addp %[q3].4s,%[q3].4s,v12.4s; addp v15.4s,v15.4s,v18.4s\n\t"\ + "cmp %w[k_left],#4; b.lt 5f\n\t"\ + "ldr s0,[%[a_ptr1]],#4; ldr s1,[%[a_ptr2]],#4\n\t"\ + "ldr s2,[%[a_ptr3]],#4; ldr s3,[%[a_ptr4]],#4\n\t"\ + "ld1r {v4.2s},[%[b_ptr]],#4; ins v0.s[1],v1.s[0]\n\t"\ + "ld1r {v5.2s},[%[b_ptr]],#4; ins v2.s[1],v3.s[0]\n\t"\ + "ld1r {v6.2s},[%[b_ptr]],#4\n\t"\ + "sub %w[k_left],%w[k_left],#4\n\t"\ + ""IMULL" v19.8h,v0.8b,v4.8b; "IMULL" v20.8h,v0.8b,v5.8b\n\t"\ + ""IMULL" v21.8h,v0.8b,v6.8b; "IMULL" v25.8h,v2.8b,v4.8b\n\t"\ + ""IMULL" v26.8h,v2.8b,v5.8b; "IMULL" v27.8h,v2.8b,v6.8b\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IADALP" %[q2].4s,v20.8h; "IADALP" %[q3].4s,v21.8h\n\t"\ + ""IADALP" v13.4s,v25.8h; "IADALP" v14.4s,v26.8h; "IADALP" v15.4s,v27.8h\n\t"\ + "5:\n\t"\ + "addp %[q1].4s,%[q1].4s,v13.4s\n\t"\ + "addp %[q2].4s,%[q2].4s,v14.4s\n\t"\ + "addp %[q3].4s,%[q3].4s,v15.4s\n\t"\ + "cmp %w[k_left],#2; b.lt 6f\n\t"\ + "ldr h0,[%[a_ptr1]],#2; ldr h1,[%[a_ptr2]],#2\n\t"\ + "ldr h2,[%[a_ptr3]],#2; ldr h3,[%[a_ptr4]],#2\n\t"\ + "ld1r {v4.4h},[%[b_ptr]],#2; ins v0.h[1],v1.h[0]\n\t"\ + "ld1r {v5.4h},[%[b_ptr]],#2; ins v2.h[1],v3.h[0]\n\t"\ + "ld1r {v6.4h},[%[b_ptr]],#2\n\t"\ + "sub %w[k_left],%w[k_left],#2\n\t"\ + "ins v0.s[1],v2.s[0]\n\t"\ + ""IMULL" v19.8h,v0.8b,v4.8b\n\t"\ + ""IMULL" v20.8h,v0.8b,v5.8b\n\t"\ + ""IMULL" v21.8h,v0.8b,v6.8b\n\t"\ + ""IADALP" %[q1].4s,v19.8h; "IADALP" %[q2].4s,v20.8h; "IADALP" %[q3].4s,v21.8h\n\t"\ + "6:\n\t"\ + "cmp %w[k_left],#1; b.lt 7f\n\t"\ + "ldr b0,[%[a_ptr1]],#1; ldr b1,[%[a_ptr2]],#1\n\t"\ + "ldr b2,[%[a_ptr3]],#1; ldr b3,[%[a_ptr4]],#1\n\t"\ + "ldr b4,[%[b_ptr]]; ins v0.b[1],v1.b[0]\n\t"\ + "ldr b5,[%[b_ptr],#1]; ins v2.b[1],v3.b[0]\n\t"\ + "ldr b6,[%[b_ptr],#2]; add %[b_ptr],%[b_ptr],#3\n\t"\ + "ins v4.b[1],v5.b[0]\n\t"\ + "ins v0.h[1],v2.h[0]; ins v4.b[2],v6.b[0]\n\t"\ + "sub %w[k_left],%w[k_left],#1\n\t"\ + ""IXTL" v0.8h,v0.8b; "IXTL" v4.8h,v4.8b\n\t"\ + ""IMLAL" %[q1].4s,v0.4h,v4.h[0]\n\t"\ + ""IMLAL" %[q2].4s,v0.4h,v4.h[1]\n\t"\ + ""IMLAL" %[q3].4s,v0.4h,v4.h[2]\n\t"\ + "7:\n\t"\ + :[q1]"=w"(cq1), [q2]"=w"(cq2), 
[q3]"=w"(cq3), [k_left]"+r"(k_left),\ + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), [a_ptr3]"+r"(a_ptr3),\ + [a_ptr4]"+r"(a_ptr4), [b_ptr]"+r"(b_ptr)\ + :[pref]"r"(next_pref)\ + :"cc","memory","v0","v1","v2","v3","v4","v5","v6",\ + "v10","v11","v12","v13","v14","v15","v16","v17","v18",\ + "v19","v20","v21","v22","v23","v24","v25","v26","v27",\ + "v28","v29","v30");\ +\ + if (c_rowmajor) {\ + I32X4X3 cqt1 = VLD3Q_I32(c_ptr);\ + cqt1.val[0] = VMLAQ_N_I32(cq1, cqt1.val[0], beta);\ + cqt1.val[1] = VMLAQ_N_I32(cq2, cqt1.val[1], beta);\ + cqt1.val[2] = VMLAQ_N_I32(cq3, cqt1.val[2], beta);\ + VST3Q_I32(c_ptr, cqt1);\ + } else {\ + cq1 = VMLAQ_N_I32(cq1, VLD1Q_I32(c_ptr), beta);\ + cq2 = VMLAQ_N_I32(cq2, VLD1Q_I32(c_ptr + LDM), beta);\ + cq3 = VMLAQ_N_I32(cq3, VLD1Q_I32(c_ptr + LDM * 2), beta);\ + VST1Q_I32(c_ptr, cq1); c_ptr += LDM;\ + VST1Q_I32(c_ptr, cq2); c_ptr += LDM;\ + VST1Q_I32(c_ptr, cq3);\ + }\ +} + + + +#define I8I32MLA_SKINNY_DOT_INLINE_FUNCS_M4(gemm) \ + I8I32MLA_SKINNYDOT_INLINE_M4N1(gemm)\ + I8I32MLA_SKINNYDOT_INLINE_M4N2(gemm)\ + I8I32MLA_SKINNYDOT_INLINE_M4N3(gemm) + +I8I32MLA_SKINNY_DOT_INLINE_FUNCS_M4(I8I32MLAGEMM) + +#define GEMM_SKINNY_DOT_INLINE_FUNC_DEDUCE(a, b, c, d)\ + GEMM_SKINNY_DOT_INLINE_PACK_FUNC(a, b, c, d) + +GEMM_SKINNY_DOT_INLINE_FUNC_DEDUCE(I8I32MLAGEMM, 1, 1, 31) +GEMM_SKINNY_DOT_INLINE_FUNC_DEDUCE(I8I32MLAGEMM, 1, 2, 31) +GEMM_SKINNY_DOT_INLINE_FUNC_DEDUCE(I8I32MLAGEMM, 1, 3, 31) + +static inline bool unroll_test_m4n1(uint32_t M, uint32_t K) { + return K <= 16384; +} + +static inline bool unroll_test_m1n1(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m4n2(uint32_t M, uint32_t K) { + return K <= 16384; +} + +static inline bool unroll_test_m1n2(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m4n3(uint32_t M, uint32_t K) { + return K <= 16384; +} + +static inline bool unroll_test_m1n3(uint32_t M, uint32_t K) { + return true; +} + +#endif \ No newline at end of file diff --git a/include/neon_armv8a/S8S32DotGemmCopy.h b/include/neon_armv8a/S8S32DotGemmCopy.h new file mode 100644 index 0000000..d3e2d8c --- /dev/null +++ b/include/neon_armv8a/S8S32DotGemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32dotgemm_int8_t_int32_t_tcopy_unroll8(const int8_t * __restrict__ src, + int32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32dotgemm_int8_t_int32_t_tcopy_unroll12(const int8_t * __restrict__ src, + int32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32dotgemm_int8_t_int32_t_ncopy_unroll8(const int8_t * __restrict__ src, + int32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32dotgemm_int8_t_int32_t_ncopy_unroll12(const int8_t * __restrict__ src, + int32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv8a/S8S32DotGemmDriver.h b/include/neon_armv8a/S8S32DotGemmDriver.h new file mode 100644 index 0000000..06e3d93 --- /dev/null +++ b/include/neon_armv8a/S8S32DotGemmDriver.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.                                            */ +/*****************************************************************************/ + + +#include <stdint.h> + +int s8s32dotgemm_serial(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, int32_t beta_inp); + +int s8s32dotgemm(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/S8S32DotGemmKernel.h b/include/neon_armv8a/S8S32DotGemmKernel.h new file mode 100644 index 0000000..8b55075 --- /dev/null +++ b/include/neon_armv8a/S8S32DotGemmKernel.h @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.
*/ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32dotgemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t kdiv4, + int32_t beta, + const int32_t * __restrict__ sa, const int32_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc); + +void s8s32dotgemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t Kdiv4, + int32_t beta, + const int32_t * __restrict__ sa, const int32_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv8a/S8S32DotGemmSkinnyDot.h b/include/neon_armv8a/S8S32DotGemmSkinnyDot.h new file mode 100644 index 0000000..58c0639 --- /dev/null +++ b/include/neon_armv8a/S8S32DotGemmSkinnyDot.h @@ -0,0 +1,103 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.                                            */ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n1(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n2(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n3(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n4(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n5(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n6(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n7(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n8(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n9(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n10(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n11(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n12(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t
beta_inp); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n1_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n2_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n3_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n4_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n5_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n6_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n7_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n8_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n9_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n10_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n11_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32dotgemm_arowmajor_bskinny_aint8_t_bint8_t_n12_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/S8S32MlaGemmCopy.h b/include/neon_armv8a/S8S32MlaGemmCopy.h new file mode 100644 index 0000000..2041972 --- /dev/null +++ b/include/neon_armv8a/S8S32MlaGemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32mlagemm_int8_t_int16_t_ncopy_unroll8(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32mlagemm_int8_t_int16_t_ncopy_unroll12(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32mlagemm_int8_t_int16_t_tcopy_unroll8(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void s8s32mlagemm_int8_t_int16_t_tcopy_unroll12(const int8_t * __restrict__ src, + int16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv8a/S8S32MlaGemmDriver.h b/include/neon_armv8a/S8S32MlaGemmDriver.h new file mode 100644 index 0000000..26121fa --- /dev/null +++ b/include/neon_armv8a/S8S32MlaGemmDriver.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.                                            */ +/*****************************************************************************/ + + +#include <stdint.h> + +int s8s32mlagemm_serial(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, int32_t beta_inp); + +int s8s32mlagemm(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/S8S32MlaGemmKernel.h b/include/neon_armv8a/S8S32MlaGemmKernel.h new file mode 100644 index 0000000..330a8a2 --- /dev/null +++ b/include/neon_armv8a/S8S32MlaGemmKernel.h @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.
*/ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32mlagemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t K, + int32_t beta, + const int16_t * __restrict__ sa, const int16_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc); + +void s8s32mlagemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t K, + int32_t beta, + const int16_t * __restrict__ sa, const int16_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv8a/S8S32MlaGemmSkinnyDot.h b/include/neon_armv8a/S8S32MlaGemmSkinnyDot.h new file mode 100644 index 0000000..ffe895b --- /dev/null +++ b/include/neon_armv8a/S8S32MlaGemmSkinnyDot.h @@ -0,0 +1,75 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.                                            */ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n1(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n2(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n3(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n4(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n5(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n6(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n7(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n8(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n1_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n2_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n3_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n4_omp(const int8_t *A, const
int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n5_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n6_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n7_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_arowmajor_bskinny_aint8_t_bint8_t_n8_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/S8S32MlaGemmSkinnyGer.h b/include/neon_armv8a/S8S32MlaGemmSkinnyGer.h new file mode 100644 index 0000000..8600d13 --- /dev/null +++ b/include/neon_armv8a/S8S32MlaGemmSkinnyGer.h @@ -0,0 +1,74 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include <stdint.h> + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n1(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n2(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n3(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n4(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n5(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n6(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n7(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n8(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, int32_t beta_inp); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n1_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n2_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n3_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n4_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n5_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n6_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n7_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); + +void s8s32mlagemm_acolmajor_bskinny_aint8_t_bint8_t_n8_omp(const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + int32_t beta_inp, uint32_t num_threads); diff --git a/include/neon_armv8a/SgemmCopy.h b/include/neon_armv8a/SgemmCopy.h new file mode 100644 index 0000000..7a74074 --- /dev/null +++ b/include/neon_armv8a/SgemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.
*/ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.                                            */ +/*****************************************************************************/ + + +#include <stdint.h> + +void sgemm_float_float_ncopy_unroll8(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void sgemm_float_float_ncopy_unroll12(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void sgemm_float_float_tcopy_unroll8(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void sgemm_float_float_tcopy_unroll12(const float * __restrict__ src, + float * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv8a/SgemmDriver.h b/include/neon_armv8a/SgemmDriver.h new file mode 100644 index 0000000..bfc4217 --- /dev/null +++ b/include/neon_armv8a/SgemmDriver.h @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.                                            */ +/*****************************************************************************/ + + +#include <stdint.h> + +int sgemm_serial(int a_rowmajor, int b_rowmajor, + const float *A, const float *B, float *C, + uint32_t M, uint32_t N, uint32_t K, float beta_inp); + +int sgemm(int a_rowmajor, int b_rowmajor, + const float *A, const float *B, float *C, + uint32_t M, uint32_t N, uint32_t K, float beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/SgemmKernel.h b/include/neon_armv8a/SgemmKernel.h new file mode 100644 index 0000000..936160c --- /dev/null +++ b/include/neon_armv8a/SgemmKernel.h @@ -0,0 +1,26 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");           */ +/* you may not use this file except in compliance with the License.          */ +/* You may obtain a copy of the License at                                   */ +/*                                                                           */ +/* http://www.apache.org/licenses/LICENSE-2.0                                */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software       */ +/* distributed under the License is distributed on an "AS IS" BASIS,         */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */ +/* See the License for the specific language governing permissions and       */ +/* limitations under the License.
*/ +/*****************************************************************************/ + + +#include + +void sgemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t K, float beta, + const float * __restrict__ sa, const float * __restrict__ sb, + float * __restrict__ C, uint32_t ldc); + +void sgemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t K, float beta, + const float * __restrict__ sa, const float * __restrict__ sb, + float * __restrict__ C, uint32_t ldc); diff --git a/include/neon_armv8a/SgemmSkinnyDot.h b/include/neon_armv8a/SgemmSkinnyDot.h new file mode 100644 index 0000000..d40593e --- /dev/null +++ b/include/neon_armv8a/SgemmSkinnyDot.h @@ -0,0 +1,319 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_arowmajor_bskinny_afloat_bfloat_n1(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n2(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n3(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n4(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n5(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n6(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n7(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n8(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n9(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n10(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n11(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n12(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n13(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n14(const float *A, const float *B, 
float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n15(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n16(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n17(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n18(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n19(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n20(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n21(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n22(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n23(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n24(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n25(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n26(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n27(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n28(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n29(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n30(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n31(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n32(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n33(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n34(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n35(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n36(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n37(const float *A, const float *B, float *C, + 
uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n38(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n39(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n40(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n41(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n42(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n43(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n44(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n45(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n46(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n47(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n48(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n49(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n50(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n1_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n2_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n3_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n4_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n5_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n6_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n7_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n8_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n9_omp(const 
float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n10_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n11_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n12_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n13_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n14_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n15_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n16_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n17_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n18_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n19_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n20_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n21_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n22_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n23_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n24_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n25_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n26_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n27_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n28_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void 
sgemm_arowmajor_bskinny_afloat_bfloat_n29_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n30_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n31_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n32_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n33_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n34_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n35_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n36_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n37_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n38_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n39_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n40_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n41_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n42_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n43_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n44_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n45_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n46_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n47_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n48_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, 
float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n49_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_arowmajor_bskinny_afloat_bfloat_n50_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/SgemmSkinnyGer.h b/include/neon_armv8a/SgemmSkinnyGer.h new file mode 100644 index 0000000..fdc16b3 --- /dev/null +++ b/include/neon_armv8a/SgemmSkinnyGer.h @@ -0,0 +1,91 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_acolmajor_bskinny_afloat_bfloat_n1(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n2(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n3(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n4(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n5(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n6(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n7(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n8(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n9(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n10(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n11(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n12(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n1_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n2_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, 
uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n3_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n4_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n5_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n6_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n7_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n8_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n9_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n10_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n11_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_acolmajor_bskinny_afloat_bfloat_n12_omp(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint8_t b_c_order, float beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/U8U32DotGemmCopy.h b/include/neon_armv8a/U8U32DotGemmCopy.h new file mode 100644 index 0000000..b33c964 --- /dev/null +++ b/include/neon_armv8a/U8U32DotGemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32dotgemm_uint8_t_uint32_t_tcopy_unroll8(const uint8_t * __restrict__ src, + uint32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32dotgemm_uint8_t_uint32_t_tcopy_unroll12(const uint8_t * __restrict__ src, + uint32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32dotgemm_uint8_t_uint32_t_ncopy_unroll8(const uint8_t * __restrict__ src, + uint32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32dotgemm_uint8_t_uint32_t_ncopy_unroll12(const uint8_t * __restrict__ src, + uint32_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv8a/U8U32DotGemmDriver.h b/include/neon_armv8a/U8U32DotGemmDriver.h new file mode 100644 index 0000000..170723a --- /dev/null +++ b/include/neon_armv8a/U8U32DotGemmDriver.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +int u8u32dotgemm_serial(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, uint32_t beta_inp); + +int u8u32dotgemm(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/U8U32DotGemmKernel.h b/include/neon_armv8a/U8U32DotGemmKernel.h new file mode 100644 index 0000000..910ef46 --- /dev/null +++ b/include/neon_armv8a/U8U32DotGemmKernel.h @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32dotgemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t kdiv4, + uint32_t beta, + const uint32_t * __restrict__ sa, const uint32_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc); + +void u8u32dotgemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t Kdiv4, + uint32_t beta, + const uint32_t * __restrict__ sa, const uint32_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv8a/U8U32DotGemmSkinnyDot.h b/include/neon_armv8a/U8U32DotGemmSkinnyDot.h new file mode 100644 index 0000000..5c8a646 --- /dev/null +++ b/include/neon_armv8a/U8U32DotGemmSkinnyDot.h @@ -0,0 +1,115 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n1(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n2(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n3(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n4(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n5(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n6(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n7(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n8(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n9(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n10(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n11(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n12(const uint8_t *A, const 
uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n1_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n2_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n3_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n4_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n5_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n6_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n7_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n8_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n9_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n10_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n11_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32dotgemm_arowmajor_bskinny_auint8_t_buint8_t_n12_omp( + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/U8U32MlaGemmCopy.h b/include/neon_armv8a/U8U32MlaGemmCopy.h new file mode 100644 index 0000000..18a6eef --- /dev/null +++ b/include/neon_armv8a/U8U32MlaGemmCopy.h @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_uint8_t_uint16_t_ncopy_unroll8(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32mlagemm_uint8_t_uint16_t_ncopy_unroll12(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32mlagemm_uint8_t_uint16_t_tcopy_unroll8(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + +void u8u32mlagemm_uint8_t_uint16_t_tcopy_unroll12(const uint8_t * __restrict__ src, + uint16_t * __restrict__ dst, uint32_t ld_dim, uint32_t dim1, uint32_t dim2); + diff --git a/include/neon_armv8a/U8U32MlaGemmDriver.h b/include/neon_armv8a/U8U32MlaGemmDriver.h new file mode 100644 index 0000000..9477c3d --- /dev/null +++ b/include/neon_armv8a/U8U32MlaGemmDriver.h @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +int u8u32mlagemm_serial(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, uint32_t beta_inp); + +int u8u32mlagemm(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/U8U32MlaGemmKernel.h b/include/neon_armv8a/U8U32MlaGemmKernel.h new file mode 100644 index 0000000..34f1285 --- /dev/null +++ b/include/neon_armv8a/U8U32MlaGemmKernel.h @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t K, + uint32_t beta, + const uint16_t * __restrict__ sa, const uint16_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc); + +void u8u32mlagemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t K, + uint32_t beta, + const uint16_t * __restrict__ sa, const uint16_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc); + diff --git a/include/neon_armv8a/U8U32MlaGemmSkinnyDot.h b/include/neon_armv8a/U8U32MlaGemmSkinnyDot.h new file mode 100644 index 0000000..e1ea3e4 --- /dev/null +++ b/include/neon_armv8a/U8U32MlaGemmSkinnyDot.h @@ -0,0 +1,75 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n1(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n2(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n3(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n4(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n5(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n6(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n7(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n8(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n1_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n2_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n3_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void 
u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n4_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n5_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n6_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n7_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_arowmajor_bskinny_auint8_t_buint8_t_n8_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + diff --git a/include/neon_armv8a/U8U32MlaGemmSkinnyGer.h b/include/neon_armv8a/U8U32MlaGemmSkinnyGer.h new file mode 100644 index 0000000..d72b3bf --- /dev/null +++ b/include/neon_armv8a/U8U32MlaGemmSkinnyGer.h @@ -0,0 +1,74 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n1(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n2(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n3(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n4(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n5(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n6(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n7(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n8(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, uint32_t beta_inp); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n1_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n2_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n3_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n4_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n5_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n6_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n7_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); + +void u8u32mlagemm_acolmajor_bskinny_auint8_t_buint8_t_n8_omp(const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t K, uint8_t b_c_order, + uint32_t beta_inp, uint32_t num_threads); diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/ReadME.md b/include/neon_armv8a/sgemm_skinny_dot_kernel/ReadME.md new file mode 100644 index 0000000..2ae6f90 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/ReadME.md @@ -0,0 +1,23 @@ +# Tuned ARMv8a SGEMM functions for skinny matrices + +### Supported shapes and orders +``` +C(MxN) = A(MxK) B(KxN) +(1). 4 < M < 51, N >> 50, K >> 50, matrix B is column-major; +(2). 
4 < N < 51, M >> 50, K >> 50, matrix A is row-major. +``` +
+### Interface +```
+sgemm_skinny1_arowmajor_nXXX_YYY(const float *A, const float *B, float *C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+XXX: a number representing the length of dimension N
+YYY: letters indicating the tuned ARM CPU, e.g. a35/a53/a7x
+b_c_order: the order of skinny matrices B & C + 0: B & C column-major; + 1: B row-major, C column-major + 2: B column-major, C row-major + 3: B & C row-major
+beta_inp: GEMM beta parameter scaling the prior contents of C +
+Parallel variants append "_omp" to the function name and take one extra trailing argument, uint32_t num_threads. +```
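A rough usage sketch of the interface above (illustrative only, not taken from the library's own documentation): it calls the n8/a35 variant declared in SgemmSkinnyDotA35.h below, and assumes packed leading dimensions (LDA = K for the row-major A, LDB = K and LDC = M for the column-major B and C under b_c_order = 0) and that beta_inp = 0.0f simply overwrites C.

```c
#include <stdint.h>
#include <stdlib.h>
#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.h"

int main(void) {
    const uint32_t M = 1024, K = 512;                 /* N = 8 is fixed by the kernel name */
    float *A = calloc((size_t)M * K, sizeof(float));  /* row-major M x K */
    float *B = calloc((size_t)K * 8, sizeof(float));  /* column-major K x 8 (b_c_order = 0) */
    float *C = calloc((size_t)M * 8, sizeof(float));  /* column-major M x 8 */
    if (!A || !B || !C) return 1;

    /* Assumed packed leading dimensions: LDA = K, LDB = K, LDC = M;
     * beta_inp = 0.0f presumably discards the old contents of C. */
    sgemm_skinny1_arowmajor_n8_a35(A, B, C, M, K, K, K, M, 0, 0.0f);

    free(A); free(B); free(C);
    return 0;
}
```

A caller that does not want to select N and the CPU variant by hand would presumably go through the generic sgemm()/sgemm_serial() driver declared in SgemmDriver.h instead.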
diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.h new file mode 100644 index 0000000..9e9b4b7 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.h @@ -0,0 +1,488 @@
+/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + +
+#include <stdint.h> +
+void sgemm_skinny1_arowmajor_n4_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n5_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n6_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n7_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n8_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n9_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n10_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n11_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n12_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n13_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n14_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n15_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n16_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n17_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n18_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n19_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n20_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n21_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n22_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n23_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n24_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n25_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n26_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); +
+void sgemm_skinny1_arowmajor_n27_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K,
uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n28_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n29_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n30_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n31_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n32_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n33_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n34_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n35_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n36_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n37_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n38_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n39_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n40_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n41_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n42_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float 
beta_inp); + +void sgemm_skinny1_arowmajor_n43_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n44_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n45_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n46_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n47_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n48_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n49_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n50_a35(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n4_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n5_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n6_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n7_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n8_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n9_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n10_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, 
uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n11_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n12_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n13_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n14_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n15_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n16_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n17_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n18_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n19_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n20_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n21_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n22_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n23_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n24_a35_omp(const float * __restrict__ A, + const 
float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n25_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n26_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n27_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n28_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n29_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n30_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n31_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n32_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n33_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n34_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n35_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n36_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n37_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void 
sgemm_skinny1_arowmajor_n38_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n39_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n40_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n41_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n42_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n43_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n44_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n45_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n46_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n47_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n48_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n49_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n50_a35_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.h new file mode 100644 index 0000000..1b2575f --- /dev/null +++ 
b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.h @@ -0,0 +1,488 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_skinny1_arowmajor_n4_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n5_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n6_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n7_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n8_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n9_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n10_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n11_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n12_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n13_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n14_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n15_a53(const float * __restrict__ A, + const float * __restrict__ B, float * 
__restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n16_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n17_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n18_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n19_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n20_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n21_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n22_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n23_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n24_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n25_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n26_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n27_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n28_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n29_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n30_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, 
uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n31_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n32_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n33_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n34_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n35_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n36_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n37_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n38_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n39_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n40_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n41_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n42_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n43_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n44_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n45_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void 
sgemm_skinny1_arowmajor_n46_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n47_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n48_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n49_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n50_a53(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n4_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n5_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n6_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n7_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n8_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n9_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n10_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n11_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n12_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n13_a53_omp(const float * __restrict__ A, + const 
float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n14_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n15_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n16_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n17_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n18_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n19_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n20_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n21_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n22_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n23_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n24_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n25_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n26_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void 
sgemm_skinny1_arowmajor_n27_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n28_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n29_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n30_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n31_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n32_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n33_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n34_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n35_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n36_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n37_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n38_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n39_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n40_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t 
b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n41_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n42_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n43_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n44_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n45_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n46_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n47_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n48_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n49_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n50_a53_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.h new file mode 100644 index 0000000..b9b61f3 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.h @@ -0,0 +1,488 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include + +void sgemm_skinny1_arowmajor_n4_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n5_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n6_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n7_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n8_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n9_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n10_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n11_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n12_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n13_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n14_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n15_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n16_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n17_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n18_a7x(const float * __restrict__ A, + const float * __restrict__ B, 
float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n19_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n20_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n21_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n22_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n23_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n24_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n25_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n26_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n27_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n28_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n29_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n30_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n31_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n32_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n33_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t 
LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n34_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n35_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n36_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n37_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n38_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n39_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n40_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n41_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n42_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n43_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n44_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n45_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n46_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n47_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n48_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void 
sgemm_skinny1_arowmajor_n49_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n50_a7x(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp); + +void sgemm_skinny1_arowmajor_n4_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n5_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n6_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n7_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n8_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n9_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n10_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n11_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n12_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n13_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n14_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n15_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void 
sgemm_skinny1_arowmajor_n16_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n17_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n18_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n19_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n20_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n21_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n22_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n23_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n24_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n25_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n26_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n27_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n28_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n29_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t 
b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n30_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n31_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n32_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n33_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n34_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n35_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n36_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n37_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n38_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n39_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n40_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n41_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n42_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n43_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, 
uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n44_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n45_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n46_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n47_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n48_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n49_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); + +void sgemm_skinny1_arowmajor_n50_a7x_omp(const float * __restrict__ A, + const float * __restrict__ B, float * __restrict__ C, + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC, + uint8_t b_c_order, float beta_inp, uint32_t num_threads); diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.h new file mode 100644 index 0000000..55778b1 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.h @@ -0,0 +1,108 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include <stdint.h> + +/* In the calculation, the content of skinny matrix B will be read multiple + * times. We rearrange its elements to make the reading sequential and + * contiguous. The process of such rearrangement is called packing.
*/ + +/* There are 5 packing types used for skinny matrix B */ +/* type_0: row-major contiguous pattern */ +/* type_1: partitioned in 4-row chunks, row-major bulk + col-major edge */ +/* type_2: partitioned in 2-row chunks, row-major bulk + col-major edge */ +/* type_3: partitioned in 2-row chunks, col-major in each chunk */ +/* type_4: partitioned in 2-row chunks, x-type interleave like shoelaces */ + +/* The selection of packing type depends on CPU architecture and problem size */ +/* cortex-a35: type_3 when N < 10, type_0 for even N, type_4 for odd N */ +/* cortex-a53: type_1 when N < 15, type_2 when 14 < N < 23, type_0 for big N */ +/* cortex-a55: the same as cortex-a53 */ +/* cortex-a76 & cortex-a72: always type_1 */ + +/* Example 1 */ +/* source matrix B (4x5): + * a b c d e + * f g h i j + * k l m n o + * p q r s t */ +/* pack results to b_scr[] */ +/* type_0 pack: abcdefghijklmnopqrst */ +/* type_1 pack: abcdfghiklmnpqrsejot */ +/* type_2 pack: abcdfghiejklmnpqrsot */ +/* type_3 pack: afbgchdiejkplqmrnsot */ +/* type_4 pack: agciejfbhdkqmsotplrn */ + +/* Example 2 */ +/* source matrix B (6x6): + * 11-12-13-14-15-16 + * 21-22-23-24-25-26 + * 31-32-33-34-35-36 + * 41-42-43-44-45-46 + * 51-52-53-54-55-56 + * 61-62-63-64-65-66 */ +/* type_0 pack: 11-12-13-14-15-16-21-22-23-24-25-26-31-32-33-34- + * 35-36-41-42-43-44-45-46-51-52-53-54-55-56-61-62-63-64-65-66 */ +/* type_1 pack: 11-12-13-14-21-22-23-24-31-32-33-34-41-42-43-44- + * 15-25-35-45-16-26-36-46-51-52-53-54-55-56-61-62-63-64-65-66 */ +/* type_2 pack: 11-12-13-14-21-22-23-24-15-25-16-26-31-32-33-34- + * 41-42-43-44-35-45-36-46-51-52-53-54-61-62-63-64-55-65-56-66 */ +/* type_3 pack: 11-21-12-22-13-23-14-24-15-25-16-26-31-41-32-42- + * 33-43-34-44-35-45-36-46-51-61-52-62-53-63-54-64-55-65-56-66 */ +/* type_4 pack: 11-22-13-24-15-26-21-12-23-14-25-16-31-42-33-44- + * 35-46-41-32-43-34-45-36-51-62-53-64-55-66-61-52-63-54-65-56 */ + +/* type_0 pack from col-major B */ +void pack_0_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_1 pack from col-major B */ +void pack_1_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_2 pack from col-major B */ +void pack_2_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_3 pack from col-major B */ +void pack_3_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_4 pack from col-major B */ +void pack_4_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_0 pack from row-major B */ +void pack_0_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_1 pack from row-major B */ +void pack_1_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_2 pack from row-major B */ +void pack_2_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_3 pack from row-major B */ +void pack_3_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); + +/* type_4 pack from row-major B */ +void pack_4_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N); +
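/* For reference only: a plain-C sketch (not part of this header) of what the
 * type_0 layout documented above amounts to when B is row-major. The name
 * pack_0_from_rm_ref and the scalar loop below are illustrative assumptions;
 * the pack_0_from_rm declared above may be implemented differently inside
 * the library. */
#include <stdint.h>
#include <string.h>

static void pack_0_from_rm_ref(float *b_scr, const float *B,
    uint32_t LDB, uint32_t K, uint32_t N) {
  /* type_0 keeps B's row-major order: concatenate K rows of N floats,
   * dropping the gap between rows implied by the leading dimension LDB */
  for (uint32_t k = 0; k < K; ++k) {
    memcpy(b_scr + (size_t)k * N, B + (size_t)k * LDB,
           (size_t)N * sizeof(float));
  }
}
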
diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotDriver.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotDriver.h new file mode 100644 index 0000000..183ee01 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotDriver.h @@ -0,0 +1,485 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSched.h" +#ifndef EMLL_SERIAL_ONLY +#include <omp.h> +#endif + +#ifndef INCLUDE_SKINNY1_DRIVER +#define INCLUDE_SKINNY1_DRIVER + +#define DRIVER_PURE_PACK_SERIAL(cpu, ndim, K_BATCH, pack_type, unroll_m) \ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu(const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp) {\ +\ + const uint8_t b_rowmajor = b_c_order & 1;\ + const uint8_t c_rowmajor = b_c_order & 2;\ +\ + __attribute__((aligned(4096))) float b_scr[ndim * K_BATCH];\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= K_BATCH * 2) k_inc = K_BATCH;\ + else if (k_inc > K_BATCH) k_inc >>= 1;\ + if (b_rowmajor == 0) {\ + pack_##pack_type##_from_cm(b_scr, B + k_pos, LDB, k_inc, ndim);\ + } else {\ + pack_##pack_type##_from_rm(b_scr, B + k_pos * LDB, LDB, k_inc, ndim);\ + }\ + uint32_t m_pos = M;\ + const float *a_ptr = A + k_pos;\ + float *c_ptr = C;\ + const uint32_t c_incr = (c_rowmajor == 0) ? 1 : LDC;\ + const float beta = (k_pos == 0) ? beta_inp : 1.0f;\ + for (; m_pos >= unroll_m; m_pos -= unroll_m) {\ + sgemm_skinny1_##cpu##_m##unroll_m##n##ndim(a_ptr, b_scr, c_ptr,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + a_ptr += LDA * unroll_m;\ + c_ptr += c_incr * unroll_m;\ + }\ + for (; m_pos > 0; m_pos--) {\ + sgemm_skinny1_##cpu##_m1n##ndim(a_ptr, b_scr, c_ptr, k_inc, LDC,\ + c_rowmajor, beta);\ + a_ptr += LDA;\ + c_ptr += c_incr;\ + }\ + }\ +} + +#define DRIVER_PURE_PACK_OMP(cpu, ndim, K_BATCH, pack_type, unroll_m) \ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu##_omp(\ + const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp, uint32_t num_threads) {\ +\ + if (num_threads <= 1) {\ + sgemm_skinny1_arowmajor_n##ndim##_##cpu(A, B, C, M, K,\ + LDA, LDB, LDC, b_c_order, beta_inp);\ + return;\ + }\ + omp_set_num_threads(num_threads);\ +\ + const uint8_t b_rowmajor = b_c_order & 1;\ + const uint8_t c_rowmajor = b_c_order & 2;\ + const uint32_t c_m_inc = (c_rowmajor == 0) ? 
1 : LDC;\ +\ + __attribute__((aligned(4096))) float b_scr[ndim * K_BATCH];\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= K_BATCH * 2) k_inc = K_BATCH;\ + else if (k_inc > K_BATCH) k_inc >>= 1;\ + const float beta = (k_pos == 0) ? beta_inp : 1.0f;\ +\ + uint32_t k_copy_left = k_inc;\ + uint32_t m_calc_done = 0;\ + _Pragma("omp parallel")\ + {\ + uint32_t k_copy_start, k_copy_end;\ + while(get_copy_task(&k_copy_left, 64, &k_copy_start, &k_copy_end)) {\ + if (b_rowmajor == 0) {\ + pack_##pack_type##_from_cm(b_scr + k_copy_start * ndim,\ + B + k_pos + k_copy_start, LDB,\ + k_copy_end - k_copy_start, ndim);\ + } else {\ + pack_##pack_type##_from_rm(b_scr + k_copy_start * ndim,\ + B + (k_pos + k_copy_start) * LDB, LDB,\ + k_copy_end - k_copy_start, ndim);\ + }\ + }\ + _Pragma("omp barrier")\ + uint32_t m_calc_start, m_calc_end;\ + while(get_irreg_task(&m_calc_done, &m_calc_start, &m_calc_end,\ + unroll_m << 2, M)) {\ + const float *a_ptr = A + m_calc_start * LDA + k_pos;\ + float *c_ptr = C + m_calc_start * c_m_inc;\ + uint32_t sub_m_left = m_calc_end - m_calc_start;\ + for (; sub_m_left >= unroll_m; sub_m_left -= unroll_m) {\ + sgemm_skinny1_##cpu##_m##unroll_m##n##ndim(a_ptr, b_scr, c_ptr,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + a_ptr += LDA * unroll_m;\ + c_ptr += c_m_inc * unroll_m;\ + }\ + for (; sub_m_left > 0; sub_m_left--) {\ + sgemm_skinny1_##cpu##_m1n##ndim(a_ptr, b_scr, c_ptr, k_inc, LDC,\ + c_rowmajor, beta);\ + a_ptr += LDA;\ + c_ptr += c_m_inc;\ + }\ + }\ + }\ + }\ +} + +#define DRIVER_MIX2_PACK_SERIAL(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m) \ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu(const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp) {\ +\ + const uint8_t b_rowmajor = b_c_order & 1;\ + const uint8_t c_rowmajor = b_c_order & 2;\ +\ + __attribute__((aligned(4096))) float b_scr[ndim * K_BATCH];\ + float * const b_scr2 = b_scr + n_pack1 * K_BATCH;\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= K_BATCH * 2) k_inc = K_BATCH;\ + else if (k_inc > K_BATCH) k_inc >>= 1;\ + if (b_rowmajor == 0) {\ + pack_##pack1##_from_cm(b_scr, B + k_pos, LDB, k_inc, n_pack1);\ + pack_##pack2##_from_cm(b_scr2, B + k_pos + n_pack1 * LDB,\ + LDB, k_inc, n_pack2);\ + } else {\ + pack_##pack1##_from_rm(b_scr, B + k_pos * LDB, LDB, k_inc, n_pack1);\ + pack_##pack2##_from_rm(b_scr2, B + k_pos * LDB + n_pack1,\ + LDB, k_inc, n_pack2);\ + }\ + uint32_t m_pos = M;\ + const float *a_ptr = A + k_pos;\ + float *c_ptr1 = C;\ + float *c_ptr2 = (c_rowmajor == 0) ? C + n_pack1 * LDC : C + n_pack1;\ + const uint32_t c_incr = (c_rowmajor == 0) ? 1 : LDC;\ + const float beta = (k_pos == 0) ? 
beta_inp : 1.0f;\ + for (; m_pos >= unroll_m; m_pos -= unroll_m) {\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack1(a_ptr, b_scr, c_ptr1,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack2(a_ptr, b_scr2, c_ptr2,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + a_ptr += LDA * unroll_m;\ + c_ptr1 += c_incr * unroll_m;\ + c_ptr2 += c_incr * unroll_m;\ + }\ + for (; m_pos > 0; m_pos--) {\ + sgemm_skinny1_##cpu##_m1n##n_pack1(a_ptr, b_scr, c_ptr1, k_inc, LDC,\ + c_rowmajor, beta);\ + sgemm_skinny1_##cpu##_m1n##n_pack2(a_ptr, b_scr2, c_ptr2, k_inc, LDC,\ + c_rowmajor, beta);\ + a_ptr += LDA;\ + c_ptr1 += c_incr;\ + c_ptr2 += c_incr;\ + }\ + }\ +} + +#define DRIVER_MIX2_PACK_OMP(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m) \ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu##_omp(\ + const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp, uint32_t num_threads) {\ +\ + if (num_threads <= 1) {\ + sgemm_skinny1_arowmajor_n##ndim##_##cpu(A, B, C, M, K,\ + LDA, LDB, LDC, b_c_order, beta_inp);\ + return;\ + }\ +\ + const uint8_t b_rowmajor = b_c_order & 1;\ + const uint8_t c_rowmajor = b_c_order & 2;\ + const uint32_t c_m_inc = (c_rowmajor == 0) ? 1 : LDC;\ +\ + __attribute__((aligned(4096))) float b_scr[ndim * K_BATCH];\ + float * const b_scr2 = b_scr + n_pack1 * K_BATCH;\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= K_BATCH * 2) k_inc = K_BATCH;\ + else if (k_inc > K_BATCH) k_inc >>= 1;\ + const float beta = (k_pos == 0) ? beta_inp : 1.0f;\ +\ + uint32_t k_copy_left = k_inc;\ + uint32_t m_calc_done = 0;\ + _Pragma("omp parallel")\ + {\ + uint32_t k_copy_start, k_copy_end;\ + while(get_copy_task(&k_copy_left, 64, &k_copy_start, &k_copy_end)) {\ + if (b_rowmajor == 0) {\ + pack_##pack1##_from_cm(b_scr + k_copy_start * n_pack1,\ + B + (k_pos + k_copy_start), LDB,\ + k_copy_end - k_copy_start, n_pack1);\ + pack_##pack2##_from_cm(b_scr2 + k_copy_start * n_pack2,\ + B + (k_pos + k_copy_start) + n_pack1 * LDB, LDB,\ + k_copy_end - k_copy_start, n_pack2);\ + } else {\ + pack_##pack1##_from_rm(b_scr + k_copy_start * n_pack1,\ + B + (k_pos + k_copy_start) * LDB, LDB,\ + k_copy_end - k_copy_start, n_pack1);\ + pack_##pack2##_from_rm(b_scr2 + k_copy_start * n_pack2,\ + B + (k_pos + k_copy_start) * LDB + n_pack1, LDB,\ + k_copy_end - k_copy_start, n_pack2);\ + }\ + }\ + _Pragma("omp barrier")\ + uint32_t m_calc_start, m_calc_end;\ + while(get_irreg_task(&m_calc_done, &m_calc_start, &m_calc_end,\ + unroll_m << 2, M)) {\ + const float *a_ptr = A + m_calc_start * LDA + k_pos;\ + float *c_ptr1 = C + m_calc_start * c_m_inc;\ + float *c_ptr2 = (c_rowmajor == 0) ?\ + c_ptr1 + n_pack1 * LDC : c_ptr1 + n_pack1;\ + uint32_t sub_m_left = m_calc_end - m_calc_start;\ + for (; sub_m_left >= unroll_m; sub_m_left -= unroll_m) {\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack1(a_ptr, b_scr, c_ptr1,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack2(a_ptr, b_scr2, c_ptr2,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + a_ptr += LDA * unroll_m;\ + c_ptr1 += c_m_inc * unroll_m;\ + c_ptr2 += c_m_inc * unroll_m;\ + }\ + for (; sub_m_left > 0; sub_m_left--) {\ + sgemm_skinny1_##cpu##_m1n##n_pack1(a_ptr, b_scr, c_ptr1, k_inc, LDC,\ + c_rowmajor, beta);\ + sgemm_skinny1_##cpu##_m1n##n_pack2(a_ptr, b_scr2, c_ptr2, k_inc, LDC,\ + c_rowmajor, beta);\ + a_ptr += 
LDA;\ + c_ptr1 += c_m_inc;\ + c_ptr2 += c_m_inc;\ + }\ + }\ + }\ + }\ +} + +#define DRIVER_MIX3_PACK_SERIAL(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m) \ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu(const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp) {\ +\ + const uint8_t b_rowmajor = b_c_order & 1;\ + const uint8_t c_rowmajor = b_c_order & 2;\ +\ + __attribute__((aligned(4096))) float b_scr[ndim * K_BATCH];\ + float * const b_scr2 = b_scr + n_pack1 * K_BATCH;\ + float * const b_scr3 = b_scr2 + n_pack2 * K_BATCH;\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= K_BATCH * 2) k_inc = K_BATCH;\ + else if (k_inc > K_BATCH) k_inc >>= 1;\ + if (b_rowmajor == 0) {\ + pack_##pack1##_from_cm(b_scr, B + k_pos, LDB, k_inc, n_pack1);\ + pack_##pack2##_from_cm(b_scr2, B + k_pos + n_pack1 * LDB,\ + LDB, k_inc, n_pack2);\ + pack_##pack3##_from_cm(b_scr3, B + k_pos + (n_pack1 + n_pack2) * LDB,\ + LDB, k_inc, n_pack3);\ + } else {\ + pack_##pack1##_from_rm(b_scr, B + k_pos * LDB, LDB, k_inc, n_pack1);\ + pack_##pack2##_from_rm(b_scr2, B + k_pos * LDB + n_pack1,\ + LDB, k_inc, n_pack2);\ + pack_##pack3##_from_rm(b_scr3, B + k_pos * LDB + n_pack1 + n_pack2,\ + LDB, k_inc, n_pack3);\ + }\ + uint32_t m_pos = M;\ + const float *a_ptr = A + k_pos;\ + float *c_ptr1 = C;\ + float *c_ptr2 = (c_rowmajor == 0) ? C + n_pack1 * LDC : C + n_pack1;\ + float *c_ptr3 = (c_rowmajor == 0) ? C + (n_pack1 + n_pack2) * LDC :\ + C + n_pack1 + n_pack2;\ + const uint32_t c_incr = (c_rowmajor == 0) ? 1 : LDC;\ + const float beta = (k_pos == 0) ? beta_inp : 1.0f;\ + for (; m_pos >= unroll_m; m_pos -= unroll_m) {\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack1(a_ptr, b_scr, c_ptr1,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack2(a_ptr, b_scr2, c_ptr2,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack3(a_ptr, b_scr3, c_ptr3,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + a_ptr += LDA * unroll_m;\ + c_ptr1 += c_incr * unroll_m;\ + c_ptr2 += c_incr * unroll_m;\ + c_ptr3 += c_incr * unroll_m;\ + }\ + for (; m_pos > 0; m_pos--) {\ + sgemm_skinny1_##cpu##_m1n##n_pack1(a_ptr, b_scr, c_ptr1, k_inc, LDC,\ + c_rowmajor, beta);\ + sgemm_skinny1_##cpu##_m1n##n_pack2(a_ptr, b_scr2, c_ptr2, k_inc, LDC,\ + c_rowmajor, beta);\ + sgemm_skinny1_##cpu##_m1n##n_pack3(a_ptr, b_scr3, c_ptr3, k_inc, LDC,\ + c_rowmajor, beta);\ + a_ptr += LDA;\ + c_ptr1 += c_incr;\ + c_ptr2 += c_incr;\ + c_ptr3 += c_incr;\ + }\ + }\ +} + +#define DRIVER_MIX3_PACK_OMP(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m) \ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu##_omp(\ + const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp, uint32_t num_threads) {\ +\ + if (num_threads <= 1) {\ + sgemm_skinny1_arowmajor_n##ndim##_##cpu(A, B, C, M, K,\ + LDA, LDB, LDC, b_c_order, beta_inp);\ + return;\ + }\ +\ + const uint8_t b_rowmajor = b_c_order & 1;\ + const uint8_t c_rowmajor = b_c_order & 2;\ + const uint32_t c_m_inc = (c_rowmajor == 0) ? 
1 : LDC;\ +\ + __attribute__((aligned(4096))) float b_scr[ndim * K_BATCH];\ + float * const b_scr2 = b_scr + n_pack1 * K_BATCH;\ + float * const b_scr3 = b_scr2 + n_pack2 * K_BATCH;\ +\ + uint32_t k_pos, k_inc;\ + for (k_pos = 0; k_pos < K; k_pos += k_inc) {\ + k_inc = K - k_pos;\ + if (k_inc >= K_BATCH * 2) k_inc = K_BATCH;\ + else if (k_inc > K_BATCH) k_inc >>= 1;\ + const float beta = (k_pos == 0) ? beta_inp : 1.0f;\ +\ + uint32_t k_copy_left = k_inc;\ + uint32_t m_calc_done = 0;\ + _Pragma("omp parallel")\ + {\ + uint32_t k_copy_start, k_copy_end;\ + while(get_copy_task(&k_copy_left, 64, &k_copy_start, &k_copy_end)) {\ + if (b_rowmajor == 0) {\ + pack_##pack1##_from_cm(b_scr + k_copy_start * n_pack1,\ + B + (k_pos + k_copy_start), LDB,\ + k_copy_end - k_copy_start, n_pack1);\ + pack_##pack2##_from_cm(b_scr2 + k_copy_start * n_pack2,\ + B + (k_pos + k_copy_start) + n_pack1 * LDB, LDB,\ + k_copy_end - k_copy_start, n_pack2);\ + pack_##pack3##_from_cm(b_scr3 + k_copy_start * n_pack3,\ + B + (k_pos + k_copy_start) + (n_pack1 + n_pack2) * LDB, LDB,\ + k_copy_end - k_copy_start, n_pack3);\ + } else {\ + pack_##pack1##_from_rm(b_scr + k_copy_start * n_pack1,\ + B + (k_pos + k_copy_start) * LDB, LDB,\ + k_copy_end - k_copy_start, n_pack1);\ + pack_##pack2##_from_rm(b_scr2 + k_copy_start * n_pack2,\ + B + (k_pos + k_copy_start) * LDB + n_pack1, LDB,\ + k_copy_end - k_copy_start, n_pack2);\ + pack_##pack3##_from_rm(b_scr3 + k_copy_start * n_pack3,\ + B + (k_pos + k_copy_start) * LDB + n_pack1 + n_pack2, LDB,\ + k_copy_end - k_copy_start, n_pack3);\ + }\ + }\ + _Pragma("omp barrier")\ + uint32_t m_calc_start, m_calc_end;\ + while(get_irreg_task(&m_calc_done, &m_calc_start, &m_calc_end,\ + unroll_m << 2, M)) {\ + const float *a_ptr = A + m_calc_start * LDA + k_pos;\ + float *c_ptr1 = C + m_calc_start * c_m_inc;\ + float *c_ptr2 = (c_rowmajor == 0) ?\ + c_ptr1 + n_pack1 * LDC : c_ptr1 + n_pack1;\ + float *c_ptr3 = (c_rowmajor == 0) ?\ + c_ptr1 + (n_pack1 + n_pack2) * LDC : c_ptr1 + n_pack1 + n_pack2;\ + uint32_t sub_m_left = m_calc_end - m_calc_start;\ + for (; sub_m_left >= unroll_m; sub_m_left -= unroll_m) {\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack1(a_ptr, b_scr, c_ptr1,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack2(a_ptr, b_scr2, c_ptr2,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + sgemm_skinny1_##cpu##_m##unroll_m##n##n_pack3(a_ptr, b_scr3, c_ptr3,\ + k_inc, LDA, LDC, c_rowmajor, &beta);\ + a_ptr += LDA * unroll_m;\ + c_ptr1 += c_m_inc * unroll_m;\ + c_ptr2 += c_m_inc * unroll_m;\ + c_ptr3 += c_m_inc * unroll_m;\ + }\ + for (; sub_m_left > 0; sub_m_left--) {\ + sgemm_skinny1_##cpu##_m1n##n_pack1(a_ptr, b_scr, c_ptr1, k_inc, LDC,\ + c_rowmajor, beta);\ + sgemm_skinny1_##cpu##_m1n##n_pack2(a_ptr, b_scr2, c_ptr2, k_inc, LDC,\ + c_rowmajor, beta);\ + sgemm_skinny1_##cpu##_m1n##n_pack3(a_ptr, b_scr3, c_ptr3, k_inc, LDC,\ + c_rowmajor, beta);\ + a_ptr += LDA;\ + c_ptr1 += c_m_inc;\ + c_ptr2 += c_m_inc;\ + c_ptr3 += c_m_inc;\ + }\ + }\ + }\ + }\ +} + +#ifdef EMLL_SERIAL_ONLY + +#define DRIVER_PURE_PACK(cpu, ndim, K_BATCH, pack_type, unroll_m) \ + DRIVER_PURE_PACK_SERIAL(cpu, ndim, K_BATCH, pack_type, unroll_m)\ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu##_omp(\ + const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp, uint32_t num_threads) {\ +\ + sgemm_skinny1_arowmajor_n##ndim##_##cpu(A, B, C, M, K,\ + LDA, LDB, 
LDC, b_c_order, beta_inp);\ +} + +#define DRIVER_MIX2_PACK(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m) \ + DRIVER_MIX2_PACK_SERIAL(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m)\ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu##_omp(\ + const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp, uint32_t num_threads) {\ +\ + sgemm_skinny1_arowmajor_n##ndim##_##cpu(A, B, C, M, K,\ + LDA, LDB, LDC, b_c_order, beta_inp);\ +} + +#define DRIVER_MIX3_PACK(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m) \ + DRIVER_MIX3_PACK_SERIAL(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m)\ +void sgemm_skinny1_arowmajor_n##ndim##_##cpu##_omp(\ + const float * __restrict__ A,\ + const float * __restrict__ B, float * __restrict__ C,\ + uint32_t M, uint32_t K, uint32_t LDA, uint32_t LDB, uint32_t LDC,\ + uint8_t b_c_order, float beta_inp, uint32_t num_threads) {\ +\ + sgemm_skinny1_arowmajor_n##ndim##_##cpu(A, B, C, M, K,\ + LDA, LDB, LDC, b_c_order, beta_inp);\ +} + +#else + +#define DRIVER_PURE_PACK(cpu, ndim, K_BATCH, pack_type, unroll_m) \ + DRIVER_PURE_PACK_SERIAL(cpu, ndim, K_BATCH, pack_type, unroll_m)\ + DRIVER_PURE_PACK_OMP(cpu, ndim, K_BATCH, pack_type, unroll_m) + +#define DRIVER_MIX2_PACK(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m) \ + DRIVER_MIX2_PACK_SERIAL(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m)\ + DRIVER_MIX2_PACK_OMP(cpu, ndim, K_BATCH, pack1, pack2, n_pack1, n_pack2, unroll_m) + +#define DRIVER_MIX3_PACK(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m) \ + DRIVER_MIX3_PACK_SERIAL(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m)\ + DRIVER_MIX3_PACK_OMP(cpu, ndim, K_BATCH, pack1, pack2, pack3, n_pack1, n_pack2, n_pack3, unroll_m) + +#endif +#endif + diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA35.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA35.h new file mode 100644 index 0000000..5cb2de5 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA35.h @@ -0,0 +1,2439 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include +#include + +#ifndef INCLUDE_A35_KERNEL +#define INCLUDE_A35_KERNEL + +/* for cortex-a35, fp32 NEON operation on q regs are not recommended, + * using d regs without broadcast is better */ + +/* for cortex-a35 fp32 fma instruction sequence, + * it's recommended to put 3 nearest fma inst together */ +#define FMA_3V(c1, c2, c3, a1, a2, a3, b1, b2, b3) \ + "fmla v"#c1".2s,v"#a1".2s,v"#b1".2s\n\t"\ + "fmla v"#c2".2s,v"#a2".2s,v"#b2".2s\n\t"\ + "fmla v"#c3".2s,v"#a3".2s,v"#b3".2s\n\t" + +#define INIT_3V(c1, c2, c3) \ + "movi v"#c1".8b,#0; movi v"#c2".8b,#0; movi v"#c3".8b,#0\n\t" + +#define INIT_4V(c1, c2, c3, c4) INIT_3V(c1, c2, c3)\ + "movi v"#c4".8b,#0\n\t" + +/* x12 - x15 for c_tmp pointers */ +/* v0 always for beta at storage status */ + +#define INIT_SAVE_M3_CR \ + "ld1r {v0.2s},[%[beta_addr]]\n\t"\ + "mov x12,%[c_ptr]; add x13,%[c_ptr],%w[LDC],UXTW #2\n\t"\ + "add x14,%[c_ptr],%w[LDC],UXTW #3\n\t" + +#define INIT_SAVE_M4_CR INIT_SAVE_M3_CR \ + "add x15,x13,%w[LDC],UXTW #3\n\t" + +#define INIT_SAVE_CC \ + "ld1r {v0.2s},[%[beta_addr]]; mov x12,%[c_ptr]\n\t"\ + "add x13,%[c_ptr],%w[LDC],UXTW #2\n\t" + +/* c1[0], c1[1] */ +/* c2[0], c2[1] */ +/* c3[0], c3[1] */ +/* c4[0], c4[1] */ +/* clobber: x12 - x13, v0 - v4 */ +#define UNIT_SAVE_M4N2_CC(c1, c2, c3, c4) \ + "ldr d1,[x12]; ldr d2,[x12,#8]\n\t"\ + "trn1 v3.2s,v"#c1".2s,v"#c2".2s; trn1 v4.2s,v"#c3".2s,v"#c4".2s\n\t"\ + "trn2 v"#c2".2s,v"#c1".2s,v"#c2".2s; trn2 v"#c4".2s,v"#c3".2s,v"#c4".2s\n\t"\ + "ldr d"#c1",[x13]; ldr d"#c3",[x13,#8]\n\t"\ + "fmla v3.2s,v1.2s,v0.2s\n\t"\ + "fmla v4.2s,v2.2s,v0.2s\n\t"\ + "fmla v"#c2".2s,v"#c1".2s,v0.2s\n\t"\ + "fmla v"#c4".2s,v"#c3".2s,v0.2s\n\t"\ + "str d3,[x12]; str d4,[x12,#8]\n\t"\ + "prfm pstl2keep,[x12,#32]; add x12,x12,%w[LDC],UXTW #3\n\t"\ + "str d"#c2",[x13]; str d"#c4",[x13,#8]\n\t"\ + "prfm pstl2keep,[x13,#32]; add x13,x13,%w[LDC],UXTW #3\n\t" + +/* clobber: x12 - x15, v0 - v4 */ +#define UNIT_SAVE_M4N2_CR(c1, c2, c3, c4) \ + "ldr d1,[x12]; ldr d2,[x13]\n\t"\ + "ldr d3,[x14]; ldr d4,[x15]\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.2s\n\t"\ + "fmla v"#c2".2s,v2.2s,v0.2s\n\t"\ + "fmla v"#c3".2s,v3.2s,v0.2s\n\t"\ + "fmla v"#c4".2s,v4.2s,v0.2s\n\t"\ + "str d"#c1",[x12],#8; str d"#c2",[x13],#8\n\t"\ + "str d"#c3",[x14],#8; str d"#c4",[x15],#8\n\t" + +/* c1[0], c1[1] */ +/* c2[0], c2[1] */ +/* c3[0], c3[1] */ +/* clobber: x12 - x13, v0 - v3 */ +#define UNIT_SAVE_M3N2_CC(c1, c2, c3) \ + "ldr d1,[x12]\n\t"\ + "trn1 v2.2s,v"#c1".2s,v"#c2".2s\n\t"\ + "trn2 v"#c2".2s,v"#c1".2s,v"#c2".2s\n\t"\ + "ldr d"#c1",[x13]; ldr s3,[x12,#8]\n\t"\ + "fmla v2.2s,v1.2s,v0.2s\n\t"\ + "fmla v"#c2".2s,v"#c1".2s,v0.2s\n\t"\ + "ldr s1,[x13,#8]\n\t"\ + "str d2,[x12]; ins v2.s[0],v"#c3".s[1]\n\t"\ + "str d"#c2",[x13]\n\t"\ + "fmla s"#c3",s3,v0.s[0]; fmla s2,s1,v0.s[0]\n\t"\ + "str s"#c3",[x12,#8]; prfm pstl2keep,[x12,#24]\n\t"\ + "add x12,x12,%w[LDC],UXTW #3\n\t"\ + "str s2,[x13,#8]; prfm pstl2keep,[x13,#24]\n\t"\ + "add x13,x13,%w[LDC],UXTW #3\n\t" + +/* clobber: x12 - x14, v0 - v3 */ +#define UNIT_SAVE_M3N2_CR(c1, c2, c3) \ + "ldr d1,[x12]; ldr d2,[x13]\n\t"\ + "ldr d3,[x14]\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.2s\n\t"\ + "fmla v"#c2".2s,v2.2s,v0.2s\n\t"\ + "fmla v"#c3".2s,v3.2s,v0.2s\n\t"\ + "str d"#c1",[x12],#8; str d"#c2",[x13],#8\n\t"\ + "str d"#c3",[x14],#8\n\t" + +/* c1[0] + c1[1] */ +/* c2[0] + c2[1] */ +/* c3[0] + c3[1] */ +/* c4[0] + c4[1] */ +/* clobber: x12, v0 - v4 */ +#define UNIT_SAVE_M4N1_CC(c1, c2, c3, c4) \ + "ldr d3,[x12]; 
ldr d4,[x12,#8]\n\t"\ + "faddp v1.2s,v"#c1".2s,v"#c2".2s\n\t"\ + "faddp v2.2s,v"#c3".2s,v"#c4".2s\n\t"\ + "fmla v1.2s,v3.2s,v0.2s\n\t"\ + "fmla v2.2s,v4.2s,v0.2s\n\t"\ + "str d1,[x12]; str d2,[x12,#8]\n\t"\ + "prfm pstl2keep,[x12,#32]\n\t"\ + "add x12,x12,%w[LDC],UXTW #2\n\t" + +/* clobber: x12 - x15, v0 - v4 */ +#define UNIT_SAVE_M4N1_CR(c1, c2, c3, c4) \ + "ldr s1,[x12]; ldr s2,[x13]\n\t"\ + "faddp v"#c1".2s,v"#c1".2s,v"#c3".2s\n\t"\ + "ld1 {v1.s}[1],[x14]; ld1 {v2.s}[1],[x15]\n\t"\ + "faddp v"#c2".2s,v"#c2".2s,v"#c4".2s\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.2s\n\t"\ + "fmla v"#c2".2s,v2.2s,v0.2s\n\t"\ + "str s"#c1",[x12],#4; str s"#c2",[x13],#4\n\t"\ + "st1 {v"#c1".s}[1],[x14],#4; st1 {v"#c2".s}[1],[x15],#4\n\t" + +/* c1[0] + c1[1] */ +/* c2[0] + c2[1] */ +/* c3[0] + c3[1] */ +/* clobber: x12, v0 - v3 */ +#define UNIT_SAVE_M3N1_CC(c1, c2, c3) \ + "ldr d1,[x12]; ldr s2,[x12,#8]\n\t"\ + "faddp v"#c1".2s,v"#c1".2s,v"#c2".2s\n\t"\ + "faddp s"#c3",v"#c3".2s\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.2s\n\t"\ + "fmla s"#c3",s2,v0.s[0]\n\t"\ + "str d"#c1",[x12]; str s"#c3",[x12,#8]\n\t"\ + "prfm pstl2keep,[x12,#24]\n\t"\ + "add x12,x12,%w[LDC],UXTW #2\n\t" + +/* clobber: x12 - x14, v0 - v3 */ +#define UNIT_SAVE_M3N1_CR(c1, c2, c3) \ + "ldr s1,[x12]; ldr s2,[x13]; ldr s3,[x14]\n\t"\ + "faddp s"#c1",v"#c1".2s; faddp s"#c2",v"#c2".2s; faddp s"#c3",v"#c3".2s\n\t"\ + "fmla s"#c1",s1,v0.s[0]\n\t"\ + "fmla s"#c2",s2,v0.s[0]\n\t"\ + "fmla s"#c3",s3,v0.s[0]\n\t"\ + "str s"#c1",[x12],#4\n\t"\ + "str s"#c2",[x13],#4\n\t"\ + "str s"#c3",[x14],#4\n\t" + +/* x0 = a_ptr1 (top) */ +/* x1 = a_ptr2 */ +/* x2 = a_ptr3 */ +/* x3 = a_ptr4 (or pref_head when M == 3) */ +/* x4 = b_ptr */ +/* w5 = k_left */ +/* x8 - x11 for pref head */ +/* x12 - x15 for c_tmp1 - c_tmp4 */ + +/* macro for GEMM with packing pattern NO.#3 */ +/* mdim = 3, 4; ndim = 5 - 8 */ +#define FUNC_PACK3(mdim, ndim) \ +static inline void sgemm_skinny1_a35_m##mdim##n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__(\ + "mov x4,%[b_scr]\n\t"\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x8,x0,%w[LDA],UXTW #4; add x9,x1,%w[LDA],UXTW #4\n\t"\ + "add x10,x2,%w[LDA],UXTW #4; add x11,x3,%w[LDA],UXTW #4\n\t"\ + "mov w5,%w[K]\n\t"\ + INIT_M##mdim##N##ndim\ + "cmp w5,#2; b.lt 4f\n\t"\ + KERNEL_M##mdim##N##ndim##_PRELOAD2\ + "cmp w5,#10; b.lt 7f\n\t"\ + ".balign 16; 8:\n\t"\ + KERNEL_M##mdim##N##ndim##_MAIN8 "b.ge 8b\n\t"\ + "7:\n\t"\ + "cmp w5,#6; b.lt 1f\n\t"\ + KERNEL_M##mdim##N##ndim##_MAIN4\ + "1:\n\t"\ + "cmp w5,#4; b.lt 2f\n\t"\ + KERNEL_M##mdim##N##ndim##_TAIL4 "b 4f\n\t"\ + "2:\n\t"\ + KERNEL_M##mdim##N##ndim##_TAIL2\ + "4:\n\t"\ + "cmp w5,#1; b.lt 5f\n\t"\ + KERNEL_M##mdim##N##ndim##_FIN1\ + "5:\n\t"\ + "cmp %w[c_rowmajor],#0; b.eq 6f\n\t"\ + INIT_SAVE_M##mdim##_CR SAVE_M##mdim##N##ndim(CR) "b 7f\n\t"\ + "6:\n\t"\ + INIT_SAVE_CC SAVE_M##mdim##N##ndim(CC)\ + "7:\n\t"\ + ::[a_ptr]"r"(a_ptr), [b_scr]"r"(b_scr), [c_ptr]"r"(c_ptr),\ + [LDA]"r"(LDA), [LDC]"r"(LDC), [K]"r"(K),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5",\ + "x8","x9","x10","x11","x12","x13","x14","x15",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11",\ + "v12","v13","v14","v15","v16","v17","v18","v19","v20","v21",\ + 
"v22","v23","v24","v25","v26","v27","v28","v29","v30","v31");\ +} + +/* acc layout for m4n4 kernel */ +/* m0n0 v16 v17 v18 v19 m0n4 */ +/* m1n0 v20 v21 v22 v23 m1n4 */ +/* m2n0 v24 v25 v26 v27 m2n4 */ +/* m3n0 v28 v29 v30 v31 m3n4 */ +/* b-holder layout for m4n4 kernel */ +/* n0 v4 v5 v6 v7 */ +/* a-holder layout for m4n4 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N4 \ + INIT_4V(16, 20, 24, 28) INIT_4V(17, 21, 25, 29)\ + INIT_4V(18, 22, 26, 30) INIT_4V(19, 23, 27, 31) + +#define SAVE_M4N4(mode) \ + UNIT_SAVE_M4N1_##mode(16, 20, 24, 28) UNIT_SAVE_M4N1_##mode(17, 21, 25, 29)\ + UNIT_SAVE_M4N1_##mode(18, 22, 26, 30) UNIT_SAVE_M4N1_##mode(19, 23, 27, 31) + +#define KERNEL_M4N4_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; add x4,x4,#32\n\t" + +#define KERNEL_M4N4_MAIN8 \ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "prfm pldl1keep,[x0,#64]; fmla v31.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4,#32]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#40]; ldr d6,[x4,#48]\n\t"\ + "prfm pldl1keep,[x1,#64]; fmla v31.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#56]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4,#64]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#72]; ldr d6,[x4,#80]\n\t"\ + "prfm pldl1keep,[x2,#64]; fmla v31.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#88]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4,#96]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; sub w5,w5,#8\n\t"\ + "ldr d5,[x4,#104]; ldr d6,[x4,#112]; cmp w5,#10\n\t"\ + "prfm pldl1keep,[x3,#64]; fmla v31.2s,v3.2s,v7.2s; add x4,x4,#128\n\t" + +#define KERNEL_M4N4_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "sub w5,w5,#4; fmla v31.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4,#32]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#40]; ldr d6,[x4,#48]\n\t"\ + "add x4,x4,#64; fmla v31.2s,v3.2s,v7.2s\n\t" + +#define KERNEL_M4N4_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr 
d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "ldr d4,[x4]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "prfm pldl1keep,[x8]; sub w5,w5,#4\n\t"\ + "prfm pldl1keep,[x9]; fmla v31.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "prfm pldl1keep,[x10]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "add x4,x4,#32\n\t"\ + "prfm pldl1keep,[x11]; fmla v31.2s,v3.2s,v7.2s\n\t" + +#define KERNEL_M4N4_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x9]\n\t"\ + "prfm pldl1keep,[x10]; sub w5,w5,#2\n\t"\ + "prfm pldl1keep,[x11]; fmla v31.2s,v3.2s,v7.2s\n\t" + +#define KERNEL_M4N4_FIN1 \ + "ldr s0,[x0],#4; ldr s4,[x4]; ldr s5,[x4,#4]; ldr s6,[x4,#8]\n\t"\ + "ldr s1,[x1],#4\n\t" FMA_3V(16, 17, 18, 0, 0, 0, 4, 5, 6)\ + "ldr s2,[x2],#4\n\t" FMA_3V(20, 21, 22, 1, 1, 1, 4, 5, 6)\ + "ldr s3,[x3],#4\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 4, 5, 6)\ + "ldr s7,[x4,#12]\n\t" FMA_3V(28, 29, 30, 3, 3, 3, 4, 5, 6)\ + "add x4,x4,#16\n\t" FMA_3V(19, 23, 27, 0, 1, 2, 7, 7, 7)\ + "fmla v31.2s,v3.2s,v7.2s\n\t" + + +/* acc layout for m4n5 kernel */ +/* m0n0 v12 v13 v14 v15 v16 m0n5 */ +/* m1n0 v17 v18 v19 v20 v21 m1n5 */ +/* m2n0 v22 v23 v24 v25 v26 m2n5 */ +/* m3n0 v27 v28 v29 v30 v31 m3n5 */ +/* b-holder layout for m4n5 kernel */ +/* n0 v4 v5 v6 v7 v8 */ +/* a-holder layout for m4n5 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N5 \ + INIT_4V(12, 17, 22, 27) INIT_4V(13, 18, 23, 28) INIT_4V(14, 19, 24, 29)\ + INIT_4V(15, 20, 25, 30) INIT_4V(16, 21, 26, 31) + +#define SAVE_M4N5(mode) \ + UNIT_SAVE_M4N1_##mode(12, 17, 22, 27) UNIT_SAVE_M4N1_##mode(13, 18, 23, 28)\ + UNIT_SAVE_M4N1_##mode(14, 19, 24, 29) UNIT_SAVE_M4N1_##mode(15, 20, 25, 30)\ + UNIT_SAVE_M4N1_##mode(16, 21, 26, 31) + +#define KERNEL_M4N5_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; add x4,x4,#40\n\t" + +#define KERNEL_M4N5_MAIN8 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-16]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-8]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "prfm pldl1keep,[x0,#64]; fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#32]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#40]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#48]; ldr d6,[x4,#56]; sub w5,w5,#8\n\t"\ + "prfm pldl1keep,[x1,#64]; fmla 
v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#64]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#72]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#80]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#88]; ldr d6,[x4,#96]; cmp w5,#10\n\t"\ + "prfm pldl1keep,[x2,#64]; fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#104]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#112]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#120]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#128]; ldr d6,[x4,#136]; add x4,x4,#160\n\t"\ + "prfm pldl1keep,[x3,#64]; fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t" + +#define KERNEL_M4N5_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-16]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-8]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d6,[x4,#16]; sub w5,w5,#4\n\t"\ + "fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#32]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#40]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#48]; ldr d6,[x4,#56]; add x4,x4,#80\n\t"\ + "fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t" + +#define KERNEL_M4N5_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-16]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-8]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "prfm pldl1keep,[x8]\n\t"\ + "prfm pldl1keep,[x9]; fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#32]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x10]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "add x4,x4,#40; sub w5,w5,#4\n\t"\ + "prfm pldl1keep,[x11]; fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t" + +#define KERNEL_M4N5_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-16]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-8]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x8]\n\t" 
FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "prfm pldl1keep,[x9]; prfm pldl1keep,[x10]; prfm pldl1keep,[x11]\n\t"\ + "sub w5,w5,#2; fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t" + +#define KERNEL_M4N5_FIN1 \ + "ldr s0,[x0],#4; ldr s4,[x4]; ldr s5,[x4,#4]; ldr s6,[x4,#8]\n\t"\ + "ldr s1,[x1],#4\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr s2,[x2],#4\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr s3,[x3],#4\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr s7,[x4,#12]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr s8,[x4,#16]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "add x4,x4,#20\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t" + + +/* acc layout for m4n6 kernel */ +/* m0n0 v8 v9 v10 v11 v12 v13 m0n6 */ +/* m1n0 v14 v15 v16 v17 v18 v19 m1n6 */ +/* m2n0 v20 v21 v22 v23 v24 v25 m2n6 */ +/* m3n0 v26 v27 v28 v29 v30 v31 m3n6 */ +/* b-holder layout for m4n6 kernel */ +/* n0 v4 v5 v6 v7 */ +/* a-holder layout for m4n5 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N6 \ + INIT_4V(8, 14, 20, 26) INIT_4V(9, 15, 21, 27) INIT_4V(10, 16, 22, 28)\ + INIT_4V(11, 17, 23, 29) INIT_4V(12, 18, 24, 30) INIT_4V(13, 19, 25, 31) + +#define SAVE_M4N6(mode) \ + UNIT_SAVE_M4N1_##mode(8, 14, 20, 26) UNIT_SAVE_M4N1_##mode(9, 15, 21, 27)\ + UNIT_SAVE_M4N1_##mode(10, 16, 22, 28) UNIT_SAVE_M4N1_##mode(11, 17, 23, 29)\ + UNIT_SAVE_M4N1_##mode(12, 18, 24, 30) UNIT_SAVE_M4N1_##mode(13, 19, 25, 31) + +#define KERNEL_M4N6_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; add x4,x4,#48\n\t" + +#define KERNEL_M4N6_MAIN8 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d7,[x4,#16]; sub w5,w5,#8\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#56]; ldr d6,[x4,#64]; cmp w5,#10\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#72]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#88]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#96]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d0,[x0],#8; ldr d5,[x4,#104]; ldr d7,[x4,#112]\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#120]\n\t" 
FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#128]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#136]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#144]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#152]; ldr d6,[x4,#160]; add x4,x4,#192\n\t"\ + "prfm pldl1keep,[x3,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N6_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d7,[x4,#16]; sub w5,w5,#4\n\t"\ + FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#56]; ldr d6,[x4,#64]; add x4,x4,#96\n\t"\ + FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N6_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; ldr d7,[x4,#16]; sub w5,w5,#4\n\t"\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "prfm pldl1keep,[x10]; add x4,x4,#48\n\t"\ + "prfm pldl1keep,[x11]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N6_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "prfm pldl1keep,[x9]; prfm pldl1keep,[x10]; sub w5,w5,#2\n\t"\ + "prfm pldl1keep,[x11]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6) + +#define KERNEL_M4N6_FIN1 \ + "ldr s0,[x0],#4; ldr s4,[x4]; ldr s5,[x4,#4]; ldr s6,[x4,#8]\n\t"\ + "ldr s1,[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr s2,[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr s3,[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr s7,[x4,#12]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 
5, 6)\ + "ldr s5,[x4,#16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr s6,[x4,#20]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "add x4,x4,#24\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6) + + +/* acc layout for m3n7 kernel */ +/* m1n0 v11 v12 v13 v14 v15 v16 v17 m0n7 */ +/* m2n0 v18 v19 v20 v21 v22 v23 v24 m1n7 */ +/* m3n0 v25 v26 v27 v28 v29 v30 v31 m2n7 */ +/* b-holder layout for m3n7 kernel */ +/* n0 v3 v4 v5 v6 v7 v8 v9 */ +/* a-holder layout for m3n7 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N7 \ + INIT_3V(11, 18, 25) INIT_3V(12, 19, 26) INIT_3V(13, 20, 27)\ + INIT_3V(14, 21, 28) INIT_3V(15, 22, 29) INIT_3V(16, 23, 30)\ + INIT_3V(17, 24, 31) + +#define SAVE_M3N7(mode) \ + UNIT_SAVE_M3N1_##mode(11, 18, 25) UNIT_SAVE_M3N1_##mode(12, 19, 26)\ + UNIT_SAVE_M3N1_##mode(13, 20, 27) UNIT_SAVE_M3N1_##mode(14, 21, 28)\ + UNIT_SAVE_M3N1_##mode(15, 22, 29) UNIT_SAVE_M3N1_##mode(16, 23, 30)\ + UNIT_SAVE_M3N1_##mode(17, 24, 31) + +#define KERNEL_M3N7_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#56\n\t" + +#define KERNEL_M3N7_MAIN8 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; ldr d3,[x4]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; ldr d3,[x4,#56]; sub w5,w5,#8\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t"\ + "ldr d4,[x4,#64]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#72]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#88]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#96]; ldr d9,[x4,#104]; ldr d3,[x4,#112]; cmp w5,#10\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t"\ + "ldr d4,[x4,#120]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#128]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#136]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#144]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#152]; ldr d9,[x4,#160]; ldr d3,[x4,#168]; add x4,x4,#224\n\t"\ + "prfm pldl1keep,[x3,#64]\n\t"\ + "ldr d4,[x4,#-48]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9) + +#define KERNEL_M3N7_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 
21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; ldr d3,[x4]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; ldr d3,[x4,#56]; add x4,x4,#112\n\t"\ + "ldr d4,[x4,#-48]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9) + +#define KERNEL_M3N7_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; ldr d3,[x4]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; add x4,x4,#56\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9) + +#define KERNEL_M3N7_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; sub w5,w5,#2\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9) + +#define KERNEL_M3N7_FIN1 \ + "ldr s0,[x0],#4; ldr s3,[x4]; ldr s4,[x4,#4]; ldr s5,[x4,#8]\n\t"\ + "ldr s1,[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr s2,[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr s6,[x4,#12]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr s7,[x4,#16]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr s8,[x4,#20]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr s9,[x4,#24]\n\t" FMA_3V(16, 23, 30, 0, 1, 2, 8, 8, 8)\ + "add x4,x4,#28\n\t" FMA_3V(17, 24, 31, 0, 1, 2, 9, 9, 9) + + +/* acc layout for m3n8 kernel */ +/* m1n0 v8 v9 v10 v11 v12 v13 v14 v15 m0n8 */ +/* m2n0 v16 v17 v18 v19 v20 v21 v22 v23 m1n8 */ +/* m3n0 v24 v25 v26 v27 v28 v29 v30 v31 m2n8 */ +/* b-holder layout for m3n8 kernel */ +/* n0 v3 v4 v5 v6 v7 */ +/* a-holder layout for m3n8 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N8 \ + INIT_3V(8, 16, 24) INIT_3V(9, 17, 25) INIT_3V(10, 18, 26)\ + INIT_3V(11, 19, 27) INIT_3V(12, 20, 28) INIT_3V(13, 21, 29)\ + INIT_3V(14, 22, 30) INIT_3V(15, 23, 31) + +#define SAVE_M3N8(mode) \ + UNIT_SAVE_M3N1_##mode(8, 16, 24) UNIT_SAVE_M3N1_##mode(9, 17, 25)\ + UNIT_SAVE_M3N1_##mode(10, 18, 26) 
UNIT_SAVE_M3N1_##mode(11, 19, 27)\ + UNIT_SAVE_M3N1_##mode(12, 20, 28) UNIT_SAVE_M3N1_##mode(13, 21, 29)\ + UNIT_SAVE_M3N1_##mode(14, 22, 30) UNIT_SAVE_M3N1_##mode(15, 23, 31) + +#define KERNEL_M3N8_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#64\n\t" + +#define KERNEL_M3N8_MAIN8 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; ldr d3,[x4,#64]; sub w5,w5,#8\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t"\ + "ldr d4,[x4,#72]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#88]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#96]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#104]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#112]; ldr d7,[x4,#120]; ldr d3,[x4,#128]; cmp w5,#10\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t"\ + "ldr d4,[x4,#136]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#144]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#152]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#160]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#168]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#176]; ldr d7,[x4,#184]; ldr d3,[x4,#192]; add x4,x4,#256\n\t"\ + "prfm pldl1keep,[x3,#64]\n\t"\ + "ldr d4,[x4,#-56]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#-48]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N8_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 
6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; ldr d3,[x4,#64]; add x4,x4,#128\n\t"\ + "ldr d4,[x4,#-56]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#-48]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N8_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; add x4,x4,#64\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N8_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; sub w5,w5,#2\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N8_FIN1 \ + "ldr s0,[x0],#4; ldr s3,[x4]; ldr s4,[x4,#4]; ldr s5,[x4,#8]\n\t"\ + "ldr s1,[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr s2,[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr s6,[x4,#12]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr s7,[x4,#16]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr s5,[x4,#20]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr s6,[x4,#24]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr s7,[x4,#28]\n\t" FMA_3V(14, 22, 30, 0, 1, 2, 6, 6, 6)\ + "add x4,x4,#32\n\t" FMA_3V(15, 23, 31, 0, 1, 2, 7, 7, 7) + +FUNC_PACK3(4, 4) + +FUNC_PACK3(4, 5) + +FUNC_PACK3(4, 6) + +FUNC_PACK3(3, 7) + +FUNC_PACK3(3, 8) + +/* macro for GEMM with packing pattern NO.#0 */ +/* mdim = 3, 4; ndim = 10, 12, 14, 16 */ +#define FUNC_PACK0(mdim, ndim) \ +static inline void sgemm_skinny1_a35_m##mdim##n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__(\ + "mov x4,%[b_scr]\n\t"\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x8,x0,%w[LDA],UXTW #4; add x9,x1,%w[LDA],UXTW #4\n\t"\ + "add x10,x2,%w[LDA],UXTW #4; add x11,x3,%w[LDA],UXTW #4\n\t"\ + "mov w5,%w[K]\n\t"\ + INIT_M##mdim##N##ndim\ + "cmp w5,#1; b.lt 4f\n\t"\ + 
KERNEL_M##mdim##N##ndim##_PRELOAD1\ + "cmp w5,#5; b.lt 1f\n\t"\ + ".balign 16; 8:\n\t"\ + KERNEL_M##mdim##N##ndim##_MAIN4 "b.ge 8b\n\t"\ + "1:\n\t"\ + "cmp w5,#3; b.lt 2f\n\t"\ + KERNEL_M##mdim##N##ndim##_MAIN2\ + "2:\n\t"\ + "cmp w5,#2; b.ne 3f\n\t"\ + KERNEL_M##mdim##N##ndim##_TAIL2 "b 4f\n\t"\ + "3:\n\t"\ + KERNEL_M##mdim##N##ndim##_TAIL1\ + "4:\n\t"\ + "cmp %w[c_rowmajor],#0; b.eq 6f\n\t"\ + INIT_SAVE_M##mdim##_CR SAVE_M##mdim##N##ndim(CR) "b 7f\n\t"\ + "6:\n\t"\ + INIT_SAVE_CC SAVE_M##mdim##N##ndim(CC)\ + "7:\n\t"\ + ::[a_ptr]"r"(a_ptr), [b_scr]"r"(b_scr), [c_ptr]"r"(c_ptr),\ + [LDA]"r"(LDA), [LDC]"r"(LDC), [K]"r"(K),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5",\ + "x8","x9","x10","x11","x12","x13","x14","x15",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11",\ + "v12","v13","v14","v15","v16","v17","v18","v19","v20","v21",\ + "v22","v23","v24","v25","v26","v27","v28","v29","v30","v31");\ +} + +/* acc layout for m4n10 kernel */ +/* m0n0 v10 v11 v12 v13 v14 m0n10 */ +/* m1n0 v15 v16 v17 v18 v19 m1n10 */ +/* m2n0 v20 v21 v22 v23 v24 m2n10 */ +/* m3n0 v25 v26 v27 v28 v29 m3n10 */ +/* b-holder layout for m4n10 kernel */ +/* n0 v5 v6 v7 v8 v9 n10 */ +/* a-holder layout for m4n10 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N10 \ + INIT_4V(10, 15, 20, 25) INIT_4V(11, 16, 21, 26)\ + INIT_4V(12, 17, 22, 27) INIT_4V(13, 18, 23, 28)\ + INIT_4V(14, 19, 24, 29) + +#define SAVE_M4N10(mode) \ + UNIT_SAVE_M4N2_##mode(10, 15, 20, 25) UNIT_SAVE_M4N2_##mode(11, 16, 21, 26)\ + UNIT_SAVE_M4N2_##mode(12, 17, 22, 27) UNIT_SAVE_M4N2_##mode(13, 18, 23, 28)\ + UNIT_SAVE_M4N2_##mode(14, 19, 24, 29) + +#define KERNEL_M4N10_PRELOAD1 \ + "ld1r {v0.2s},[x0],#4\n\t"\ + "ldr d5,[x4]; ldr d6,[x4,#8]; ldr d7,[x4,#16]; add x4,x4,#40\n\t" + +#define KERNEL_M4N10_MAIN4 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#-16]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#-8]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#8]; ldr d7,[x4,#16]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#24]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#32]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#48]; ldr d7,[x4,#56]\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#64]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#72]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#88]; ldr d7,[x4,#96]\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 
0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#104]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#112]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4,#120]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#128]; ldr d7,[x4,#136]; add x4,x4,#160\n\t"\ + "prfm pldl1keep,[x3,#64]; sub w5,w5,#4\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; cmp w5,#5; fmla v29.2s,v3.2s,v9.2s\n\t" + +#define KERNEL_M4N10_MAIN2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#-16]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#-8]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#8]; ldr d7,[x4,#16]\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#24]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#32]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#48]; ldr d7,[x4,#56]; add x4,x4,#80\n\t"\ + "sub w5,w5,#2\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t" + +#define KERNEL_M4N10_TAIL2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#-16]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#-8]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "ldr d5,[x4]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "ld1r {v0.2s},[x0],#4; ldr d6,[x4,#8]; ldr d7,[x4,#16]\n\t"\ + "prfm pldl1keep,[x8]; prfm pldl1keep,[x9]\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#24]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#32]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "prfm pldl1keep,[x10]; sub w5,w5,#2\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "prfm pldl1keep,[x11]; add x4,x4,#40\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t" + +#define KERNEL_M4N10_TAIL1 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(10, 11, 12, 0, 0, 0, 5, 6, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(15, 16, 17, 1, 1, 1, 5, 6, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 5, 6, 7)\ + "ldr d8,[x4,#-16]\n\t" FMA_3V(25, 26, 27, 3, 3, 3, 5, 6, 7)\ + "ldr d9,[x4,#-8]\n\t" FMA_3V(13, 18, 23, 0, 1, 2, 8, 8, 8)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(14, 19, 24, 0, 1, 2, 9, 9, 9)\ + "prfm pldl1keep,[x9]; prfm pldl1keep,[x10]; prfm pldl1keep,[x11]\n\t"\ + "sub w5,w5,#1\n\t"\ + "fmla v28.2s,v3.2s,v8.2s; fmla v29.2s,v3.2s,v9.2s\n\t" + + +/* acc layout for m4n12 kernel */ +/* m0n0 v8 v9 v10 v11 v12 v13 m0n12 */ +/* m1n0 v14 v15 v16 v17 v18 v19 m1n12 */ +/* m2n0 v20 v21 v22 v23 v24 v25 m2n12 */ +/* m3n0 v26 v27 v28 v29 v30 v31 m3n12 */ +/* b-holder layout for m4n12 kernel */ +/* n0 v4 v5 
v6/v7 v7/v6 v5 v6/v7 n12 */ +/* a-holder layout for m4n12 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N12 \ + INIT_4V(8, 14, 20, 26) INIT_4V(9, 15, 21, 27)\ + INIT_4V(10, 16, 22, 28) INIT_4V(11, 17, 23, 29)\ + INIT_4V(12, 18, 24, 30) INIT_4V(13, 19, 25, 31) + +#define SAVE_M4N12(mode) \ + UNIT_SAVE_M4N2_##mode(8, 14, 20, 26) UNIT_SAVE_M4N2_##mode(9, 15, 21, 27)\ + UNIT_SAVE_M4N2_##mode(10, 16, 22, 28) UNIT_SAVE_M4N2_##mode(11, 17, 23, 29)\ + UNIT_SAVE_M4N2_##mode(12, 18, 24, 30) UNIT_SAVE_M4N2_##mode(13, 19, 25, 31) + +#define KERNEL_M4N12_PRELOAD1 \ + "ld1r {v0.2s},[x0],#4\n\t"\ + "ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; add x4,x4,#48\n\t" + +#define KERNEL_M4N12_MAIN4 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ld1r {v0.2s},[x0],#4; ldr d5,[x4,#8]; ldr d7,[x4,#16]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "ld1r {v0.2s},[x0],#4; ldr d5,[x4,#56]; ldr d6,[x4,#64]\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#72]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#88]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#96]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ld1r {v0.2s},[x0],#4; ldr d5,[x4,#104]; ldr d7,[x4,#112]\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#120]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#128]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#136]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#144]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "ld1r {v0.2s},[x0],#4; sub w5,w5,#4; ldr d5,[x4,#152]; cmp w5,#5\n\t"\ + "ldr d6,[x4,#160]; add x4,x4,#192\n\t"\ + "prfm pldl1keep,[x3,#64]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N12_MAIN2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 
12, 13, 0, 0, 0, 7, 5, 6)\ + "ld1r {v0.2s},[x0],#4; ldr d5,[x4,#8]; ldr d7,[x4,#16]\n\t"\ + FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "ld1r {v0.2s},[x0],#4; ldr d5,[x4,#56]; ldr d6,[x4,#64]; add x4,x4,#96\n\t"\ + "sub w5,w5,#2\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N12_TAIL2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ld1r {v0.2s},[x0],#4; ldr d5,[x4,#8]; ldr d7,[x4,#16]\n\t"\ + "prfm pldl1keep,[x8]; prfm pldl1keep,[x9]\n\t"\ + FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d7,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "prfm pldl1keep,[x10]; add x4,x4,#48\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 6, 5, 7)\ + "prfm pldl1keep,[x11]; sub w5,w5,#2\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N12_TAIL1 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-16]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-8]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "prfm pldl1keep,[x8]; prfm pldl1keep,[x9]; prfm pldl1keep,[x10]\n\t"\ + "sub w5,w5,#1\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "prfm pldl1keep,[x11]\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6) + + +/* acc layout for m3n14 kernel */ +/* m0n0 v11 v12 v13 v14 v15 v16 v17 m0n14 */ +/* m1n0 v18 v19 v20 v21 v22 v23 v24 m1n14 */ +/* m2n0 v25 v26 v27 v28 v29 v30 v31 m2n14 */ +/* b-holder layout for m3n14 kernel */ +/* n0 v3 v4 v5 v6 v7 v8 v9 n14 */ +/* a-holder layout for m3n14 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N14 \ + INIT_3V(11, 18, 25) INIT_3V(12, 19, 26) INIT_3V(13, 20, 27)\ + INIT_3V(14, 21, 28) INIT_3V(15, 22, 29) INIT_3V(16, 23, 30)\ + INIT_3V(17, 24, 31) + +#define SAVE_M3N14(mode) \ + UNIT_SAVE_M3N2_##mode(11, 18, 25) UNIT_SAVE_M3N2_##mode(12, 19, 26)\ + UNIT_SAVE_M3N2_##mode(13, 20, 27) UNIT_SAVE_M3N2_##mode(14, 21, 28)\ + UNIT_SAVE_M3N2_##mode(15, 22, 29) UNIT_SAVE_M3N2_##mode(16, 23, 30)\ + UNIT_SAVE_M3N2_##mode(17, 24, 31) + +#define KERNEL_M3N14_PRELOAD1 \ + "ld1r {v0.2s},[x0],#4\n\t"\ + "ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#56\n\t" + +#define KERNEL_M3N14_MAIN4 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" 
FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; ldr d3,[x4,#56]; ldr d4,[x4,#64]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#72]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#88]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#96]; ldr d9,[x4,#104]; ldr d3,[x4,#112]; ldr d4,[x4,#120]\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#128]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#136]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#144]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#152]; ldr d9,[x4,#160]; ldr d3,[x4,#168]; ldr d4,[x4,#176]\n\t"\ + "add x4,x4,#224; prfm pldl1keep,[x2,#64]; sub w5,w5,#4\n\t"\ + FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4; cmp w5,#5\n\t"\ + FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9) + +#define KERNEL_M3N14_MAIN2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; ldr d3,[x4,#56]\n\t"\ + "ldr d4,[x4,#64]; add x4,x4,#112\n\t"\ + "sub w5,w5,#2\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9) + +#define KERNEL_M3N14_TAIL2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(15, 22, 29, 0, 
1, 2, 7, 7, 7)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; add x4,x4,#56\n\t"\ + "sub w5,w5,#2\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9) + +#define KERNEL_M3N14_TAIL1 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-24]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-16]; ldr d9,[x4,#-8]; sub w5,w5,#1\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(16, 17, 23, 0, 0, 1, 8, 9, 8)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(24, 30, 31, 1, 2, 2, 9, 8, 9) + + +/* acc layout for m3n16 kernel */ +/* m0n0 v8 v9 v10 v11 v12 v13 v14 v15 m0n16 */ +/* m1n0 v16 v17 v18 v19 v20 v21 v22 v23 m1n16 */ +/* m2n0 v24 v25 v26 v27 v28 v29 v30 v31 m2n16 */ +/* b-holder layout for m3n16 kernel */ +/* n0 v3 v4 v5 v6 v7 v5 v6 v7 n16 */ +/* a-holder layout for m3n16 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N16 \ + INIT_3V(8, 16, 24) INIT_3V(9, 17, 25) INIT_3V(10, 18, 26)\ + INIT_3V(11, 19, 27) INIT_3V(12, 20, 28) INIT_3V(13, 21, 29)\ + INIT_3V(14, 22, 30) INIT_3V(15, 23, 31) + +#define SAVE_M3N16(mode) \ + UNIT_SAVE_M3N2_##mode(8, 16, 24) UNIT_SAVE_M3N2_##mode(9, 17, 25)\ + UNIT_SAVE_M3N2_##mode(10, 18, 26) UNIT_SAVE_M3N2_##mode(11, 19, 27)\ + UNIT_SAVE_M3N2_##mode(12, 20, 28) UNIT_SAVE_M3N2_##mode(13, 21, 29)\ + UNIT_SAVE_M3N2_##mode(14, 22, 30) UNIT_SAVE_M3N2_##mode(15, 23, 31) + +#define KERNEL_M3N16_PRELOAD1 \ + "ld1r {v0.2s},[x0],#4\n\t"\ + "ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#64\n\t" + +#define KERNEL_M3N16_MAIN4 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; ldr d3,[x4,#64]; ldr d4,[x4,#72]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#88]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 
5)\ + "ldr d7,[x4,#96]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#104]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#112]; ldr d7,[x4,#120]; ldr d3,[x4,#128]; ldr d4,[x4,#136]\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#144]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#152]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#160]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#168]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#176]; ldr d7,[x4,#184]; ldr d3,[x4,#192]; ldr d4,[x4,#200]\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#208]; add x4,x4,#256\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "sub w5,w5,#4; ld1r {v0.2s},[x0],#4; cmp w5,#5\n\t"\ + FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7) + +#define KERNEL_M3N16_MAIN2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; ldr d3,[x4,#64]; ldr d4,[x4,#72]\n\t"\ + FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#80]; add x4,x4,#128\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "ld1r {v0.2s},[x0],#4; sub w5,w5,#2\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7) + +#define KERNEL_M3N16_TAIL2 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "ld1r {v0.2s},[x0],#4\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7)\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "add x4,x4,#64; prfm pldl1keep,[x8]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "sub w5,w5,#2; prfm pldl1keep,[x9]\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7) + +#define KERNEL_M3N16_TAIL1 \ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" 
FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-40]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; sub w5,w5,#1\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(14, 15, 22, 0, 0, 1, 6, 7, 6)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(23, 30, 31, 1, 2, 2, 7, 6, 7) + +FUNC_PACK0(4, 10) + +FUNC_PACK0(4, 12) + +FUNC_PACK0(3, 14) + +FUNC_PACK0(3, 16) + +/* macro for GEMM with packing pattern NO.#4 */ +/* mdim = 3, 4; ndim = 9, 11, 13, 15, 17, 18 */ +#define FUNC_PACK4(mdim, ndim) \ +static inline void sgemm_skinny1_a35_m##mdim##n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__(\ + "mov x4,%[b_scr]\n\t"\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x8,x0,%w[LDA],UXTW #4; add x9,x1,%w[LDA],UXTW #4\n\t"\ + "add x10,x2,%w[LDA],UXTW #4; add x11,x3,%w[LDA],UXTW #4\n\t"\ + "mov w5,%w[K]\n\t"\ + INIT_M##mdim##N##ndim\ + "cmp w5,#2; b.lt 4f\n\t"\ + KERNEL_M##mdim##N##ndim##_PRELOAD2\ + "cmp w5,#6; b.lt 1f\n\t"\ + ".balign 16; 8:\n\t"\ + KERNEL_M##mdim##N##ndim##_MAIN4 "b.ge 8b\n\t"\ + "1:\n\t"\ + "cmp w5,#4; b.lt 2f\n\t"\ + KERNEL_M##mdim##N##ndim##_TAIL4 "b 4f\n\t"\ + "2:\n\t"\ + KERNEL_M##mdim##N##ndim##_TAIL2\ + "4:\n\t"\ + "cmp w5,#1; b.lt 5f\n\t"\ + KERNEL_M##mdim##N##ndim##_FIN1\ + "5:\n\t"\ + "cmp %w[c_rowmajor],#0; b.eq 6f\n\t"\ + INIT_SAVE_M##mdim##_CR SAVE_M##mdim##N##ndim(CR) "b 7f\n\t"\ + "6:\n\t"\ + INIT_SAVE_CC SAVE_M##mdim##N##ndim(CC)\ + "7:\n\t"\ + ::[a_ptr]"r"(a_ptr), [b_scr]"r"(b_scr), [c_ptr]"r"(c_ptr),\ + [LDA]"r"(LDA), [LDC]"r"(LDC), [K]"r"(K),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5",\ + "x8","x9","x10","x11","x12","x13","x14","x15",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11",\ + "v12","v13","v14","v15","v16","v17","v18","v19","v20","v21",\ + "v22","v23","v24","v25","v26","v27","v28","v29","v30","v31");\ +} + +/* acc layout for m4n9 kernel */ +/* m0n0 v12 v13 v14 v15 v16_h m0n9 */ +/* m1n0 v17 v18 v19 v20 v21_h m1n9 */ +/* m2n0 v22 v23 v24 v25 v26_h m2n9 */ +/* m3n0 v27 v28 v29 v30 v31_h m3n9 */ +/* b-holder layout for m4n9 kernel */ +/* n0 v4 v5 v6 v7 v8(s) n9 */ +/* a-holder layout for m4n9 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N9 \ + INIT_4V(12, 17, 22, 27) INIT_4V(13, 18, 23, 28)\ + INIT_4V(14, 19, 24, 29) INIT_4V(15, 20, 25, 30)\ + INIT_4V(16, 21, 26, 31) + +#define KERNEL_M4N9_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; add x4,x4,#72\n\t" + +#define KERNEL_M4N9_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-48]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-40]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "rev64 v0.2s,v0.2s; ldr d5,[x4,#-24]; ldr d6,[x4,#-16]; sub w5,w5,#4\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t" FMA_3V(30, 31, 12, 3, 3, 0, 7, 8, 4)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(13, 14, 17, 
0, 0, 1, 5, 6, 4)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 22, 1, 1, 2, 5, 6, 4)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(23, 24, 27, 2, 2, 3, 5, 6, 4)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 15, 3, 3, 0, 5, 6, 7)\ + "ldr d0,[x0],#8; ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; cmp w5,#6\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t" FMA_3V(20, 25, 30, 1, 2, 3, 7, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#32]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#40]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "rev64 v0.2s,v0.2s; ldr d5,[x4,#48]; ldr d6,[x4,#56]; add x4,x4,#144\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t" FMA_3V(30, 31, 12, 3, 3, 0, 7, 8, 4)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(13, 14, 17, 0, 0, 1, 5, 6, 4)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 22, 1, 1, 2, 5, 6, 4)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(23, 24, 27, 2, 2, 3, 5, 6, 4)\ + "ldr d7,[x4,#-80]\n\t" FMA_3V(28, 29, 15, 3, 3, 0, 5, 6, 7)\ + "ldr d0,[x0],#8; ldr d4,[x4,#-72]; ldr d5,[x4,#-64]; ldr d6,[x4,#-56]\n\t"\ + "prfm pldl1keep,[x3,#64]\n\t" FMA_3V(20, 25, 30, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N9_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-48]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-40]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "rev64 v0.2s,v0.2s; ldr d5,[x4,#-24]; ldr d6,[x4,#-16]; sub w5,w5,#4\n\t"\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(30, 31, 12, 3, 3, 0, 7, 8, 4)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(13, 14, 17, 0, 0, 1, 5, 6, 4)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 22, 1, 1, 2, 5, 6, 4)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(23, 24, 27, 2, 2, 3, 5, 6, 4)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 15, 3, 3, 0, 5, 6, 7)\ + "ldr d0,[x0],#8; ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(20, 25, 30, 1, 2, 3, 7, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#32]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#40]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "rev64 v0.2s,v0.2s; ldr d5,[x4,#48]; ldr d6,[x4,#56]; add x4,x4,#72\n\t"\ + "prfm pldl1keep,[x10]\n\t" FMA_3V(30, 31, 12, 3, 3, 0, 7, 8, 4)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(13, 14, 17, 0, 0, 1, 5, 6, 4)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 22, 1, 1, 2, 5, 6, 4)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(23, 24, 27, 2, 2, 3, 5, 6, 4)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 15, 3, 3, 0, 5, 6, 7)\ + "prfm pldl1keep,[x11]\n\t" FMA_3V(20, 25, 30, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N9_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-48]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr d8,[x4,#-40]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "rev64 v0.2s,v0.2s; ldr d5,[x4,#-24]; ldr d6,[x4,#-16]; sub w5,w5,#2\n\t"\ + "prfm pldl1keep,[x8]\n\t" 
FMA_3V(30, 31, 12, 3, 3, 0, 7, 8, 4)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(13, 14, 17, 0, 0, 1, 5, 6, 4)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 22, 1, 1, 2, 5, 6, 4)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(23, 24, 27, 2, 2, 3, 5, 6, 4)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(28, 29, 15, 3, 3, 0, 5, 6, 7)\ + "prfm pldl1keep,[x9]; prfm pldl1keep,[x10]\n\t"\ + "prfm pldl1keep,[x11]\n\t" FMA_3V(20, 25, 30, 1, 2, 3, 7, 7, 7) + +#define KERNEL_M4N9_FIN1 \ + "ld1r {v0.2s},[x0],#4; ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(12, 13, 14, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(17, 18, 19, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(22, 23, 24, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(27, 28, 29, 3, 3, 3, 4, 5, 6)\ + "ldr s8,[x4,#32]\n\t" FMA_3V(15, 20, 25, 0, 1, 2, 7, 7, 7)\ + "add x4,x4,#36\n\t" FMA_3V(16, 21, 26, 0, 1, 2, 8, 8, 8)\ + "fmla v30.2s,v3.2s,v7.2s; fmla v31.2s,v3.2s,v8.2s\n\t" + +#define SAVE_M4N9(mode) \ + UNIT_SAVE_M4N2_##mode(12, 17, 22, 27) UNIT_SAVE_M4N2_##mode(13, 18, 23, 28)\ + UNIT_SAVE_M4N2_##mode(14, 19, 24, 29) UNIT_SAVE_M4N2_##mode(15, 20, 25, 30)\ + UNIT_SAVE_M4N1_##mode(16, 21, 26, 31) + + +/* acc layout for m4n11 kernel */ +/* m0n0 v8 v9 v10 v11 v12 v13_h m0n11 */ +/* m1n0 v14 v15 v16 v17 v18 v19_h m1n11 */ +/* m2n0 v20 v21 v22 v23 v24 v25_h m2n11 */ +/* m3n0 v26 v27 v28 v29 v30 v31_h m3n11 */ +/* b-holder layout for m4n11 kernel */ +/* n0 v4 v5 v6/v7 v7/v6 v5/v7 v6(s) n11 */ +/* a-holder layout for m4n11 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2, a_ptr4->v3 */ + +#define INIT_M4N11 \ + INIT_4V(8, 14, 20, 26) INIT_4V(9, 15, 21, 27)\ + INIT_4V(10, 16, 22, 28) INIT_4V(11, 17, 23, 29)\ + INIT_4V(12, 18, 24, 30) INIT_4V(13, 19, 25, 31) + +#define KERNEL_M4N11_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d4,[x4]; ldr d5,[x4,#8]; ldr d6,[x4,#16]; add x4,x4,#88\n\t" + +#define KERNEL_M4N11_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-64]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-56]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-48]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d5,[x4,#-32]; ldr d7,[x4,#-24]; prfm pldl1keep,[x0,#64]\n\t"\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#-16]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 18, 0, 0, 1, 6, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; prfm pldl1keep,[x1,#64]; sub w5,w5,#4\n\t"\ + "ldr d6,[x4,#16]; fmla v24.2s,v2.2s,v7.2s; fmla v30.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d5,[x4,#56]; ldr d7,[x4,#64]; prfm pldl1keep,[x2,#64]\n\t"\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 
6, 6)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#72]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d7,[x4,#80]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d4,[x4,#88]\n\t" FMA_3V(11, 12, 18, 0, 0, 1, 6, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#96]; prfm pldl1keep,[x3,#64]; cmp w5,#6\n\t"\ + "ldr d6,[x4,#104]; add x4,x4,#176\n\t"\ + "fmla v24.2s,v2.2s,v7.2s; fmla v30.2s,v3.2s,v7.2s\n\t" + +#define KERNEL_M4N11_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-64]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-56]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-48]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d5,[x4,#-32]; ldr d7,[x4,#-24]; prfm pldl1keep,[x8]\n\t"\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#-16]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d7,[x4,#-8]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "ldr d4,[x4]\n\t" FMA_3V(11, 12, 18, 0, 0, 1, 6, 7, 7)\ + "ldr d0,[x0],#8; ldr d5,[x4,#8]; prfm pldl1keep,[x9]; sub w5,w5,#4\n\t"\ + "ldr d6,[x4,#16]; fmla v24.2s,v2.2s,v7.2s; fmla v30.2s,v3.2s,v7.2s\n\t"\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#24]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#32]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#40]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d5,[x4,#56]; ldr d7,[x4,#64]; prfm pldl1keep,[x10]\n\t"\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#72]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr d7,[x4,#80]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "prfm pldl1keep,[x11]\n\t" FMA_3V(11, 12, 18, 0, 0, 1, 6, 7, 7)\ + "add x4,x4,#88; fmla v24.2s,v2.2s,v7.2s; fmla v30.2s,v3.2s,v7.2s\n\t" + +#define KERNEL_M4N11_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ldr d2,[x2],#8\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ldr d3,[x3],#8\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-64]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-56]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr d6,[x4,#-48]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + "ldr d5,[x4,#-32]; ldr d7,[x4,#-24]; prfm pldl1keep,[x8]\n\t"\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 7)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 7)\ + "rev64 v3.2s,v3.2s\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 7)\ + "ldr d6,[x4,#-16]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 7)\ + "ldr 
d7,[x4,#-8]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 6, 6, 6)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(11, 12, 18, 0, 0, 1, 6, 7, 7)\ + "prfm pldl1keep,[x10]; sub w5,w5,#2\n\t"\ + "fmla v24.2s,v2.2s,v7.2s; fmla v30.2s,v3.2s,v7.2s\n\t"\ + "prfm pldl1keep,[x11]\n\t" + +#define KERNEL_M4N11_FIN1 \ + "ld1r {v0.2s},[x0],#4; ldr d4,[x4]; ldr d5,[x4,#8]\n\t"\ + "ldr d6,[x4,#16]; add x4,x4,#44\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 4, 5, 6)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(14, 15, 16, 1, 1, 1, 4, 5, 6)\ + "ld1r {v3.2s},[x3],#4\n\t" FMA_3V(20, 21, 22, 2, 2, 2, 4, 5, 6)\ + "ldr d7,[x4,#-20]\n\t" FMA_3V(26, 27, 28, 3, 3, 3, 4, 5, 6)\ + "ldr d5,[x4,#-12]\n\t" FMA_3V(17, 23, 29, 1, 2, 3, 7, 7, 7)\ + "ldr s6,[x4,#-4]\n\t" FMA_3V(18, 24, 30, 1, 2, 3, 5, 5, 5)\ + FMA_3V(11, 12, 13, 0, 0, 0, 7, 5, 6)\ + FMA_3V(19, 25, 31, 1, 2, 3, 6, 6, 6) + +#define SAVE_M4N11(mode) \ + UNIT_SAVE_M4N2_##mode(8, 14, 20, 26) UNIT_SAVE_M4N2_##mode(9, 15, 21, 27)\ + UNIT_SAVE_M4N2_##mode(10, 16, 22, 28) UNIT_SAVE_M4N2_##mode(11, 17, 23, 29)\ + UNIT_SAVE_M4N2_##mode(12, 18, 24, 30) UNIT_SAVE_M4N1_##mode(13, 19, 25, 31) + + +/* acc layout for m3n13 kernel */ +/* m0n0 v11 v12 v13 v14 v15 v16 v17_h m0n13 */ +/* m1n0 v18 v19 v20 v21 v22 v23 v24_h m1n13 */ +/* m2n0 v25 v26 v27 v28 v29 v30 v31_h m2n13 */ +/* b-holder layout for m3n13 kernel */ +/* n0 v3 v4 v5 v6 v7 v8 v9(s) n13 */ +/* a-holder layout for m3n13 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N13 \ + INIT_3V(11, 18, 25) INIT_3V(12, 19, 26) INIT_3V(13, 20, 27)\ + INIT_3V(14, 21, 28) INIT_3V(15, 22, 29) INIT_3V(16, 23, 30)\ + INIT_3V(17, 24, 31) + +#define KERNEL_M3N13_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#104\n\t" + +#define KERNEL_M3N13_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-72]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-64]; ldr d9,[x4,#-56]; prfm pldl1keep,[x0,#64]\n\t"\ + "ldr d3,[x4,#-48]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-32]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-16]; ldr d8,[x4,#-8]; prfm pldl1keep,[x1,#64]\n\t"\ + "ldr d3,[x4]; cmp w5,#6\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(15, 22, 16, 0, 1, 0, 7, 7, 8)\ + "ldr d0,[x0],#8\n\t" FMA_3V(29, 23, 30, 2, 1, 2, 7, 8, 8)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; ldr d3,[x4,#56]; ldr d4,[x4,#64]\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#72]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#88]; ldr d8,[x4,#96]; 
ldr d3,[x4,#104]; ldr d4,[x4,#112]\n\t"\ + "add x4,x4,#208\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-88]\n\t" FMA_3V(15, 22, 16, 0, 1, 0, 7, 7, 8)\ + "ldr d0,[x0],#8\n\t" FMA_3V(29, 23, 30, 2, 1, 2, 7, 8, 8) + +#define KERNEL_M3N13_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-72]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-64]; ldr d9,[x4,#-56]; prfm pldl1keep,[x3]\n\t"\ + "ldr d3,[x4,#-48]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-32]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-16]; ldr d8,[x4,#-8]; prfm pldl1keep,[x8]\n\t"\ + "ldr d3,[x4]\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(15, 22, 16, 0, 1, 0, 7, 7, 8)\ + "ldr d0,[x0],#8\n\t" FMA_3V(29, 23, 30, 2, 1, 2, 7, 8, 8)\ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#40]; ldr d9,[x4,#48]; ldr d3,[x4,#56]; ldr d4,[x4,#64]\n\t"\ + FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#72]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#88]; ldr d8,[x4,#96]\n\t"\ + "add x4,x4,#104\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(15, 22, 16, 0, 1, 0, 7, 7, 8)\ + FMA_3V(29, 23, 30, 2, 1, 2, 7, 8, 8) + +#define KERNEL_M3N13_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-80]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-72]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-64]; ldr d9,[x4,#-56]\n\t"\ + "ldr d3,[x4,#-48]; sub w5,w5,#2\n\t"\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr d5,[x4,#-32]\n\t" FMA_3V(16, 23, 17, 0, 1, 0, 8, 8, 9)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 24, 31, 2, 1, 2, 8, 9, 9)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-24]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-16]; ldr d8,[x4,#-8]\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(15, 22, 16, 0, 1, 0, 7, 7, 8)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(29, 23, 30, 2, 1, 2, 7, 8, 8) + +#define KERNEL_M3N13_FIN1 \ + "ld1r {v0.2s},[x0],#4; ldr d3,[x4]\n\t"\ + "ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#52\n\t"\ + "ld1r {v1.2s},[x1],#4\n\t" FMA_3V(11, 12, 13, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(18, 19, 20, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-28]\n\t" FMA_3V(25, 26, 27, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-20]\n\t" FMA_3V(14, 21, 28, 0, 1, 2, 6, 6, 6)\ + "ldr d8,[x4,#-12]\n\t" 
FMA_3V(15, 22, 29, 0, 1, 2, 7, 7, 7)\ + "ldr s9,[x4,#-4]\n\t" FMA_3V(16, 23, 30, 0, 1, 2, 8, 8, 8)\ + FMA_3V(17, 24, 31, 0, 1, 2, 9, 9, 9) + +#define SAVE_M3N13(mode) \ + UNIT_SAVE_M3N2_##mode(11, 18, 25) UNIT_SAVE_M3N2_##mode(12, 19, 26)\ + UNIT_SAVE_M3N2_##mode(13, 20, 27) UNIT_SAVE_M3N2_##mode(14, 21, 28)\ + UNIT_SAVE_M3N2_##mode(15, 22, 29) UNIT_SAVE_M3N2_##mode(16, 23, 30)\ + UNIT_SAVE_M3N1_##mode(17, 24, 31) + + +/* acc layout for m3n15 kernel */ +/* m0n0 v8 v9 v10 v11 v12 v13 v14 v15_h m0n15 */ +/* m1n0 v16 v17 v18 v19 v20 v21 v22 v23_h m1n15 */ +/* m2n0 v24 v25 v26 v27 v28 v29 v30 v31_h m2n15 */ +/* b-holder layout for m3n15 kernel */ +/* n0 v3 v4 v5 v6 v7/v5 v5/v6 v6/v7 v7(s) n15 */ +/* a-holder layout for m3n15 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N15 \ + INIT_3V(8, 16, 24) INIT_3V(9, 17, 25) INIT_3V(10, 18, 26)\ + INIT_3V(11, 19, 27) INIT_3V(12, 20, 28) INIT_3V(13, 21, 29)\ + INIT_3V(14, 22, 30) INIT_3V(15, 23, 31) + +#define KERNEL_M3N15_PRELOAD2 \ + "ldr d0,[x0],#8\n\t"\ + "ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]; add x4,x4,#120\n\t" + +#define KERNEL_M3N15_MAIN4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-96]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-88]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-80]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-72]; ldr d7,[x4,#-64]; ldr d3,[x4,#-56]; ldr d4,[x4,#-48]\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(13, 21, 14, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(29, 22, 30, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; ldr d3,[x4,#64]; ldr d4,[x4,#72]\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#88]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d5,[x4,#96]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d6,[x4,#104]; ldr d7,[x4,#112]; ldr d3,[x4,#120]\n\t"\ + "ldr d4,[x4,#128]; sub w5,w5,#4\n\t"\ + FMA_3V(12, 20, 28, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#136]; add x4,x4,#240\n\t"\ + FMA_3V(13, 21, 14, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8; cmp w5,#6\n\t"\ + FMA_3V(29, 22, 30, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N15_TAIL4 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-96]\n\t" FMA_3V(24, 25, 
26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-88]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-80]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-72]; ldr d7,[x4,#-64]; ldr d3,[x4,#-56]; ldr d4,[x4,#-48]\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]; ldr d3,[x4]; ldr d4,[x4,#8]\n\t"\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#16]\n\t" FMA_3V(13, 21, 14, 0, 1, 0, 6, 6, 7)\ + "ldr d0,[x0],#8\n\t" FMA_3V(29, 22, 30, 2, 1, 2, 6, 7, 7)\ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#24]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#32]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#40]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#48]; ldr d7,[x4,#56]; ldr d3,[x4,#64]; ldr d4,[x4,#72]\n\t"\ + FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#80]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#88]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d5,[x4,#96]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d6,[x4,#104]; ldr d7,[x4,#112]\n\t"\ + "sub w5,w5,#4\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 5, 5, 5)\ + "prfm pldl1keep,[x9]; add x4,x4,#120\n\t"\ + FMA_3V(13, 21, 14, 0, 1, 0, 6, 6, 7)\ + FMA_3V(29, 22, 30, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N15_TAIL2 \ + "ldr d1,[x1],#8\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ldr d2,[x2],#8\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-96]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-88]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-80]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-72]; ldr d7,[x4,#-64]; ldr d3,[x4,#-56]; ldr d4,[x4,#-48]\n\t"\ + "prfm pldl1keep,[x3]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 5, 5, 5)\ + "ldr d5,[x4,#-40]\n\t" FMA_3V(14, 22, 15, 0, 1, 0, 6, 6, 7)\ + "rev64 v0.2s,v0.2s\n\t" FMA_3V(30, 23, 31, 2, 1, 2, 6, 7, 7)\ + "rev64 v1.2s,v1.2s\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "rev64 v2.2s,v2.2s\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-32]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d5,[x4,#-24]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d6,[x4,#-16]; ldr d7,[x4,#-8]\n\t"\ + "prfm pldl1keep,[x8]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 5, 5, 5)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(13, 21, 14, 0, 1, 0, 6, 6, 7)\ + "sub w5,w5,#2\n\t" FMA_3V(29, 22, 30, 2, 1, 2, 6, 7, 7) + +#define KERNEL_M3N15_FIN1 \ + "ld1r {v0.2s},[x0],#4; ldr d3,[x4]; ldr d4,[x4,#8]; ldr d5,[x4,#16]\n\t"\ + "ld1r {v1.2s},[x1],#4; add x4,x4,#60\n\t" FMA_3V(8, 9, 10, 0, 0, 0, 3, 4, 5)\ + "ld1r {v2.2s},[x2],#4\n\t" FMA_3V(16, 17, 18, 1, 1, 1, 3, 4, 5)\ + "ldr d6,[x4,#-36]\n\t" FMA_3V(24, 25, 26, 2, 2, 2, 3, 4, 5)\ + "ldr d7,[x4,#-28]\n\t" FMA_3V(11, 19, 27, 0, 1, 2, 6, 6, 6)\ + "ldr d5,[x4,#-20]\n\t" FMA_3V(12, 20, 28, 0, 1, 2, 7, 7, 7)\ + "ldr d6,[x4,#-12]\n\t" FMA_3V(13, 21, 29, 0, 1, 2, 
5, 5, 5)\ + "ldr s7,[x4,#-4]\n\t" FMA_3V(14, 22, 30, 0, 1, 2, 6, 6, 6)\ + FMA_3V(15, 23, 31, 0, 1, 2, 7, 7, 7) + +#define SAVE_M3N15(mode) \ + UNIT_SAVE_M3N2_##mode(8, 16, 24) UNIT_SAVE_M3N2_##mode(9, 17, 25)\ + UNIT_SAVE_M3N2_##mode(10, 18, 26) UNIT_SAVE_M3N2_##mode(11, 19, 27)\ + UNIT_SAVE_M3N2_##mode(12, 20, 28) UNIT_SAVE_M3N2_##mode(13, 21, 29)\ + UNIT_SAVE_M3N2_##mode(14, 22, 30) UNIT_SAVE_M3N1_##mode(15, 23, 31) + + +/* acc layout for m3n17 kernel */ +/* m0n0 v5 v6 v7 v8 v9 v10 v11 v12 v13_h m0n17 */ +/* m1n0 v14 v15 v16 v17 v18 v19 v20 v21 v22_h m1n17 */ +/* m2n0 v23 v24 v25 v26 v27 v28 v29 v30 v31_h m2n17 */ +/* b-holder layout for m3n17 kernel */ +/* n0 v3-4 alt n17 */ +/* a-holder layout for m3n17 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N17 \ + INIT_3V(5, 14, 23) INIT_3V(6, 15, 24) INIT_3V(7, 16, 25)\ + INIT_3V(8, 17, 26) INIT_3V(9, 18, 27) INIT_3V(10, 19, 28)\ + INIT_3V(11, 20, 29) INIT_3V(12, 21, 30) INIT_3V(13, 22, 31) + +#define KERNEL_M3N17_PRELOAD2 \ + "ldr d3,[x4],#136\n\t" + +#define KERNEL_M3N17_MAIN4 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "prfm pldl1keep,[x0,#64]\n\t"\ + "ldr d4,[x4,#-128]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-120]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-112]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-104]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-96]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-88]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-80]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-72]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-64]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x1,#64]\n\t"\ + "ldr d3,[x4,#-56]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-48]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-40]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-24]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-16]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-8]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "ldr d3,[x4,#8]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#16]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#24]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#40]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#56]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#64]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 4, 4, 4)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t"\ + "ldr d4,[x4,#80]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#88]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#96]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#104]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#112]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#120]; sub w5,w5,#4\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#128]; cmp w5,#6\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#136]; add x4,x4,#272\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 
4) + +#define KERNEL_M3N17_TAIL4 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "ldr d4,[x4,#-128]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-120]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-112]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-104]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-96]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-88]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-80]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-72]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-64]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x3]\n\t"\ + "ldr d3,[x4,#-56]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-48]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-40]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-24]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-16]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-8]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "ldr d3,[x4,#8]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#16]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#24]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#40]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#48]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#56]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#64]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 4, 4, 4)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x8]\n\t"\ + "ldr d4,[x4,#80]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#88]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#96]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#104]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#112]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#120]; sub w5,w5,#4\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#128]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "prfm pldl1keep,[x9]; add x4,x4,#136\n\t"\ + FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4) + +#define KERNEL_M3N17_TAIL2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "prfm pldl1keep,[x3]\n\t"\ + "ldr d4,[x4,#-128]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-120]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-112]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-104]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-96]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-88]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-80]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-72]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-64]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "sub w5,w5,#2; prfm pldl1keep,[x8]\n\t"\ + "ldr d3,[x4,#-56]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-48]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-40]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-24]\n\t" FMA_3V(9, 18, 
27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-16]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-8]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3) + +#define KERNEL_M3N17_FIN1 \ + "ldr d3,[x4],#68\n\t"\ + "ld1r {v0.2s},[x0],#4; ld1r {v1.2s},[x1],#4; ld1r {v2.2s},[x2],#4\n\t"\ + "ldr d4,[x4,#-60]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-52]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-44]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-36]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-28]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-20]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-12]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr s3,[x4,#-4]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3) + +#define SAVE_M3N17(mode) \ + UNIT_SAVE_M3N2_##mode(5, 14, 23) UNIT_SAVE_M3N2_##mode(6, 15, 24)\ + UNIT_SAVE_M3N2_##mode(7, 16, 25) UNIT_SAVE_M3N2_##mode(8, 17, 26)\ + UNIT_SAVE_M3N2_##mode(9, 18, 27) UNIT_SAVE_M3N2_##mode(10, 19, 28)\ + UNIT_SAVE_M3N2_##mode(11, 20, 29) UNIT_SAVE_M3N2_##mode(12, 21, 30)\ + UNIT_SAVE_M3N1_##mode(13, 22, 31) + + +/* acc layout for m3n18 kernel */ +/* m0n0 v5 v6 v7 v8 v9 v10 v11 v12 v13 m0n18 */ +/* m1n0 v14 v15 v16 v17 v18 v19 v20 v21 v22 m1n18 */ +/* m2n0 v23 v24 v25 v26 v27 v28 v29 v30 v31 m2n18 */ +/* b-holder layout for m3n18 kernel */ +/* n0 v3-4 alt n18 */ +/* a-holder layout for m3n18 kernel */ +/* a_ptr1->v0, a_ptr2->v1, a_ptr3->v2 */ + +#define INIT_M3N18 \ + INIT_3V(5, 14, 23) INIT_3V(6, 15, 24) INIT_3V(7, 16, 25)\ + INIT_3V(8, 17, 26) INIT_3V(9, 18, 27) INIT_3V(10, 19, 28)\ + INIT_3V(11, 20, 29) INIT_3V(12, 21, 30) INIT_3V(13, 22, 31) + +#define KERNEL_M3N18_PRELOAD2 \ + "ldr d3,[x4],#144\n\t" + +#define KERNEL_M3N18_MAIN4 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "prfm pldl1keep,[x0,#64]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#-136]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-128]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-120]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-112]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-104]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-96]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-88]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-80]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x1,#64]; cmp w5,#6\n\t"\ + "ldr d3,[x4,#-64]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-56]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-48]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-32]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-24]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-16]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-8]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 4, 4, 4)\ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "prfm pldl1keep,[x2,#64]\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#16]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#24]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr 
d4,[x4,#40]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#48]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#56]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#64]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "add x4,x4,#288\n\t"\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "ldr d3,[x4,#-208]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-200]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-192]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-184]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-176]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-168]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-160]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-152]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-144]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 4, 4, 4) + +#define KERNEL_M3N18_TAIL4 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "prfm pldl1keep,[x3]; sub w5,w5,#4\n\t"\ + "ldr d4,[x4,#-136]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-128]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-120]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-112]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-104]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-96]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-88]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-80]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x8]\n\t"\ + "ldr d3,[x4,#-64]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-56]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-48]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-32]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-24]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-16]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-8]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 4, 4, 4)\ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "ldr d4,[x4,#8]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#16]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#24]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#32]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#40]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#48]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#56]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#64]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "add x4,x4,#144\n\t"\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "ldr d3,[x4,#-64]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-56]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-48]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-32]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-24]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-16]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-8]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(13, 22, 31, 0, 
1, 2, 4, 4, 4) + +#define KERNEL_M3N18_TAIL2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr d2,[x2],#8\n\t"\ + "prfm pldl1keep,[x3]; sub w5,w5,#2\n\t"\ + "ldr d4,[x4,#-136]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-128]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-120]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-112]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-104]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-96]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-88]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-80]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-72]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3)\ + "rev64 v0.2s,v0.2s; rev64 v1.2s,v1.2s; rev64 v2.2s,v2.2s\n\t"\ + "prfm pldl1keep,[x8]\n\t"\ + "ldr d3,[x4,#-64]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-56]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-48]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-40]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-32]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-24]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-16]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-8]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 3, 3, 3)\ + "prfm pldl1keep,[x9]\n\t" FMA_3V(13, 22, 31, 0, 1, 2, 4, 4, 4) + +#define KERNEL_M3N18_FIN1 \ + "ldr d3,[x4],#72\n\t"\ + "ld1r {v0.2s},[x0],#4; ld1r {v1.2s},[x1],#4; ld1r {v2.2s},[x2],#4\n\t"\ + "ldr d4,[x4,#-64]\n\t" FMA_3V(5, 14, 23, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-56]\n\t" FMA_3V(6, 15, 24, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-48]\n\t" FMA_3V(7, 16, 25, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-40]\n\t" FMA_3V(8, 17, 26, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-32]\n\t" FMA_3V(9, 18, 27, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-24]\n\t" FMA_3V(10, 19, 28, 0, 1, 2, 4, 4, 4)\ + "ldr d4,[x4,#-16]\n\t" FMA_3V(11, 20, 29, 0, 1, 2, 3, 3, 3)\ + "ldr d3,[x4,#-8]\n\t" FMA_3V(12, 21, 30, 0, 1, 2, 4, 4, 4)\ + FMA_3V(13, 22, 31, 0, 1, 2, 3, 3, 3) + +#define SAVE_M3N18(mode) \ + UNIT_SAVE_M3N2_##mode(5, 14, 23) UNIT_SAVE_M3N2_##mode(6, 15, 24)\ + UNIT_SAVE_M3N2_##mode(7, 16, 25) UNIT_SAVE_M3N2_##mode(8, 17, 26)\ + UNIT_SAVE_M3N2_##mode(9, 18, 27) UNIT_SAVE_M3N2_##mode(10, 19, 28)\ + UNIT_SAVE_M3N2_##mode(11, 20, 29) UNIT_SAVE_M3N2_##mode(12, 21, 30)\ + UNIT_SAVE_M3N2_##mode(13, 22, 31) + + +FUNC_PACK4(4, 9) + +FUNC_PACK4(4, 11) + +FUNC_PACK4(3, 13) + +FUNC_PACK4(3, 15) + +FUNC_PACK4(3, 17) + +FUNC_PACK4(3, 18) + + +#define INIT_M1N4 \ + float32x2_t cd1, cd2, cd3, cd4;\ + cd1 = cd2 = cd3 = cd4 = vdup_n_f32(0.0f); + +#define INIT_M1N5 INIT_M1N4 float32x2_t cd5 = vdup_n_f32(0.0f); + +#define INIT_M1N6 INIT_M1N5 float32x2_t cd6 = vdup_n_f32(0.0f); + +#define INIT_M1N7 INIT_M1N6 float32x2_t cd7 = vdup_n_f32(0.0f); + +#define INIT_M1N8 INIT_M1N7 float32x2_t cd8 = vdup_n_f32(0.0f); + +#define INIT_M1N10 INIT_M1N5 + +#define INIT_M1N12 INIT_M1N6 + +#define INIT_M1N14 INIT_M1N7 + +#define INIT_M1N16 INIT_M1N8 + +#define INIT_M1N9 \ + float32x2_t cd1, cd2, cd3, cd4;\ + cd1 = cd2 = cd3 = cd4 = vdup_n_f32(0.0f);\ + float32x2_t cd0 = vdup_n_f32(0.0f); + +#define INIT_M1N11 INIT_M1N10 float32x2_t cd0 = vdup_n_f32(0.0f); + +#define INIT_M1N13 INIT_M1N12 float32x2_t cd0 = vdup_n_f32(0.0f); + +#define INIT_M1N15 INIT_M1N14 float32x2_t cd0 = vdup_n_f32(0.0f); + +#define INIT_M1N17 INIT_M1N16 float32x2_t cd0 = vdup_n_f32(0.0f); + +#define INIT_M1N18 INIT_M1N16 float32x2_t cd9 = vdup_n_f32(0.0f); + +#define LOAD_4D_B \ + float32x2_t bd1 = 
vld1_f32(b_ptr);\ + float32x2_t bd2 = vld1_f32(b_ptr + 2);\ + float32x2_t bd3 = vld1_f32(b_ptr + 4);\ + float32x2_t bd4 = vld1_f32(b_ptr + 6); + +#define LOAD_5D_B LOAD_4D_B float32x2_t bd5 = vld1_f32(b_ptr + 8); + +#define LOAD_6D_B LOAD_5D_B float32x2_t bd6 = vld1_f32(b_ptr + 10); + +#define LOAD_7D_B LOAD_6D_B float32x2_t bd7 = vld1_f32(b_ptr + 12); + +#define LOAD_8D_B LOAD_7D_B float32x2_t bd8 = vld1_f32(b_ptr + 14); + +#define LOAD_9D_B LOAD_8D_B float32x2_t bd9 = vld1_f32(b_ptr + 16); + +#define ACC_4D \ + cd1 = vfma_f32(cd1, ad1, bd1);\ + cd2 = vfma_f32(cd2, ad1, bd2);\ + cd3 = vfma_f32(cd3, ad1, bd3);\ + cd4 = vfma_f32(cd4, ad1, bd4); + +#define ACC_5D ACC_4D cd5 = vfma_f32(cd5, ad1, bd5); + +#define ACC_6D ACC_5D cd6 = vfma_f32(cd6, ad1, bd6); + +#define ACC_7D ACC_6D cd7 = vfma_f32(cd7, ad1, bd7); + +#define ACC_8D ACC_7D cd8 = vfma_f32(cd8, ad1, bd8); + +#define ACC_9D ACC_8D cd9 = vfma_f32(cd9, ad1, bd9); + +#define REDUC_4D \ + float cs1 = vpadds_f32(cd1); float cs2 = vpadds_f32(cd2);\ + float cs3 = vpadds_f32(cd3); float cs4 = vpadds_f32(cd4);\ + +#define REDUC_5D REDUC_4D float cs5 = vpadds_f32(cd5); + +#define REDUC_6D REDUC_5D float cs6 = vpadds_f32(cd6); + +#define REDUC_7D REDUC_6D float cs7 = vpadds_f32(cd7); + +#define REDUC_8D REDUC_7D float cs8 = vpadds_f32(cd8); + +#define ACC_4S \ + cs1 += as1 * b_ptr[0]; cs2 += as1 * b_ptr[1];\ + cs3 += as1 * b_ptr[2]; cs4 += as1 * b_ptr[3];\ + +#define ACC_5S ACC_4S cs5 += as1 * b_ptr[4]; + +#define ACC_6S ACC_5S cs6 += as1 * b_ptr[5]; + +#define ACC_7S ACC_6S cs7 += as1 * b_ptr[6]; + +#define ACC_8S ACC_7S cs8 += as1 * b_ptr[7]; + +#define UNIT_SAVE_M1N1_CC(cs1) \ + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr += LDC; + +#define UNIT_SAVE_M1N1_CR(cs1) \ + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr++; + +#define UNIT_SAVE_M1N2_CC(cd1) \ + c_ptr[0] = c_ptr[0] * beta + vget_lane_f32(cd1, 0);\ + c_ptr[LDC] = c_ptr[LDC] * beta + vget_lane_f32(cd1, 1);\ + c_ptr += LDC * 2; + +#define UNIT_SAVE_M1N2_CR(cd1) \ + cd1 = vfma_n_f32(cd1, vld1_f32(c_ptr), beta);\ + vst1_f32(c_ptr, cd1); c_ptr += 2; + +#define SAVE_M1N4(mode) \ + UNIT_SAVE_M1N1_##mode(cs1) UNIT_SAVE_M1N1_##mode(cs2)\ + UNIT_SAVE_M1N1_##mode(cs3) UNIT_SAVE_M1N1_##mode(cs4)\ + +#define SAVE_M1N5(mode) SAVE_M1N4(mode) UNIT_SAVE_M1N1_##mode(cs5) + +#define SAVE_M1N6(mode) SAVE_M1N5(mode) UNIT_SAVE_M1N1_##mode(cs6) + +#define SAVE_M1N7(mode) SAVE_M1N6(mode) UNIT_SAVE_M1N1_##mode(cs7) + +#define SAVE_M1N8(mode) SAVE_M1N7(mode) UNIT_SAVE_M1N1_##mode(cs8) + +#define SAVE_M1N10(mode) \ + UNIT_SAVE_M1N2_##mode(cd1) UNIT_SAVE_M1N2_##mode(cd2)\ + UNIT_SAVE_M1N2_##mode(cd3) UNIT_SAVE_M1N2_##mode(cd4)\ + UNIT_SAVE_M1N2_##mode(cd5) + +#define SAVE_M1N12(mode) SAVE_M1N10(mode) UNIT_SAVE_M1N2_##mode(cd6) + +#define SAVE_M1N14(mode) SAVE_M1N12(mode) UNIT_SAVE_M1N2_##mode(cd7) + +#define SAVE_M1N16(mode) SAVE_M1N14(mode) UNIT_SAVE_M1N2_##mode(cd8) + +#define SAVE_M1N18(mode) SAVE_M1N16(mode) UNIT_SAVE_M1N2_##mode(cd9) + +#define SAVE_M1N9(mode) \ + UNIT_SAVE_M1N2_##mode(cd1) UNIT_SAVE_M1N2_##mode(cd2)\ + UNIT_SAVE_M1N2_##mode(cd3) UNIT_SAVE_M1N2_##mode(cd4)\ + UNIT_SAVE_M1N1_##mode(cs0) + +#define SAVE_M1N11(mode) SAVE_M1N10(mode) UNIT_SAVE_M1N1_##mode(cs0) + +#define SAVE_M1N13(mode) SAVE_M1N12(mode) UNIT_SAVE_M1N1_##mode(cs0) + +#define SAVE_M1N15(mode) SAVE_M1N14(mode) UNIT_SAVE_M1N1_##mode(cs0) + +#define SAVE_M1N17(mode) SAVE_M1N16(mode) UNIT_SAVE_M1N1_##mode(cs0) + +#define COMPUTE_M1_PACK3(ndim) \ + for (; k_left > 1; k_left -= 2) {\ + float32x2_t ad1 = vld1_f32(a_ptr); a_ptr += 2;\ + 
LOAD_##ndim##D_B\ + ACC_##ndim##D\ + b_ptr += 2 * ndim;\ + }\ + REDUC_##ndim##D\ + if (k_left > 0) {\ + float as1 = *a_ptr;\ + ACC_##ndim##S\ + } + +#define COMPUTE_M1_PACK0_BASE(ndiv2) \ + for (; k_left > 0; k_left--) {\ + float32x2_t ad1 = vld1_dup_f32(a_ptr); a_ptr++;\ + LOAD_##ndiv2##D_B\ + ACC_##ndiv2##D\ + b_ptr += ndiv2 * 2;\ + } + +#define COMPUTE_M1_PACK0_N10 COMPUTE_M1_PACK0_BASE(5) +#define COMPUTE_M1_PACK0_N12 COMPUTE_M1_PACK0_BASE(6) +#define COMPUTE_M1_PACK0_N14 COMPUTE_M1_PACK0_BASE(7) +#define COMPUTE_M1_PACK0_N16 COMPUTE_M1_PACK0_BASE(8) +#define COMPUTE_M1_PACK0(ndim) COMPUTE_M1_PACK0_N##ndim + +#define COMPUTE_M1_PACK4_EVEN(ndiv2) \ + for (; k_left > 1; k_left -= 2) {\ + float32x2_t ad1 = vld1_f32(a_ptr); a_ptr += 2;\ + {\ + LOAD_##ndiv2##D_B b_ptr += ndiv2 * 2;\ + ACC_##ndiv2##D\ + }\ + ad1 = vrev64_f32(ad1);\ + LOAD_##ndiv2##D_B b_ptr += ndiv2 * 2;\ + ACC_##ndiv2##D\ + }\ + if (k_left > 0) {\ + float32x2_t ad1 = vld1_dup_f32(a_ptr);\ + LOAD_##ndiv2##D_B\ + ACC_##ndiv2##D\ + } + +#define COMPUTE_M1_PACK4_N18 COMPUTE_M1_PACK4_EVEN(9) + +#define COMPUTE_M1_PACK4_ODD(ndiv2) \ + for (; k_left > 1; k_left -= 2) {\ + float32x2_t ad1 = vld1_f32(a_ptr); a_ptr += 2;\ + {\ + LOAD_##ndiv2##D_B\ + float32x2_t bd0 = vld1_f32(b_ptr + ndiv2 * 2);\ + b_ptr += ndiv2 * 2 + 2;\ + ACC_##ndiv2##D\ + cd0 = vfma_f32(cd0, ad1, bd0);\ + }\ + ad1 = vrev64_f32(ad1);\ + LOAD_##ndiv2##D_B b_ptr += ndiv2 * 2;\ + ACC_##ndiv2##D\ + }\ + float cs0 = vpadds_f32(cd0);\ + if (k_left > 0) {\ + float32x2_t ad1 = vld1_dup_f32(a_ptr);\ + LOAD_##ndiv2##D_B\ + float bs0 = b_ptr[ndiv2 * 2];\ + ACC_##ndiv2##D\ + cs0 += bs0 * vget_lane_f32(ad1, 0);\ + } + +#define COMPUTE_M1_PACK4_N9 COMPUTE_M1_PACK4_ODD(4) +#define COMPUTE_M1_PACK4_N11 COMPUTE_M1_PACK4_ODD(5) +#define COMPUTE_M1_PACK4_N13 COMPUTE_M1_PACK4_ODD(6) +#define COMPUTE_M1_PACK4_N15 COMPUTE_M1_PACK4_ODD(7) +#define COMPUTE_M1_PACK4_N17 COMPUTE_M1_PACK4_ODD(8) + +#define COMPUTE_M1_PACK4(ndim) COMPUTE_M1_PACK4_N##ndim + +#define FUNC_EDGE(ndim, pack) \ +static inline void sgemm_skinny1_a35_m1n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_ptr,\ + float * __restrict__ c_ptr, uint32_t k_left, uint32_t LDC,\ + uint8_t c_rowmajor, float beta) {\ + INIT_M1N##ndim\ + COMPUTE_M1_PACK##pack(ndim)\ + if (c_rowmajor == 0) {\ + SAVE_M1N##ndim(CC)\ + } else {\ + SAVE_M1N##ndim(CR)\ + }\ +} + +FUNC_EDGE(4, 3) + +FUNC_EDGE(5, 3) + +FUNC_EDGE(6, 3) + +FUNC_EDGE(7, 3) + +FUNC_EDGE(8, 3) + +FUNC_EDGE(10, 0) + +FUNC_EDGE(12, 0) + +FUNC_EDGE(14, 0) + +FUNC_EDGE(16, 0) + +FUNC_EDGE(9, 4) + +FUNC_EDGE(11, 4) + +FUNC_EDGE(13, 4) + +FUNC_EDGE(15, 4) + +FUNC_EDGE(17, 4) + +FUNC_EDGE(18, 4) + +#endif diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA53.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA53.h new file mode 100644 index 0000000..f0f9370 --- /dev/null +++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA53.h @@ -0,0 +1,4306 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at                              */ +/*                                                                           */ +/*     http://www.apache.org/licenses/LICENSE-2.0                            */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software      */ +/* distributed under the License is distributed on an "AS IS" BASIS,        */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and      */ +/* limitations under the License.                                           */ +/*****************************************************************************/ + + +#include <arm_neon.h> +#include <stdint.h> + +#ifndef INCLUDE_A53_KERNEL +#define INCLUDE_A53_KERNEL + +/* x0 - x3 for a_ptrs */ +/* x4 for b_ptr, x5 for k_left */ +/* x6 - x9 for a_pref */ +/* x10 - x11, x16 - x17 and x19 - x20 for vec_fill */ +/* x12 - x15 for c_tmp */ + +#define INIT_SAVE \ + "ldr s0,[%[beta_addr]]; mov x12,%[c_ptr]\n\t"\ + "add x13,%[c_ptr],%w[LDC],UXTW #2; add x14,%[c_ptr],%w[LDC],UXTW #3\n\t"\ + "add x15,x13,%w[LDC],UXTW #3\n\t" + +#define UNIT_SAVE_M4N4_VR_CC(c1, c2, c3, c4) \ + "ldr q1,[x12]; ldr q2,[x13]; ldr q3,[x14]; ldr q4,[x15]\n\t"\ + "zip1 v5.4s,v"#c1".4s,v"#c2".4s; zip1 v6.4s,v"#c3".4s,v"#c4".4s\n\t"\ + "zip2 v7.4s,v"#c1".4s,v"#c2".4s; zip2 v"#c4".4s,v"#c3".4s,v"#c4".4s\n\t"\ + "zip1 v"#c1".2d,v5.2d,v6.2d; zip1 v"#c3".2d,v7.2d,v"#c4".2d\n\t"\ + "zip2 v"#c2".2d,v5.2d,v6.2d; zip2 v"#c4".2d,v7.2d,v"#c4".2d\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]; fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\ + "fmla v"#c3".4s,v3.4s,v0.s[0]; fmla v"#c4".4s,v4.4s,v0.s[0]\n\t"\ + "str q"#c1",[x12]; prfm pldl2keep,[x12,#32]\n\t"\ + "add x12,x12,%w[LDC],UXTW #4; prfm pstl1keep,[x12,#8]\n\t"\ + "str q"#c2",[x13]; prfm pldl2keep,[x13,#32]\n\t"\ + "add x13,x13,%w[LDC],UXTW #4; prfm pstl1keep,[x13,#8]\n\t"\ + "str q"#c3",[x14]; prfm pldl2keep,[x14,#32]\n\t"\ + "add x14,x14,%w[LDC],UXTW #4; prfm pstl1keep,[x14,#8]\n\t"\ + "str q"#c4",[x15]; prfm pldl2keep,[x15,#32]\n\t"\ + "add x15,x15,%w[LDC],UXTW #4; prfm pstl1keep,[x15,#8]\n\t" + +#define UNIT_SAVE_M4N4_VR_CR(c1, c2, c3, c4) \ + "ldr q1,[x12]; ldr q2,[x13]; ldr q3,[x14]; ldr q4,[x15]\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]; fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\ + "fmla v"#c3".4s,v3.4s,v0.s[0]; fmla v"#c4".4s,v4.4s,v0.s[0]\n\t"\ + "str q"#c1",[x12],#16; str q"#c2",[x13],#16\n\t"\ + "str q"#c3",[x14],#16; str q"#c4",[x15],#16\n\t" + +#define UNIT_SAVE_M4N4_VC_CC(c1, c2, c3, c4) \ + "ldr q1,[x12]; ldr q2,[x13]; ldr q3,[x14]; ldr q4,[x15]\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]; fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\ + "fmla v"#c3".4s,v3.4s,v0.s[0]; fmla v"#c4".4s,v4.4s,v0.s[0]\n\t"\ + "str q"#c1",[x12]; prfm pldl2keep,[x12,#32]\n\t"\ + "add x12,x12,%w[LDC],UXTW #4; prfm pstl1keep,[x12,#8]\n\t"\ + "str q"#c2",[x13]; prfm pldl2keep,[x13,#32]\n\t"\ + "add x13,x13,%w[LDC],UXTW #4; prfm pstl1keep,[x13,#8]\n\t"\ + "str q"#c3",[x14]; prfm pldl2keep,[x14,#32]\n\t"\ + "add x14,x14,%w[LDC],UXTW #4; prfm pstl1keep,[x14,#8]\n\t"\ + "str q"#c4",[x15]; prfm pldl2keep,[x15,#32]\n\t"\ + "add x15,x15,%w[LDC],UXTW #4; prfm pstl1keep,[x15,#8]\n\t" + +#define UNIT_SAVE_M4N4_VC_CR(c1, c2, c3, c4) \ + "zip1 v1.4s,v"#c1".4s,v"#c2".4s; zip1 v2.4s,v"#c3".4s,v"#c4".4s\n\t"\ + "zip2 v3.4s,v"#c1".4s,v"#c2".4s; zip2 v4.4s,v"#c3".4s,v"#c4".4s\n\t"\ + "zip1 v"#c1".2d,v1.2d,v2.2d; zip2 v"#c2".2d,v1.2d,v2.2d\n\t"\ + "ldr q1,[x12]; ldr q2,[x13]\n\t"\ + "zip1 v"#c3".2d,v3.2d,v4.2d; zip2 v"#c4".2d,v3.2d,v4.2d\n\t"\ + "ldr q3,[x14]; ldr q4,[x15]\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]; fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\ + "fmla v"#c3".4s,v3.4s,v0.s[0]; fmla v"#c4".4s,v4.4s,v0.s[0]\n\t"\ +
"str q"#c1",[x12],#16; str q"#c2",[x13],#16\n\t"\ + "str q"#c3",[x14],#16; str q"#c4",[x15],#16\n\t" + +#define EDGE_SAVE_M4N1K4_CC(c1, c2, c3, c4) \ + "ldr q1,[x12]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c2".4s\n\t"\ + "faddp v"#c3".4s,v"#c3".4s,v"#c4".4s\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c3".4s\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]; str q"#c1",[x12]\n\t"\ + "prfm pldl1keep,[x12,#32]; add x12,x12,%w[LDC],UXTW #2\n\t" + +#define EDGE_SAVE_M4N1K4_CR(c1, c2, c3, c4) \ + "ldr s1,[x12]; ldr s2,[x13]; ldr s3,[x14]; ldr s4,[x15]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c2".4s\n\t"\ + "ins v1.s[1],v2.s[0]; ins v3.s[1],v4.s[0]\n\t"\ + "faddp v"#c3".4s,v"#c3".4s,v"#c4".4s\n\t"\ + "ins v1.d[1],v3.d[0]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c3".4s\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]\n\t"\ + "st1 {v"#c1".s}[0],[x12],#4; st1 {v"#c1".s}[1],[x13],#4\n\t"\ + "st1 {v"#c1".s}[2],[x14],#4; st1 {v"#c1".s}[3],[x15],#4\n\t" + +#define EDGE_SAVE_M4N1K2_CC(c1, c2) \ + "ldr q1,[x12]\n\t"\ + "trn1 v2.4s,v"#c1".4s,v"#c2".4s; trn2 v3.4s,v"#c1".4s,v"#c2".4s\n\t"\ + "fadd v2.4s,v2.4s,v3.4s; fmla v2.4s,v1.4s,v0.s[0]\n\t"\ + "str q2,[x12]; prfm pstl2keep,[x12,#32]; add x12,x12,%w[LDC],UXTW #2\n\t" + +#define EDGE_SAVE_M4N1K2_CR(c1, c2) \ + "ldr s1,[x12]; ldr s2,[x13]; ldr s3,[x14]; ldr s4,[x15]\n\t"\ + "dup d5,v"#c1".d[1]; ins v1.s[1],v2.s[0]\n\t"\ + "dup d6,v"#c2".d[1]; ins v3.s[1],v4.s[0]\n\t"\ + "faddp v"#c1".2s,v"#c1".2s,v"#c2".2s; faddp v"#c2".2s,v5.2s,v6.2s\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.s[0]; fmla v"#c2".2s,v3.2s,v0.s[0]\n\t"\ + "st1 {v"#c1".s}[0],[x12],#4; st1 {v"#c1".s}[1],[x13],#4\n\t"\ + "st1 {v"#c2".s}[0],[x14],#4; st1 {v"#c2".s}[1],[x15],#4\n\t" + +#define EDGE_SAVE_M4N1K1_CC(c1) \ + "ldr q1,[x12]; fmla v"#c1".4s,v1.4s,v0.s[0]\n\t"\ + "str q"#c1",[x12]; prfm pstl2keep,[x12,#32]\n\t"\ + "add x12,x12,%w[LDC],UXTW #2\n\t" + +#define EDGE_SAVE_M4N1K1_CR(c1) \ + "ldr s1,[x12]; ldr s2,[x13]; ldr s3,[x14]; ldr s4,[x15]\n\t"\ + "ins v1.s[1],v2.s[0]; ins v3.s[1],v4.s[0]; ins v1.d[1],v3.d[0]\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]\n\t"\ + "st1 {v"#c1".s}[0],[x12],#4; st1 {v"#c1".s}[1],[x13],#4\n\t"\ + "st1 {v"#c1".s}[2],[x14],#4; st1 {v"#c1".s}[3],[x15],#4\n\t" + +#define INIT_1V(c1) "movi v"#c1".16b,#0\n\t" + +#define INIT_2V(c1, c2) \ + "movi v"#c1".16b,#0; movi v"#c2".16b,#0\n\t"\ + +#define INIT_4V(c1, c2, c3, c4) INIT_2V(c1, c2) INIT_2V(c3, c4) + +/* m4n4 c_vec */ +/* v28(v24) */ +/* v29(v25) */ +/* v30(v26) */ +/* v31(v27) */ +#define INIT_M4N4 \ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N4(mode) \ + "fadd v28.4s,v28.4s,v24.4s; fadd v29.4s,v29.4s,v25.4s\n\t"\ + "fadd v30.4s,v30.4s,v26.4s; fadd v31.4s,v31.4s,v27.4s\n\t"\ + UNIT_SAVE_M4N4_VR_##mode(28, 29, 30, 31) + +#define KERNEL_M4N4_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#64; ldr d9,[x4,#-48]; ldr x10,[x4,#-40]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N4_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".s[0]; prfm pldl1keep,[x1,#80]\n\t"\ + "fmla v25.4s,v8.4s,v"#ac2".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-32]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-24]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmov v10.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v28.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x4,#96]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac2".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d11,[x4,#-16]\n\t"\ + "fmla 
v30.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-8]\n\t"\ + "fmla v31.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x2,#80]\n\t"\ + "fmov v11.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".s[2]; prfm pldl1keep,[x3,#80]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[2]; ldr x16,[x2,#-8]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d8,[x4]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[2]; ldr x10,[x4,#8]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmov v8.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v28.4s,v11.4s,v"#ac1".s[3]; cmp w5,#12\n\t"\ + "fmla v29.4s,v11.4s,v"#ac2".s[3]; ldr x11,[x3,#-8]\n\t"\ + "fmov v"#an3".d[1],x16; ldr d9,[x4,#16]\n\t"\ + "fmla v30.4s,v11.4s,v"#ac3".s[3]; ldr x10,[x4,#24]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".s[3]; add x4,x4,#64\n\t" + +#define KERNEL_M4N4_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v25.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-32]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-24]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v28.4s,v9.4s,v"#ac1".s[1]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr d11,[x4,#-16]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-8]\n\t"\ + "fmla v31.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmla v28.4s,v11.4s,v"#ac1".s[3]; prfm pldl1keep,[x9]\n\t"\ + "fmla v29.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v30.4s,v11.4s,v"#ac3".s[3]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".s[3]\n\t" + +#define KERNEL_M4N4_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4],#16\n\t"\ + "ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "fmla v28.4s,v8.4s,v0.s[0]; sub w5,w5,#1\n\t"\ + "fmla v29.4s,v8.4s,v1.s[0]; cmp w5,#1\n\t"\ + "fmla v30.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v31.4s,v8.4s,v3.s[0]\n\t" + +/* m4n5 c_vec */ +/* v21(v20) v22_comp */ +/* v24(v23) v25_comp */ +/* v27(v26) v28_comp */ +/* v30(v29) v31_comp */ + +#define INIT_M4N5 \ + INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N5(mode) \ + "fadd v21.4s,v21.4s,v20.4s; fadd v24.4s,v24.4s,v23.4s\n\t"\ + "fadd v27.4s,v27.4s,v26.4s; fadd v30.4s,v30.4s,v29.4s\n\t"\ + UNIT_SAVE_M4N4_VR_##mode(21, 24, 27, 30) EDGE_SAVE_M4N1K4_##mode(22, 25, 28, 31) + +#define KERNEL_M4N5_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#80; ldr d9,[x4,#-64]; ldr x10,[x4,#-56]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N5_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[0]; prfm pldl1keep,[x1,#80]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-48]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-40]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmov v10.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x4,#96]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d11,[x4,#-32]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x2,#80]\n\t"\ + "fmov v11.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[2]\n\t"\ + 
"fmla v23.4s,v10.4s,v"#ac2".s[2]; ldr x16,[x2,#-8]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d12,[x4,#-16]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[2]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmov v12.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v21.4s,v11.4s,v"#ac1".s[3]; cmp w5,#12\n\t"\ + "fmla v24.4s,v11.4s,v"#ac2".s[3]; ldr x11,[x3,#-8]\n\t"\ + "fmov v"#an3".d[1],x16; ldr d8,[x4]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac3".s[3]; ldr x10,[x4,#8]\n\t"\ + "fmla v30.4s,v11.4s,v"#ac4".s[3]; add x4,x4,#80\n\t"\ + "fmla v22.4s,v12.4s,v"#ac1".4s; prfm pldl1keep,[x3,#64]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-64]\n\t"\ + "fmla v25.4s,v12.4s,v"#ac2".4s; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v12.4s,v"#ac3".4s; prfm pldl1keep,[x4,#48]\n\t"\ + "fmla v31.4s,v12.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N5_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-48]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-40]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x6]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr d11,[x4,#-32]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmla v23.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "ldr d12,[x4,#-16]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[2]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmov v12.d[1],x10\n\t"\ + "fmla v21.4s,v11.4s,v"#ac1".s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v24.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac3".s[3]\n\t"\ + "fmla v30.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "fmla v22.4s,v12.4s,v"#ac1".4s\n\t"\ + "fmla v25.4s,v12.4s,v"#ac2".4s\n\t"\ + "fmla v28.4s,v12.4s,v"#ac3".4s; prfm pldl1keep,[x9]\n\t"\ + "fmla v31.4s,v12.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N5_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4],#16\n\t"\ + "ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "fmla v21.4s,v8.4s,v0.s[0]; sub w5,w5,#1\n\t"\ + "fmla v24.4s,v8.4s,v1.s[0]; cmp w5,#1\n\t"\ + "fmla v27.4s,v8.4s,v2.s[0]\n\t"\ + "ldr s9,[x4],#4\n\t"\ + "fmla v30.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v22.4s,v0.4s,v9.s[0]\n\t"\ + "fmla v25.4s,v1.4s,v9.s[0]\n\t"\ + "fmla v28.4s,v2.4s,v9.s[0]\n\t"\ + "fmla v31.4s,v3.4s,v9.s[0]\n\t" + +/* m4n6 c_vec */ +/* v17(v16) v18_comp v19_comp */ +/* v21(v20) v22_comp v23_comp */ +/* v25(v24) v26_comp v27_comp */ +/* v29(v28) v30_comp v31_comp */ + +#define INIT_M4N6 \ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N6(mode) \ + "fadd v17.4s,v17.4s,v16.4s; fadd v21.4s,v21.4s,v20.4s\n\t"\ + "fadd v25.4s,v25.4s,v24.4s; fadd v29.4s,v29.4s,v28.4s\n\t"\ + UNIT_SAVE_M4N4_VR_##mode(17, 21, 25, 29) EDGE_SAVE_M4N1K4_##mode(18, 22, 26, 30)\ + EDGE_SAVE_M4N1K4_##mode(19, 23, 27, 31) + +#define KERNEL_M4N6_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#96; ldr d9,[x4,#-80]; ldr x10,[x4,#-72]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N6_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]; prfm pldl1keep,[x1,#80]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-64]\n\t"\ + "fmla 
v24.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmov v10.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x4,#96]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d8,[x4,#-48]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x2,#80]\n\t"\ + "fmov v8.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".s[2]; ldr x16,[x2,#-8]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d9,[x4,#-32]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-16]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac1".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v25.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[3]; cmp w5,#12\n\t"\ + "fmla v18.4s,v9.4s,v"#ac1".4s; ldr x11,[x3,#-8]\n\t"\ + "fmla v22.4s,v9.4s,v"#ac2".4s; prfm pldl1keep,[x3,#64]\n\t"\ + "fmov v"#an3".d[1],x16; ldr d8,[x4]\n\t"\ + "fmla v26.4s,v9.4s,v"#ac3".4s; ldr x10,[x4,#8]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".4s; prfm pldl1keep,[x4,#144]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#16]\n\t"\ + "fmla v23.4s,v10.4s,v"#ac2".4s; ldr x10,[x4,#24]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac3".4s; add x4,x4,#96\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N6_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x6]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr d8,[x4,#-48]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmov v8.d[1],x10\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "ldr d9,[x4,#-32]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-16]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac1".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v25.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac1".4s\n\t"\ + "fmla v22.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v9.4s,v"#ac3".4s\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".4s; prfm pldl1keep,[x9]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmla v23.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v27.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N6_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4],#16\n\t"\ + "ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "fmla v17.4s,v8.4s,v0.s[0]; sub w5,w5,#1\n\t"\ + "fmla v21.4s,v8.4s,v1.s[0]; cmp w5,#1\n\t"\ + "fmla v25.4s,v8.4s,v2.s[0]\n\t"\ + "ldr d9,[x4],#8\n\t"\ + "fmla v29.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v18.4s,v0.4s,v9.s[0]\n\t"\ + "fmla v22.4s,v1.4s,v9.s[0]\n\t"\ + "fmla v26.4s,v2.4s,v9.s[0]\n\t"\ + "fmla v30.4s,v3.4s,v9.s[0]\n\t"\ + "fmla v19.4s,v0.4s,v9.s[1]\n\t"\ + "fmla v23.4s,v1.4s,v9.s[1]\n\t"\ + "fmla 
v27.4s,v2.4s,v9.s[1]\n\t"\ + "fmla v31.4s,v3.4s,v9.s[1]\n\t" + + +/* m4n7 c_vec */ +/* v13(v12) v14_comp v15_comp v16_comp */ +/* v18(v17) v19_comp v20_comp v21_comp */ +/* v23(v22) v24_comp v25_comp v26_comp */ +/* v28(v27) v29_comp v30_comp v31_comp */ + +#define INIT_M4N7 \ + INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N7(mode) \ + "fadd v13.4s,v13.4s,v12.4s; fadd v18.4s,v18.4s,v17.4s\n\t"\ + "fadd v23.4s,v23.4s,v22.4s; fadd v28.4s,v28.4s,v27.4s\n\t"\ + UNIT_SAVE_M4N4_VR_##mode(13, 18, 23, 28) EDGE_SAVE_M4N1K4_##mode(14, 19, 24, 29)\ + EDGE_SAVE_M4N1K4_##mode(15, 20, 25, 30) EDGE_SAVE_M4N1K4_##mode(16, 21, 26, 31) + +#define KERNEL_M4N7_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#112; ldr d9,[x4,#-96]; ldr x10,[x4,#-88]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N7_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-80]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmov v10.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x4,#56]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d8,[x4,#-64]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-48]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmla v13.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d10,[x4,#-32]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac2".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac3".s[3]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac1".4s; ldr x10,[x4,#-8]\n\t"\ + "fmla v19.4s,v9.4s,v"#ac2".4s; ldr x16,[x2,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".4s\n\t"\ + "fmov v11.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".4s; cmp w5,#12\n\t"\ + "fmla v15.4s,v10.4s,v"#ac1".4s; prfm pldl1keep,[x3,#64]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmov v"#an3".d[1],x16; ldr d8,[x4]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".4s; ldr x10,[x4,#8]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s; prfm pldl1keep,[x4,#120]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".4s; ldr x11,[x3,#-8]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#16]\n\t"\ + "fmla v21.4s,v11.4s,v"#ac2".4s; ldr x10,[x4,#24]\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s; add x4,x4,#112\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N7_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-80]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr d8,[x4,#-64]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla 
v28.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-48]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac4".s[2]; sub w5,w5,#4\n\t"\ + "fmla v13.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "ldr d10,[x4,#-32]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac2".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac3".s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac1".4s; ldr x10,[x4,#-8]\n\t"\ + "fmla v19.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".4s\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".4s\n\t"\ + "fmla v15.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s; prfm pldl1keep,[x9]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".4s\n\t"\ + "fmla v21.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N7_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4],#16\n\t"\ + "ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "fmla v13.4s,v8.4s,v0.s[0]; sub w5,w5,#1\n\t"\ + "fmla v18.4s,v8.4s,v1.s[0]; cmp w5,#1\n\t"\ + "fmla v23.4s,v8.4s,v2.s[0]\n\t"\ + "ldr d9,[x4],#8\n\t"\ + "fmla v28.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v14.4s,v0.4s,v9.s[0]\n\t"\ + "fmla v19.4s,v1.4s,v9.s[0]\n\t"\ + "ldr s10,[x4],#4\n\t"\ + "fmla v24.4s,v2.4s,v9.s[0]\n\t"\ + "fmla v29.4s,v3.4s,v9.s[0]\n\t"\ + "fmla v15.4s,v0.4s,v9.s[1]\n\t"\ + "fmla v20.4s,v1.4s,v9.s[1]\n\t"\ + "fmla v25.4s,v2.4s,v9.s[1]\n\t"\ + "fmla v30.4s,v3.4s,v9.s[1]\n\t"\ + "fmla v16.4s,v0.4s,v10.s[0]\n\t"\ + "fmla v21.4s,v1.4s,v10.s[0]\n\t"\ + "fmla v26.4s,v2.4s,v10.s[0]\n\t"\ + "fmla v31.4s,v3.4s,v10.s[0]\n\t" + + +/* m4n8 c_vec */ +/* v24 - v25 */ +/* v26 - v27 */ +/* v28 - v29 */ +/* v30 - v31 */ + +#define INIT_M4N8 \ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N8(mode) \ + UNIT_SAVE_M4N4_VR_##mode(24, 26, 28, 30) UNIT_SAVE_M4N4_VR_##mode(25, 27, 29, 31) + +#define KERNEL_M4N8_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#128; ldr d9,[x4,#-112]; ldr x10,[x4,#-104]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N8_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-96]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#40]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-80]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v31.4s,v9.4s,v"#ac4".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v25.4s,v11.4s,v"#ac1".s[1]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v29.4s,v11.4s,v"#ac3".s[1]; cmp w5,#12\n\t"\ + "fmov v"#an1".d[1],x16; ldr d9,[x4,#-48]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla 
v24.4s,v8.4s,v"#ac1".s[2]; ldr x11,[x1,#-8]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-32]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac3".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac4".s[2]; prfm pldl1keep,[x4,#104]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac1".s[2]; prfm pldl1keep,[x3,#80]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac2".s[2]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac3".s[2]; add x4,x4,#128\n\t"\ + "fmla v31.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "fmov v11.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".s[3]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac2".s[3]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[3]; ldr x16,[x2,#-8]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d8,[x4,#-128]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[3]; ldr x10,[x4,#-120]\n\t"\ + "fmla v25.4s,v11.4s,v"#ac1".s[3]; ldr x11,[x3],#16\n\t"\ + "fmla v27.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-112]\n\t"\ + "fmov v"#an3".d[1],x16; fmov d"#an4",x11\n\t"\ + "fmla v29.4s,v11.4s,v"#ac3".s[3]; ldr x10,[x4,#-104]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".s[3]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N8_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-96]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-80]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v31.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v8.d[1],x10\n\t"\ + "fmla v25.4s,v11.4s,v"#ac1".s[1]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac2".s[1]\n\t"\ + "fmla v29.4s,v11.4s,v"#ac3".s[1]\n\t"\ + "ldr d9,[x4,#-48]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".s[2]; prfm pldl1keep,[x7]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-32]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac3".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac4".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac1".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac2".s[2]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac3".s[2]\n\t"\ + "fmla v31.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".s[3]; prfm pldl1keep,[x9]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac2".s[3]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[3]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[3]\n\t"\ + "fmla v25.4s,v11.4s,v"#ac1".s[3]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v29.4s,v11.4s,v"#ac3".s[3]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".s[3]\n\t" + +#define KERNEL_M4N8_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; add x4,x4,#32\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v24.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v25.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v26.4s,v8.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v27.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v28.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v29.4s,v9.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "fmla v30.4s,v8.4s,v3.s[0]; cmp w5,#1\n\t"\ + "fmla v31.4s,v9.4s,v3.s[0]\n\t" + + +/* m4n9 c_vec */ 
+/* v20 - v21 v22_comp */ +/* v23 - v24 v25_comp */ +/* v26 - v27 v28_comp */ +/* v29 - v30 v31_comp */ + +#define INIT_M4N9 \ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27)\ + INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N9(mode) \ + UNIT_SAVE_M4N4_VR_##mode(20, 23, 26, 29) UNIT_SAVE_M4N4_VR_##mode(21, 24, 27, 30)\ + EDGE_SAVE_M4N1K4_##mode(22, 25, 28, 31) + +#define KERNEL_M4N9_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#144; ldr d9,[x4,#-128]; ldr x10,[x4,#-120]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N9_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-112]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#24]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-96]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-80]\n\t"\ + "fmla v23.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v21.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d10,[x4,#-64]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v20.4s,v9.4s,v"#ac1".s[2]; ldr x11,[x1,#-8]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-48]\n\t"\ + "fmla v26.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x4,#88]\n\t"\ + "fmla v21.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-32]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac3".s[2]; add x4,x4,#144\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[3]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[3]; cmp w5,#12\n\t"\ + "fmov v"#an2".d[1],x11; ldr d10,[x4,#-160]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-152]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[3]; prfm pldl1keep,[x4,#8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[3]; ldr x16,[x2,#-8]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-144]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[3]; ldr x10,[x4,#-136]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[3]; ldr x11,[x3],#16\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-128]\n\t"\ + "fmov v"#an3".d[1],x16; fmov d"#an4",x11\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".4s; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".4s; ldr x11,[x3,#-8]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".4s; prfm pldl1keep,[x3,#64]\n\t" + +#define KERNEL_M4N9_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-112]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[0]\n\t"\ + 
"fmov v10.d[1],x10; ldr d8,[x4,#-96]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-80]\n\t"\ + "fmla v23.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v21.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "ldr d10,[x4,#-64]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v20.4s,v9.4s,v"#ac1".s[2]; prfm pldl1keep,[x7]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-48]\n\t"\ + "fmla v26.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v21.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-32]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v27.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "ldr d10,[x4,#-16]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[3]; prfm pldl1keep,[x9]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[3]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N9_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; add x4,x4,#36\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v20.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v21.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v23.4s,v8.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v24.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v26.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v27.4s,v9.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "ldr s10,[x4,#-4]\n\t"\ + "fmla v29.4s,v8.4s,v3.s[0]; cmp w5,#1\n\t"\ + "fmla v30.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v22.4s,v10.4s,v0.4s\n\t"\ + "fmla v25.4s,v10.4s,v1.4s\n\t"\ + "fmla v28.4s,v10.4s,v2.4s\n\t"\ + "fmla v31.4s,v10.4s,v3.4s\n\t" + + +/* m4n10 c_vec */ +/* v16 - v17 v18_comp v19_comp */ +/* v20 - v21 v22_comp v23_comp */ +/* v24 - v25 v26_comp v27_comp */ +/* v28 - v29 v30_comp v31_comp */ + +#define INIT_M4N10 \ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N10(mode) \ + UNIT_SAVE_M4N4_VR_##mode(16, 20, 24, 28) UNIT_SAVE_M4N4_VR_##mode(17, 21, 25, 29)\ + EDGE_SAVE_M4N1K4_##mode(18, 22, 26, 30) EDGE_SAVE_M4N1K4_##mode(19, 23, 27, 31) + +#define KERNEL_M4N10_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#160; ldr d9,[x4,#-144]; ldr x10,[x4,#-136]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N10_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-128]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#32]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr 
d8,[x4,#-112]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-96]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v28.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v17.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v25.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d10,[x4,#-80]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[2]; ldr x11,[x1,#-8]\n\t"\ + "fmla v20.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x4,#96]\n\t"\ + "fmla v17.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-48]\n\t"\ + "fmla v21.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".s[2]; add x4,x4,#160\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[3]; cmp w5,#12\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[3]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d10,[x4,#-192]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-184]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[3]; ldr x16,[x2,#-8]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-176]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[3]; ldr x10,[x4,#-168]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[3]; ldr x11,[x3],#16\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4,#-160]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".4s; ldr x10,[x4,#-152]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".4s; prfm pldl1keep,[x4]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmov v"#an3".d[1],x16; fmov d"#an4",x11\n\t"\ + "fmla v19.4s,v11.4s,v"#ac1".4s; ldr x11,[x3,#-8]\n\t"\ + "fmla v23.4s,v11.4s,v"#ac2".4s; prfm pldl1keep,[x3,#64]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-144]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s; ldr x10,[x4,#-136]\n\t" + +#define KERNEL_M4N10_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-128]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-112]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-96]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v28.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v17.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v25.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "ldr d10,[x4,#-80]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[2]; prfm pldl1keep,[x7]\n\t"\ + "fmla v20.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".s[2]; ldr 
x10,[x4,#-56]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v17.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-48]\n\t"\ + "fmla v21.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "ldr d10,[x4,#-32]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[3]; prfm pldl1keep,[x9]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmla v19.4s,v11.4s,v"#ac1".4s\n\t"\ + "fmla v23.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v27.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N10_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; add x4,x4,#40\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v17.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v8.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v21.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v24.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v25.4s,v9.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "ldr d10,[x4,#-8]\n\t"\ + "fmla v28.4s,v8.4s,v3.s[0]; cmp w5,#1\n\t"\ + "fmla v29.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v18.4s,v0.4s,v10.s[0]\n\t"\ + "fmla v22.4s,v1.4s,v10.s[0]\n\t"\ + "fmla v26.4s,v2.4s,v10.s[0]\n\t"\ + "fmla v30.4s,v3.4s,v10.s[0]\n\t"\ + "fmla v19.4s,v0.4s,v10.s[1]\n\t"\ + "fmla v23.4s,v1.4s,v10.s[1]\n\t"\ + "fmla v27.4s,v2.4s,v10.s[1]\n\t"\ + "fmla v31.4s,v3.4s,v10.s[1]\n\t" + + +/* m4n11 c_vec */ +/* v12 - v13 v14_comp v15_comp v16_comp */ +/* v17 - v18 v19_comp v20_comp v21_comp */ +/* v22 - v23 v24_comp v25_comp v26_comp */ +/* v27 - v28 v29_comp v30_comp v31_comp */ + +#define INIT_M4N11 \ + INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27)\ + INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N11(mode) \ + UNIT_SAVE_M4N4_VR_##mode(12, 17, 22, 27) UNIT_SAVE_M4N4_VR_##mode(13, 18, 23, 28)\ + EDGE_SAVE_M4N1K4_##mode(14, 19, 24, 29) EDGE_SAVE_M4N1K4_##mode(15, 20, 25, 30)\ + EDGE_SAVE_M4N1K4_##mode(16, 21, 26, 31) + +#define KERNEL_M4N11_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#176; ldr d9,[x4,#-160]; ldr x10,[x4,#-152]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N11_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-144]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#48]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-128]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-112]\n\t"\ + "fmla v17.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla 
v27.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v13.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d10,[x4,#-96]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; ldr x11,[x1,#-8]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-80]\n\t"\ + "fmla v22.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x4,#112]\n\t"\ + "fmla v13.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v23.4s,v10.4s,v"#ac3".s[2]; add x4,x4,#176\n\t"\ + "fmla v28.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v12.4s,v11.4s,v"#ac1".s[3]; cmp w5,#12\n\t"\ + "fmla v17.4s,v11.4s,v"#ac2".s[3]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v22.4s,v11.4s,v"#ac3".s[3]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d9,[x4,#-224]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac4".s[3]; ldr x10,[x4,#-216]\n\t"\ + "fmla v13.4s,v8.4s,v"#ac1".s[3]; prfm pldl1keep,[x4]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-208]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac3".s[3]; ldr x10,[x4,#-200]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]; ldr x16,[x2,#-8]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac1".4s\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-192]\n\t"\ + "fmla v19.4s,v9.4s,v"#ac2".4s; ldr x10,[x4,#-184]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".4s\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".4s\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4,#-176]\n\t"\ + "fmla v15.4s,v10.4s,v"#ac1".4s; ldr x10,[x4,#-168]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmov v"#an3".d[1],x16; ldr d"#an4",[x3],#16\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".4s; prfm pldl1keep,[x3,#64]\n\t"\ + "fmla v21.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-160]\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s; ldr x10,[x4,#-152]\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N11_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-144]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-128]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-112]\n\t"\ + "fmla v17.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[1]; sub w5,w5,#4\n\t"\ + "fmla v27.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v13.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "ldr d10,[x4,#-96]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; prfm pldl1keep,[x7]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-80]\n\t"\ + "fmla v22.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v13.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v11.d[1],x10; ldr 
d8,[x4,#-64]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v23.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v8.d[1],x10\n\t"\ + "fmla v12.4s,v11.4s,v"#ac1".s[3]\n\t"\ + "fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v22.4s,v11.4s,v"#ac3".s[3]\n\t"\ + "ldr d9,[x4,#-48]\n\t"\ + "fmla v27.4s,v11.4s,v"#ac4".s[3]; ldr x10,[x4,#-40]\n\t"\ + "fmla v13.4s,v8.4s,v"#ac1".s[3]; prfm pldl1keep,[x9]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-32]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac3".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac1".4s\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v19.4s,v9.4s,v"#ac2".4s; ldr x10,[x4,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".4s\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".4s\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v15.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".4s\n\t"\ + "fmla v21.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N11_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; add x4,x4,#44\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v13.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v17.4s,v8.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v18.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v22.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v23.4s,v9.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "ldr d10,[x4,#-12]\n\t"\ + "fmla v27.4s,v8.4s,v3.s[0]; cmp w5,#1\n\t"\ + "fmla v28.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v14.4s,v0.4s,v10.s[0]\n\t"\ + "ldr s11,[x4,#-4]\n\t"\ + "fmla v19.4s,v1.4s,v10.s[0]\n\t"\ + "fmla v24.4s,v2.4s,v10.s[0]\n\t"\ + "fmla v29.4s,v3.4s,v10.s[0]\n\t"\ + "fmla v15.4s,v0.4s,v10.s[1]\n\t"\ + "fmla v20.4s,v1.4s,v10.s[1]\n\t"\ + "fmla v25.4s,v2.4s,v10.s[1]\n\t"\ + "fmla v30.4s,v3.4s,v10.s[1]\n\t"\ + "fmla v16.4s,v0.4s,v11.s[0]\n\t"\ + "fmla v21.4s,v1.4s,v11.s[0]\n\t"\ + "fmla v26.4s,v2.4s,v11.s[0]\n\t"\ + "fmla v31.4s,v3.4s,v11.s[0]\n\t" + + +/* m4n12 c_vec */ +/* v20 - v22 */ +/* v23 - v25 */ +/* v26 - v28 */ +/* v29 - v31 */ + +#define INIT_M4N12 \ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27)\ + INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N12(mode) \ + UNIT_SAVE_M4N4_VR_##mode(20, 23, 26, 29) UNIT_SAVE_M4N4_VR_##mode(21, 24, 27, 30)\ + UNIT_SAVE_M4N4_VR_##mode(22, 25, 28, 31) + +#define KERNEL_M4N12_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#192; ldr d9,[x4,#-176]; ldr x10,[x4,#-168]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N12_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-160]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#8]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-144]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[0]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-128]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla 
v31.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d10,[x4,#-112]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-96]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x4,#72]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-80]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[1]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[2]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d10,[x4,#-64]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[2]; ldr x16,[x2,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-48]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-32]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[2]; prfm pldl1keep,[x4,#136]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[3]; sub w5,w5,#4\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[3]; prfm pldl1keep,[x3,#64]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "fmov v"#an3".d[1],x16; ldr d10,[x4,#-16]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[3]; ldr x11,[x3,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[3]; ldr x10,[x4,#8]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[3]; cmp w5,#12\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[3]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#16]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[3]; ldr x10,[x4,#24]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[3]; add x4,x4,#192\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[3]\n\t" + +#define KERNEL_M4N12_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v"#ac4".d[1],x11; fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "ldr d10,[x4,#-160]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-144]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[0]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-128]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[0]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "ldr d10,[x4,#-112]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[1]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmov v10.d[1],x10; ldr 
d8,[x4,#-96]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-80]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[1]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[2]\n\t"\ + "ldr d10,[x4,#-64]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[2]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-48]\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-32]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[2]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v23.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "ldr d10,[x4,#-16]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac1".s[3]; sub w5,w5,#4\n\t"\ + "fmla v24.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10\n\t"\ + "fmla v27.4s,v9.4s,v"#ac3".s[3]\n\t"\ + "fmla v30.4s,v9.4s,v"#ac4".s[3]; prfm pldl1keep,[x9]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac1".s[3]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac2".s[3]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac3".s[3]\n\t"\ + "fmla v31.4s,v10.4s,v"#ac4".s[3]\n\t" + +#define KERNEL_M4N12_TL1 \ + "ldr s0,[x0],#4; ldr q8,[x4]; ldr q9,[x4,#16]\n\t"\ + "ldr q10,[x4,#32]; add x4,x4,#48\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v20.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v21.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v22.4s,v10.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v23.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v24.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v25.4s,v10.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v26.4s,v8.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "fmla v27.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v28.4s,v10.4s,v2.s[0]\n\t"\ + "cmp w5,#1\n\t"\ + "fmla v29.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v30.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v31.4s,v10.4s,v3.s[0]\n\t" + + +/* m4n13 c_vec */ +/* v16 - v18 v19_comp */ +/* v20 - v22 v23_comp */ +/* v24 - v26 v27_comp */ +/* v28 - v30 v31_comp */ + +#define INIT_M4N13 \ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N13(mode) \ + UNIT_SAVE_M4N4_VR_##mode(16, 20, 24, 28) UNIT_SAVE_M4N4_VR_##mode(17, 21, 25, 29)\ + UNIT_SAVE_M4N4_VR_##mode(18, 22, 26, 30) EDGE_SAVE_M4N1K4_##mode(19, 23, 27, 31) + +#define KERNEL_M4N13_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#208; ldr d9,[x4,#-192]; ldr x10,[x4,#-184]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N13_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-176]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-168]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#24]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-160]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[0]; ldr 
x10,[x4,#-152]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[0]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-144]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d10,[x4,#-128]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-120]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[1]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-112]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x4,#88]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-96]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[2]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d10,[x4,#-80]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[2]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[2]; ldr x16,[x2,#-8]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-48]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[2]; prfm pldl1keep,[x4,#152]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[3]; prfm pldl1keep,[x3,#64]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "fmov v"#an3".d[1],x16; ldr d10,[x4,#-32]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[3]; sub w5,w5,#4\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[3]; cmp w5,#12\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[3]\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[3]; ldr x16,[x4,#8]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[3]; ldr x11,[x3,#-8]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[3]\n\t"\ + "ldr d9,[x4,#16]\n\t"\ + "fmla v19.4s,v11.4s,v"#ac1".4s; ldr x10,[x4,#24]\n\t"\ + "fmla v23.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmov v8.d[1],x16\n\t"\ + "fmla v27.4s,v11.4s,v"#ac3".4s; add x4,x4,#208\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N13_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-176]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-168]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-160]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[0]\n\t"\ + "fmov v8.d[1],x10; 
ldr d9,[x4,#-144]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[0]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "ldr d10,[x4,#-128]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-120]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[1]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-112]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-96]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[1]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[2]\n\t"\ + "ldr d10,[x4,#-80]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[2]\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-64]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-48]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[2]; prfm pldl1keep,[x9]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v24.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "ldr d10,[x4,#-32]\n\t"\ + "fmla v28.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v17.4s,v9.4s,v"#ac1".s[3]; sub w5,w5,#4\n\t"\ + "fmla v21.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac3".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac1".s[3]\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v22.4s,v10.4s,v"#ac2".s[3]\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".s[3]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".s[3]\n\t"\ + "fmla v19.4s,v11.4s,v"#ac1".4s\n\t"\ + "fmla v23.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v27.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N13_TL1 \ + "ldr s0,[x0],#4; ldr q8,[x4]; ldr q9,[x4,#16]\n\t"\ + "ldr q10,[x4,#32]; add x4,x4,#52\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v17.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v18.4s,v10.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v20.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v21.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v22.4s,v10.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v24.4s,v8.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "fmla v25.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v26.4s,v10.4s,v2.s[0]\n\t"\ + "ldr s11,[x4,#-4]\n\t"\ + "fmla v28.4s,v8.4s,v3.s[0]; cmp w5,#1\n\t"\ + "fmla v29.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v30.4s,v10.4s,v3.s[0]\n\t"\ + "fmla v19.4s,v0.4s,v11.4s\n\t"\ + "fmla v23.4s,v1.4s,v11.4s\n\t"\ + "fmla v27.4s,v2.4s,v11.4s\n\t"\ + "fmla v31.4s,v3.4s,v11.4s\n\t" + + +/* m4n14 c_vec */ +/* v12 - v14 v15_comp v16_comp */ +/* v17 - v19 v20_comp v21_comp */ +/* v22 - v24 v25_comp v26_comp */ +/* v27 - v29 v30_comp v31_comp */ + +#define INIT_M4N14 \ + INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27)\ + INIT_4V(28, 29, 30, 31) + +#define 
SAVE_M4N14(mode) \ + UNIT_SAVE_M4N4_VR_##mode(12, 17, 22, 27) UNIT_SAVE_M4N4_VR_##mode(13, 18, 23, 28)\ + UNIT_SAVE_M4N4_VR_##mode(14, 19, 24, 29) EDGE_SAVE_M4N1K4_##mode(15, 20, 25, 30)\ + EDGE_SAVE_M4N1K4_##mode(16, 21, 26, 31) + +#define KERNEL_M4N14_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr d3,[x3],#16\n\t"\ + "ldr q8,[x4],#224; ldr d9,[x4,#-208]; ldr x10,[x4,#-200]; ldr x11,[x3,#-8]\n\t" + +#define KERNEL_M4N14_K8_L4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmov v9.d[1],x10; ldr d"#an1",[x0],#16\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[0]; prfm pldl1keep,[x0,#64]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-192]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-184]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[0]; prfm pldl1keep,[x4,#8]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-176]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-168]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac1".s[0]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-160]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac2".s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[0]; ldr x16,[x0,#-8]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an2",[x1],#16\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[1]; prfm pldl1keep,[x1,#64]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d10,[x4,#-144]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-136]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[1]; ldr x11,[x1,#-8]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-128]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x4,#72]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-112]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[1]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10; ldr d"#an3",[x2],#16\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x2,#64]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[2]\n\t"\ + "fmov v"#an2".d[1],x11; ldr d10,[x4,#-96]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[2]; ldr x10,[x4,#-88]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[2]; ldr x16,[x2,#-8]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-80]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d11,[x4,#-64]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[2]; prfm pldl1keep,[x4,#136]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v11.d[1],x10; ldr d"#an4",[x3],#16\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[3]; prfm pldl1keep,[x3,#64]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "fmov v"#an3".d[1],x16; ldr d9,[x4,#-48]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-40]\n\t"\ + "fmla v13.4s,v11.4s,v"#ac1".s[3]; sub w5,w5,#4\n\t"\ + "fmla v18.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-32]\n\t"\ + "fmla v23.4s,v11.4s,v"#ac3".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v11.4s,v"#ac4".s[3]; cmp w5,#12\n\t"\ + "fmla v14.4s,v9.4s,v"#ac1".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla 
v19.4s,v9.4s,v"#ac2".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".s[3]; ldr x11,[x3,#-8]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmov v11.d[1],x10; ldr d8,[x4]\n\t"\ + "fmla v15.4s,v10.4s,v"#ac1".4s; ldr x16,[x4,#8]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".4s\n\t"\ + "ldr d9,[x4,#16]\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s; ldr x10,[x4,#24]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".4s; add x4,x4,#224\n\t"\ + "fmla v21.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmov v8.d[1],x16\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N14_K8_T4(ac1, ac2, ac3, ac4) \ + "fmov v9.d[1],x10\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[0]\n\t"\ + "fmov v"#ac4".d[1],x11; ldr d10,[x4,#-192]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[0]; ldr x10,[x4,#-184]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[0]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-176]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[0]; ldr x10,[x4,#-168]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac1".s[0]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-160]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac2".s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[0]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[1]\n\t"\ + "ldr d10,[x4,#-144]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[1]; ldr x10,[x4,#-136]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-128]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[1]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac1".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-112]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac2".s[1]; ldr x10,[x4,#-104]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[1]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "fmov v9.d[1],x10\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[2]\n\t"\ + "ldr d10,[x4,#-96]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[2]; ldr x10,[x4,#-88]\n\t"\ + "fmla v13.4s,v9.4s,v"#ac1".s[2]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmov v10.d[1],x10; ldr d8,[x4,#-80]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac3".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v28.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac1".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d11,[x4,#-64]\n\t"\ + "fmla v19.4s,v10.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v24.4s,v10.4s,v"#ac3".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v29.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v17.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".s[3]\n\t"\ + "ldr d9,[x4,#-48]\n\t"\ + "fmla v27.4s,v8.4s,v"#ac4".s[3]; ldr x10,[x4,#-40]\n\t"\ + "fmla v13.4s,v11.4s,v"#ac1".s[3]; sub w5,w5,#4\n\t"\ + "fmla v18.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmov v9.d[1],x10; ldr d10,[x4,#-32]\n\t"\ + "fmla v23.4s,v11.4s,v"#ac3".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac1".s[3]\n\t"\ + "fmov v10.d[1],x10; ldr d11,[x4,#-16]\n\t"\ + "fmla v19.4s,v9.4s,v"#ac2".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v24.4s,v9.4s,v"#ac3".s[3]\n\t"\ + "fmla v29.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmov v11.d[1],x10\n\t"\ + "fmla 
v15.4s,v10.4s,v"#ac1".4s\n\t"\ + "fmla v20.4s,v10.4s,v"#ac2".4s; prfm pldl1keep,[x9]\n\t"\ + "fmla v25.4s,v10.4s,v"#ac3".4s\n\t"\ + "fmla v30.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".4s\n\t"\ + "fmla v21.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s\n\t"\ + "fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N14_TL1 \ + "ldr s0,[x0],#4; ldr q8,[x4]; ldr q9,[x4,#16]\n\t"\ + "ldr q10,[x4,#32]; add x4,x4,#56\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v13.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v14.4s,v10.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v17.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v19.4s,v10.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4\n\t"\ + "fmla v22.4s,v8.4s,v2.s[0]; sub w5,w5,#1\n\t"\ + "fmla v23.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v24.4s,v10.4s,v2.s[0]\n\t"\ + "ldr d11,[x4,#-8]\n\t"\ + "fmla v27.4s,v8.4s,v3.s[0]; cmp w5,#1\n\t"\ + "fmla v28.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v29.4s,v10.4s,v3.s[0]\n\t"\ + "fmla v15.4s,v0.4s,v11.s[0]\n\t"\ + "fmla v20.4s,v1.4s,v11.s[0]\n\t"\ + "fmla v25.4s,v2.4s,v11.s[0]\n\t"\ + "fmla v30.4s,v3.4s,v11.s[0]\n\t"\ + "fmla v16.4s,v0.4s,v11.s[1]\n\t"\ + "fmla v21.4s,v1.4s,v11.s[1]\n\t"\ + "fmla v26.4s,v2.4s,v11.s[1]\n\t"\ + "fmla v31.4s,v3.4s,v11.s[1]\n\t" + +#define FUNC_K4(ndim) \ +static inline void sgemm_skinny1_a53_m4n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__ (\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x6,x0,%w[LDA],UXTW #4; add x7,x1,%w[LDA],UXTW #4\n\t"\ + "add x8,x2,%w[LDA],UXTW #4; add x9,x3,%w[LDA],UXTW #4\n\t"\ + "mov x4,%[b_scr]; mov w5,%w[K]\n\t"\ + INIT_M4N##ndim\ + "cmp w5,#4; b.lt 4f\n\t"\ + KERNEL_M4N##ndim##_PRELOAD4\ + "cmp w5,#12; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + KERNEL_M4N##ndim##_K8_L4(0, 1, 2, 3, 4, 5, 6, 7)\ + KERNEL_M4N##ndim##_K8_L4(4, 5, 6, 7, 0, 1, 2, 3)\ + "b.ge 1b; 2:\n\t"\ + "cmp w5,#8; b.lt 3f\n\t"\ + KERNEL_M4N##ndim##_K8_L4(0, 1, 2, 3, 4, 5, 6, 7)\ + KERNEL_M4N##ndim##_K8_T4(4, 5, 6, 7)\ + "b 4f; 3:\n\t"\ + KERNEL_M4N##ndim##_K8_T4(0, 1, 2, 3)\ + "4:\n\t"\ + "cmp w5,#1; b.lt 6f\n\t"\ + "5:\n\t"\ + KERNEL_M4N##ndim##_TL1\ + "b.ge 5b; 6:\n\t"\ + INIT_SAVE\ + "cmp %w[c_rowmajor],#0; b.eq 7f\n\t"\ + SAVE_M4N##ndim(CR) "b 8f\n\t"\ + "7:\n\t"\ + SAVE_M4N##ndim(CC)\ + "8:\n\t"\ + ::[a_ptr]"r"(a_ptr), [c_ptr]"r"(c_ptr), [b_scr]"r"(b_scr),\ + [K]"r"(K), [LDA]"r"(LDA), [LDC]"r"(LDC),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5","x6","x7","x8","x9",\ + "x10","x11","x12","x13","x14","x15","x16",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","v12","v13",\ + "v14","v15","v16","v17","v18","v19","v20","v21","v22","v23","v24","v25",\ + "v26","v27","v28","v29","v30","v31");\ +} + +FUNC_K4(4) +FUNC_K4(5) +FUNC_K4(6) +FUNC_K4(7) +FUNC_K4(8) +FUNC_K4(9) +FUNC_K4(10) +FUNC_K4(11) +FUNC_K4(12) +FUNC_K4(13) +FUNC_K4(14) + +/* m4n15 c_vec */ +/* v14 - v16 v23_comp v24_comp v25_comp */ +/* v17 - v19 v29_comp v30_comp v31_comp */ +/* v20 - v22 v23_comp v24_comp v25_comp */ +/* v26 - v28 v29_comp v30_comp v31_comp */ + +#define INIT_M4N15 \ + INIT_4V(14, 15, 16, 17) INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25)\ + INIT_4V(26, 27, 28, 29) INIT_2V(30, 31) + +#define SAVE_M4N15(mode) \ + 
UNIT_SAVE_M4N4_VR_##mode(14, 17, 20, 26) UNIT_SAVE_M4N4_VR_##mode(15, 18, 21, 27)\ + UNIT_SAVE_M4N4_VR_##mode(16, 19, 22, 28) EDGE_SAVE_M4N1K2_##mode(23, 29)\ + EDGE_SAVE_M4N1K2_##mode(24, 30) EDGE_SAVE_M4N1K2_##mode(25, 31) + +#define KERNEL_M4N15_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#120; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N15_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-88]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-80]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d4,[x4,#-72]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-64]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-56]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-48]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[2]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v5.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#128]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d6,[x4,#-40]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[3]; ldr x10,[x4,#-32]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-24]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[3]; ldr x10,[x4,#-16]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "ins v7.d[1],v7.d[0]; dup v8.2d,x10\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[1]; ldr x10,[x4,#-8]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[3]; add x4,x4,#120\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "dup v6.2d,x10; ldr d4,[x4,#-120]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac1".4s; ldr x10,[x4,#-112]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".4s; sub w5,w5,#2\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".4s\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-104]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac2".4s; ldr x10,[x4,#-96]\n\t"\ + "fmla v25.4s,v6.4s,v"#ac1".4s; cmp w5,#6\n\t"\ + "fmla v31.4s,v6.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N15_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-88]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-80]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d4,[x4,#-72]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-64]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-56]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-48]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[2]; prfm pldl1keep,[x6]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v5.d[1],x10\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "ldr d6,[x4,#-40]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[3]; ldr x10,[x4,#-32]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmla 
v18.4s,v5.4s,v"#ac2".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-24]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[3]; ldr x10,[x4,#-16]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "ins v7.d[1],v7.d[0]; dup v8.2d,x10\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[1]; ldr x10,[x4,#-8]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "dup v6.2d,x10\n\t"\ + "fmla v23.4s,v7.4s,v"#ac1".4s; prfm pldl1keep,[x8]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".4s; sub w5,w5,#2\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".4s\n\t"\ + "fmla v30.4s,v8.4s,v"#ac2".4s; prfm pldl1keep,[x9]\n\t"\ + "fmla v25.4s,v6.4s,v"#ac1".4s\n\t"\ + "fmla v31.4s,v6.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N15_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#60\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v14.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v15.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v16.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v17.4s,v4.4s,v1.s[0]; ldr w10,[x4,#-12]\n\t"\ + "fmla v18.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v19.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; dup v7.2d,x10\n\t"\ + "fmla v20.4s,v4.4s,v2.s[0]; ldr w11,[x4,#-8]\n\t"\ + "fmla v21.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v22.4s,v6.4s,v2.s[0]\n\t"\ + "ins v0.d[1],v2.d[0]; dup v8.2d,x11\n\t"\ + "fmla v26.4s,v4.4s,v3.s[0]; ldr w16,[x4,#-4]\n\t"\ + "fmla v27.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v28.4s,v6.4s,v3.s[0]\n\t"\ + "ins v1.d[1],v3.d[0]; dup v6.2d,x16\n\t"\ + "fmla v23.4s,v7.4s,v0.4s\n\t"\ + "fmla v24.4s,v8.4s,v0.4s\n\t"\ + "fmla v29.4s,v7.4s,v1.4s\n\t"\ + "fmla v30.4s,v8.4s,v1.4s\n\t"\ + "fmla v25.4s,v6.4s,v0.4s\n\t"\ + "fmla v31.4s,v6.4s,v1.4s\n\t" + + +/* m4n16 c_vec */ +/* v16 - v19 */ +/* v20 - v23 */ +/* v24 - v27 */ +/* v28 - v31 */ + +#define INIT_M4N16 \ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N16(mode) \ + UNIT_SAVE_M4N4_VR_##mode(16, 20, 24, 28) UNIT_SAVE_M4N4_VR_##mode(17, 21, 25, 29)\ + UNIT_SAVE_M4N4_VR_##mode(18, 22, 26, 30) UNIT_SAVE_M4N4_VR_##mode(19, 23, 27, 31) + +#define KERNEL_M4N16_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#128; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N16_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v16.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v24.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-96]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v28.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-80]\n\t"\ + "fmla v25.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v29.4s,v5.4s,v"#ac2".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v18.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-64]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-56]\n\t"\ + "fmla v26.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v30.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v19.4s,v7.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v27.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d5,[x4,#-48]\n\t"\ + "fmla v31.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov 
v5.d[1],x10; ldr d6,[x4,#-32]\n\t"\ + "fmla v24.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v4.4s,v"#ac2".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v17.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-16]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-8]\n\t"\ + "fmla v25.4s,v5.4s,v"#ac1".s[3]; add x4,x4,#128\n\t"\ + "fmla v29.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-128]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac1".s[1]; ldr x16,[x4,#-120]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x4]\n\t"\ + "fmla v26.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ldr d5,[x4,#-112]\n\t"\ + "fmla v30.4s,v6.4s,v"#ac2".s[3]; ldr x10,[x4,#-104]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac1".s[1]; sub w5,w5,#2\n\t"\ + "fmla v23.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmov v4.d[1],x16\n\t"\ + "fmla v27.4s,v7.4s,v"#ac1".s[3]; cmp w5,#6\n\t"\ + "fmla v31.4s,v7.4s,v"#ac2".s[3]\n\t" + +#define KERNEL_M4N16_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v16.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac2".s[0]\n\t"\ + "fmla v24.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-96]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v28.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-80]\n\t"\ + "fmla v25.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v29.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-64]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-56]\n\t"\ + "fmla v26.4s,v6.4s,v"#ac1".s[2]; prfm pldl1keep,[x7]\n\t"\ + "fmla v30.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v19.4s,v7.4s,v"#ac1".s[0]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x8]\n\t"\ + "fmla v27.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "ldr d5,[x4,#-48]\n\t"\ + "fmla v31.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-32]\n\t"\ + "fmla v24.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v28.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-16]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-8]\n\t"\ + "fmla v25.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla v29.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10\n\t"\ + "fmla v18.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v26.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac1".s[1]; sub w5,w5,#2\n\t"\ + "fmla v23.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmla v27.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v31.4s,v7.4s,v"#ac2".s[3]\n\t" + +#define KERNEL_M4N16_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#64\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v16.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v17.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v18.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v20.4s,v4.4s,v1.s[0]\n\t"\ + "fmla v21.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v22.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; ldr d7,[x4,#-16]\n\t"\ + "fmla v24.4s,v4.4s,v2.s[0]; ldr x10,[x4,#-8]\n\t"\ + "fmla v25.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v26.4s,v6.4s,v2.s[0]\n\t"\ + "fmov v7.d[1],x10\n\t"\ + "fmla v28.4s,v4.4s,v3.s[0]\n\t"\ + "fmla v29.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v30.4s,v6.4s,v3.s[0]\n\t"\ + "fmla v19.4s,v7.4s,v0.s[0]\n\t"\ + "fmla v23.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v27.4s,v7.4s,v2.s[0]\n\t"\ + "fmla 
v31.4s,v7.4s,v3.s[0]\n\t" + + +/* m4n17 c_vec */ +/* v14 - v17 v26_comp */ +/* v18 - v21 v31_comp */ +/* v22 - v25 v26_comp */ +/* v27 - v30 v31_comp */ + +#define INIT_M4N17 INIT_2V(14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N17(mode) \ + UNIT_SAVE_M4N4_VR_##mode(14, 18, 22, 27) UNIT_SAVE_M4N4_VR_##mode(15, 19, 23, 28)\ + UNIT_SAVE_M4N4_VR_##mode(16, 20, 24, 29) UNIT_SAVE_M4N4_VR_##mode(17, 21, 25, 30)\ + EDGE_SAVE_M4N1K2_##mode(26, 31) + +#define KERNEL_M4N17_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#136; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N17_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-104]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-96]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-88]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-80]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-72]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-64]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v17.4s,v7.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d5,[x4,#-56]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-48]\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-40]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-32]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[3]; prfm pldl1keep,[x4,#112]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-24]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-16]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-8]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#160]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ins v8.d[1],v8.d[0]; ldr d4,[x4]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[3]; ldr x16,[x4,#8]\n\t"\ + "fmla v17.4s,v7.4s,v"#ac1".s[1]; add x4,x4,#136\n\t"\ + "fmla v21.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "ldr d5,[x4,#-120]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[3]; ldr x10,[x4,#-112]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[3]; sub w5,w5,#2\n\t"\ + "fmov v4.d[1],x16\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".4s; cmp w5,#6\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N17_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac2".s[0]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-104]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-96]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-88]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-80]\n\t"\ + "fmla 
v28.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-72]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-64]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v17.4s,v7.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x7]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "ldr d5,[x4,#-56]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-48]\n\t"\ + "fmla v14.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-40]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-32]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-24]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-16]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-8]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ins v8.d[1],v8.d[0]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v17.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[3]; sub w5,w5,#2\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".4s\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N17_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#68\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v14.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v15.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v16.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v18.4s,v4.4s,v1.s[0]\n\t"\ + "fmla v19.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v20.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; ldr d7,[x4,#-20]\n\t"\ + "fmla v22.4s,v4.4s,v2.s[0]; ldr x10,[x4,#-12]\n\t"\ + "fmla v23.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v24.4s,v6.4s,v2.s[0]\n\t"\ + "fmov v7.d[1],x10; ldr s8,[x4,#-4]\n\t"\ + "fmla v27.4s,v4.4s,v3.s[0]\n\t"\ + "fmla v28.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v29.4s,v6.4s,v3.s[0]\n\t"\ + "ins v8.d[1],v8.d[0]\n\t"\ + "fmla v17.4s,v7.4s,v0.s[0]\n\t"\ + "fmla v21.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v25.4s,v7.4s,v2.s[0]\n\t"\ + "ins v0.d[1],v2.d[0]; ins v1.d[1],v3.d[0]\n\t"\ + "fmla v30.4s,v7.4s,v3.s[0]\n\t"\ + "fmla v26.4s,v8.4s,v0.4s\n\t"\ + "fmla v31.4s,v8.4s,v1.4s\n\t" + + +/* m4n18 c_vec */ +/* v12 - v15 v24_comp v25_comp */ +/* v16 - v19 v30_comp v31_comp */ +/* v20 - v23 v24_comp v25_comp */ +/* v26 - v29 v30_comp v31_comp */ + +#define INIT_M4N18 INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N18(mode) \ + UNIT_SAVE_M4N4_VR_##mode(12, 16, 20, 26) UNIT_SAVE_M4N4_VR_##mode(13, 17, 21, 27)\ + UNIT_SAVE_M4N4_VR_##mode(14, 18, 22, 28) UNIT_SAVE_M4N4_VR_##mode(15, 19, 23, 29)\ + EDGE_SAVE_M4N1K2_##mode(24, 30) EDGE_SAVE_M4N1K2_##mode(25, 31) + +#define KERNEL_M4N18_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#144; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N18_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-112]\n\t"\ + 
"fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-96]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-88]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-80]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d5,[x4,#-64]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-48]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-40]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[3]; prfm pldl1keep,[x4,#112]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-32]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-16]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#160]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ins v8.d[1],v8.d[0]; ldr d9,[x4,#-8]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[1]; add x4,x4,#144\n\t"\ + "fmla v19.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "ins v9.d[1],v9.d[0]; ldr d4,[x4,#-144]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac1".s[3]; ldr x10,[x4,#-136]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".s[3]; sub w5,w5,#2\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".4s\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-128]\n\t"\ + "fmla v30.4s,v8.4s,v"#ac2".4s; ldr x10,[x4,#-120]\n\t"\ + "fmla v25.4s,v9.4s,v"#ac1".4s; cmp w5,#6\n\t"\ + "fmla v31.4s,v9.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N18_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-112]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac2".s[0]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-96]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-88]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-80]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x7]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "ldr d5,[x4,#-64]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-48]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-40]\n\t"\ + "fmla v26.4s,v4.4s,v"#ac2".s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-32]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla 
v27.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-16]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v22.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ins v8.d[1],v8.d[0]; ldr d9,[x4,#-8]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "ins v9.d[1],v9.d[0]\n\t"\ + "fmla v23.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".s[3]; sub w5,w5,#2\n\t"\ + "fmla v24.4s,v8.4s,v"#ac1".4s\n\t"\ + "fmla v30.4s,v8.4s,v"#ac2".4s\n\t"\ + "fmla v25.4s,v9.4s,v"#ac1".4s\n\t"\ + "fmla v31.4s,v9.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N18_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#72\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v12.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v13.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v14.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v16.4s,v4.4s,v1.s[0]\n\t"\ + "fmla v17.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; ldr d7,[x4,#-24]\n\t"\ + "fmla v20.4s,v4.4s,v2.s[0]; ldr x10,[x4,#-16]\n\t"\ + "fmla v21.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v22.4s,v6.4s,v2.s[0]\n\t"\ + "fmov v7.d[1],x10; ldr s8,[x4,#-8]\n\t"\ + "fmla v26.4s,v4.4s,v3.s[0]; ldr w10,[x4,#-4]\n\t"\ + "fmla v27.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v28.4s,v6.4s,v3.s[0]\n\t"\ + "ins v8.d[1],v8.d[0]; dup v9.2d,x10\n\t"\ + "fmla v15.4s,v7.4s,v0.s[0]\n\t"\ + "fmla v19.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v23.4s,v7.4s,v2.s[0]\n\t"\ + "ins v0.d[1],v2.d[0]; ins v1.d[1],v3.d[0]\n\t"\ + "fmla v29.4s,v7.4s,v3.s[0]\n\t"\ + "fmla v24.4s,v8.4s,v0.4s\n\t"\ + "fmla v30.4s,v8.4s,v1.4s\n\t"\ + "fmla v25.4s,v9.4s,v0.4s\n\t"\ + "fmla v31.4s,v9.4s,v1.4s\n\t" + + +/* m4n19 c_vec */ +/* v10 - v13 v22_comp v23_comp v24_comp */ +/* v14 - v17 v29_comp v30_comp v31_comp */ +/* v18 - v21 v22_comp v23_comp v24_comp */ +/* v25 - v28 v29_comp v30_comp v31_comp */ + +#define INIT_M4N19 \ + INIT_2V(10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N19(mode) \ + UNIT_SAVE_M4N4_VR_##mode(10, 14, 18, 25) UNIT_SAVE_M4N4_VR_##mode(11, 15, 19, 26)\ + UNIT_SAVE_M4N4_VR_##mode(12, 16, 20, 27) UNIT_SAVE_M4N4_VR_##mode(13, 17, 21, 28)\ + EDGE_SAVE_M4N1K2_##mode(22, 29) EDGE_SAVE_M4N1K2_##mode(23, 30) EDGE_SAVE_M4N1K2_##mode(24, 31) + +#define KERNEL_M4N19_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#152; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N19_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v10.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v14.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-120]\n\t"\ + "fmla v11.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-112]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v25.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-104]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-96]\n\t"\ + "fmla v26.4s,v5.4s,v"#ac2".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v12.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-88]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-80]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v27.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-72]\n\t"\ + "fmla v13.4s,v7.4s,v"#ac1".s[0]; ldr 
x10,[x4,#-64]\n\t"\ + "fmla v17.4s,v7.4s,v"#ac2".s[0]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v5.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v28.4s,v7.4s,v"#ac2".s[2]\n\t"\ + "fmla v10.4s,v4.4s,v"#ac1".s[1]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v14.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d6,[x4,#-56]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-48]\n\t"\ + "fmla v25.4s,v4.4s,v"#ac2".s[3]; prfm pldl1keep,[x4,#112]\n\t"\ + "fmla v11.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-40]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-32]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla v26.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-24]\n\t"\ + "fmla v12.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#160]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ins v8.d[1],v8.d[0]; ldr d9,[x4,#-16]\n\t"\ + "fmla v27.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v13.4s,v7.4s,v"#ac1".s[1]; ldr x11,[x3],#8\n\t"\ + "fmla v17.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "ins v9.d[1],v9.d[0]; ldr d6,[x4,#-8]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac1".s[3]; add x4,x4,#152\n\t"\ + "fmla v28.4s,v7.4s,v"#ac2".s[3]; sub w5,w5,#2\n\t"\ + "fmla v22.4s,v8.4s,v"#ac1".4s\n\t"\ + "ins v6.d[1],v6.d[0]; ldr d4,[x4,#-152]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac2".4s; ldr x10,[x4,#-144]\n\t"\ + "fmla v23.4s,v9.4s,v"#ac1".4s; cmp w5,#6\n\t"\ + "fmla v30.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-136]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".4s; ldr x10,[x4,#-128]\n\t"\ + "fmla v31.4s,v6.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N19_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v10.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v14.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-120]\n\t"\ + "fmla v11.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-112]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac2".s[0]\n\t"\ + "fmla v25.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-104]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-96]\n\t"\ + "fmla v26.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v12.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-88]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-80]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v27.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-72]\n\t"\ + "fmla v13.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-64]\n\t"\ + "fmla v17.4s,v7.4s,v"#ac2".s[0]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v5.d[1],x10\n\t"\ + "fmla v28.4s,v7.4s,v"#ac2".s[2]\n\t"\ + "fmla v10.4s,v4.4s,v"#ac1".s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v14.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "ldr d6,[x4,#-56]\n\t"\ + "fmla v18.4s,v4.4s,v"#ac1".s[3]; ldr x10,[x4,#-48]\n\t"\ + "fmla v25.4s,v4.4s,v"#ac2".s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v11.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-40]\n\t"\ + "fmla v15.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-32]\n\t"\ + "fmla v19.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla v26.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-24]\n\t"\ + "fmla v12.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v16.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v20.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "ins v8.d[1],v8.d[0]; ldr d9,[x4,#-16]\n\t"\ + "fmla v27.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v13.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "fmla v17.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "ins v9.d[1],v9.d[0]; ldr d6,[x4,#-8]\n\t"\ + "fmla v21.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v28.4s,v7.4s,v"#ac2".s[3]; 
sub w5,w5,#2\n\t"\ + "fmla v22.4s,v8.4s,v"#ac1".4s\n\t"\ + "ins v6.d[1],v6.d[0]\n\t"\ + "fmla v29.4s,v8.4s,v"#ac2".4s\n\t"\ + "fmla v23.4s,v9.4s,v"#ac1".4s\n\t"\ + "fmla v30.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".4s\n\t"\ + "fmla v31.4s,v6.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N19_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#76\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v10.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v11.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v12.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4\n\t"\ + "fmla v14.4s,v4.4s,v1.s[0]\n\t"\ + "fmla v15.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v16.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; ldr d7,[x4,#-28]\n\t"\ + "fmla v18.4s,v4.4s,v2.s[0]; ldr x10,[x4,#-20]\n\t"\ + "fmla v19.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v20.4s,v6.4s,v2.s[0]\n\t"\ + "fmov v7.d[1],x10; ldr s8,[x4,#-12]\n\t"\ + "fmla v25.4s,v4.4s,v3.s[0]; ldr w10,[x4,#-8]\n\t"\ + "fmla v26.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v27.4s,v6.4s,v3.s[0]\n\t"\ + "ins v8.d[1],v8.d[0]; dup v9.2d,x10\n\t"\ + "fmla v13.4s,v7.4s,v0.s[0]; ldr w10,[x4,#-4]\n\t"\ + "fmla v17.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v21.4s,v7.4s,v2.s[0]\n\t"\ + "ins v0.d[1],v2.d[0]; ins v1.d[1],v3.d[0]\n\t"\ + "fmla v28.4s,v7.4s,v3.s[0]\n\t"\ + "fmla v22.4s,v8.4s,v0.4s\n\t"\ + "fmla v29.4s,v8.4s,v1.4s\n\t"\ + "dup v6.2d,x10\n\t"\ + "fmla v23.4s,v9.4s,v0.4s\n\t"\ + "fmla v30.4s,v9.4s,v1.4s\n\t"\ + "fmla v24.4s,v6.4s,v0.4s\n\t"\ + "fmla v31.4s,v6.4s,v1.4s\n\t" + + +/* m4n20 c_vec */ +/* v12 - v16 */ +/* v17 - v21 */ +/* v22 - v26 */ +/* v27 - v31 */ + +#define INIT_M4N20 INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N20(mode) UNIT_SAVE_M4N4_VR_##mode(12, 17, 22, 27)\ + UNIT_SAVE_M4N4_VR_##mode(13, 18, 23, 28) UNIT_SAVE_M4N4_VR_##mode(14, 19, 24, 29)\ + UNIT_SAVE_M4N4_VR_##mode(15, 20, 25, 30) UNIT_SAVE_M4N4_VR_##mode(16, 21, 26, 31) + +#define KERNEL_M4N20_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#160; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N20_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-128]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#72]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-112]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-104]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-96]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d4,[x4,#-80]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#120]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-64]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v5.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ 
+ "fmla v12.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d6,[x4,#-48]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-32]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#168]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-16]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[1]; add x4,x4,#160\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[1]\n\t"\ + "fmov v8.d[1],x10\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[3]; sub w5,w5,#2\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "ldr d4,[x4,#-160]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[1]; ldr x10,[x4,#-152]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[3]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; cmp w5,#6\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "ldr d5,[x4,#-144]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[3]; ldr x10,[x4,#-136]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[3]\n\t" + +#define KERNEL_M4N20_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-128]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-112]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-104]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-96]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]; prfm pldl1keep,[x7]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d4,[x4,#-80]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[0]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-64]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v5.d[1],x10\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "ldr d6,[x4,#-48]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-32]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-16]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[1]\n\t"\ + "fmov v8.d[1],x10\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[3]; sub w5,w5,#2\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[3]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[3]\n\t" + +#define 
KERNEL_M4N20_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#80\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v12.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v13.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v14.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4; ldr d7,[x4,#-32]\n\t"\ + "fmla v17.4s,v4.4s,v1.s[0]; ldr x10,[x4,#-24]\n\t"\ + "fmla v18.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v19.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; fmov v7.d[1],x10\n\t"\ + "fmla v22.4s,v4.4s,v2.s[0]\n\t"\ + "fmla v23.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v24.4s,v6.4s,v2.s[0]\n\t"\ + "ldr d8,[x4,#-16]\n\t"\ + "fmla v27.4s,v4.4s,v3.s[0]; ldr x10,[x4,#-8]\n\t"\ + "fmla v28.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v29.4s,v6.4s,v3.s[0]\n\t"\ + "fmov v8.d[1],x10\n\t"\ + "fmla v15.4s,v7.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v25.4s,v7.4s,v2.s[0]\n\t"\ + "fmla v30.4s,v7.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v21.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v26.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v31.4s,v8.4s,v3.s[0]\n\t" + + +/* m4n21 c_vec */ +/* v12 - v16 v10_comp */ +/* v17 - v21 v11_comp */ +/* v22 - v26 v10_comp */ +/* v27 - v31 v11_comp */ + +#define INIT_M4N21 \ + INIT_2V(10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N21(mode) UNIT_SAVE_M4N4_VR_##mode(12, 17, 22, 27)\ + UNIT_SAVE_M4N4_VR_##mode(13, 18, 23, 28) UNIT_SAVE_M4N4_VR_##mode(14, 19, 24, 29)\ + UNIT_SAVE_M4N4_VR_##mode(15, 20, 25, 30) UNIT_SAVE_M4N4_VR_##mode(16, 21, 26, 31)\ + EDGE_SAVE_M4N1K2_##mode(10, 11) + +#define KERNEL_M4N21_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#168; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N21_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-136]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-128]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-120]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-112]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-104]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-96]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d4,[x4,#-88]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-80]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#120]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-72]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-64]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v5.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d6,[x4,#-56]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[1]; ldr x10,[x4,#-48]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-40]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[1]; ldr x10,[x4,#-32]\n\t"\ + "fmla 
v18.4s,v5.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#176]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-24]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[3]; ldr x10,[x4,#-16]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-8]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[3]; add x4,x4,#168\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "ins v9.d[1],v9.d[0]; ldr d4,[x4,#-168]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[1]; ldr x10,[x4,#-160]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[3]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; sub w5,w5,#2\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "ldr d5,[x4,#-152]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[3]; ldr x10,[x4,#-144]\n\t"\ + "fmla v10.4s,v9.4s,v"#ac1".4s; cmp w5,#6\n\t"\ + "fmla v11.4s,v9.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N21_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-136]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-128]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-120]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-112]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-104]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-96]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v8.d[1],x10; ldr d4,[x4,#-88]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-80]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x7]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-72]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-64]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmov v5.d[1],x10\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[2]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "ldr d6,[x4,#-56]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[1]; ldr x10,[x4,#-48]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-40]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[1]; ldr x10,[x4,#-32]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d8,[x4,#-24]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[3]; ldr x10,[x4,#-16]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[1]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[1]\n\t"\ + "fmov v8.d[1],x10; ldr d9,[x4,#-8]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[3]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "ins v9.d[1],v9.d[0]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[3]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; sub w5,w5,#2\n\t"\ + "fmla v21.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v26.4s,v8.4s,v"#ac1".s[3]\n\t"\ + "fmla v31.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v10.4s,v9.4s,v"#ac1".4s\n\t"\ + "fmla v11.4s,v9.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N21_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add 
x4,x4,#84\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v12.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v13.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v14.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4; ldr d7,[x4,#-36]\n\t"\ + "fmla v17.4s,v4.4s,v1.s[0]; ldr x10,[x4,#-28]\n\t"\ + "fmla v18.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v19.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; fmov v7.d[1],x10\n\t"\ + "fmla v22.4s,v4.4s,v2.s[0]\n\t"\ + "fmla v23.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v24.4s,v6.4s,v2.s[0]\n\t"\ + "ldr d8,[x4,#-20]\n\t"\ + "fmla v27.4s,v4.4s,v3.s[0]; ldr x10,[x4,#-12]\n\t"\ + "fmla v28.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v29.4s,v6.4s,v3.s[0]\n\t"\ + "fmov v8.d[1],x10; ldr s9,[x4,#-4]\n\t"\ + "fmla v15.4s,v7.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v25.4s,v7.4s,v2.s[0]\n\t"\ + "ins v9.d[1],v9.d[0]\n\t"\ + "fmla v30.4s,v7.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]\n\t"\ + "fmla v21.4s,v8.4s,v1.s[0]\n\t"\ + "ins v0.d[1],v2.d[0]; ins v1.d[1],v3.d[0]\n\t"\ + "fmla v26.4s,v8.4s,v2.s[0]\n\t"\ + "fmla v31.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v10.4s,v9.4s,v0.4s\n\t"\ + "fmla v11.4s,v9.4s,v1.4s\n\t" + + +/* m4n22 c_vec */ +/* v12 - v16 v10_comp v8_comp */ +/* v17 - v21 v11_comp v9_comp */ +/* v22 - v26 v10_comp v8_comp */ +/* v27 - v31 v11_comp v9_comp */ + +#define INIT_M4N22 \ + INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N22(mode) UNIT_SAVE_M4N4_VR_##mode(12, 17, 22, 27)\ + UNIT_SAVE_M4N4_VR_##mode(13, 18, 23, 28) UNIT_SAVE_M4N4_VR_##mode(14, 19, 24, 29)\ + UNIT_SAVE_M4N4_VR_##mode(15, 20, 25, 30) UNIT_SAVE_M4N4_VR_##mode(16, 21, 26, 31)\ + EDGE_SAVE_M4N1K2_##mode(10, 11) EDGE_SAVE_M4N1K2_##mode(8, 9) + +#define KERNEL_M4N22_PRELOAD2 \ + "ldr d0,[x0],#8; ldr d1,[x1],#8; ldr x16,[x2],#8; ldr x11,[x3],#8\n\t"\ + "ldr q4,[x4]; ldr d5,[x4,#16]; ldr x10,[x4,#24]\n\t"\ + "add x4,x4,#176; fmov v0.d[1],x16\n\t" + +#define KERNEL_M4N22_MAIN2(ac1, ac2, an1, an2, ap1, ap2) \ + "fmov v5.d[1],x10; ldr d"#an1",[x0],#8\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-144]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#64]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-128]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-112]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]; ldr x16,[x2],#8\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-96]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x4,#120]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-80]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v4.4s,v"#ac2".s[0]\n\t"\ + "fmov v6.d[1],x10; ldr d"#an2",[x1],#8\n\t"\ + "fmla v26.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmla v31.4s,v4.4s,v"#ac2".s[2]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v12.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmov v"#an1".d[1],x16; ldr d7,[x4,#-64]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v22.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla 
v27.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-48]\n\t"\ + "fmla v13.4s,v6.4s,v"#ac1".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x4,#176]\n\t"\ + "fmla v23.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-32]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v14.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-16]\n\t"\ + "fmla v24.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".s[3]; ldr x11,[x3],#8\n\t"\ + "fmla v15.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "ins v6.d[1],v6.d[0]; ldr d7,[x4,#-8]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac2".s[1]; add x4,x4,#176\n\t"\ + "fmla v25.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "ins v7.d[1],v7.d[0]; ldr d4,[x4,#-176]\n\t"\ + "fmla v16.4s,v5.4s,v"#ac1".s[1]; ldr x10,[x4,#-168]\n\t"\ + "fmla v21.4s,v5.4s,v"#ac2".s[1]; sub w5,w5,#2\n\t"\ + "fmla v26.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v31.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmla v10.4s,v6.4s,v"#ac1".4s; cmp w5,#6\n\t"\ + "fmla v11.4s,v6.4s,v"#ac2".4s\n\t"\ + "ldr d5,[x4,#-160]\n\t"\ + "fmla v8.4s,v7.4s,v"#ac1".4s; ldr x10,[x4,#-152]\n\t"\ + "fmla v9.4s,v7.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N22_TAIL2(ac1, ac2) \ + "fmov v5.d[1],x10\n\t"\ + "fmla v12.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v17.4s,v4.4s,v"#ac2".s[0]; prfm pldl1keep,[x6]\n\t"\ + "fmla v22.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmov v"#ac2".d[1],x11; ldr d6,[x4,#-144]\n\t"\ + "fmla v13.4s,v5.4s,v"#ac1".s[0]; ldr x10,[x4,#-136]\n\t"\ + "fmla v18.4s,v5.4s,v"#ac2".s[0]\n\t"\ + "fmla v27.4s,v4.4s,v"#ac2".s[2]\n\t"\ + "fmov v6.d[1],x10; ldr d7,[x4,#-128]\n\t"\ + "fmla v23.4s,v5.4s,v"#ac1".s[2]; ldr x10,[x4,#-120]\n\t"\ + "fmla v28.4s,v5.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v6.4s,v"#ac1".s[0]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-112]\n\t"\ + "fmla v19.4s,v6.4s,v"#ac2".s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v24.4s,v6.4s,v"#ac1".s[2]\n\t"\ + "fmla v29.4s,v6.4s,v"#ac2".s[2]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-96]\n\t"\ + "fmla v15.4s,v7.4s,v"#ac1".s[0]; ldr x10,[x4,#-88]\n\t"\ + "fmla v20.4s,v7.4s,v"#ac2".s[0]; prfm pldl1keep,[x7]\n\t"\ + "fmla v25.4s,v7.4s,v"#ac1".s[2]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-80]\n\t"\ + "fmla v30.4s,v7.4s,v"#ac2".s[2]; ldr x10,[x4,#-72]\n\t"\ + "fmla v16.4s,v4.4s,v"#ac1".s[0]\n\t"\ + "fmla v21.4s,v4.4s,v"#ac2".s[0]\n\t"\ + "fmov v6.d[1],x10\n\t"\ + "fmla v26.4s,v4.4s,v"#ac1".s[2]\n\t"\ + "fmla v31.4s,v4.4s,v"#ac2".s[2]; prfm pldl1keep,[x8]\n\t"\ + "fmla v12.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "ldr d7,[x4,#-64]\n\t"\ + "fmla v17.4s,v5.4s,v"#ac2".s[1]; ldr x10,[x4,#-56]\n\t"\ + "fmla v22.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla v27.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmov v7.d[1],x10; ldr d4,[x4,#-48]\n\t"\ + "fmla v13.4s,v6.4s,v"#ac1".s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v18.4s,v6.4s,v"#ac2".s[1]; prfm pldl1keep,[x9]\n\t"\ + "fmla v23.4s,v6.4s,v"#ac1".s[3]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-32]\n\t"\ + "fmla v28.4s,v6.4s,v"#ac2".s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v14.4s,v7.4s,v"#ac1".s[1]\n\t"\ + "fmla v19.4s,v7.4s,v"#ac2".s[1]\n\t"\ + "fmov v5.d[1],x10; ldr d6,[x4,#-16]\n\t"\ + "fmla v24.4s,v7.4s,v"#ac1".s[3]\n\t"\ + "fmla v29.4s,v7.4s,v"#ac2".s[3]\n\t"\ + "fmla v15.4s,v4.4s,v"#ac1".s[1]\n\t"\ + "ins v6.d[1],v6.d[0]; ldr d7,[x4,#-8]\n\t"\ + "fmla v20.4s,v4.4s,v"#ac2".s[1]\n\t"\ + "fmla v25.4s,v4.4s,v"#ac1".s[3]\n\t"\ + "fmla v30.4s,v4.4s,v"#ac2".s[3]\n\t"\ + "ins v7.d[1],v7.d[0]\n\t"\ + "fmla v16.4s,v5.4s,v"#ac1".s[1]\n\t"\ + "fmla 
v21.4s,v5.4s,v"#ac2".s[1]; sub w5,w5,#2\n\t"\ + "fmla v26.4s,v5.4s,v"#ac1".s[3]\n\t"\ + "fmla v31.4s,v5.4s,v"#ac2".s[3]\n\t"\ + "fmla v10.4s,v6.4s,v"#ac1".4s\n\t"\ + "fmla v11.4s,v6.4s,v"#ac2".4s\n\t"\ + "fmla v8.4s,v7.4s,v"#ac1".4s\n\t"\ + "fmla v9.4s,v7.4s,v"#ac2".4s\n\t" + +#define KERNEL_M4N22_FIN1 \ + "ldr s0,[x0],#4; ldr q4,[x4]; ldr q5,[x4,#16]\n\t"\ + "ldr q6,[x4,#32]; add x4,x4,#88\n\t"\ + "ldr s1,[x1],#4\n\t"\ + "fmla v12.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v13.4s,v5.4s,v0.s[0]\n\t"\ + "fmla v14.4s,v6.4s,v0.s[0]\n\t"\ + "ldr s2,[x2],#4; ldr d7,[x4,#-40]\n\t"\ + "fmla v17.4s,v4.4s,v1.s[0]; ldr x10,[x4,#-32]\n\t"\ + "fmla v18.4s,v5.4s,v1.s[0]\n\t"\ + "fmla v19.4s,v6.4s,v1.s[0]\n\t"\ + "ldr s3,[x3],#4; fmov v7.d[1],x10\n\t"\ + "fmla v22.4s,v4.4s,v2.s[0]\n\t"\ + "fmla v23.4s,v5.4s,v2.s[0]\n\t"\ + "fmla v27.4s,v4.4s,v3.s[0]\n\t"\ + "ldr d4,[x4,#-24]\n\t"\ + "fmla v24.4s,v6.4s,v2.s[0]; ldr x10,[x4,#-16]\n\t"\ + "fmla v28.4s,v5.4s,v3.s[0]\n\t"\ + "fmla v29.4s,v6.4s,v3.s[0]\n\t"\ + "fmov v4.d[1],x10; ldr s5,[x4,#-8]\n\t"\ + "fmla v15.4s,v7.4s,v0.s[0]; ldr w11,[x4,#-4]\n\t"\ + "fmla v20.4s,v7.4s,v1.s[0]\n\t"\ + "fmla v25.4s,v7.4s,v2.s[0]\n\t"\ + "ins v5.d[1],v5.d[0]; dup v6.2d,x11\n\t"\ + "fmla v30.4s,v7.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v4.4s,v0.s[0]\n\t"\ + "fmla v21.4s,v4.4s,v1.s[0]\n\t"\ + "ins v0.d[1],v2.d[0]; ins v1.d[1],v3.d[0]\n\t"\ + "fmla v26.4s,v4.4s,v2.s[0]\n\t"\ + "fmla v31.4s,v4.4s,v3.s[0]\n\t"\ + "fmla v10.4s,v5.4s,v0.4s\n\t"\ + "fmla v11.4s,v5.4s,v1.4s\n\t"\ + "fmla v8.4s,v6.4s,v0.4s\n\t"\ + "fmla v9.4s,v6.4s,v1.4s\n\t" + + +#define FUNC_K2(ndim) \ +static inline void sgemm_skinny1_a53_m4n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__ (\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x6,x0,%w[LDA],UXTW #4; add x7,x1,%w[LDA],UXTW #4\n\t"\ + "add x8,x2,%w[LDA],UXTW #4; add x9,x3,%w[LDA],UXTW #4\n\t"\ + "mov x4,%[b_scr]; mov w5,%w[K]\n\t"\ + INIT_M4N##ndim\ + "cmp w5,#2; b.lt 4f\n\t"\ + KERNEL_M4N##ndim##_PRELOAD2\ + "cmp w5,#6; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + KERNEL_M4N##ndim##_MAIN2(0, 1, 2, 3, 0, 1)\ + KERNEL_M4N##ndim##_MAIN2(2, 3, 0, 1, 2, 3)\ + "b.ge 1b; 2:\n\t"\ + "cmp w5,#4; b.lt 3f\n\t"\ + KERNEL_M4N##ndim##_MAIN2(0, 1, 2, 3, 0, 1)\ + KERNEL_M4N##ndim##_TAIL2(2, 3)\ + "b 4f; 3:\n\t"\ + KERNEL_M4N##ndim##_TAIL2(0, 1)\ + "4:\n\t"\ + "cmp w5,#1; b.lt 6f\n\t"\ + "5:\n\t"\ + KERNEL_M4N##ndim##_FIN1\ + "6:\n\t"\ + INIT_SAVE\ + "cmp %w[c_rowmajor],#0; b.eq 7f\n\t"\ + SAVE_M4N##ndim(CR) "b 8f\n\t"\ + "7:\n\t"\ + SAVE_M4N##ndim(CC)\ + "8:\n\t"\ + ::[a_ptr]"r"(a_ptr), [c_ptr]"r"(c_ptr), [b_scr]"r"(b_scr),\ + [K]"r"(K), [LDA]"r"(LDA), [LDC]"r"(LDC),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5","x6","x7","x8","x9",\ + "x10","x11","x12","x13","x14","x15","x16",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","v12","v13",\ + "v14","v15","v16","v17","v18","v19","v20","v21","v22","v23","v24","v25",\ + "v26","v27","v28","v29","v30","v31");\ +} + +FUNC_K2(15) +FUNC_K2(16) +FUNC_K2(17) +FUNC_K2(18) +FUNC_K2(19) +FUNC_K2(20) +FUNC_K2(21) +FUNC_K2(22) + +#define INIT_M4N23 INIT_4V(6, 7, 8, 9) \ + INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25)\ + INIT_2V(26, 27) INIT_1V(28) + +#define 
SAVE_M4N23(mode) \ + UNIT_SAVE_M4N4_VC_##mode(6, 7, 8, 9) UNIT_SAVE_M4N4_VC_##mode(10, 11, 12, 13)\ + UNIT_SAVE_M4N4_VC_##mode(14, 15, 16, 17) UNIT_SAVE_M4N4_VC_##mode(18, 19, 20, 21)\ + UNIT_SAVE_M4N4_VC_##mode(22, 23, 24, 25) EDGE_SAVE_M4N1K1_##mode(26)\ + EDGE_SAVE_M4N1K1_##mode(27) EDGE_SAVE_M4N1K1_##mode(28) + +#define KERNEL_M4N23_PRELOAD2 \ + "ldr x16,[x0],#8; ldr x17,[x1],#8; ldr x19,[x2],#8; ldr x20,[x3],#8\n\t"\ + "ldr q2,[x4]; ldr q3,[x4,#16]; ldr x10,[x4,#24]; add x4,x4,#184\n\t"\ + "mov w11,w16; bfi x11,x17,#32,#32; fmov d0,x11\n\t"\ + "mov w11,w19; bfi x11,x20,#32,#32; fmov v0.d[1],x11\n\t" + +#define KERNEL_M4N23_MAIN2(ap1, ap2) \ + "fmov v3.d[1],x10; ldr d4,[x4,#-152]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-144]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]; ldr x16,[x0],#8\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-136]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-128]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x4,#48]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-120]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-112]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]; ldr x17,[x1],#8\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-104]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-96]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; mov w11,w16\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]; bfi x11,x17,#32,#32\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]; ldr x19,[x2],#8\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-88]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-80]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-72]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-64]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v6.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-56]\n\t"\ + "fmla v7.4s,v1.4s,v2.s[0]; ldr x10,[x4,#-48]\n\t"\ + "fmla v8.4s,v1.4s,v2.s[1]; prfm pldl1keep,[x4,#112]\n\t"\ + "fmla v9.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d0,x11\n\t"\ + "fmla v10.4s,v1.4s,v2.s[3]; mov w11,w19\n\t"\ + "fmla v11.4s,v1.4s,v3.s[0]; ldr x20,[x3],#8\n\t"\ + "fmla v12.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-40]\n\t"\ + "fmla v13.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-32]\n\t"\ + "fmla v14.4s,v1.4s,v3.s[3]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v15.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-24]\n\t"\ + "fmla v16.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-16]\n\t"\ + "fmla v17.4s,v1.4s,v4.s[2]; bfi x11,x20,#32,#32\n\t"\ + "fmla v18.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; fmov v0.d[1],x11\n\t"\ + "fmla v19.4s,v1.4s,v2.s[0]; sub w5,w5,#2\n\t"\ + "fmla v20.4s,v1.4s,v2.s[1]; cmp w5,#6\n\t"\ + "fmla v21.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-8]\n\t"\ + "fmla v22.4s,v1.4s,v2.s[3]; prfm pldl1keep,[x4,#176]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[0]; add x4,x4,#184\n\t"\ + "fmla v24.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-184]\n\t"\ + "fmla v25.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-176]\n\t"\ + "fmla v26.4s,v1.4s,v3.s[3]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-168]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[0]; ldr x10,[x4,#-160]\n\t"\ + "fmla v28.4s,v1.4s,v4.s[1]\n\t" + +#define KERNEL_M4N23_TAIL2 \ + "fmov v3.d[1],x10; ldr d4,[x4,#-152]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr 
x10,[x4,#-144]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-136]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-128]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-120]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-112]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-104]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-96]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-88]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-80]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x6]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-72]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-64]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v6.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v7.4s,v1.4s,v2.s[0]\n\t"\ + "fmla v8.4s,v1.4s,v2.s[1]\n\t"\ + "fmla v9.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-56]\n\t"\ + "fmla v10.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-48]\n\t"\ + "fmla v11.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x7]\n\t"\ + "fmla v12.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10; ldr d2,[x4,#-40]\n\t"\ + "fmla v13.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-32]\n\t"\ + "fmla v14.4s,v1.4s,v3.s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v15.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-24]\n\t"\ + "fmla v16.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-16]\n\t"\ + "fmla v17.4s,v1.4s,v4.s[2]; prfm pldl1keep,[x9]\n\t"\ + "fmla v18.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v19.4s,v1.4s,v2.s[0]; sub w5,w5,#2\n\t"\ + "fmla v20.4s,v1.4s,v2.s[1]\n\t"\ + "fmla v21.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-8]\n\t"\ + "fmla v22.4s,v1.4s,v2.s[3]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[0]\n\t"\ + "fmla v24.4s,v1.4s,v3.s[1]\n\t"\ + "fmla v25.4s,v1.4s,v3.s[2]\n\t"\ + "fmla v26.4s,v1.4s,v3.s[3]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[0]\n\t"\ + "fmla v28.4s,v1.4s,v4.s[1]\n\t" + +#define KERNEL_M4N23_FIN1 \ + "ldr w16,[x0],#4; ldr q2,[x4]\n\t"\ + "ldr w17,[x1],#4; ldr d3,[x4,#16]\n\t"\ + "ldr w19,[x2],#4; ldr x10,[x4,#24]\n\t"\ + "ldr w20,[x3],#4; orr x16,x16,x17,LSL #32\n\t"\ + "fmov d0,x16; orr x19,x19,x20,LSL #32; fmov v0.d[1],x19\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#32]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#40]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#48]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#56]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#64]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#72]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#80]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr w10,[x4,#88]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; add x4,x4,#92\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "fmla 
v24.4s,v0.4s,v3.s[2]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t" + + +#define INIT_M4N24 INIT_4V(6, 7, 8, 9) \ + INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25)\ + INIT_4V(26, 27, 28, 29) + +#define SAVE_M4N24(mode) \ + UNIT_SAVE_M4N4_VC_##mode(6, 7, 8, 9) UNIT_SAVE_M4N4_VC_##mode(10, 11, 12, 13)\ + UNIT_SAVE_M4N4_VC_##mode(14, 15, 16, 17) UNIT_SAVE_M4N4_VC_##mode(18, 19, 20, 21)\ + UNIT_SAVE_M4N4_VC_##mode(22, 23, 24, 25) UNIT_SAVE_M4N4_VC_##mode(26, 27, 28, 29) + +#define KERNEL_M4N24_PRELOAD2 \ + "ldr x16,[x0],#8; ldr x17,[x1],#8; ldr x19,[x2],#8; ldr x20,[x3],#8\n\t"\ + "ldr q2,[x4]; ldr q3,[x4,#16]; ldr x10,[x4,#24]; add x4,x4,#192\n\t"\ + "mov w11,w16; bfi x11,x17,#32,#32; fmov d0,x11\n\t"\ + "mov w11,w19; bfi x11,x20,#32,#32; fmov v0.d[1],x11\n\t" + +#define KERNEL_M4N24_MAIN2(ap1, ap2) \ + "fmov v3.d[1],x10; ldr d4,[x4,#-160]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]; ldr x16,[x0],#8\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-144]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-136]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x4,#48]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-128]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-120]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]; ldr x17,[x1],#8\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-112]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; mov w11,w16\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]; bfi x11,x17,#32,#32\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]; ldr x19,[x2],#8\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-96]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-88]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-80]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-64]\n\t"\ + "fmla v6.4s,v1.4s,v2.s[0]; ldr x10,[x4,#-56]\n\t"\ + "fmla v7.4s,v1.4s,v2.s[1]; prfm pldl1keep,[x4,#112]\n\t"\ + "fmla v8.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d0,x11\n\t"\ + "fmla v9.4s,v1.4s,v2.s[3]; mov w11,w19\n\t"\ + "fmla v10.4s,v1.4s,v3.s[0]; ldr x20,[x3],#8\n\t"\ + "fmla v11.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-48]\n\t"\ + "fmla v12.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v13.4s,v1.4s,v3.s[3]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v14.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-32]\n\t"\ + "fmla v15.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v16.4s,v1.4s,v4.s[2]; bfi x11,x20,#32,#32\n\t"\ + "fmla v17.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; fmov v0.d[1],x11\n\t"\ + "fmla v18.4s,v1.4s,v2.s[0]; sub w5,w5,#2\n\t"\ + "fmla v19.4s,v1.4s,v2.s[1]; cmp w5,#6\n\t"\ + "fmla v20.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-16]\n\t"\ + "fmla v21.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v22.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x4,#176]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10; ldr d2,[x4]\n\t"\ + "fmla 
v24.4s,v1.4s,v3.s[2]; ldr x10,[x4,#8]\n\t"\ + "fmla v25.4s,v1.4s,v3.s[3]; add x4,x4,#192\n\t"\ + "fmla v26.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-176]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-168]\n\t"\ + "fmla v28.4s,v1.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v1.4s,v4.s[3]\n\t" + +#define KERNEL_M4N24_TAIL2 \ + "fmov v3.d[1],x10; ldr d4,[x4,#-160]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-152]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-144]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-136]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-128]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-120]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-112]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-104]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-96]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-88]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x6]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-80]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-72]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-64]\n\t"\ + "fmla v6.4s,v1.4s,v2.s[0]; ldr x10,[x4,#-56]\n\t"\ + "fmla v7.4s,v1.4s,v2.s[1]; prfm pldl1keep,[x7]\n\t"\ + "fmla v8.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v9.4s,v1.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v1.4s,v3.s[0]\n\t"\ + "fmla v11.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-48]\n\t"\ + "fmla v12.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-40]\n\t"\ + "fmla v13.4s,v1.4s,v3.s[3]; prfm pldl1keep,[x8]\n\t"\ + "fmla v14.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-32]\n\t"\ + "fmla v15.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-24]\n\t"\ + "fmla v16.4s,v1.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v18.4s,v1.4s,v2.s[0]; sub w5,w5,#2\n\t"\ + "fmla v19.4s,v1.4s,v2.s[1]\n\t"\ + "fmla v20.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-16]\n\t"\ + "fmla v21.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-8]\n\t"\ + "fmla v22.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x9]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v24.4s,v1.4s,v3.s[2]\n\t"\ + "fmla v25.4s,v1.4s,v3.s[3]\n\t"\ + "fmla v26.4s,v1.4s,v4.s[0]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[1]\n\t"\ + "fmla v28.4s,v1.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v1.4s,v4.s[3]\n\t" + +#define KERNEL_M4N24_FIN1 \ + "ldr w16,[x0],#4; ldr q2,[x4]\n\t"\ + "ldr w17,[x1],#4; ldr d3,[x4,#16]\n\t"\ + "ldr w19,[x2],#4; ldr x10,[x4,#24]\n\t"\ + "ldr w20,[x3],#4; orr x16,x16,x17,LSL #32\n\t"\ + "fmov d0,x16; orr x19,x19,x20,LSL #32; fmov v0.d[1],x19\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#32]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#40]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#48]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#56]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]\n\t"\ + "fmla 
v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#64]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#72]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#80]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#88]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; add x4,x4,#96\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t" + + +#define INIT_M4N25 INIT_4V(6, 7, 8, 9) \ + INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25)\ + INIT_4V(26, 27, 28, 29) INIT_1V(30) + +#define SAVE_M4N25(mode) \ + UNIT_SAVE_M4N4_VC_##mode(6, 7, 8, 9) UNIT_SAVE_M4N4_VC_##mode(10, 11, 12, 13)\ + UNIT_SAVE_M4N4_VC_##mode(14, 15, 16, 17) UNIT_SAVE_M4N4_VC_##mode(18, 19, 20, 21)\ + UNIT_SAVE_M4N4_VC_##mode(22, 23, 24, 25) UNIT_SAVE_M4N4_VC_##mode(26, 27, 28, 29)\ + EDGE_SAVE_M4N1K1_##mode(30) + +#define KERNEL_M4N25_PRELOAD2 \ + "ldr x16,[x0],#8; ldr x17,[x1],#8; ldr x19,[x2],#8; ldr x20,[x3],#8\n\t"\ + "ldr q2,[x4]; ldr q3,[x4,#16]; ldr x10,[x4,#24]; add x4,x4,#200\n\t"\ + "mov w11,w16; bfi x11,x17,#32,#32; fmov d0,x11\n\t"\ + "mov w11,w19; bfi x11,x20,#32,#32; fmov v0.d[1],x11\n\t" + +#define KERNEL_M4N25_MAIN2(ap1, ap2) \ + "fmov v3.d[1],x10; ldr d4,[x4,#-168]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-160]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]; ldr x16,[x0],#8\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-152]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-144]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x4,#48]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-136]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-128]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]; ldr x17,[x1],#8\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-120]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-112]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; mov w11,w16\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]; bfi x11,x17,#32,#32\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-104]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-96]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-88]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-80]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]; ldr x19,[x2],#8\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-72]\n\t"\ + "fmla v30.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-64]\n\t"\ + "fmla v6.4s,v1.4s,v2.s[1]; prfm pldl1keep,[x4,#96]\n\t"\ + "fmla v7.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d0,x11\n\t"\ + "fmla v8.4s,v1.4s,v2.s[3]; mov w11,w19\n\t"\ + "fmla v9.4s,v1.4s,v3.s[0]\n\t"\ + "fmla v10.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-56]\n\t"\ + "fmla v11.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-48]\n\t"\ + "fmla v12.4s,v1.4s,v3.s[3]; ldr x20,[x3],#8\n\t"\ + "fmla v13.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr 
d3,[x4,#-40]\n\t"\ + "fmla v14.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-32]\n\t"\ + "fmla v15.4s,v1.4s,v4.s[2]; sub w5,w5,#2\n\t"\ + "fmla v16.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v17.4s,v1.4s,v2.s[0]; bfi x11,x20,#32,#32\n\t"\ + "fmla v18.4s,v1.4s,v2.s[1]; cmp w5,#6\n\t"\ + "fmla v19.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-24]; fmov v0.d[1],x11\n\t"\ + "fmla v20.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-16]\n\t"\ + "fmla v21.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x4,#144]\n\t"\ + "fmla v22.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-8]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[2]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v24.4s,v1.4s,v3.s[3]; add x4,x4,#200\n\t"\ + "fmla v25.4s,v1.4s,v4.s[0]\n\t"\ + "ldr d2,[x4,#-200]\n\t"\ + "fmla v26.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-192]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[2]; prfm pldl1keep,[x4]\n\t"\ + "fmla v28.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-184]\n\t"\ + "fmla v29.4s,v1.4s,v5.s[0]; ldr x10,[x4,#-176]\n\t"\ + "fmla v30.4s,v1.4s,v5.s[1]\n\t" + +#define KERNEL_M4N25_TAIL2 \ + "fmov v3.d[1],x10; ldr d4,[x4,#-168]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-160]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-152]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-144]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x6]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-136]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-128]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-120]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-112]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-104]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-96]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x7]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-88]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-80]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-72]\n\t"\ + "fmla v30.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-64]\n\t"\ + "fmla v6.4s,v1.4s,v2.s[1]; prfm pldl1keep,[x8]\n\t"\ + "fmla v7.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v8.4s,v1.4s,v2.s[3]\n\t"\ + "fmla v9.4s,v1.4s,v3.s[0]\n\t"\ + "fmla v10.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-56]\n\t"\ + "fmla v11.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-48]\n\t"\ + "fmla v12.4s,v1.4s,v3.s[3]\n\t"\ + "fmla v13.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-40]\n\t"\ + "fmla v14.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-32]\n\t"\ + "fmla v15.4s,v1.4s,v4.s[2]; sub w5,w5,#2\n\t"\ + "fmla v16.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v17.4s,v1.4s,v2.s[0]\n\t"\ + "fmla v18.4s,v1.4s,v2.s[1]\n\t"\ + "fmla v19.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-24]\n\t"\ + "fmla v20.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-16]\n\t"\ + "fmla v21.4s,v1.4s,v3.s[0]\n\t"\ + "fmla v22.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-8]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[2]; prfm pldl1keep,[x9]\n\t"\ + "fmla v24.4s,v1.4s,v3.s[3]\n\t"\ + "fmla v25.4s,v1.4s,v4.s[0]\n\t"\ + "fmla v26.4s,v1.4s,v4.s[1]\n\t"\ + "fmla 
v27.4s,v1.4s,v4.s[2]\n\t"\ + "fmla v28.4s,v1.4s,v4.s[3]\n\t"\ + "fmla v29.4s,v1.4s,v5.s[0]\n\t"\ + "fmla v30.4s,v1.4s,v5.s[1]\n\t" + +#define KERNEL_M4N25_FIN1 \ + "ldr w16,[x0],#4; ldr q2,[x4]\n\t"\ + "ldr w17,[x1],#4; ldr d3,[x4,#16]\n\t"\ + "ldr w19,[x2],#4; ldr x10,[x4,#24]\n\t"\ + "ldr w20,[x3],#4; orr x16,x16,x17,LSL #32\n\t"\ + "fmov d0,x16; orr x19,x19,x20,LSL #32; fmov v0.d[1],x19\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#32]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#40]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#48]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#56]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#64]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#72]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#80]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#88]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; add x4,x4,#100\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr s2,[x4,#-4]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmla v30.4s,v0.4s,v2.s[0]\n\t" + + +#define INIT_M4N26 INIT_4V(6, 7, 8, 9) \ + INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25)\ + INIT_4V(26, 27, 28, 29) INIT_2V(30, 31) + +#define SAVE_M4N26(mode) \ + UNIT_SAVE_M4N4_VC_##mode(6, 7, 8, 9) UNIT_SAVE_M4N4_VC_##mode(10, 11, 12, 13)\ + UNIT_SAVE_M4N4_VC_##mode(14, 15, 16, 17) UNIT_SAVE_M4N4_VC_##mode(18, 19, 20, 21)\ + UNIT_SAVE_M4N4_VC_##mode(22, 23, 24, 25) UNIT_SAVE_M4N4_VC_##mode(26, 27, 28, 29)\ + EDGE_SAVE_M4N1K1_##mode(30) EDGE_SAVE_M4N1K1_##mode(31) + +#define KERNEL_M4N26_PRELOAD2 \ + "ldr x16,[x0],#8; ldr x17,[x1],#8; ldr x19,[x2],#8; ldr x20,[x3],#8\n\t"\ + "ldr q2,[x4]; ldr q3,[x4,#16]; ldr x10,[x4,#24]; add x4,x4,#208\n\t"\ + "mov w11,w16; bfi x11,x17,#32,#32; fmov d0,x11\n\t"\ + "mov w11,w19; bfi x11,x20,#32,#32; fmov v0.d[1],x11\n\t" + +#define KERNEL_M4N26_MAIN2(ap1, ap2) \ + "fmov v3.d[1],x10; ldr d4,[x4,#-176]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-168]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]; ldr x16,[x0],#8\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-160]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-152]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x4,#48]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-144]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-136]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]; ldr x17,[x1],#8\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-128]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; mov w11,w16\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]; bfi x11,x17,#32,#32\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr 
d2,[x4,#-112]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-104]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x"#ap1",#64]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-96]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]; ldr x19,[x2],#8\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-80]\n\t"\ + "fmla v30.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v31.4s,v0.4s,v2.s[1]; prfm pldl1keep,[x4,#96]\n\t"\ + "fmla v6.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d0,x11\n\t"\ + "fmla v7.4s,v1.4s,v2.s[3]; mov w11,w19\n\t"\ + "fmla v8.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x"#ap2",#64]\n\t"\ + "fmla v9.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-64]\n\t"\ + "fmla v10.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v11.4s,v1.4s,v3.s[3]; ldr x20,[x3],#8\n\t"\ + "fmla v12.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-48]\n\t"\ + "fmla v13.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v14.4s,v1.4s,v4.s[2]; sub w5,w5,#2\n\t"\ + "fmla v15.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v16.4s,v1.4s,v2.s[0]; bfi x11,x20,#32,#32\n\t"\ + "fmla v17.4s,v1.4s,v2.s[1]; cmp w5,#6\n\t"\ + "fmla v18.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-32]; fmov v0.d[1],x11\n\t"\ + "fmla v19.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v20.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x4,#144]\n\t"\ + "fmla v21.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-16]\n\t"\ + "fmla v22.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-8]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[3]; add x4,x4,#208\n\t"\ + "fmla v24.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v5.d[1],x10; ldr d2,[x4,#-208]\n\t"\ + "fmla v25.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-200]\n\t"\ + "fmla v26.4s,v1.4s,v4.s[2]; prfm pldl1keep,[x4]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v2.d[1],x10\n\t"\ + "fmla v28.4s,v1.4s,v5.s[0]\n\t"\ + "fmla v29.4s,v1.4s,v5.s[1]\n\t"\ + "ldr d3,[x4,#-192]\n\t"\ + "fmla v30.4s,v1.4s,v5.s[2]; ldr x10,[x4,#-184]\n\t"\ + "fmla v31.4s,v1.4s,v5.s[3]\n\t" + +#define KERNEL_M4N26_TAIL2 \ + "fmov v3.d[1],x10; ldr d4,[x4,#-176]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-168]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]; bfxil x17,x16,#32,#32\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10; fmov d1,x17\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]; bfxil x20,x19,#32,#32\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "fmov v1.d[1],x20; ldr d2,[x4,#-160]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-152]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x6]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-144]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-136]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-128]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-120]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-112]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]; ldr x10,[x4,#-104]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]; prfm pldl1keep,[x7]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-96]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]; ldr x10,[x4,#-88]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#-80]\n\t"\ + "fmla v30.4s,v0.4s,v2.s[0]; ldr x10,[x4,#-72]\n\t"\ + "fmla v31.4s,v0.4s,v2.s[1]; prfm 
pldl1keep,[x8]\n\t"\ + "fmla v6.4s,v1.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v7.4s,v1.4s,v2.s[3]\n\t"\ + "fmla v8.4s,v1.4s,v3.s[0]; prfm pldl1keep,[x9]\n\t"\ + "fmla v9.4s,v1.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-64]\n\t"\ + "fmla v10.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-56]\n\t"\ + "fmla v11.4s,v1.4s,v3.s[3]\n\t"\ + "fmla v12.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#-48]\n\t"\ + "fmla v13.4s,v1.4s,v4.s[1]; ldr x10,[x4,#-40]\n\t"\ + "fmla v14.4s,v1.4s,v4.s[2]; sub w5,w5,#2\n\t"\ + "fmla v15.4s,v1.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10\n\t"\ + "fmla v16.4s,v1.4s,v2.s[0]\n\t"\ + "fmla v17.4s,v1.4s,v2.s[1]\n\t"\ + "fmla v18.4s,v1.4s,v2.s[2]\n\t"\ + "ldr d4,[x4,#-32]\n\t"\ + "fmla v19.4s,v1.4s,v2.s[3]; ldr x10,[x4,#-24]\n\t"\ + "fmla v20.4s,v1.4s,v3.s[0]\n\t"\ + "fmla v21.4s,v1.4s,v3.s[1]\n\t"\ + "fmov v4.d[1],x10; ldr d5,[x4,#-16]\n\t"\ + "fmla v22.4s,v1.4s,v3.s[2]; ldr x10,[x4,#-8]\n\t"\ + "fmla v23.4s,v1.4s,v3.s[3]\n\t"\ + "fmla v24.4s,v1.4s,v4.s[0]\n\t"\ + "fmov v5.d[1],x10\n\t"\ + "fmla v25.4s,v1.4s,v4.s[1]\n\t"\ + "fmla v26.4s,v1.4s,v4.s[2]\n\t"\ + "fmla v27.4s,v1.4s,v4.s[3]\n\t"\ + "fmla v28.4s,v1.4s,v5.s[0]\n\t"\ + "fmla v29.4s,v1.4s,v5.s[1]\n\t"\ + "fmla v30.4s,v1.4s,v5.s[2]\n\t"\ + "fmla v31.4s,v1.4s,v5.s[3]\n\t" + +#define KERNEL_M4N26_FIN1 \ + "ldr w16,[x0],#4; ldr q2,[x4]\n\t"\ + "ldr w17,[x1],#4; ldr d3,[x4,#16]\n\t"\ + "ldr w19,[x2],#4; ldr x10,[x4,#24]\n\t"\ + "ldr w20,[x3],#4; orr x16,x16,x17,LSL #32\n\t"\ + "fmov d0,x16; orr x19,x19,x20,LSL #32; fmov v0.d[1],x19\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#32]\n\t"\ + "fmla v6.4s,v0.4s,v2.s[0]; ldr x10,[x4,#40]\n\t"\ + "fmla v7.4s,v0.4s,v2.s[1]\n\t"\ + "fmla v8.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v9.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v10.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v11.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#48]\n\t"\ + "fmla v12.4s,v0.4s,v3.s[2]; ldr x10,[x4,#56]\n\t"\ + "fmla v13.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v14.4s,v0.4s,v4.s[0]\n\t"\ + "fmov v2.d[1],x10; ldr d3,[x4,#64]\n\t"\ + "fmla v15.4s,v0.4s,v4.s[1]; ldr x10,[x4,#72]\n\t"\ + "fmla v16.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v17.4s,v0.4s,v4.s[3]\n\t"\ + "fmov v3.d[1],x10; ldr d4,[x4,#80]\n\t"\ + "fmla v18.4s,v0.4s,v2.s[0]; ldr x10,[x4,#88]\n\t"\ + "fmla v19.4s,v0.4s,v2.s[1]; add x4,x4,#104\n\t"\ + "fmla v20.4s,v0.4s,v2.s[2]\n\t"\ + "fmov v4.d[1],x10\n\t"\ + "fmla v21.4s,v0.4s,v2.s[3]\n\t"\ + "fmla v22.4s,v0.4s,v3.s[0]\n\t"\ + "fmla v23.4s,v0.4s,v3.s[1]\n\t"\ + "ldr d2,[x4,#-8]\n\t"\ + "fmla v24.4s,v0.4s,v3.s[2]\n\t"\ + "fmla v25.4s,v0.4s,v3.s[3]\n\t"\ + "fmla v26.4s,v0.4s,v4.s[0]\n\t"\ + "fmla v27.4s,v0.4s,v4.s[1]\n\t"\ + "fmla v28.4s,v0.4s,v4.s[2]\n\t"\ + "fmla v29.4s,v0.4s,v4.s[3]\n\t"\ + "fmla v30.4s,v0.4s,v2.s[0]\n\t"\ + "fmla v31.4s,v0.4s,v2.s[1]\n\t" + +#define FUNC_K1(ndim) \ +static inline void sgemm_skinny1_a53_m4n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__ (\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x6,x0,%w[LDA],UXTW #4; add x7,x1,%w[LDA],UXTW #4\n\t"\ + "add x8,x2,%w[LDA],UXTW #4; add x9,x3,%w[LDA],UXTW #4\n\t"\ + "mov x4,%[b_scr]; mov w5,%w[K]\n\t"\ + INIT_M4N##ndim\ + "cmp w5,#2; b.lt 4f\n\t"\ + KERNEL_M4N##ndim##_PRELOAD2\ + "cmp w5,#6; b.lt 2f\n\t"\ + ".balign 16; 1:\n\t"\ + KERNEL_M4N##ndim##_MAIN2(0, 1)\ + KERNEL_M4N##ndim##_MAIN2(2, 3)\ + "b.ge 1b; 
2:\n\t"\ + "cmp w5,#4; b.lt 3f\n\t"\ + KERNEL_M4N##ndim##_MAIN2(0, 1)\ + KERNEL_M4N##ndim##_TAIL2\ + "b 4f; 3:\n\t"\ + KERNEL_M4N##ndim##_TAIL2\ + "4:\n\t"\ + "cmp w5,#1; b.lt 6f\n\t"\ + "5:\n\t"\ + KERNEL_M4N##ndim##_FIN1\ + "6:\n\t"\ + INIT_SAVE\ + "cmp %w[c_rowmajor],#0; b.eq 7f\n\t"\ + SAVE_M4N##ndim(CR) "b 8f\n\t"\ + "7:\n\t"\ + SAVE_M4N##ndim(CC)\ + "8:\n\t"\ + ::[a_ptr]"r"(a_ptr), [c_ptr]"r"(c_ptr), [b_scr]"r"(b_scr),\ + [K]"r"(K), [LDA]"r"(LDA), [LDC]"r"(LDC),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5","x6","x7","x8","x9",\ + "x10","x11","x12","x13","x14","x15","x16","x17","x19","x20",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","v12","v13",\ + "v14","v15","v16","v17","v18","v19","v20","v21","v22","v23","v24","v25",\ + "v26","v27","v28","v29","v30","v31");\ +} + +FUNC_K1(23) +FUNC_K1(24) +FUNC_K1(25) +FUNC_K1(26) + +#define INIT_M1N4 \ + float32x4_t cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = vdupq_n_f32(0.0f); + +#define INIT_M1N5 INIT_M1N4 float32x4_t cq5 = cq1; + +#define INIT_M1N6 INIT_M1N5 float32x4_t cq6 = cq1; + +#define INIT_M1N7 INIT_M1N6 float32x4_t cq7 = cq1; + +#define INIT_M1N8 \ + float32x4_t cq1, cq2; cq1 = cq2 = vdupq_n_f32(0.0f); + +#define INIT_M1N9 INIT_M1N8 float32x4_t cq3 = cq1; + +#define INIT_M1N10 INIT_M1N9 float32x4_t cq4 = cq1; + +#define INIT_M1N11 INIT_M1N10 float32x4_t cq5 = cq1; + +#define INIT_M1N12 \ + float32x4_t cq1, cq2, cq3; cq1 = cq2 = cq3 = vdupq_n_f32(0.0f); + +#define INIT_M1N13 INIT_M1N12 float32x4_t cq4 = cq1; + +#define INIT_M1N14 INIT_M1N13 float32x4_t cq5 = cq2; + +#define ACC_K4M1N4 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 1);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 2);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 3); + +#define ACC_K4M1N5 ACC_K4M1N4 \ + bq1 = vld1q_f32(b_rd + 16); cq5 = vfmaq_f32(cq5, aq1, bq1); + +#define ACC_K4M1N6 ACC_K4M1N5 \ + bq2 = vld1q_f32(b_rd + 20); cq6 = vfmaq_f32(cq6, aq1, bq2); + +#define ACC_K4M1N7 ACC_K4M1N6 \ + bq3 = vld1q_f32(b_rd + 24); cq7 = vfmaq_f32(cq7, aq1, bq3); + +#define ACC_K4M1N8 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0); bq1 = vld1q_f32(b_rd + 8);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0); bq2 = vld1q_f32(b_rd + 12);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 1); bq1 = vld1q_f32(b_rd + 16);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 1); bq2 = vld1q_f32(b_rd + 20);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 2); bq1 = vld1q_f32(b_rd + 24);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 2); bq2 = vld1q_f32(b_rd + 28);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 3);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 3); + +#define ACC_K4M1N9 ACC_K4M1N8 \ + bq1 = vld1q_f32(b_rd + 32); cq3 = vfmaq_f32(cq3, bq1, aq1); + +#define ACC_K4M1N10 ACC_K4M1N9 \ + bq2 = vld1q_f32(b_rd + 36); cq4 = vfmaq_f32(cq4, bq2, aq1); + +#define ACC_K4M1N11 ACC_K4M1N10 \ + bq1 = vld1q_f32(b_rd + 40); cq5 = vfmaq_f32(cq5, bq1, aq1); + +#define ACC_K4M1N12 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0); bq1 = 
vld1q_f32(b_rd + 12);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0); bq2 = vld1q_f32(b_rd + 16);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 0); bq3 = vld1q_f32(b_rd + 20);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 1); bq1 = vld1q_f32(b_rd + 24);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 1); bq2 = vld1q_f32(b_rd + 28);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 1); bq3 = vld1q_f32(b_rd + 32);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 2); bq1 = vld1q_f32(b_rd + 36);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 2); bq2 = vld1q_f32(b_rd + 40);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 2); bq3 = vld1q_f32(b_rd + 44);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 3);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 3);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 3); + +#define ACC_K4M1N13 ACC_K4M1N12 \ + bq1 = vld1q_f32(b_rd + 48); cq4 = vfmaq_f32(cq4, bq1, aq1); + +#define ACC_K4M1N14 ACC_K4M1N13 \ + bq2 = vld1q_f32(b_rd + 52); cq5 = vfmaq_f32(cq5, bq2, aq1); + +#define REDUC_N4 \ + cq1 = vaddq_f32(cq1, cq2); cq3 = vaddq_f32(cq3, cq4);\ + cq1 = vaddq_f32(cq1, cq3); + +#define REDUC_N5 REDUC_N4 \ + float32x2_t cd1 = vadd_f32(vget_low_f32(cq5), vget_high_f32(cq5));\ + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + +#define REDUC_N6 REDUC_N5 \ + float32x2_t cd2 = vadd_f32(vget_low_f32(cq6), vget_high_f32(cq6));\ + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + +#define REDUC_N7 REDUC_N6 \ + float32x2_t cd3 = vadd_f32(vget_low_f32(cq7), vget_high_f32(cq7));\ + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + +#define REDUC_N8 {} + +#define REDUC_N9 \ + float32x2_t cd1 = vadd_f32(vget_low_f32(cq3), vget_high_f32(cq3));\ + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + +#define REDUC_N10 REDUC_N9 \ + float32x2_t cd2 = vadd_f32(vget_low_f32(cq4), vget_high_f32(cq4));\ + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + +#define REDUC_N11 REDUC_N10 \ + float32x2_t cd3 = vadd_f32(vget_low_f32(cq5), vget_high_f32(cq5));\ + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + +#define REDUC_N12 {} + +#define REDUC_N13 \ + float32x2_t cd1 = vadd_f32(vget_low_f32(cq4), vget_high_f32(cq4));\ + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + +#define REDUC_N14 REDUC_N13 \ + float32x2_t cd2 = vadd_f32(vget_low_f32(cq5), vget_high_f32(cq5));\ + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + +#define ACC_K1M1N4 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1); + +#define ACC_K1M1N5 ACC_K1M1N4 cs1 += as1 * b_rd[4]; + +#define ACC_K1M1N6 ACC_K1M1N5 cs2 += as1 * b_rd[5]; + +#define ACC_K1M1N7 ACC_K1M1N6 cs3 += as1 * b_rd[6]; + +#define ACC_K1M1N8 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = vfmaq_n_f32(cq2, bq2, as1); + +#define ACC_K1M1N9 ACC_K1M1N8 cs1 += as1 * b_rd[8]; + +#define ACC_K1M1N10 ACC_K1M1N9 cs2 += as1 * b_rd[9]; + +#define ACC_K1M1N11 ACC_K1M1N10 cs3 += as1 * b_rd[10]; + +#define ACC_K1M1N12 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1); + +#define ACC_K1M1N13 ACC_K1M1N12 cs1 += as1 * b_rd[12]; + +#define ACC_K1M1N14 ACC_K1M1N13 cs2 += as1 * b_rd[13]; + +#define UNIT_SAVE_M1N4_CC(cq1) \ + c_ptr[0] = c_ptr[0] * beta + vgetq_lane_f32(cq1, 0);\ + c_ptr[LDC] = c_ptr[LDC] * beta + 
vgetq_lane_f32(cq1, 1);\ + c_ptr += LDC * 2;\ + c_ptr[0] = c_ptr[0] * beta + vgetq_lane_f32(cq1, 2);\ + c_ptr[LDC] = c_ptr[LDC] * beta + vgetq_lane_f32(cq1, 3);\ + c_ptr += LDC * 2; + +#define UNIT_SAVE_M1N4_CR(cq1) \ + cq1 = vfmaq_n_f32(cq1, vld1q_f32(c_ptr), beta);\ + vst1q_f32(c_ptr, cq1); c_ptr += 4; + +#define UNIT_SAVE_M1N1_CC(cs1) \ + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr += LDC; + +#define UNIT_SAVE_M1N1_CR(cs1) \ + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr++; + +#define SAVE_M1N4(mode) UNIT_SAVE_M1N4_##mode(cq1) + +#define SAVE_M1N5(mode) SAVE_M1N4(mode) UNIT_SAVE_M1N1_##mode(cs1) + +#define SAVE_M1N6(mode) SAVE_M1N5(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define SAVE_M1N7(mode) SAVE_M1N6(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N8(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) + +#define SAVE_M1N9(mode) SAVE_M1N8(mode) UNIT_SAVE_M1N1_##mode(cs1) + +#define SAVE_M1N10(mode) SAVE_M1N9(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define SAVE_M1N11(mode) SAVE_M1N10(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N12(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) UNIT_SAVE_M1N4_##mode(cq3) + +#define SAVE_M1N13(mode) SAVE_M1N12(mode) UNIT_SAVE_M1N1_##mode(cs1) + +#define SAVE_M1N14(mode) SAVE_M1N13(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define FUNC_EDGE_K4(ndim) \ +static inline void sgemm_skinny1_a53_m1n##ndim(\ + const float * __restrict__ a_rd, const float * __restrict__ b_rd,\ + float * __restrict__ c_ptr, uint32_t k_left, uint32_t LDC,\ + uint8_t c_rowmajor, float beta) {\ + INIT_M1N##ndim\ + for (; k_left > 3; k_left -= 4) {\ + ACC_K4M1N##ndim b_rd += ndim * 4;\ + }\ + REDUC_N##ndim\ + for (; k_left > 0; k_left--) {\ + ACC_K1M1N##ndim b_rd += ndim;\ + }\ + if (c_rowmajor == 0) {\ + SAVE_M1N##ndim(CC)\ + } else {\ + SAVE_M1N##ndim(CR)\ + }\ +} + +FUNC_EDGE_K4(4) +FUNC_EDGE_K4(5) +FUNC_EDGE_K4(6) +FUNC_EDGE_K4(7) +FUNC_EDGE_K4(8) +FUNC_EDGE_K4(9) +FUNC_EDGE_K4(10) +FUNC_EDGE_K4(11) +FUNC_EDGE_K4(12) +FUNC_EDGE_K4(13) +FUNC_EDGE_K4(14) + +#define INIT_M1N15 \ + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = vdupq_n_f32(0.0f);\ + float32x2_t cd1, cd2, cd3;\ + cd1 = cd2 = cd3 = vdup_n_f32(0.0f); + +#define INIT_M1N16 \ + float32x4_t cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = vdupq_n_f32(0.0f); + +#define INIT_M1N17 INIT_M1N16 float32x2_t cd1 = vdup_n_f32(0.0f); + +#define INIT_M1N18 INIT_M1N17 float32x2_t cd2 = vdup_n_f32(0.0f); + +#define INIT_M1N19 INIT_M1N18 float32x2_t cd3 = vdup_n_f32(0.0f); + +#define INIT_M1N20 INIT_M1N16 float32x4_t cq5 = vdupq_n_f32(0.0f); + +#define INIT_M1N21 INIT_M1N20 float32x2_t cd1 = vdup_n_f32(0.0f); + +#define INIT_M1N22 INIT_M1N21 float32x2_t cd2 = vdup_n_f32(0.0f); + +#define ACC_M1N15K2 \ + float32x2_t ad1 = vld1_f32(a_rd); a_rd += 2;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + cq1 = vfmaq_lane_f32(cq1, bq1, ad1, 0); bq1 = vld1q_f32(b_rd + 12);\ + cq2 = vfmaq_lane_f32(cq2, bq2, ad1, 0); bq2 = vld1q_f32(b_rd + 16);\ + cq3 = vfmaq_lane_f32(cq3, bq3, ad1, 0); bq3 = vld1q_f32(b_rd + 20);\ + cq4 = vfmaq_lane_f32(cq4, bq1, ad1, 1); float32x2_t bd1 = vld1_f32(b_rd + 24);\ + cq5 = vfmaq_lane_f32(cq5, bq2, ad1, 1); float32x2_t bd2 = vld1_f32(b_rd + 26);\ + cq6 = vfmaq_lane_f32(cq6, bq3, ad1, 1); float32x2_t bd3 = vld1_f32(b_rd + 28);\ + cd1 = vfma_f32(cd1, ad1, bd1);\ + cd2 = vfma_f32(cd2, ad1, bd2);\ + cd3 = vfma_f32(cd3, ad1, bd3); + +#define ACC_M1N16K2 \ + float32x2_t ad1 = vld1_f32(a_rd); a_rd += 2;\ 
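+ /* two k steps per iteration: ad1 holds a[k] and a[k+1]; the 32 packed B values for these two steps accumulate into cq1..cq4 */\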
+ float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + cq1 = vfmaq_lane_f32(cq1, bq1, ad1, 0); bq1 = vld1q_f32(b_rd + 16);\ + cq2 = vfmaq_lane_f32(cq2, bq2, ad1, 0); bq2 = vld1q_f32(b_rd + 20);\ + cq3 = vfmaq_lane_f32(cq3, bq3, ad1, 0); bq3 = vld1q_f32(b_rd + 24);\ + cq4 = vfmaq_lane_f32(cq4, bq4, ad1, 0); bq4 = vld1q_f32(b_rd + 28);\ + cq1 = vfmaq_lane_f32(cq1, bq1, ad1, 1);\ + cq2 = vfmaq_lane_f32(cq2, bq2, ad1, 1);\ + cq3 = vfmaq_lane_f32(cq3, bq3, ad1, 1);\ + cq4 = vfmaq_lane_f32(cq4, bq4, ad1, 1); + +#define ACC_M1N17K2 ACC_M1N16K2 \ + float32x2_t bd1 = vld1_f32(b_rd + 32);\ + cd1 = vfma_f32(cd1, ad1, bd1); + +#define ACC_M1N18K2 ACC_M1N17K2 \ + bd1 = vld1_f32(b_rd + 34); cd2 = vfma_f32(cd2, ad1, bd1); + +#define ACC_M1N19K2 ACC_M1N18K2 \ + bd1 = vld1_f32(b_rd + 36); cd3 = vfma_f32(cd3, ad1, bd1); + +#define ACC_M1N20K2 \ + float32x2_t ad1 = vld1_f32(a_rd); a_rd += 2;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + cq1 = vfmaq_lane_f32(cq1, bq1, ad1, 0); bq1 = vld1q_f32(b_rd + 20);\ + cq2 = vfmaq_lane_f32(cq2, bq2, ad1, 0); bq2 = vld1q_f32(b_rd + 24);\ + cq3 = vfmaq_lane_f32(cq3, bq3, ad1, 0); bq3 = vld1q_f32(b_rd + 28);\ + cq4 = vfmaq_lane_f32(cq4, bq4, ad1, 0); bq4 = vld1q_f32(b_rd + 32);\ + cq5 = vfmaq_lane_f32(cq5, bq5, ad1, 0); bq5 = vld1q_f32(b_rd + 36);\ + cq1 = vfmaq_lane_f32(cq1, bq1, ad1, 1);\ + cq2 = vfmaq_lane_f32(cq2, bq2, ad1, 1);\ + cq3 = vfmaq_lane_f32(cq3, bq3, ad1, 1);\ + cq4 = vfmaq_lane_f32(cq4, bq4, ad1, 1);\ + cq5 = vfmaq_lane_f32(cq5, bq5, ad1, 1); + +#define ACC_M1N21K2 ACC_M1N20K2 \ + float32x2_t bd1 = vld1_f32(b_rd + 40); cd1 = vfma_f32(cd1, ad1, bd1); + +#define ACC_M1N22K2 ACC_M1N21K2 \ + float32x2_t bd2 = vld1_f32(b_rd + 42); cd2 = vfma_f32(cd2, ad1, bd2); + +#define REDUC_M1N15 \ + cq1 = vaddq_f32(cq1, cq4); cq2 = vaddq_f32(cq2, cq5); cq3 = vaddq_f32(cq3, cq6);\ + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1);\ + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1);\ + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + +#define REDUC_M1N16 {} + +#define REDUC_M1N17 float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + +#define REDUC_M1N18 REDUC_M1N17 \ + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + +#define REDUC_M1N19 REDUC_M1N18 \ + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + +#define REDUC_M1N20 {} + +#define REDUC_M1N21 float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + +#define REDUC_M1N22 REDUC_M1N21 \ + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + +#define ACC_M1N15K1 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1); float bs1 = b_rd[12];\ + cq2 = vfmaq_n_f32(cq2, bq2, as1); float bs2 = b_rd[13];\ + cq3 = vfmaq_n_f32(cq3, bq3, as1); float bs3 = b_rd[14];\ + cs1 += as1 * bs1; cs2 += as1 * bs2; cs3 += as1 * bs3; + +#define ACC_M1N16K1 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1); cq2 = vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1); cq4 = vfmaq_n_f32(cq4, bq4, 
as1); + +#define ACC_M1N17K1 ACC_M1N16K1 cs1 += as1 * b_rd[16]; + +#define ACC_M1N18K1 ACC_M1N17K1 cs2 += as1 * b_rd[17]; + +#define ACC_M1N19K1 ACC_M1N18K1 cs3 += as1 * b_rd[18]; + +#define ACC_M1N20K1 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1); cq2 = vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1); cq4 = vfmaq_n_f32(cq4, bq4, as1);\ + cq5 = vfmaq_n_f32(cq5, bq5, as1); + +#define ACC_M1N21K1 ACC_M1N20K1 cs1 += as1 * b_rd[20]; + +#define ACC_M1N22K1 ACC_M1N21K1 cs2 += as1 * b_rd[21]; + +#define SAVE_M1N15(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) UNIT_SAVE_M1N4_##mode(cq3)\ + UNIT_SAVE_M1N1_##mode(cs1) UNIT_SAVE_M1N1_##mode(cs2) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N16(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2)\ + UNIT_SAVE_M1N4_##mode(cq3) UNIT_SAVE_M1N4_##mode(cq4) + +#define SAVE_M1N17(mode) SAVE_M1N16(mode) UNIT_SAVE_M1N1_##mode(cs1) + +#define SAVE_M1N18(mode) SAVE_M1N17(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define SAVE_M1N19(mode) SAVE_M1N18(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N20(mode) SAVE_M1N16(mode) UNIT_SAVE_M1N4_##mode(cq5) + +#define SAVE_M1N21(mode) SAVE_M1N20(mode) UNIT_SAVE_M1N1_##mode(cs1) + +#define SAVE_M1N22(mode) SAVE_M1N21(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define FUNC_EDGE_K2(ndim) \ +static inline void sgemm_skinny1_a53_m1n##ndim(\ + const float * __restrict__ a_rd, const float * __restrict__ b_rd,\ + float * __restrict__ c_ptr, uint32_t k_left, uint32_t LDC,\ + uint8_t c_rowmajor, float beta) {\ + INIT_M1N##ndim\ + for (; k_left > 1; k_left -= 2) {\ + ACC_M1N##ndim##K2 b_rd += ndim * 2;\ + }\ + REDUC_M1N##ndim\ + for (; k_left > 0; k_left--) {\ + ACC_M1N##ndim##K1 b_rd += ndim;\ + }\ + if (c_rowmajor == 0) {\ + SAVE_M1N##ndim(CC)\ + } else {\ + SAVE_M1N##ndim(CR)\ + }\ +} + +FUNC_EDGE_K2(15) +FUNC_EDGE_K2(16) +FUNC_EDGE_K2(17) +FUNC_EDGE_K2(18) +FUNC_EDGE_K2(19) +FUNC_EDGE_K2(20) +FUNC_EDGE_K2(21) +FUNC_EDGE_K2(22) + +#define INIT_M1N23 INIT_M1N20 \ + float cs1 = 0.0f, cs2 = 0.0f, cs3 = 0.0f; + +#define INIT_M1N24 INIT_M1N20 float32x4_t cq6 = vdupq_n_f32(0.0f); + +#define INIT_M1N25 INIT_M1N24 float cs1 = 0.0f; + +#define INIT_M1N26 INIT_M1N25 float cs2 = 0.0f; + +#define ACC_M1N23K1 ACC_M1N20K1 \ + cs1 += as1 * b_rd[20]; cs2 += as1 * b_rd[21]; cs3 += as1 * b_rd[22]; + +#define ACC_M1N24K1 ACC_M1N20K1 \ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + cq6 = vfmaq_n_f32(cq6, bq6, as1); + +#define ACC_M1N25K1 ACC_M1N24K1 cs1 += as1 * b_rd[24]; + +#define ACC_M1N26K1 ACC_M1N25K1 cs2 += as1 * b_rd[25]; + +#define SAVE_M1N23(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) UNIT_SAVE_M1N4_##mode(cq3)\ + UNIT_SAVE_M1N4_##mode(cq4) UNIT_SAVE_M1N4_##mode(cq5)\ + UNIT_SAVE_M1N1_##mode(cs1) UNIT_SAVE_M1N1_##mode(cs2) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N24(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) UNIT_SAVE_M1N4_##mode(cq3)\ + UNIT_SAVE_M1N4_##mode(cq4) UNIT_SAVE_M1N4_##mode(cq5) UNIT_SAVE_M1N4_##mode(cq6) + +#define SAVE_M1N25(mode) SAVE_M1N24(mode) UNIT_SAVE_M1N1_##mode(cs1) + +#define SAVE_M1N26(mode) SAVE_M1N25(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define FUNC_EDGE_K1(ndim) \ +static inline void sgemm_skinny1_a53_m1n##ndim(\ + const float * __restrict__ a_rd, const float * __restrict__ b_rd,\ + float * __restrict__ c_ptr, 
uint32_t k_left, uint32_t LDC,\
+ uint8_t c_rowmajor, float beta) {\
+ INIT_M1N##ndim\
+ for (; k_left > 0; k_left--) {\
+ ACC_M1N##ndim##K1 b_rd += ndim;\
+ }\
+ if (c_rowmajor == 0) {\
+ SAVE_M1N##ndim(CC)\
+ } else {\
+ SAVE_M1N##ndim(CR)\
+ }\
+}
+
+FUNC_EDGE_K1(23)
+FUNC_EDGE_K1(24)
+FUNC_EDGE_K1(25)
+FUNC_EDGE_K1(26)
+
+#endif
diff --git a/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA7x.h b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA7x.h
new file mode 100644
index 0000000..016ce73
--- /dev/null
+++ b/include/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA7x.h
@@ -0,0 +1,2556 @@
+/*****************************************************************************/
+/* Copyright YouDao, Inc. */
+/* */
+/* Licensed under the Apache License, Version 2.0 (the "License"); */
+/* you may not use this file except in compliance with the License. */
+/* You may obtain a copy of the License at */
+/* */
+/* http://www.apache.org/licenses/LICENSE-2.0 */
+/* */
+/* Unless required by applicable law or agreed to in writing, software */
+/* distributed under the License is distributed on an "AS IS" BASIS, */
+/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */
+/* See the License for the specific language governing permissions and */
+/* limitations under the License. */
+/*****************************************************************************/
+
+
+#include <arm_neon.h>
+#include <stdint.h>
+
+#ifndef INCLUDE_A7X_KERNEL
+#define INCLUDE_A7X_KERNEL
+
+/* x0 - x3 for a_ptrs */
+/* x4 for b_ptr, x5 for k_left */
+/* x6 - x9 for a_pref */
+/* x12 - x15 for c_tmp */
+
+#define INIT_1V(c1) "movi v"#c1".16b,#0\n\t"
+
+#define INIT_2V(c1, c2) INIT_1V(c1) INIT_1V(c2)
+
+#define INIT_4V(c1, c2, c3, c4) INIT_2V(c1, c2) INIT_2V(c3, c4)
+
+#define INIT_SAVE \
+ "ldr s0,[%[beta_addr]]; mov x12,%[c_ptr]\n\t"\
+ "add x13,%[c_ptr],%w[LDC],UXTW #2; add x14,%[c_ptr],%w[LDC],UXTW #3\n\t"\
+ "add x15,x13,%w[LDC],UXTW #3\n\t"
+
+#define UNIT_SAVE_M4N4_CC(c1, c2, c3, c4) \
+ "trn1 v1.4s,v"#c1".4s,v"#c2".4s; trn1 v2.4s,v"#c3".4s,v"#c4".4s\n\t"\
+ "trn2 v3.4s,v"#c1".4s,v"#c2".4s; trn2 v4.4s,v"#c3".4s,v"#c4".4s\n\t"\
+ "trn1 v"#c1".2d,v1.2d,v2.2d; trn1 v"#c2".2d,v3.2d,v4.2d\n\t"\
+ "trn2 v"#c3".2d,v1.2d,v2.2d; trn2 v"#c4".2d,v3.2d,v4.2d\n\t"\
+ "ldr q1,[x12]; ldr q2,[x13]; ldr q3,[x14]; ldr q4,[x15]\n\t"\
+ "fmla v"#c1".4s,v1.4s,v0.s[0]; fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\
+ "fmla v"#c3".4s,v3.4s,v0.s[0]; fmla v"#c4".4s,v4.4s,v0.s[0]\n\t"\
+ "str q"#c1",[x12]; prfm pstl2keep,[x12,#32]; add x12,x12,%w[LDC],UXTW #4\n\t"\
+ "str q"#c2",[x13]; prfm pstl2keep,[x13,#32]; add x13,x13,%w[LDC],UXTW #4\n\t"\
+ "str q"#c3",[x14]; prfm pstl2keep,[x14,#32]; add x14,x14,%w[LDC],UXTW #4\n\t"\
+ "str q"#c4",[x15]; prfm pstl2keep,[x15,#32]; add x15,x15,%w[LDC],UXTW #4\n\t"
+
+#define EDGE_SAVE_M4N1_CC(c1, c2, c3, c4) \
+ "ldr q1,[x12]\n\t"\
+ "faddp v"#c1".4s,v"#c1".4s,v"#c2".4s\n\t"\
+ "faddp v"#c3".4s,v"#c3".4s,v"#c4".4s\n\t"\
+ "faddp v"#c1".4s,v"#c1".4s,v"#c3".4s\n\t"\
+ "fmla v"#c1".4s,v1.4s,v0.s[0]\n\t"\
+ "str q"#c1",[x12]; prfm pstl2keep,[x12,#32]\n\t"\
+ "add x12,x12,%w[LDC],UXTW #2\n\t"
+
+#define UNIT_SAVE_M4N4_CR(c1, c2, c3, c4) \
+ "ldr q1,[x12]; ldr q2,[x13]; ldr q3,[x14]; ldr q4,[x15]\n\t"\
+ "fmla v"#c1".4s,v1.4s,v0.s[0]; fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\
+ "fmla v"#c3".4s,v3.4s,v0.s[0]; fmla v"#c4".4s,v4.4s,v0.s[0]\n\t"\
+ "str q"#c1",[x12],#16; str q"#c2",[x13],#16\n\t"\
+ "str q"#c3",[x14],#16; str q"#c4",[x15],#16\n\t"
+
+#define EDGE_SAVE_M4N1_CR(c1, c2, c3, c4) \
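+ /* horizontally add the four per-row accumulators c1..c4 of one extra output column, blend with beta*C (beta is kept in v0.s[0]) and store one element to each of the four row pointers x12-x15 */\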
+ "ldr s1,[x12]; ldr s2,[x13]; ldr s3,[x14]; ldr s4,[x15]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c2".4s\n\t"\ + "ins v1.s[1],v2.s[0]; ins v3.s[1],v4.s[0]\n\t"\ + "faddp v"#c3".4s,v"#c3".4s,v"#c4".4s\n\t"\ + "ins v1.d[1],v3.d[0]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c3".4s\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]\n\t"\ + "st1 {v"#c1".s}[0],[x12],#4; st1 {v"#c1".s}[1],[x13],#4\n\t"\ + "st1 {v"#c1".s}[2],[x14],#4; st1 {v"#c1".s}[3],[x15],#4\n\t" + +#define FUNC_M4(ndim) \ +static inline void sgemm_skinny1_a7x_m4n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__ (\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3; add x3,x1,%w[LDA],UXTW #3\n\t"\ + "add x6,x0,%w[LDA],UXTW #4; add x7,x1,%w[LDA],UXTW #4\n\t"\ + "add x8,x2,%w[LDA],UXTW #4; add x9,x3,%w[LDA],UXTW #4\n\t"\ + "mov x4,%[b_scr]; mov w5,%w[K]\n\t"\ + INIT_M4N##ndim\ + "cmp w5,#4; b.lt 4f\n\t"\ + KERNEL_M4N##ndim##_PRELOAD4\ + "cmp w5,#20; b.lt 1f\n\t"\ + ".balign 16; 9:\n\t"\ + "prfm pldl2keep,[x6]; add x6,x6,#64\n\t"\ + KERNEL_M4N##ndim##_MAIN4(0, 1, 2, 3, 4, 5, 6, 7)\ + "prfm pldl2keep,[x7]; add x7,x7,#64\n\t"\ + KERNEL_M4N##ndim##_MAIN4(4, 5, 6, 7, 0, 1, 2, 3)\ + "prfm pldl2keep,[x8]; add x8,x8,#64\n\t"\ + KERNEL_M4N##ndim##_MAIN4(0, 1, 2, 3, 4, 5, 6, 7)\ + "prfm pldl2keep,[x9]; add x9,x9,#64\n\t"\ + KERNEL_M4N##ndim##_MAIN4(4, 5, 6, 7, 0, 1, 2, 3)\ + "cmp w5,#20; b.ge 9b; 1:\n\t"\ + "cmp w5,#12; b.lt 2f\n\t"\ + KERNEL_M4N##ndim##_MAIN4(0, 1, 2, 3, 4, 5, 6, 7)\ + KERNEL_M4N##ndim##_MAIN4(4, 5, 6, 7, 0, 1, 2, 3)\ + "2:\n\t"\ + "cmp w5,#8; b.lt 3f\n\t"\ + KERNEL_M4N##ndim##_MAIN4(0, 1, 2, 3, 4, 5, 6, 7)\ + KERNEL_M4N##ndim##_TAIL4(4, 5, 6, 7)\ + "b 4f; 3:\n\t"\ + KERNEL_M4N##ndim##_TAIL4(0, 1, 2, 3)\ + "4:\n\t"\ + "cmp w5,#1; b.lt 6f\n\t"\ + "5:\n\t"\ + KERNEL_M4N##ndim##_TL1 "b.gt 5b\n\t"\ + "6:\n\t"\ + INIT_SAVE\ + "cmp %w[c_rowmajor],#0; b.eq 7f\n\t"\ + SAVE_M4N##ndim(CR) "b 8f\n\t"\ + "7:\n\t"\ + SAVE_M4N##ndim(CC)\ + "8:\n\t"\ + ::[a_ptr]"r"(a_ptr), [b_scr]"r"(b_scr), [c_ptr]"r"(c_ptr),\ + [K]"r"(K), [LDA]"r"(LDA), [LDC]"r"(LDC),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x3","x4","x5","x6","x7","x8","x9",\ + "x12","x13","x14","x15",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","v12","v13",\ + "v14","v15","v16","v17","v18","v19","v20","v21","v22","v23","v24","v25",\ + "v26","v27","v28","v29","v30","v31");\ +} + +#define UNIT_SAVE_M3N4_CC(c1, c2, c3) \ + "ldr d1,[x12]; ldr s2,[x12,#8]\n\t"\ + "ldr d3,[x13]; ldr s4,[x13,#8]\n\t"\ + "trn1 v5.4s,v"#c1".4s,v"#c2".4s; trn2 v"#c2".4s,v"#c1".4s,v"#c2".4s\n\t"\ + "mov v6.8b,v"#c2".8b; mov v"#c1".16b,v5.16b\n\t"\ + "fmla v5.2s,v1.2s,v0.s[0]; fmla v6.2s,v3.2s,v0.s[0]\n\t"\ + "fmov s1,s"#c3"; ins v3.s[0],v"#c3".s[1]\n\t"\ + "fmla s1,s2,v0.s[0]; fmla s3,s4,v0.s[0]\n\t"\ + "str d5,[x12]; str s1,[x12,#8]; prfm pstl2keep,[x12,#24]\n\t"\ + "add x12,x12,%w[LDC],UXTW #4\n\t"\ + "str d6,[x13]; str s3,[x13,#8]; prfm pstl2keep,[x13,#24]\n\t"\ + "add x13,x13,%w[LDC],UXTW #4\n\t"\ + "ldr d1,[x14]; ldr s2,[x14,#8]\n\t"\ + "ldr d3,[x15]; ldr s4,[x15,#8]\n\t"\ + "ins v"#c1".d[0],v"#c1".d[1]; ins v"#c2".d[0],v"#c2".d[1]\n\t"\ + "ins v5.s[0],v"#c3".s[2]; ins v6.s[0],v"#c3".s[3]\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.s[0]; fmla v"#c2".2s,v3.2s,v0.s[0]\n\t"\ + "fmla s5,s2,v0.s[0]; fmla s6,s4,v0.s[0]\n\t"\ + "str d"#c1",[x14]; 
str s5,[x14,#8]; prfm pstl2keep,[x14,#24]\n\t"\ + "add x14,x14,%w[LDC],UXTW #4\n\t"\ + "str d"#c2",[x15]; str s6,[x15,#8]; prfm pstl2keep,[x15,#24]\n\t"\ + "add x15,x15,%w[LDC],UXTW #4\n\t" + +#define UNIT_SAVE_M3N4_CR(c1, c2, c3) \ + "ldr q1,[x12]; ldr q2,[x13]; ldr q3,[x14]\n\t"\ + "fmla v"#c1".4s,v1.4s,v0.s[0]\n\t"\ + "fmla v"#c2".4s,v2.4s,v0.s[0]\n\t"\ + "fmla v"#c3".4s,v3.4s,v0.s[0]\n\t"\ + "str q"#c1",[x12],#16; str q"#c2",[x13],#16; str q"#c3",[x14],#16\n\t" + +#define EDGE_SAVE_M3N1_CC(c1, c2, c3) \ + "ldr d1,[x12]; ldr s2,[x12,#8]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c2".4s\n\t"\ + "faddp v"#c3".4s,v"#c3".4s,v"#c3".4s\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c1".4s\n\t"\ + "faddp s"#c3",v"#c3".2s\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.s[0]; fmla s"#c3",s2,v0.s[0]\n\t"\ + "str d"#c1",[x12]; str s"#c3",[x12,#8]\n\t"\ + "prfm pstl2keep,[x12,#24]\n\t"\ + "add x12,x12,%w[LDC],UXTW #2\n\t" + +#define EDGE_SAVE_M3N1_CR(c1, c2, c3) \ + "ldr s1,[x12]; ldr s2,[x13]; ldr s3,[x14]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c2".4s\n\t"\ + "faddp v"#c3".4s,v"#c3".4s,v"#c3".4s\n\t"\ + "ins v1.s[1],v2.s[0]\n\t"\ + "faddp v"#c1".4s,v"#c1".4s,v"#c1".4s\n\t"\ + "faddp s"#c3",v"#c3".2s\n\t"\ + "fmla v"#c1".2s,v1.2s,v0.s[0]; fmla s"#c3",s3,v0.s[0]\n\t"\ + "st1 {v"#c1".s}[0],[x12],#4; st1 {v"#c1".s}[1],[x13],#4\n\t"\ + "str s"#c3",[x14],#4\n\t" + +#define FUNC_M3(ndim) \ +static inline void sgemm_skinny1_a7x_m3n##ndim(\ + const float * __restrict__ a_ptr, const float * __restrict__ b_scr,\ + float * __restrict__ c_ptr, uint32_t K, uint32_t LDA, uint32_t LDC,\ + uint8_t c_rowmajor, const float * __restrict__ beta_addr) {\ + __asm__ __volatile__ (\ + "mov x0,%[a_ptr]; add x1,%[a_ptr],%w[LDA],UXTW #2\n\t"\ + "add x2,%[a_ptr],%w[LDA],UXTW #3\n\t"\ + "add x6,x1,%w[LDA],UXTW #3; add x7,x0,%w[LDA],UXTW #4\n\t"\ + "add x8,x1,%w[LDA],UXTW #4\n\t"\ + "mov x4,%[b_scr]; mov w5,%w[K]\n\t"\ + INIT_M3N##ndim\ + "cmp w5,#4; b.lt 4f\n\t"\ + KERNEL_M3N##ndim##_PRELOAD4\ + "cmp w5,#20; b.lt 1f\n\t"\ + ".balign 16; 9:\n\t"\ + KERNEL_M3N##ndim##_MAIN4(0, 1, 2, 3, 4, 5)\ + "prfm pldl2keep,[x6]; add x6,x6,#64\n\t"\ + KERNEL_M3N##ndim##_MAIN4(3, 4, 5, 0, 1, 2)\ + "prfm pldl2keep,[x7]; add x7,x7,#64\n\t"\ + KERNEL_M3N##ndim##_MAIN4(0, 1, 2, 3, 4, 5)\ + "prfm pldl2keep,[x8]; add x8,x8,#64\n\t"\ + KERNEL_M3N##ndim##_MAIN4(3, 4, 5, 0, 1, 2)\ + "cmp w5,#20; b.ge 9b; 1:\n\t"\ + "cmp w5,#12; b.lt 2f\n\t"\ + KERNEL_M3N##ndim##_MAIN4(0, 1, 2, 3, 4, 5)\ + KERNEL_M3N##ndim##_MAIN4(3, 4, 5, 0, 1, 2)\ + "2:\n\t"\ + "cmp w5,#8; b.lt 3f\n\t"\ + KERNEL_M3N##ndim##_MAIN4(0, 1, 2, 3, 4, 5)\ + KERNEL_M3N##ndim##_TAIL4(3, 4, 5)\ + "b 4f; 3:\n\t"\ + KERNEL_M3N##ndim##_TAIL4(0, 1, 2)\ + "4:\n\t"\ + "cmp w5,#1; b.lt 6f\n\t"\ + "5:\n\t"\ + KERNEL_M3N##ndim##_TL1 "b.gt 5b\n\t"\ + "6:\n\t"\ + INIT_SAVE\ + "cmp %w[c_rowmajor],#0; b.eq 7f\n\t"\ + SAVE_M3N##ndim(CR) "b 8f\n\t"\ + "7:\n\t"\ + SAVE_M3N##ndim(CC)\ + "8:\n\t"\ + ::[a_ptr]"r"(a_ptr), [b_scr]"r"(b_scr), [c_ptr]"r"(c_ptr),\ + [K]"r"(K), [LDA]"r"(LDA), [LDC]"r"(LDC),\ + [beta_addr]"r"(beta_addr), [c_rowmajor]"r"(c_rowmajor)\ + :"cc","memory","x0","x1","x2","x4","x5","x6","x7","x8",\ + "x12","x13","x14","x15",\ + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10","v11","v12","v13",\ + "v14","v15","v16","v17","v18","v19","v20","v21","v22","v23","v24","v25",\ + "v26","v27","v28","v29","v30","v31");\ +} + + +#define INIT_M4N4 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19) + +#define SAVE_M4N4(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15) + +#define KERNEL_M4N4_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; 
ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]; ldr q11,[x4,#48]\n\t"\ + "add x4,x4,#64\n\t" + +#define KERNEL_M4N4_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#16]; sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#32]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#48]; add x4,x4,#64\n\t" + +#define KERNEL_M4N4_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "sub w5,w5,#4; prfm pldl2keep,[x9]\n\t"\ + "fadd v12.4s,v12.4s,v16.4s; fadd v13.4s,v13.4s,v17.4s\n\t"\ + "fadd v14.4s,v14.4s,v18.4s; fadd v15.4s,v15.4s,v19.4s\n\t" + +#define KERNEL_M4N4_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t" + + +#define INIT_M4N5 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) + +#define SAVE_M4N5(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + EDGE_SAVE_M4N1_##mode(20, 21, 22, 23) + +#define KERNEL_M4N5_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]; ldr q11,[x4,#48]\n\t"\ + "ldr q24,[x4,#64]; add x4,x4,#80\n\t" + +#define KERNEL_M4N5_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#16]; sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#32]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#48]\n\t"\ + "fmla v20.4s,v24.4s,v"#ac1".4s; fmla 
v21.4s,v24.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v24.4s,v"#ac3".4s; fmla v23.4s,v24.4s,v"#ac4".4s\n\t"\ + "ldr q24,[x4,#64]; add x4,x4,#80\n\t" + +#define KERNEL_M4N5_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "prfm pldl2keep,[x9]\n\t"\ + "fmla v20.4s,v24.4s,v"#ac1".4s; fmla v21.4s,v24.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v24.4s,v"#ac3".4s; fmla v23.4s,v24.4s,v"#ac4".4s\n\t"\ + "sub w5,w5,#4\n\t"\ + "fadd v12.4s,v12.4s,v16.4s; fadd v13.4s,v13.4s,v17.4s\n\t"\ + "fadd v14.4s,v14.4s,v18.4s; fadd v15.4s,v15.4s,v19.4s\n\t" + +#define KERNEL_M4N5_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr s9,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v0.4s; fmla v21.4s,v9.4s,v1.4s\n\t"\ + "fmla v22.4s,v9.4s,v2.4s; fmla v23.4s,v9.4s,v3.4s\n\t" + + +#define INIT_M4N6 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27) + +#define SAVE_M4N6(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + EDGE_SAVE_M4N1_##mode(20, 21, 22, 23) EDGE_SAVE_M4N1_##mode(24, 25, 26, 27) + +#define KERNEL_M4N6_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]; ldr q11,[x4,#48]\n\t"\ + "ldr q28,[x4,#64]; ldr q29,[x4,#80]; add x4,x4,#96\n\t" + +#define KERNEL_M4N6_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#16]; sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#32]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#48]\n\t"\ + "fmla v20.4s,v28.4s,v"#ac1".4s; fmla v21.4s,v28.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v28.4s,v"#ac3".4s; fmla v23.4s,v28.4s,v"#ac4".4s\n\t"\ + "ldr q28,[x4,#64]\n\t"\ + "fmla v24.4s,v29.4s,v"#ac1".4s; fmla v25.4s,v29.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v29.4s,v"#ac3".4s; fmla v27.4s,v29.4s,v"#ac4".4s\n\t"\ + "ldr q29,[x4,#80]; add x4,x4,#96\n\t" + +#define KERNEL_M4N6_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla 
v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "prfm pldl2keep,[x9]\n\t"\ + "fmla v20.4s,v28.4s,v"#ac1".4s; fmla v21.4s,v28.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v28.4s,v"#ac3".4s; fmla v23.4s,v28.4s,v"#ac4".4s\n\t"\ + "fmla v24.4s,v29.4s,v"#ac1".4s; fmla v25.4s,v29.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v29.4s,v"#ac3".4s; fmla v27.4s,v29.4s,v"#ac4".4s\n\t"\ + "sub w5,w5,#4\n\t"\ + "fadd v12.4s,v12.4s,v16.4s; fadd v13.4s,v13.4s,v17.4s\n\t"\ + "fadd v14.4s,v14.4s,v18.4s; fadd v15.4s,v15.4s,v19.4s\n\t" + +#define KERNEL_M4N6_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr s9,[x4],#4; ldr s10,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v0.4s; fmla v21.4s,v9.4s,v1.4s\n\t"\ + "fmla v22.4s,v9.4s,v2.4s; fmla v23.4s,v9.4s,v3.4s\n\t"\ + "fmla v24.4s,v10.4s,v0.4s; fmla v25.4s,v10.4s,v1.4s\n\t"\ + "fmla v26.4s,v10.4s,v2.4s; fmla v27.4s,v10.4s,v3.4s\n\t" + + +#define INIT_M4N7 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N7(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + EDGE_SAVE_M4N1_##mode(20, 21, 22, 23) EDGE_SAVE_M4N1_##mode(24, 25, 26, 27)\ + EDGE_SAVE_M4N1_##mode(28, 29, 30, 31) + +#define KERNEL_M4N7_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]; ldr q11,[x4,#48]\n\t"\ + "add x4,x4,#112\n\t" + +#define KERNEL_M4N7_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-48]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#-32]; sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-16]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".4s; fmla v21.4s,v8.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".4s; fmla v23.4s,v8.4s,v"#ac4".4s\n\t"\ + "ldr q8,[x4],#112\n\t"\ + "fmla v24.4s,v9.4s,v"#ac1".4s; fmla v25.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v9.4s,v"#ac3".4s; fmla v27.4s,v9.4s,v"#ac4".4s\n\t"\ + "ldr q9,[x4,#-96]\n\t"\ + "fmla v28.4s,v10.4s,v"#ac1".4s; fmla v29.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v30.4s,v10.4s,v"#ac3".4s; fmla v31.4s,v10.4s,v"#ac4".4s\n\t"\ + "ldr q10,[x4,#-80]; ldr q11,[x4,#-64]\n\t" + +#define KERNEL_M4N7_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla 
v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-48]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#-32]; sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[2]; fmla v13.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[2]; fmla v15.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-16]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "prfm pldl2keep,[x9]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "fmla v20.4s,v8.4s,v"#ac1".4s; fmla v21.4s,v8.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v8.4s,v"#ac3".4s; fmla v23.4s,v8.4s,v"#ac4".4s\n\t"\ + "fmla v24.4s,v9.4s,v"#ac1".4s; fmla v25.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v9.4s,v"#ac3".4s; fmla v27.4s,v9.4s,v"#ac4".4s\n\t"\ + "fmla v28.4s,v10.4s,v"#ac1".4s; fmla v29.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v30.4s,v10.4s,v"#ac3".4s; fmla v31.4s,v10.4s,v"#ac4".4s\n\t"\ + "fadd v12.4s,v12.4s,v16.4s; fadd v13.4s,v13.4s,v17.4s\n\t"\ + "fadd v14.4s,v14.4s,v18.4s; fadd v15.4s,v15.4s,v19.4s\n\t" + +#define KERNEL_M4N7_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr s9,[x4],#4; ldr s10,[x4],#4; ldr s11,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v0.4s; fmla v21.4s,v9.4s,v1.4s\n\t"\ + "fmla v22.4s,v9.4s,v2.4s; fmla v23.4s,v9.4s,v3.4s\n\t"\ + "fmla v24.4s,v10.4s,v0.4s; fmla v25.4s,v10.4s,v1.4s\n\t"\ + "fmla v26.4s,v10.4s,v2.4s; fmla v27.4s,v10.4s,v3.4s\n\t"\ + "fmla v28.4s,v11.4s,v0.4s; fmla v29.4s,v11.4s,v1.4s\n\t"\ + "fmla v30.4s,v11.4s,v2.4s; fmla v31.4s,v11.4s,v3.4s\n\t" + + +#define INIT_M4N8 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19) + +#define SAVE_M4N8(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + UNIT_SAVE_M4N4_##mode(16, 17, 18, 19) + +#define KERNEL_M4N8_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]; ldr q11,[x4,#48]\n\t"\ + "add x4,x4,#128\n\t" + +#define KERNEL_M4N8_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-64]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-48]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-32]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[1]; fmla v17.4s,v11.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[1]; fmla v19.4s,v11.4s,v"#ac4".s[1]\n\t"\ + "ldr q11,[x4,#-16]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[2]; fmla v13.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[2]; fmla v15.4s,v8.4s,v"#ac4".s[2]\n\t"\ + "ldr q8,[x4],#128\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[2]; fmla v17.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[2]; fmla v19.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-112]; sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[3]; fmla v13.4s,v10.4s,v"#ac2".s[3]\n\t"\ + 
"ldr q"#an4",[x3],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[3]; fmla v15.4s,v10.4s,v"#ac4".s[3]\n\t"\ + "ldr q10,[x4,#-96]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#-80]\n\t" + +#define KERNEL_M4N8_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-64]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-48]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-32]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[1]; fmla v17.4s,v11.4s,v"#ac2".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[1]; fmla v19.4s,v11.4s,v"#ac4".s[1]\n\t"\ + "ldr q11,[x4,#-16]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[2]; fmla v13.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[2]; fmla v15.4s,v8.4s,v"#ac4".s[2]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[2]; fmla v17.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[2]; fmla v19.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "sub w5,w5,#4\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[3]; fmla v13.4s,v10.4s,v"#ac2".s[3]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[3]; fmla v15.4s,v10.4s,v"#ac4".s[3]\n\t"\ + "fmla v16.4s,v11.4s,v"#ac1".s[3]; fmla v17.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "prfm pldl2keep,[x9]\n\t"\ + "fmla v18.4s,v11.4s,v"#ac3".s[3]; fmla v19.4s,v11.4s,v"#ac4".s[3]\n\t" + +#define KERNEL_M4N8_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v0.s[0]; fmla v17.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v9.4s,v2.s[0]; fmla v19.4s,v9.4s,v3.s[0]\n\t" + + +#define INIT_M4N9 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) + +#define SAVE_M4N9(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + UNIT_SAVE_M4N4_##mode(16, 17, 18, 19) EDGE_SAVE_M4N1_##mode(20, 21, 22, 23) + +#define KERNEL_M4N9_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#144\n\t" + +#define KERNEL_M4N9_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-96]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-80]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-64]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[1]; fmla v19.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q8,[x4,#-48]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; fmla v13.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v9.4s,v"#ac3".s[2]; 
fmla v15.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-32]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]; fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac3".s[2]; fmla v19.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-16]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]; fmla v13.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[3]; fmla v15.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "ldr q8,[x4],#144\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[3]; fmla v17.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "sub w5,w5,#4\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[3]; fmla v19.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "ldr q9,[x4,#-128]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".4s; fmla v21.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".4s; fmla v23.4s,v10.4s,v"#ac4".4s\n\t"\ + "ldr q10,[x4,#-112]\n\t" + +#define KERNEL_M4N9_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-96]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-80]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-64]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[1]; fmla v19.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q8,[x4,#-48]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; fmla v13.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac3".s[2]; fmla v15.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-32]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]; fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac3".s[2]; fmla v19.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-16]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]; fmla v13.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[3]; fmla v15.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[3]; fmla v17.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "prfm pldl2keep,[x9]; sub w5,w5,#4\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[3]; fmla v19.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".4s; fmla v21.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".4s; fmla v23.4s,v10.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N9_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr s10,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v0.s[0]; fmla v17.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v9.4s,v2.s[0]; fmla v19.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v10.4s,v0.4s; fmla v21.4s,v10.4s,v1.4s\n\t"\ + "fmla v22.4s,v10.4s,v2.4s; fmla v23.4s,v10.4s,v3.4s\n\t" + + +#define INIT_M4N10 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27) + +#define SAVE_M4N10(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + UNIT_SAVE_M4N4_##mode(16, 17, 18, 19)\ + EDGE_SAVE_M4N1_##mode(20, 21, 22, 23) EDGE_SAVE_M4N1_##mode(24, 25, 26, 27) + +#define KERNEL_M4N10_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#160\n\t" + +#define KERNEL_M4N10_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + 
"fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-112]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-96]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-80]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[1]; fmla v19.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q8,[x4,#-64]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; fmla v13.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v9.4s,v"#ac3".s[2]; fmla v15.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-48]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]; fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac3".s[2]; fmla v19.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-32]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]; fmla v13.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[3]; fmla v15.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#-16]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[3]; fmla v17.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "sub w5,w5,#4\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[3]; fmla v19.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "ldr q8,[x4],#160\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".4s; fmla v21.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".4s; fmla v23.4s,v10.4s,v"#ac4".4s\n\t"\ + "ldr q9,[x4,#-144]\n\t"\ + "fmla v24.4s,v11.4s,v"#ac1".4s; fmla v25.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s; fmla v27.4s,v11.4s,v"#ac4".4s\n\t"\ + "ldr q10,[x4,#-128]\n\t" + +#define KERNEL_M4N10_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-112]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-96]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-80]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[1]; fmla v19.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q8,[x4,#-64]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; fmla v13.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac3".s[2]; fmla v15.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-48]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]; fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac3".s[2]; fmla v19.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-32]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]; fmla v13.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[3]; fmla v15.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#-16]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[3]; fmla v17.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "prfm pldl2keep,[x9]; sub w5,w5,#4\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[3]; fmla v19.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".4s; fmla v21.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".4s; fmla v23.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmla v24.4s,v11.4s,v"#ac1".4s; fmla 
v25.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v11.4s,v"#ac3".4s; fmla v27.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N10_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr s10,[x4],#4; ldr s11,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v0.s[0]; fmla v17.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v9.4s,v2.s[0]; fmla v19.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v10.4s,v0.4s; fmla v21.4s,v10.4s,v1.4s\n\t"\ + "fmla v22.4s,v10.4s,v2.4s; fmla v23.4s,v10.4s,v3.4s\n\t"\ + "fmla v24.4s,v11.4s,v0.4s; fmla v25.4s,v11.4s,v1.4s\n\t"\ + "fmla v26.4s,v11.4s,v2.4s; fmla v27.4s,v11.4s,v3.4s\n\t" + + +#define INIT_M4N11 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M4N11(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + UNIT_SAVE_M4N4_##mode(16, 17, 18, 19)\ + EDGE_SAVE_M4N1_##mode(20, 21, 22, 23) EDGE_SAVE_M4N1_##mode(24, 25, 26, 27)\ + EDGE_SAVE_M4N1_##mode(28, 29, 30, 31) + +#define KERNEL_M4N11_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#176\n\t" + +#define KERNEL_M4N11_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-128]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-112]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-96]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[1]; fmla v19.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q11,[x4,#-80]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; fmla v13.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v9.4s,v"#ac3".s[2]; fmla v15.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q8,[x4,#-64]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]; fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac3".s[2]; fmla v19.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-48]\n\t"\ + "fmla v12.4s,v11.4s,v"#ac1".s[3]; fmla v13.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v14.4s,v11.4s,v"#ac3".s[3]; fmla v15.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "ldr q10,[x4,#-32]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[3]; fmla v17.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "sub w5,w5,#4\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[3]; fmla v19.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#-16]\n\t"\ + "fmla v20.4s,v9.4s,v"#ac1".4s; fmla v21.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v9.4s,v"#ac3".4s; fmla v23.4s,v9.4s,v"#ac4".4s\n\t"\ + "ldr q8,[x4],#176\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".4s; fmla v25.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".4s; fmla v27.4s,v10.4s,v"#ac4".4s\n\t"\ + "ldr q9,[x4,#-160]\n\t"\ + "fmla v28.4s,v11.4s,v"#ac1".4s; fmla v29.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v30.4s,v11.4s,v"#ac3".4s; fmla v31.4s,v11.4s,v"#ac4".4s\n\t"\ + "ldr q10,[x4,#-144]\n\t" + +#define KERNEL_M4N11_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla 
v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-128]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-112]\n\t"\ + "fmla v12.4s,v10.4s,v"#ac1".s[1]; fmla v13.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "fmla v14.4s,v10.4s,v"#ac3".s[1]; fmla v15.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-96]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[1]; fmla v17.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[1]; fmla v19.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q11,[x4,#-80]\n\t"\ + "fmla v12.4s,v9.4s,v"#ac1".s[2]; fmla v13.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v9.4s,v"#ac3".s[2]; fmla v15.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q8,[x4,#-64]\n\t"\ + "fmla v16.4s,v10.4s,v"#ac1".s[2]; fmla v17.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + "fmla v18.4s,v10.4s,v"#ac3".s[2]; fmla v19.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-48]\n\t"\ + "fmla v12.4s,v11.4s,v"#ac1".s[3]; fmla v13.4s,v11.4s,v"#ac2".s[3]\n\t"\ + "fmla v14.4s,v11.4s,v"#ac3".s[3]; fmla v15.4s,v11.4s,v"#ac4".s[3]\n\t"\ + "ldr q10,[x4,#-32]\n\t"\ + "fmla v16.4s,v8.4s,v"#ac1".s[3]; fmla v17.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "prfm pldl2keep,[x9]; sub w5,w5,#4\n\t"\ + "fmla v18.4s,v8.4s,v"#ac3".s[3]; fmla v19.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "ldr q11,[x4,#-16]\n\t"\ + "fmla v20.4s,v9.4s,v"#ac1".4s; fmla v21.4s,v9.4s,v"#ac2".4s\n\t"\ + "fmla v22.4s,v9.4s,v"#ac3".4s; fmla v23.4s,v9.4s,v"#ac4".4s\n\t"\ + "fmla v24.4s,v10.4s,v"#ac1".4s; fmla v25.4s,v10.4s,v"#ac2".4s\n\t"\ + "fmla v26.4s,v10.4s,v"#ac3".4s; fmla v27.4s,v10.4s,v"#ac4".4s\n\t"\ + "fmla v28.4s,v11.4s,v"#ac1".4s; fmla v29.4s,v11.4s,v"#ac2".4s\n\t"\ + "fmla v30.4s,v11.4s,v"#ac3".4s; fmla v31.4s,v11.4s,v"#ac4".4s\n\t" + +#define KERNEL_M4N11_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr d10,[x4],#8; ldr s11,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v0.s[0]; fmla v17.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v9.4s,v2.s[0]; fmla v19.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v0.4s,v10.s[0]; fmla v21.4s,v1.4s,v10.s[0]\n\t"\ + "fmla v22.4s,v2.4s,v10.s[0]; fmla v23.4s,v3.4s,v10.s[0]\n\t"\ + "fmla v24.4s,v0.4s,v10.s[1]; fmla v25.4s,v1.4s,v10.s[1]\n\t"\ + "fmla v26.4s,v2.4s,v10.s[1]; fmla v27.4s,v3.4s,v10.s[1]\n\t"\ + "fmla v28.4s,v0.4s,v11.s[0]; fmla v29.4s,v1.4s,v11.s[0]\n\t"\ + "fmla v30.4s,v2.4s,v11.s[0]; fmla v31.4s,v3.4s,v11.s[0]\n\t" + + +#define INIT_M4N12 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) + +#define SAVE_M4N12(mode) UNIT_SAVE_M4N4_##mode(12, 13, 14, 15)\ + UNIT_SAVE_M4N4_##mode(16, 17, 18, 19) UNIT_SAVE_M4N4_##mode(20, 21, 22, 23) + +#define KERNEL_M4N12_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16; ldr q3,[x3],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#192\n\t" + +#define KERNEL_M4N12_MAIN4(ac1, ac2, ac3, ac4, an1, an2, an3, an4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-144]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla 
v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-128]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[0]; fmla v21.4s,v10.4s,v"#ac2".s[0]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[0]; fmla v23.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "ldr q10,[x4,#-112]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[1]; fmla v13.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[1]; fmla v15.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q8,[x4,#-96]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#-80]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[1]; fmla v21.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[1]; fmla v23.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-64]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[2]; fmla v13.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[2]; fmla v15.4s,v8.4s,v"#ac4".s[2]\n\t"\ + "ldr q8,[x4,#-48]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[2]; fmla v17.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[2]; fmla v19.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-32]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[2]; fmla v21.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[2]; fmla v23.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr q10,[x4,#-16]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]; fmla v13.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "ldr q"#an4",[x3],#16\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[3]; fmla v15.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "ldr q8,[x4],#192\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[3]; fmla v17.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "sub w5,w5,#4\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[3]; fmla v19.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "ldr q9,[x4,#-176]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[3]; fmla v21.4s,v10.4s,v"#ac2".s[3]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[3]; fmla v23.4s,v10.4s,v"#ac4".s[3]\n\t"\ + "ldr q10,[x4,#-160]\n\t" + +#define KERNEL_M4N12_TAIL4(ac1, ac2, ac3, ac4) \ + "fmla v12.4s,v8.4s,v"#ac1".s[0]; fmla v13.4s,v8.4s,v"#ac2".s[0]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[0]; fmla v15.4s,v8.4s,v"#ac4".s[0]\n\t"\ + "ldr q8,[x4,#-144]; prfm pldl2keep,[x6]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[0]; fmla v17.4s,v9.4s,v"#ac2".s[0]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[0]; fmla v19.4s,v9.4s,v"#ac4".s[0]\n\t"\ + "ldr q9,[x4,#-128]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[0]; fmla v21.4s,v10.4s,v"#ac2".s[0]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[0]; fmla v23.4s,v10.4s,v"#ac4".s[0]\n\t"\ + "ldr q10,[x4,#-112]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[1]; fmla v13.4s,v8.4s,v"#ac2".s[1]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[1]; fmla v15.4s,v8.4s,v"#ac4".s[1]\n\t"\ + "ldr q8,[x4,#-96]; prfm pldl2keep,[x7]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[1]; fmla v17.4s,v9.4s,v"#ac2".s[1]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[1]; fmla v19.4s,v9.4s,v"#ac4".s[1]\n\t"\ + "ldr q9,[x4,#-80]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[1]; fmla v21.4s,v10.4s,v"#ac2".s[1]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[1]; fmla v23.4s,v10.4s,v"#ac4".s[1]\n\t"\ + "ldr q10,[x4,#-64]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[2]; fmla v13.4s,v8.4s,v"#ac2".s[2]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[2]; fmla v15.4s,v8.4s,v"#ac4".s[2]\n\t"\ + "ldr q8,[x4,#-48]; prfm pldl2keep,[x8]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[2]; fmla v17.4s,v9.4s,v"#ac2".s[2]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[2]; fmla v19.4s,v9.4s,v"#ac4".s[2]\n\t"\ + "ldr q9,[x4,#-32]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[2]; fmla v21.4s,v10.4s,v"#ac2".s[2]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[2]; fmla v23.4s,v10.4s,v"#ac4".s[2]\n\t"\ + "ldr 
q10,[x4,#-16]\n\t"\ + "fmla v12.4s,v8.4s,v"#ac1".s[3]; fmla v13.4s,v8.4s,v"#ac2".s[3]\n\t"\ + "fmla v14.4s,v8.4s,v"#ac3".s[3]; fmla v15.4s,v8.4s,v"#ac4".s[3]\n\t"\ + "fmla v16.4s,v9.4s,v"#ac1".s[3]; fmla v17.4s,v9.4s,v"#ac2".s[3]\n\t"\ + "sub w5,w5,#4; prfm pldl2keep,[x9]\n\t"\ + "fmla v18.4s,v9.4s,v"#ac3".s[3]; fmla v19.4s,v9.4s,v"#ac4".s[3]\n\t"\ + "fmla v20.4s,v10.4s,v"#ac1".s[3]; fmla v21.4s,v10.4s,v"#ac2".s[3]\n\t"\ + "fmla v22.4s,v10.4s,v"#ac3".s[3]; fmla v23.4s,v10.4s,v"#ac4".s[3]\n\t" + +#define KERNEL_M4N12_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4; ldr s3,[x3],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr q10,[x4],#16\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v8.4s,v3.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v0.s[0]; fmla v17.4s,v9.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v9.4s,v2.s[0]; fmla v19.4s,v9.4s,v3.s[0]\n\t"\ + "fmla v20.4s,v10.4s,v0.s[0]; fmla v21.4s,v10.4s,v1.s[0]\n\t"\ + "fmla v22.4s,v10.4s,v2.s[0]; fmla v23.4s,v10.4s,v3.s[0]\n\t" + +FUNC_M4(4) +FUNC_M4(5) +FUNC_M4(6) +FUNC_M4(7) +FUNC_M4(8) +FUNC_M4(9) +FUNC_M4(10) +FUNC_M4(11) +FUNC_M4(12) + +#define FMA_M3N4(c1, c2, c3, a1, a2, a3, b1, k) \ + "fmla v"#c1".4s,v"#b1".4s,v"#a1".s["#k"]\n\t"\ + "fmla v"#c2".4s,v"#b1".4s,v"#a2".s["#k"]\n\t"\ + "fmla v"#c3".4s,v"#b1".4s,v"#a3".s["#k"]\n\t" + +#define FMA_M3N1(c1, c2, c3, a1, a2, a3, b1) \ + "fmla v"#c1".4s,v"#b1".4s,v"#a1".4s\n\t"\ + "fmla v"#c2".4s,v"#b1".4s,v"#a2".4s\n\t"\ + "fmla v"#c3".4s,v"#b1".4s,v"#a3".4s\n\t" + + +#define INIT_M3N13 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) + +#define SAVE_M3N13(mode) UNIT_SAVE_M3N4_##mode(12, 13, 14)\ + UNIT_SAVE_M3N4_##mode(15, 16, 17) UNIT_SAVE_M3N4_##mode(18, 19, 20)\ + EDGE_SAVE_M3N1_##mode(21, 22, 23) + +#define KERNEL_M3N13_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#208\n\t" + +#define KERNEL_M3N13_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-160]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 0) "ldr q9,[x4,#-144]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 0) "ldr q10,[x4,#-128]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-112]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 1) "ldr q9,[x4,#-96]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 1) "ldr q10,[x4,#-80]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-64]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 2) "ldr q9,[x4,#-48]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 2) "ldr q10,[x4,#-32]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 3) "ldr q11,[x4,#-16]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 3) "ldr q8,[x4],#208\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 3) "ldr q9,[x4,#-192]\n\t"\ + FMA_M3N1(21, 22, 23, ac1, ac2, ac3, 11) "ldr q10,[x4,#-176]\n\t" + +#define KERNEL_M3N13_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-160]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 0) "ldr q9,[x4,#-144]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 0) "ldr q10,[x4,#-128]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-112]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 1) "ldr q9,[x4,#-96]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 1) "ldr q10,[x4,#-80]\n\t"\ + "prfm 
pldl2keep,[x7]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-64]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 2) "ldr q9,[x4,#-48]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 2) "ldr q10,[x4,#-32]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 3) "ldr q11,[x4,#-16]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 3)\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 3)\ + FMA_M3N1(21, 22, 23, ac1, ac2, ac3, 11) + +#define KERNEL_M3N13_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr q10,[x4],#16; ldr s11,[x4],#4\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v1.s[0]; fmla v17.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v18.4s,v10.4s,v0.s[0]; fmla v19.4s,v10.4s,v1.s[0]\n\t"\ + "fmla v20.4s,v10.4s,v2.s[0]; fmla v21.4s,v0.4s,v11.s[0]\n\t"\ + "fmla v22.4s,v1.4s,v11.s[0]; fmla v23.4s,v2.4s,v11.s[0]\n\t" + + +#define INIT_M3N14 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_2V(24, 25) INIT_1V(26) + +#define SAVE_M3N14(mode) UNIT_SAVE_M3N4_##mode(12, 13, 14)\ + UNIT_SAVE_M3N4_##mode(15, 16, 17) UNIT_SAVE_M3N4_##mode(18, 19, 20)\ + EDGE_SAVE_M3N1_##mode(21, 22, 23) EDGE_SAVE_M3N1_##mode(24, 25, 26) + +#define KERNEL_M3N14_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#224\n\t" + +#define KERNEL_M3N14_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-176]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 0) "ldr q9,[x4,#-160]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 0) "ldr q10,[x4,#-144]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-128]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 1) "ldr q9,[x4,#-112]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 1) "ldr q10,[x4,#-96]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 2) "ldr q11,[x4,#-80]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-64]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 2) "ldr q9,[x4,#-48]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 11, 3) "ldr q10,[x4,#-32]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 8, 3) "ldr q11,[x4,#-16]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 9, 3) "ldr q8,[x4],#224\n\t"\ + FMA_M3N1(21, 22, 23, ac1, ac2, ac3, 10) "ldr q9,[x4,#-208]\n\t"\ + FMA_M3N1(24, 25, 26, ac1, ac2, ac3, 11) "ldr q10,[x4,#-192]\n\t" + +#define KERNEL_M3N14_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-176]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 0) "ldr q9,[x4,#-160]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 0) "ldr q10,[x4,#-144]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-128]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 1) "ldr q9,[x4,#-112]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 1) "ldr q10,[x4,#-96]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 2) "ldr q11,[x4,#-80]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-64]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 2) "ldr q9,[x4,#-48]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 11, 3) "ldr q10,[x4,#-32]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 8, 3) "ldr 
q11,[x4,#-16]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 9, 3)\ + FMA_M3N1(21, 22, 23, ac1, ac2, ac3, 10)\ + FMA_M3N1(24, 25, 26, ac1, ac2, ac3, 11) + +#define KERNEL_M3N14_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr q10,[x4],#16; ldr d11,[x4],#8\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v16.4s,v9.4s,v1.s[0]; fmla v17.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v18.4s,v10.4s,v0.s[0]; fmla v19.4s,v10.4s,v1.s[0]\n\t"\ + "fmla v20.4s,v10.4s,v2.s[0]; fmla v21.4s,v0.4s,v11.s[0]\n\t"\ + "fmla v22.4s,v1.4s,v11.s[0]; fmla v23.4s,v2.4s,v11.s[0]\n\t"\ + "fmla v24.4s,v0.4s,v11.s[1]; fmla v25.4s,v1.4s,v11.s[1]\n\t"\ + "fmla v26.4s,v2.4s,v11.s[1]\n\t" + + +#define INIT_M3N15 INIT_4V(12, 13, 14, 15) INIT_4V(16, 17, 18, 19)\ + INIT_4V(20, 21, 22, 23) INIT_4V(24, 25, 26, 27) INIT_2V(28, 29) + +#define SAVE_M3N15(mode) UNIT_SAVE_M3N4_##mode(12, 13, 14)\ + UNIT_SAVE_M3N4_##mode(15, 16, 17) UNIT_SAVE_M3N4_##mode(18, 19, 20)\ + EDGE_SAVE_M3N1_##mode(21, 22, 23) EDGE_SAVE_M3N1_##mode(24, 25, 26)\ + EDGE_SAVE_M3N1_##mode(27, 28, 29) + +#define KERNEL_M3N15_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q8,[x4]; ldr q9,[x4,#16]; ldr q10,[x4,#32]\n\t"\ + "add x4,x4,#240\n\t" + +#define KERNEL_M3N15_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-192]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 0) "ldr q9,[x4,#-176]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 0) "ldr q10,[x4,#-160]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 1) "ldr q11,[x4,#-144]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 1) "ldr q8,[x4,#-128]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 1) "ldr q9,[x4,#-112]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 11, 2) "ldr q10,[x4,#-96]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 8, 2) "ldr q11,[x4,#-80]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-64]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 10, 3) "ldr q9,[x4,#-48]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 11, 3) "ldr q10,[x4,#-32]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 8, 3) "ldr q11,[x4,#-16]\n\t"\ + FMA_M3N1(21, 22, 23, ac1, ac2, ac3, 9) "ldr q8,[x4],#240\n\t"\ + FMA_M3N1(24, 25, 26, ac1, ac2, ac3, 10) "ldr q9,[x4,#-224]\n\t"\ + FMA_M3N1(27, 28, 29, ac1, ac2, ac3, 11) "ldr q10,[x4,#-208]\n\t" + +#define KERNEL_M3N15_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-192]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 0) "ldr q9,[x4,#-176]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 0) "ldr q10,[x4,#-160]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 8, 1) "ldr q11,[x4,#-144]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 9, 1) "ldr q8,[x4,#-128]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 10, 1) "ldr q9,[x4,#-112]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 11, 2) "ldr q10,[x4,#-96]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 8, 2) "ldr q11,[x4,#-80]\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-64]\n\t"\ + FMA_M3N4(12, 13, 14, ac1, ac2, ac3, 10, 3) "ldr q9,[x4,#-48]\n\t"\ + FMA_M3N4(15, 16, 17, ac1, ac2, ac3, 11, 3) "ldr q10,[x4,#-32]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(18, 19, 20, ac1, ac2, ac3, 8, 3) "ldr 
q11,[x4,#-16]\n\t"\ + FMA_M3N1(21, 22, 23, ac1, ac2, ac3, 9)\ + FMA_M3N1(24, 25, 26, ac1, ac2, ac3, 10)\ + FMA_M3N1(27, 28, 29, ac1, ac2, ac3, 11) + +#define KERNEL_M3N15_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q8,[x4],#16; ldr q9,[x4],#16; ldr q10,[x4],#16; ldr d11,[x4],#8\n\t"\ + "fmla v12.4s,v8.4s,v0.s[0]; fmla v13.4s,v8.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v14.4s,v8.4s,v2.s[0]; fmla v15.4s,v9.4s,v0.s[0]\n\t"\ + "ldr s8,[x4],#4\n\t"\ + "fmla v16.4s,v9.4s,v1.s[0]; fmla v17.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v18.4s,v10.4s,v0.s[0]; fmla v19.4s,v10.4s,v1.s[0]\n\t"\ + "fmla v20.4s,v10.4s,v2.s[0]; fmla v21.4s,v0.4s,v11.s[0]\n\t"\ + "fmla v22.4s,v1.4s,v11.s[0]; fmla v23.4s,v2.4s,v11.s[0]\n\t"\ + "fmla v24.4s,v0.4s,v11.s[1]; fmla v25.4s,v1.4s,v11.s[1]\n\t"\ + "fmla v26.4s,v2.4s,v11.s[1]; fmla v27.4s,v0.4s,v8.s[0]\n\t"\ + "fmla v28.4s,v1.4s,v8.s[0]; fmla v29.4s,v2.4s,v8.s[0]\n\t" + + +#define INIT_M3N16 INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) + +#define SAVE_M3N16(mode) UNIT_SAVE_M3N4_##mode(10, 11, 12)\ + UNIT_SAVE_M3N4_##mode(13, 14, 15) UNIT_SAVE_M3N4_##mode(16, 17, 18)\ + UNIT_SAVE_M3N4_##mode(19, 20, 21) + +#define KERNEL_M3N16_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]; ldr q8,[x4,#32]\n\t"\ + "add x4,x4,#256\n\t" + +#define KERNEL_M3N16_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-192]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-176]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-144]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-128]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-96]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-80]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-32]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 3) "ldr q6,[x4]; add x4,x4,#256\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 3) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-224]\n\t" + +#define KERNEL_M3N16_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-192]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-176]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-144]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-128]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-96]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-80]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, 
ac2, ac3, 8, 2) "ldr q8,[x4,#-32]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 3)\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 3)\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 9, 3) + +#define KERNEL_M3N16_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16; ldr q8,[x4],#16; ldr q9,[x4],#16\n\t"\ + "fmla v10.4s,v6.4s,v0.s[0]; fmla v11.4s,v6.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v12.4s,v6.4s,v2.s[0]; fmla v13.4s,v7.4s,v0.s[0]\n\t"\ + "fmla v14.4s,v7.4s,v1.s[0]; fmla v15.4s,v7.4s,v2.s[0]\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]; fmla v17.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v8.4s,v2.s[0]; fmla v19.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v1.s[0]; fmla v21.4s,v9.4s,v2.s[0]\n\t" + + +#define INIT_M3N17 INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_2V(22, 23) INIT_1V(24) + +#define SAVE_M3N17(mode) UNIT_SAVE_M3N4_##mode(10, 11, 12)\ + UNIT_SAVE_M3N4_##mode(13, 14, 15) UNIT_SAVE_M3N4_##mode(16, 17, 18)\ + UNIT_SAVE_M3N4_##mode(19, 20, 21) EDGE_SAVE_M3N1_##mode(22, 23, 24) + +#define KERNEL_M3N17_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]; ldr q8,[x4,#32]\n\t"\ + "add x4,x4,#272\n\t" + +#define KERNEL_M3N17_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-224]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-192]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-160]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-144]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-96]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 6, 2) "ldr q9,[x4,#-80]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 7, 2) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 8, 2) "ldr q7,[x4,#-48]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-32]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 7, 3) "ldr q6,[x4]; add x4,x4,#272\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 8, 3) "ldr q7,[x4,#-256]\n\t"\ + FMA_M3N1(22, 23, 24, ac1, ac2, ac3, 9) "ldr q8,[x4,#-240]\n\t" + +#define KERNEL_M3N17_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-224]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-192]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-160]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-144]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 2) "ldr q8,[x4,#-96]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 6, 2) "ldr q9,[x4,#-80]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 7, 2) "ldr q6,[x4,#-64]\n\t"\ + 
FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 8, 2) "ldr q7,[x4,#-48]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-32]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 7, 3)\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 8, 3)\ + FMA_M3N1(22, 23, 24, ac1, ac2, ac3, 9) + +#define KERNEL_M3N17_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16; ldr q8,[x4],#16; ldr q9,[x4],#16\n\t"\ + "fmla v10.4s,v6.4s,v0.s[0]; fmla v11.4s,v6.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v12.4s,v6.4s,v2.s[0]; fmla v13.4s,v7.4s,v0.s[0]; ldr s6,[x4],#4\n\t"\ + "fmla v14.4s,v7.4s,v1.s[0]; fmla v15.4s,v7.4s,v2.s[0]\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]; fmla v17.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v8.4s,v2.s[0]; fmla v19.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v1.s[0]; fmla v21.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v22.4s,v0.4s,v6.s[0]; fmla v23.4s,v1.4s,v6.s[0]\n\t"\ + "fmla v24.4s,v2.4s,v6.s[0]\n\t" + + +#define INIT_M3N18 INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25) INIT_2V(26, 27) + +#define SAVE_M3N18(mode) UNIT_SAVE_M3N4_##mode(10, 11, 12)\ + UNIT_SAVE_M3N4_##mode(13, 14, 15) UNIT_SAVE_M3N4_##mode(16, 17, 18)\ + UNIT_SAVE_M3N4_##mode(19, 20, 21) EDGE_SAVE_M3N1_##mode(22, 23, 24)\ + EDGE_SAVE_M3N1_##mode(25, 26, 27) + +#define KERNEL_M3N18_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]; ldr q8,[x4,#32]\n\t"\ + "add x4,x4,#288\n\t" + +#define KERNEL_M3N18_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-240]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-224]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-208]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-160]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 1) "ldr q9,[x4,#-144]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 1) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 2) "ldr q7,[x4,#-112]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-96]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 2) "ldr q9,[x4,#-80]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 2) "ldr q6,[x4,#-64]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-32]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 3) "ldr q6,[x4]\n\t"\ + FMA_M3N1(22, 23, 24, ac1, ac2, ac3, 8) "ldr q7,[x4,#16]\n\t"\ + FMA_M3N1(25, 26, 27, ac1, ac2, ac3, 9) "ldr q8,[x4,#32]; add x4,x4,#288\n\t" + +#define KERNEL_M3N18_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-240]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-224]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-208]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q8,[x4,#-160]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 1) "ldr 
q9,[x4,#-144]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 1) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 2) "ldr q7,[x4,#-112]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-96]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 2) "ldr q9,[x4,#-80]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 2) "ldr q6,[x4,#-64]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 8, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-32]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 7, 3)\ + FMA_M3N1(22, 23, 24, ac1, ac2, ac3, 8)\ + FMA_M3N1(25, 26, 27, ac1, ac2, ac3, 9) + +#define KERNEL_M3N18_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16; ldr q8,[x4],#16; ldr q9,[x4],#16\n\t"\ + "fmla v10.4s,v6.4s,v0.s[0]; fmla v11.4s,v6.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v12.4s,v6.4s,v2.s[0]; fmla v13.4s,v7.4s,v0.s[0]; ldr d6,[x4],#8\n\t"\ + "fmla v14.4s,v7.4s,v1.s[0]; fmla v15.4s,v7.4s,v2.s[0]\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]; fmla v17.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v8.4s,v2.s[0]; fmla v19.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v1.s[0]; fmla v21.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v22.4s,v0.4s,v6.s[0]; fmla v23.4s,v1.4s,v6.s[0]\n\t"\ + "fmla v24.4s,v2.4s,v6.s[0]; fmla v25.4s,v0.4s,v6.s[1]\n\t"\ + "fmla v26.4s,v1.4s,v6.s[1]; fmla v27.4s,v2.4s,v6.s[1]\n\t" + + +#define INIT_M3N19 INIT_4V(10, 11, 12, 13) INIT_4V(14, 15, 16, 17)\ + INIT_4V(18, 19, 20, 21) INIT_4V(22, 23, 24, 25)\ + INIT_4V(26, 27, 28, 29) INIT_1V(30) + +#define SAVE_M3N19(mode) UNIT_SAVE_M3N4_##mode(10, 11, 12)\ + UNIT_SAVE_M3N4_##mode(13, 14, 15) UNIT_SAVE_M3N4_##mode(16, 17, 18)\ + UNIT_SAVE_M3N4_##mode(19, 20, 21) EDGE_SAVE_M3N1_##mode(22, 23, 24)\ + EDGE_SAVE_M3N1_##mode(25, 26, 27) EDGE_SAVE_M3N1_##mode(28, 29, 30) + +#define KERNEL_M3N19_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]; ldr q8,[x4,#32]\n\t"\ + "add x4,x4,#304\n\t" + +#define KERNEL_M3N19_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-256]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-224]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q9,[x4,#-208]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q7,[x4,#-176]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 9, 1) "ldr q8,[x4,#-160]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 1) "ldr q9,[x4,#-144]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 2) "ldr q6,[x4,#-128]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-96]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 2) "ldr q9,[x4,#-80]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 3) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-32]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N1(22, 23, 24, ac1, ac2, ac3, 7) "ldr q6,[x4]\n\t"\ + FMA_M3N1(25, 26, 27, ac1, ac2, ac3, 8) "ldr q7,[x4,#16]\n\t"\ + FMA_M3N1(28, 29, 30, ac1, ac2, ac3, 9) "ldr 
q8,[x4,#32]; add x4,x4,#304\n\t" + +#define KERNEL_M3N19_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#-256]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 8, 0) "ldr q8,[x4,#-224]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 0) "ldr q9,[x4,#-208]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 1) "ldr q7,[x4,#-176]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 9, 1) "ldr q8,[x4,#-160]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 1) "ldr q9,[x4,#-144]\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 2) "ldr q6,[x4,#-128]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 9, 2) "ldr q8,[x4,#-96]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 2) "ldr q9,[x4,#-80]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(10, 11, 12, ac1, ac2, ac3, 7, 3) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(13, 14, 15, ac1, ac2, ac3, 8, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(16, 17, 18, ac1, ac2, ac3, 9, 3) "ldr q8,[x4,#-32]\n\t"\ + FMA_M3N4(19, 20, 21, ac1, ac2, ac3, 6, 3) "ldr q9,[x4,#-16]\n\t"\ + FMA_M3N1(22, 23, 24, ac1, ac2, ac3, 7)\ + FMA_M3N1(25, 26, 27, ac1, ac2, ac3, 8)\ + FMA_M3N1(28, 29, 30, ac1, ac2, ac3, 9) + +#define KERNEL_M3N19_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16; ldr q8,[x4],#16; ldr q9,[x4],#16\n\t"\ + "fmla v10.4s,v6.4s,v0.s[0]; fmla v11.4s,v6.4s,v1.s[0]; subs w5,w5,#1\n\t"\ + "fmla v12.4s,v6.4s,v2.s[0]; fmla v13.4s,v7.4s,v0.s[0]; ldr d6,[x4],#8\n\t"\ + "fmla v14.4s,v7.4s,v1.s[0]; fmla v15.4s,v7.4s,v2.s[0]; ldr s7,[x4],#4\n\t"\ + "fmla v16.4s,v8.4s,v0.s[0]; fmla v17.4s,v8.4s,v1.s[0]\n\t"\ + "fmla v18.4s,v8.4s,v2.s[0]; fmla v19.4s,v9.4s,v0.s[0]\n\t"\ + "fmla v20.4s,v9.4s,v1.s[0]; fmla v21.4s,v9.4s,v2.s[0]\n\t"\ + "fmla v22.4s,v0.4s,v6.s[0]; fmla v23.4s,v1.4s,v6.s[0]\n\t"\ + "fmla v24.4s,v2.4s,v6.s[0]; fmla v25.4s,v0.4s,v6.s[1]\n\t"\ + "fmla v26.4s,v1.4s,v6.s[1]; fmla v27.4s,v2.4s,v6.s[1]\n\t"\ + "fmla v28.4s,v0.4s,v7.s[0]; fmla v29.4s,v1.4s,v7.s[0]\n\t"\ + "fmla v30.4s,v2.4s,v7.s[0]\n\t" + + +#define INIT_M3N20 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_2V(20, 21) INIT_1V(22) + +#define SAVE_M3N20(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22) + +#define KERNEL_M3N20_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N20_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#320\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-144]\n\t"\ 
+ FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-96]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-80]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3) "ldr q6,[x4]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N20_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#320\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-144]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-96]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-80]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3)\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) + +#define KERNEL_M3N20_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "subs w5,w5,#1\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) + + +#define INIT_M3N21 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23) INIT_2V(24, 25) + +#define SAVE_M3N21(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22)\ + EDGE_SAVE_M3N1_##mode(23, 24, 25) + +#define KERNEL_M3N21_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N21_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add 
x4,x4,#336\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-224]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-192]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-160]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-144]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-128]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-96]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-80]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-64]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-32]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3) "ldr q6,[x4]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(23, 24, 25, ac1, ac2, ac3, 7) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N21_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#336\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-224]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-192]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-160]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-144]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-128]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-96]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-80]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-64]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-32]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3)\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(23, 24, 25, ac1, ac2, ac3, 7) + +#define KERNEL_M3N21_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "ldr s7,[x4],#4\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) "subs w5,w5,#1\n\t"\ + "fmla v23.4s,v0.4s,v7.s[0]; fmla v24.4s,v1.4s,v7.s[0]\n\t"\ + "fmla v25.4s,v2.4s,v7.s[0]\n\t" + + +#define INIT_M3N22 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_1V(28) + +#define SAVE_M3N22(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + 
UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22)\ + EDGE_SAVE_M3N1_##mode(23, 24, 25) EDGE_SAVE_M3N1_##mode(26, 27, 28)\ + +#define KERNEL_M3N22_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N22_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#352\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-224]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-144]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-96]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-80]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(23, 24, 25, ac1, ac2, ac3, 6) "ldr q6,[x4]\n\t"\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 7) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N22_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#352\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-224]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-144]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-96]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-80]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(23, 24, 25, ac1, ac2, ac3, 6)\ + FMA_M3N1(26, 27, 28, ac1, 
ac2, ac3, 7) + +#define KERNEL_M3N22_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "ldr d7,[x4],#8\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) "subs w5,w5,#1\n\t"\ + "fmla v23.4s,v0.4s,v7.s[0]; fmla v24.4s,v1.4s,v7.s[0]\n\t"\ + "fmla v25.4s,v2.4s,v7.s[0]; fmla v26.4s,v0.4s,v7.s[1]\n\t"\ + "fmla v27.4s,v1.4s,v7.s[1]; fmla v28.4s,v2.4s,v7.s[1]\n\t" + + +#define INIT_M3N23 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M3N23(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22)\ + EDGE_SAVE_M3N1_##mode(23, 24, 25) EDGE_SAVE_M3N1_##mode(26, 27, 28)\ + EDGE_SAVE_M3N1_##mode(29, 30, 31) + +#define KERNEL_M3N23_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N23_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#368\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-256]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-240]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-224]\n\t"\ + "ldr q"#an2",[x1],#16\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-192]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-160]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-144]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-128]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-96]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-80]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-48]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-32]\n\t"\ + FMA_M3N1(23, 24, 25, ac1, ac2, ac3, 6) "ldr q6,[x4]\n\t"\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 7) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(29, 30, 31, ac1, ac2, ac3, 7) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N23_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + "add x4,x4,#368\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-256]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-240]\n\t"\ + 
FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-224]\n\t"\ + "prfm pldl2keep,[x7]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-192]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-160]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-144]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-128]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-96]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-80]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-48]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-32]\n\t"\ + FMA_M3N1(23, 24, 25, ac1, ac2, ac3, 6)\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 7) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(29, 30, 31, ac1, ac2, ac3, 7) + +#define KERNEL_M3N23_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "ldr d7,[x4],#8\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) "ldr s6,[x4],#4\n\t"\ + "fmla v23.4s,v0.4s,v7.s[0]; fmla v24.4s,v1.4s,v7.s[0]; subs w5,w5,#1\n\t"\ + "fmla v25.4s,v2.4s,v7.s[0]; fmla v26.4s,v0.4s,v7.s[1]\n\t"\ + "fmla v27.4s,v1.4s,v7.s[1]; fmla v28.4s,v2.4s,v7.s[1]\n\t"\ + "fmla v29.4s,v0.4s,v6.s[0]; fmla v30.4s,v1.4s,v6.s[0]\n\t"\ + "fmla v31.4s,v2.4s,v6.s[0]\n\t" + + +#define INIT_M3N24 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23) INIT_2V(24, 25) + +#define SAVE_M3N24(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22)\ + UNIT_SAVE_M3N4_##mode(23, 24, 25) + +#define KERNEL_M3N24_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N24_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#144]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#160]\n\t"\ + "ldr q"#an2",[x1],#16; add x4,x4,#384\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-144]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(17, 18, 19, 
ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-80]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-64]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 3) "ldr q6,[x4]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N24_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#144]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#160]\n\t"\ + "prfm pldl2keep,[x7]; add x4,x4,#384\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-176]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-144]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-80]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-64]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 3)\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 3) + +#define KERNEL_M3N24_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) "subs w5,w5,#1\n\t"\ + FMA_M3N4(23, 24, 25, 0, 1, 2, 7, 0) + + +#define INIT_M3N25 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_1V(28) + +#define SAVE_M3N25(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22)\ + UNIT_SAVE_M3N4_##mode(23, 24, 25) EDGE_SAVE_M3N1_##mode(26, 27, 28) + +#define KERNEL_M3N25_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N25_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 
0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#144]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#160]\n\t"\ + "ldr q"#an2",[x1],#16; add x4,x4,#400\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-224]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-192]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-160]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-144]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-128]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-96]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-80]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-64]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-48]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-32]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 3) "ldr q6,[x4]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 7) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N25_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#144]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#160]\n\t"\ + "prfm pldl2keep,[x7]; add x4,x4,#400\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-224]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-208]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-192]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-176]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-160]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-144]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-128]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-112]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-96]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-80]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-64]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-48]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-32]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 3)\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 7) + +#define KERNEL_M3N25_TL1 \ 
+ "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) "ldr s6,[x4],#4\n\t"\ + FMA_M3N4(23, 24, 25, 0, 1, 2, 7, 0) "subs w5,w5,#1\n\t"\ + "fmla v26.4s,v0.4s,v6.s[0]; fmla v27.4s,v1.4s,v6.s[0]\n\t"\ + "fmla v28.4s,v2.4s,v6.s[0]\n\t" + + +#define INIT_M3N26 INIT_4V(8, 9, 10, 11) INIT_4V(12, 13, 14, 15)\ + INIT_4V(16, 17, 18, 19) INIT_4V(20, 21, 22, 23)\ + INIT_4V(24, 25, 26, 27) INIT_4V(28, 29, 30, 31) + +#define SAVE_M3N26(mode) UNIT_SAVE_M3N4_##mode(8, 9, 10)\ + UNIT_SAVE_M3N4_##mode(11, 12, 13) UNIT_SAVE_M3N4_##mode(14, 15, 16)\ + UNIT_SAVE_M3N4_##mode(17, 18, 19) UNIT_SAVE_M3N4_##mode(20, 21, 22)\ + UNIT_SAVE_M3N4_##mode(23, 24, 25) EDGE_SAVE_M3N1_##mode(26, 27, 28)\ + EDGE_SAVE_M3N1_##mode(29, 30, 31) + +#define KERNEL_M3N26_PRELOAD4 \ + "ldr q0,[x0],#16; ldr q1,[x1],#16; ldr q2,[x2],#16\n\t"\ + "ldr q6,[x4]; ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N26_MAIN4(ac1, ac2, ac3, an1, an2, an3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "ldr q"#an1",[x0],#16\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#144]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#160]\n\t"\ + "ldr q"#an2",[x1],#16; add x4,x4,#416\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-224]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-176]\n\t"\ + "ldr q"#an3",[x2],#16\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-144]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-96]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-80]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 6) "ldr q6,[x4]\n\t"\ + FMA_M3N1(29, 30, 31, ac1, ac2, ac3, 7) "ldr q7,[x4,#16]\n\t" + +#define KERNEL_M3N26_TAIL4(ac1, ac2, ac3) \ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#32]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#48]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 0) "ldr q7,[x4,#80]\n\t"\ + "prfm pldl2keep,[x6]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 0) "ldr q6,[x4,#96]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, 
ac3, 7, 0) "ldr q7,[x4,#112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#128]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#144]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#160]\n\t"\ + "prfm pldl2keep,[x7]; add x4,x4,#416\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-240]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 1) "ldr q6,[x4,#-224]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 1) "ldr q7,[x4,#-208]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-192]\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-176]\n\t"\ + "prfm pldl2keep,[x8]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-160]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-144]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 2) "ldr q6,[x4,#-128]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 2) "ldr q7,[x4,#-112]\n\t"\ + FMA_M3N4(8, 9, 10, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-96]\n\t"\ + "sub w5,w5,#4\n\t"\ + FMA_M3N4(11, 12, 13, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-80]\n\t"\ + FMA_M3N4(14, 15, 16, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-64]\n\t"\ + FMA_M3N4(17, 18, 19, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-48]\n\t"\ + FMA_M3N4(20, 21, 22, ac1, ac2, ac3, 6, 3) "ldr q6,[x4,#-32]\n\t"\ + FMA_M3N4(23, 24, 25, ac1, ac2, ac3, 7, 3) "ldr q7,[x4,#-16]\n\t"\ + FMA_M3N1(26, 27, 28, ac1, ac2, ac3, 6)\ + FMA_M3N1(29, 30, 31, ac1, ac2, ac3, 7) + +#define KERNEL_M3N26_TL1 \ + "ldr s0,[x0],#4; ldr s1,[x1],#4; ldr s2,[x2],#4\n\t"\ + "ldr q6,[x4],#16; ldr q7,[x4],#16\n\t"\ + FMA_M3N4(8, 9, 10, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(11, 12, 13, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(14, 15, 16, 0, 1, 2, 6, 0) "ldr q6,[x4],#16\n\t"\ + FMA_M3N4(17, 18, 19, 0, 1, 2, 7, 0) "ldr q7,[x4],#16\n\t"\ + FMA_M3N4(20, 21, 22, 0, 1, 2, 6, 0) "ldr d6,[x4],#8\n\t"\ + FMA_M3N4(23, 24, 25, 0, 1, 2, 7, 0) "subs w5,w5,#1\n\t"\ + "fmla v26.4s,v0.4s,v6.s[0]; fmla v27.4s,v1.4s,v6.s[0]\n\t"\ + "fmla v28.4s,v2.4s,v6.s[0]; fmla v29.4s,v0.4s,v6.s[1]\n\t"\ + "fmla v30.4s,v1.4s,v6.s[1]; fmla v31.4s,v2.4s,v6.s[1]\n\t" + +FUNC_M3(13) +FUNC_M3(14) +FUNC_M3(15) +FUNC_M3(16) +FUNC_M3(17) +FUNC_M3(18) +FUNC_M3(19) +FUNC_M3(20) +FUNC_M3(21) +FUNC_M3(22) +FUNC_M3(23) +FUNC_M3(24) +FUNC_M3(25) +FUNC_M3(26) + + +#define INIT_M1N4 \ + float32x4_t cq1, cq2, cq3, cq4;\ + cq1 = cq2 = cq3 = cq4 = vdupq_n_f32(0.0f); + +#define INIT_M1N5 INIT_M1N4 float32x4_t cq5 = vdupq_n_f32(0.0f); +#define INIT_M1N6 INIT_M1N5 float32x4_t cq6 = vdupq_n_f32(0.0f); +#define INIT_M1N7 INIT_M1N6 float32x4_t cq7 = vdupq_n_f32(0.0f); + +#define INIT_M1N8 \ + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8;\ + cq1 = cq2 = cq3 = cq4 = vdupq_n_f32(0.0f);\ + cq5 = cq6 = cq7 = cq8 = vdupq_n_f32(0.0f); + +#define INIT_M1N9 INIT_M1N8 float32x4_t cq9 = vdupq_n_f32(0.0f); +#define INIT_M1N10 INIT_M1N9 float32x4_t cq10 = vdupq_n_f32(0.0f); +#define INIT_M1N11 INIT_M1N10 float32x4_t cq11 = vdupq_n_f32(0.0f); + +#define INIT_M1N12 \ + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6;\ + float32x4_t cq7, cq8, cq9, cq10, cq11, cq12;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = vdupq_n_f32(0.0f);\ + cq7 = cq8 = cq9 = cq10 = cq11 = cq12 = vdupq_n_f32(0.0f); + +#define INIT_M1N13 INIT_M1N12 float32x4_t cq13 = vdupq_n_f32(0.0f); +#define INIT_M1N14 INIT_M1N13 float32x4_t cq14 = vdupq_n_f32(0.0f); +#define INIT_M1N15 INIT_M1N14 float32x4_t cq15 = vdupq_n_f32(0.0f); + +#define INIT_M1N16 \ + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = cq7 = cq8 = 
vdupq_n_f32(0.0f);\ + +#define INIT_M1N17 INIT_M1N16 float32x4_t cq9 = vdupq_n_f32(0.0f); +#define INIT_M1N18 INIT_M1N17 float32x4_t cq10 = vdupq_n_f32(0.0f); +#define INIT_M1N19 INIT_M1N18 float32x4_t cq11 = vdupq_n_f32(0.0f); + +#define INIT_M1N20 \ + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8, cq9, cq10;\ + cq1 = cq2 = cq3 = cq4 = cq5 = vdupq_n_f32(0.0f);\ + cq6 = cq7 = cq8 = cq9 = cq10 = vdupq_n_f32(0.0f); + +#define INIT_M1N21 INIT_M1N20 float32x4_t cq11 = vdupq_n_f32(0.0f); +#define INIT_M1N22 INIT_M1N21 float32x4_t cq12 = vdupq_n_f32(0.0f); +#define INIT_M1N23 INIT_M1N22 float32x4_t cq13 = vdupq_n_f32(0.0f); + +#define INIT_M1N24 \ + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6;\ + float32x4_t cq7, cq8, cq9, cq10, cq11, cq12;\ + cq1 = cq2 = cq3 = cq4 = cq5 = cq6 = vdupq_n_f32(0.0f);\ + cq7 = cq8 = cq9 = cq10 = cq11 = cq12 = vdupq_n_f32(0.0f); + +#define INIT_M1N25 INIT_M1N24 float32x4_t cq13 = vdupq_n_f32(0.0f); +#define INIT_M1N26 INIT_M1N25 float32x4_t cq14 = vdupq_n_f32(0.0f); + +#define ACC_K4M1N4 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 1);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 2);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 3); + +#define UNIT_ACC_K4M1N1(q_no, off) \ + float32x4_t bq##q_no = vld1q_f32(b_rd + off);\ + cq##q_no = vfmaq_f32(cq##q_no, bq##q_no, aq1); + +#define ACC_K4M1N5 ACC_K4M1N4 UNIT_ACC_K4M1N1(5, 16) +#define ACC_K4M1N6 ACC_K4M1N5 UNIT_ACC_K4M1N1(6, 20) +#define ACC_K4M1N7 ACC_K4M1N6 UNIT_ACC_K4M1N1(7, 24) + +#define ACC_K4M1N8 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + float32x4_t bq7 = vld1q_f32(b_rd + 24);\ + float32x4_t bq8 = vld1q_f32(b_rd + 28);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 1);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 1);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 2);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 2);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 3);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 3); + +#define ACC_K4M1N9 ACC_K4M1N8 UNIT_ACC_K4M1N1(9, 32) +#define ACC_K4M1N10 ACC_K4M1N9 UNIT_ACC_K4M1N1(10, 36) +#define ACC_K4M1N11 ACC_K4M1N10 UNIT_ACC_K4M1N1(11, 40) + +#define ACC_K4M1N12 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + float32x4_t bq7 = vld1q_f32(b_rd + 24);\ + float32x4_t bq8 = vld1q_f32(b_rd + 28);\ + float32x4_t bq9 = vld1q_f32(b_rd + 32);\ + float32x4_t bq10 = vld1q_f32(b_rd + 36);\ + float32x4_t bq11 = vld1q_f32(b_rd + 40);\ + float32x4_t bq12 = vld1q_f32(b_rd + 44);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 0);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 1);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 1);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 1);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 2);\ + cq8 = 
vfmaq_laneq_f32(cq8, bq8, aq1, 2);\ + cq9 = vfmaq_laneq_f32(cq9, bq9, aq1, 2);\ + cq10 = vfmaq_laneq_f32(cq10, bq10, aq1, 3);\ + cq11 = vfmaq_laneq_f32(cq11, bq11, aq1, 3);\ + cq12 = vfmaq_laneq_f32(cq12, bq12, aq1, 3); + +#define ACC_K4M1N13 ACC_K4M1N12 UNIT_ACC_K4M1N1(13, 48) +#define ACC_K4M1N14 ACC_K4M1N13 UNIT_ACC_K4M1N1(14, 52) +#define ACC_K4M1N15 ACC_K4M1N14 UNIT_ACC_K4M1N1(15, 56) + +#define ACC_K4M1N16 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + float32x4_t bq7 = vld1q_f32(b_rd + 24);\ + float32x4_t bq8 = vld1q_f32(b_rd + 28);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0); bq1 = vld1q_f32(b_rd + 32);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0); bq2 = vld1q_f32(b_rd + 36);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 0); bq3 = vld1q_f32(b_rd + 40);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 0); bq4 = vld1q_f32(b_rd + 44);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 1); bq5 = vld1q_f32(b_rd + 48);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 1); bq6 = vld1q_f32(b_rd + 52);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 1); bq7 = vld1q_f32(b_rd + 56);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 1); bq8 = vld1q_f32(b_rd + 60);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 2);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 2);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 2);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 2);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 3);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 3);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 3);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 3); + +#define ACC_K4M1N17 ACC_K4M1N16 UNIT_ACC_K4M1N1(9, 64) +#define ACC_K4M1N18 ACC_K4M1N17 UNIT_ACC_K4M1N1(10, 68) +#define ACC_K4M1N19 ACC_K4M1N18 UNIT_ACC_K4M1N1(11, 72) + +#define ACC_K4M1N20 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + float32x4_t bq7 = vld1q_f32(b_rd + 24);\ + float32x4_t bq8 = vld1q_f32(b_rd + 28);\ + float32x4_t bq9 = vld1q_f32(b_rd + 32);\ + float32x4_t bq10 = vld1q_f32(b_rd + 36);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0); bq1 = vld1q_f32(b_rd + 40);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0); bq2 = vld1q_f32(b_rd + 44);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 0); bq3 = vld1q_f32(b_rd + 48);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 0); bq4 = vld1q_f32(b_rd + 52);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 0); bq5 = vld1q_f32(b_rd + 56);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 1); bq6 = vld1q_f32(b_rd + 60);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 1); bq7 = vld1q_f32(b_rd + 64);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 1); bq8 = vld1q_f32(b_rd + 68);\ + cq9 = vfmaq_laneq_f32(cq9, bq9, aq1, 1); bq9 = vld1q_f32(b_rd + 72);\ + cq10 = vfmaq_laneq_f32(cq10, bq10, aq1, 1); bq10 = vld1q_f32(b_rd + 76);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 2);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 2);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 2);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 2);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 2);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 3);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 3);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 3);\ + cq9 = vfmaq_laneq_f32(cq9, bq9, aq1, 3);\ + cq10 = vfmaq_laneq_f32(cq10, 
bq10, aq1, 3); + +#define ACC_K4M1N21 ACC_K4M1N20 UNIT_ACC_K4M1N1(11, 80) +#define ACC_K4M1N22 ACC_K4M1N21 UNIT_ACC_K4M1N1(12, 84) +#define ACC_K4M1N23 ACC_K4M1N22 UNIT_ACC_K4M1N1(13, 88) + +#define ACC_K4M1N24 \ + float32x4_t aq1 = vld1q_f32(a_rd); a_rd += 4;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + float32x4_t bq7 = vld1q_f32(b_rd + 24);\ + float32x4_t bq8 = vld1q_f32(b_rd + 28);\ + float32x4_t bq9 = vld1q_f32(b_rd + 32);\ + float32x4_t bq10 = vld1q_f32(b_rd + 36);\ + float32x4_t bq11 = vld1q_f32(b_rd + 40);\ + float32x4_t bq12 = vld1q_f32(b_rd + 44);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 0); bq1 = vld1q_f32(b_rd + 48);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 0); bq2 = vld1q_f32(b_rd + 52);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 0); bq3 = vld1q_f32(b_rd + 56);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 0); bq4 = vld1q_f32(b_rd + 60);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 0); bq5 = vld1q_f32(b_rd + 64);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 0); bq6 = vld1q_f32(b_rd + 68);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 1); bq7 = vld1q_f32(b_rd + 72);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 1); bq8 = vld1q_f32(b_rd + 76);\ + cq9 = vfmaq_laneq_f32(cq9, bq9, aq1, 1); bq9 = vld1q_f32(b_rd + 80);\ + cq10 = vfmaq_laneq_f32(cq10, bq10, aq1, 1); bq10 = vld1q_f32(b_rd + 84);\ + cq11 = vfmaq_laneq_f32(cq11, bq11, aq1, 1); bq11 = vld1q_f32(b_rd + 88);\ + cq12 = vfmaq_laneq_f32(cq12, bq12, aq1, 1); bq12 = vld1q_f32(b_rd + 92);\ + cq1 = vfmaq_laneq_f32(cq1, bq1, aq1, 2);\ + cq2 = vfmaq_laneq_f32(cq2, bq2, aq1, 2);\ + cq3 = vfmaq_laneq_f32(cq3, bq3, aq1, 2);\ + cq4 = vfmaq_laneq_f32(cq4, bq4, aq1, 2);\ + cq5 = vfmaq_laneq_f32(cq5, bq5, aq1, 2);\ + cq6 = vfmaq_laneq_f32(cq6, bq6, aq1, 2);\ + cq7 = vfmaq_laneq_f32(cq7, bq7, aq1, 3);\ + cq8 = vfmaq_laneq_f32(cq8, bq8, aq1, 3);\ + cq9 = vfmaq_laneq_f32(cq9, bq9, aq1, 3);\ + cq10 = vfmaq_laneq_f32(cq10, bq10, aq1, 3);\ + cq11 = vfmaq_laneq_f32(cq11, bq11, aq1, 3);\ + cq12 = vfmaq_laneq_f32(cq12, bq12, aq1, 3); + +#define ACC_K4M1N25 ACC_K4M1N24 UNIT_ACC_K4M1N1(13, 96) +#define ACC_K4M1N26 ACC_K4M1N25 UNIT_ACC_K4M1N1(14, 100) + +#define REDUC_M1N4 \ + cq1 = vaddq_f32(cq1, cq2); cq3 = vaddq_f32(cq3, cq4);\ + cq1 = vaddq_f32(cq1, cq3); + +#define UNIT_REDUC_1V(q_no, s_no) \ + float32x2_t cd##s_no = vadd_f32(vget_low_f32(cq##q_no),\ + vget_high_f32(cq##q_no));\ + float cs##s_no = vget_lane_f32(cd##s_no, 0) + vget_lane_f32(cd##s_no, 1); + +#define REDUC_M1N5 REDUC_M1N4 UNIT_REDUC_1V(5, 1) +#define REDUC_M1N6 REDUC_M1N5 UNIT_REDUC_1V(6, 2) +#define REDUC_M1N7 REDUC_M1N6 UNIT_REDUC_1V(7, 3) + +#define REDUC_M1N8 \ + cq1 = vaddq_f32(cq1, cq3); cq2 = vaddq_f32(cq2, cq4);\ + cq5 = vaddq_f32(cq5, cq7); cq6 = vaddq_f32(cq6, cq8);\ + cq1 = vaddq_f32(cq1, cq5); cq2 = vaddq_f32(cq2, cq6); + +#define REDUC_M1N9 REDUC_M1N8 UNIT_REDUC_1V(9, 1) +#define REDUC_M1N10 REDUC_M1N9 UNIT_REDUC_1V(10, 2) +#define REDUC_M1N11 REDUC_M1N10 UNIT_REDUC_1V(11, 3) + +#define REDUC_M1N12 \ + cq1 = vaddq_f32(cq1, cq4); cq2 = vaddq_f32(cq2, cq5);\ + cq3 = vaddq_f32(cq3, cq6); cq7 = vaddq_f32(cq7, cq10);\ + cq8 = vaddq_f32(cq8, cq11); cq9 = vaddq_f32(cq9, cq12);\ + cq1 = vaddq_f32(cq1, cq7); cq2 = vaddq_f32(cq2, cq8);\ + cq3 = vaddq_f32(cq3, cq9); + +#define REDUC_M1N13 REDUC_M1N12 UNIT_REDUC_1V(13, 1) +#define REDUC_M1N14 REDUC_M1N13 UNIT_REDUC_1V(14, 2) +#define REDUC_M1N15 REDUC_M1N14 
UNIT_REDUC_1V(15, 3) + +#define REDUC_M1N16 \ + cq1 = vaddq_f32(cq1, cq5); cq2 = vaddq_f32(cq2, cq6);\ + cq3 = vaddq_f32(cq3, cq7); cq4 = vaddq_f32(cq4, cq8); + +#define REDUC_M1N17 REDUC_M1N16 UNIT_REDUC_1V(9, 1) +#define REDUC_M1N18 REDUC_M1N17 UNIT_REDUC_1V(10, 2) +#define REDUC_M1N19 REDUC_M1N18 UNIT_REDUC_1V(11, 3) + +#define REDUC_M1N20 \ + cq1 = vaddq_f32(cq1, cq6); cq2 = vaddq_f32(cq2, cq7);\ + cq3 = vaddq_f32(cq3, cq8); cq4 = vaddq_f32(cq4, cq9);\ + cq5 = vaddq_f32(cq5, cq10); + +#define REDUC_M1N21 REDUC_M1N20 UNIT_REDUC_1V(11, 1) +#define REDUC_M1N22 REDUC_M1N21 UNIT_REDUC_1V(12, 2) +#define REDUC_M1N23 REDUC_M1N22 UNIT_REDUC_1V(13, 3) + +#define REDUC_M1N24 \ + cq1 = vaddq_f32(cq1, cq7); cq2 = vaddq_f32(cq2, cq8);\ + cq3 = vaddq_f32(cq3, cq9); cq4 = vaddq_f32(cq4, cq10);\ + cq5 = vaddq_f32(cq5, cq11); cq6 = vaddq_f32(cq6, cq12); + +#define REDUC_M1N25 REDUC_M1N24 UNIT_REDUC_1V(13, 1) +#define REDUC_M1N26 REDUC_M1N25 UNIT_REDUC_1V(14, 2) + +#define ACC_K1M1N4 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1); + +#define ACC_K1M1N5 ACC_K1M1N4 cs1 += as1 * b_rd[4]; +#define ACC_K1M1N6 ACC_K1M1N5 cs2 += as1 * b_rd[5]; +#define ACC_K1M1N7 ACC_K1M1N6 cs3 += as1 * b_rd[6]; + +#define ACC_K1M1N8 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = vfmaq_n_f32(cq2, bq2, as1); + +#define ACC_K1M1N9 ACC_K1M1N8 cs1 += as1 * b_rd[8]; +#define ACC_K1M1N10 ACC_K1M1N9 cs2 += as1 * b_rd[9]; +#define ACC_K1M1N11 ACC_K1M1N10 cs3 += as1 * b_rd[10]; + +#define ACC_K1M1N12 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1); + +#define ACC_K1M1N13 ACC_K1M1N12 cs1 += as1 * b_rd[12]; +#define ACC_K1M1N14 ACC_K1M1N13 cs2 += as1 * b_rd[13]; +#define ACC_K1M1N15 ACC_K1M1N14 cs3 += as1 * b_rd[14]; + +#define ACC_K1M1N16 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1);\ + cq4 = vfmaq_n_f32(cq4, bq4, as1); + +#define ACC_K1M1N17 ACC_K1M1N16 cs1 += as1 * b_rd[16]; +#define ACC_K1M1N18 ACC_K1M1N17 cs2 += as1 * b_rd[17]; +#define ACC_K1M1N19 ACC_K1M1N18 cs3 += as1 * b_rd[18]; + +#define ACC_K1M1N20 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1);\ + cq4 = vfmaq_n_f32(cq4, bq4, as1);\ + cq5 = vfmaq_n_f32(cq5, bq5, as1); + +#define ACC_K1M1N21 ACC_K1M1N20 cs1 += as1 * b_rd[20]; +#define ACC_K1M1N22 ACC_K1M1N21 cs2 += as1 * b_rd[21]; +#define ACC_K1M1N23 ACC_K1M1N22 cs3 += as1 * b_rd[22]; + +#define ACC_K1M1N24 \ + float as1 = *a_rd++;\ + float32x4_t bq1 = vld1q_f32(b_rd);\ + float32x4_t bq2 = vld1q_f32(b_rd + 4);\ + float32x4_t bq3 = vld1q_f32(b_rd + 8);\ + float32x4_t bq4 = vld1q_f32(b_rd + 12);\ + float32x4_t bq5 = vld1q_f32(b_rd + 16);\ + float32x4_t bq6 = vld1q_f32(b_rd + 20);\ + cq1 = vfmaq_n_f32(cq1, bq1, as1);\ + cq2 = 
vfmaq_n_f32(cq2, bq2, as1);\ + cq3 = vfmaq_n_f32(cq3, bq3, as1);\ + cq4 = vfmaq_n_f32(cq4, bq4, as1);\ + cq5 = vfmaq_n_f32(cq5, bq5, as1);\ + cq6 = vfmaq_n_f32(cq6, bq6, as1); + +#define ACC_K1M1N25 ACC_K1M1N24 cs1 += as1 * b_rd[24]; +#define ACC_K1M1N26 ACC_K1M1N25 cs2 += as1 * b_rd[25]; + +#define UNIT_SAVE_M1N4_CC(cq1) \ + c_ptr[0] = c_ptr[0] * beta + vgetq_lane_f32(cq1, 0);\ + c_ptr[LDC] = c_ptr[LDC] * beta + vgetq_lane_f32(cq1, 1);\ + c_ptr += LDC * 2;\ + c_ptr[0] = c_ptr[0] * beta + vgetq_lane_f32(cq1, 2);\ + c_ptr[LDC] = c_ptr[LDC] * beta + vgetq_lane_f32(cq1, 3);\ + c_ptr += LDC * 2; + +#define UNIT_SAVE_M1N4_CR(cq1) \ + cq1 = vfmaq_n_f32(cq1, vld1q_f32(c_ptr), beta);\ + vst1q_f32(c_ptr, cq1); c_ptr += 4; + +#define UNIT_SAVE_M1N1_CC(cs1) \ + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr += LDC; + +#define UNIT_SAVE_M1N1_CR(cs1) \ + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr++; + +#define SAVE_M1N4(mode) UNIT_SAVE_M1N4_##mode(cq1) + +#define SAVE_M1N5(mode) SAVE_M1N4(mode) UNIT_SAVE_M1N1_##mode(cs1) +#define SAVE_M1N6(mode) SAVE_M1N5(mode) UNIT_SAVE_M1N1_##mode(cs2) +#define SAVE_M1N7(mode) SAVE_M1N6(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N8(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) + +#define SAVE_M1N9(mode) SAVE_M1N8(mode) UNIT_SAVE_M1N1_##mode(cs1) +#define SAVE_M1N10(mode) SAVE_M1N9(mode) UNIT_SAVE_M1N1_##mode(cs2) +#define SAVE_M1N11(mode) SAVE_M1N10(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N12(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) UNIT_SAVE_M1N4_##mode(cq3) + +#define SAVE_M1N13(mode) SAVE_M1N12(mode) UNIT_SAVE_M1N1_##mode(cs1) +#define SAVE_M1N14(mode) SAVE_M1N13(mode) UNIT_SAVE_M1N1_##mode(cs2) +#define SAVE_M1N15(mode) SAVE_M1N14(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N16(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2)\ + UNIT_SAVE_M1N4_##mode(cq3) UNIT_SAVE_M1N4_##mode(cq4) + +#define SAVE_M1N17(mode) SAVE_M1N16(mode) UNIT_SAVE_M1N1_##mode(cs1) +#define SAVE_M1N18(mode) SAVE_M1N17(mode) UNIT_SAVE_M1N1_##mode(cs2) +#define SAVE_M1N19(mode) SAVE_M1N18(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N20(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2)\ + UNIT_SAVE_M1N4_##mode(cq3) UNIT_SAVE_M1N4_##mode(cq4) UNIT_SAVE_M1N4_##mode(cq5) + +#define SAVE_M1N21(mode) SAVE_M1N20(mode) UNIT_SAVE_M1N1_##mode(cs1) +#define SAVE_M1N22(mode) SAVE_M1N21(mode) UNIT_SAVE_M1N1_##mode(cs2) +#define SAVE_M1N23(mode) SAVE_M1N22(mode) UNIT_SAVE_M1N1_##mode(cs3) + +#define SAVE_M1N24(mode) \ + UNIT_SAVE_M1N4_##mode(cq1) UNIT_SAVE_M1N4_##mode(cq2) UNIT_SAVE_M1N4_##mode(cq3)\ + UNIT_SAVE_M1N4_##mode(cq4) UNIT_SAVE_M1N4_##mode(cq5) UNIT_SAVE_M1N4_##mode(cq6) + +#define SAVE_M1N25(mode) SAVE_M1N24(mode) UNIT_SAVE_M1N1_##mode(cs1) +#define SAVE_M1N26(mode) SAVE_M1N25(mode) UNIT_SAVE_M1N1_##mode(cs2) + +#define FUNC_M1(ndim) \ +static inline void sgemm_skinny1_a7x_m1n##ndim(\ + const float * __restrict__ a_rd, const float * __restrict__ b_rd,\ + float * __restrict__ c_ptr, uint32_t k_left, uint32_t LDC,\ + uint8_t c_rowmajor, float beta) {\ + INIT_M1N##ndim\ + for (; k_left > 3; k_left -= 4) {\ + ACC_K4M1N##ndim\ + b_rd += 4 * ndim;\ + }\ + REDUC_M1N##ndim\ + for (; k_left > 0; k_left--) {\ + ACC_K1M1N##ndim\ + b_rd += ndim;\ + }\ + if (c_rowmajor == 0) {\ + SAVE_M1N##ndim(CC)\ + } else {\ + SAVE_M1N##ndim(CR)\ + }\ +} + +FUNC_M1(4) +FUNC_M1(5) +FUNC_M1(6) +FUNC_M1(7) +FUNC_M1(8) +FUNC_M1(9) +FUNC_M1(10) +FUNC_M1(11) +FUNC_M1(12) +FUNC_M1(13) +FUNC_M1(14) +FUNC_M1(15) +FUNC_M1(16) +FUNC_M1(17) 
+FUNC_M1(18) +FUNC_M1(19) +FUNC_M1(20) +FUNC_M1(21) +FUNC_M1(22) +FUNC_M1(23) +FUNC_M1(24) +FUNC_M1(25) +FUNC_M1(26) + +#endif diff --git a/src/arm_neon/ARMCompareAndSwap.c b/src/arm_neon/ARMCompareAndSwap.c new file mode 100644 index 0000000..4dcf12f --- /dev/null +++ b/src/arm_neon/ARMCompareAndSwap.c @@ -0,0 +1,112 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include +#include + +#if __aarch64__ +/* detect ARMv8 CAS support */ +static __thread uint8_t blas_arm64_cas_type = 0; +static __thread uint8_t blas_arm64_cas_init = 0; + +#ifndef HWCAP_ATOMICS +#define HWCAP_ATOMICS (1 << 8) +#endif + +static uint8_t blas_arm64_get_cas_support() { + if (!blas_arm64_cas_init) { + blas_arm64_cas_type = (getauxval(AT_HWCAP) & HWCAP_ATOMICS) ? + 1 : 0; + blas_arm64_cas_init = 1; + } + return blas_arm64_cas_type; +} + +#endif + +uint32_t atomicCAS_U32(uint32_t comp, uint32_t write, uint32_t *dst) { +#if __aarch64__ + if (blas_arm64_get_cas_support()) { + uint32_t tmp = comp; + __asm__ __volatile__( + "cas %w0,%w1,[%2]\n\t" + :"+r"(tmp):"r"(write),"r"(dst):"cc","memory"); + return tmp; + } else { + register uint32_t tmp __asm("w0"); + register uint32_t comp_asm __asm("w2") = comp; + register uint32_t write_asm __asm("w3") = write; + register uint32_t *dst_asm __asm("x4") = dst; + __asm__ __volatile__( + "1:\n\t" + "ldxr %w0,[%x3]; cmp %w0,%w1; bne 2f; stxr w1,%w2,[%x3]\n\t" + "cmp w1,#0; bne 1b\n\t" + "2:\n\t" + :"+r"(tmp):"r"(comp_asm),"r"(write_asm),"r"(dst_asm):"x1","cc","memory"); + return tmp; + } +#else + register uint32_t tmp __asm("r0"); + register uint32_t comp_asm __asm("r2") = comp; + register uint32_t write_asm __asm("r3") = write; + register uint32_t *dst_asm __asm("r4") = dst; + __asm__ __volatile__( + "1:\n\t" + "ldrex %0,[%3]; cmp %0,%1; bne 2f; strex r1,%2,[%3]\n\t" + "cmp r1,#0; bne 1b\n\t" + "2:\n\t" + :"+r"(tmp):"r"(comp_asm),"r"(write_asm),"r"(dst_asm):"r1","cc","memory"); + return tmp; +#endif +} + +uint64_t atomicCAS_U64(uint64_t comp, uint64_t write, uint64_t *dst) { + uint64_t tmp; +#if __aarch64__ + if (blas_arm64_get_cas_support()) { + tmp = comp; + __asm__ __volatile__( + "cas %x0,%x1,[%2]\n\t" + :"+r"(tmp):"r"(write),"r"(dst):"cc","memory"); + } else { + __asm__ __volatile__( + "mov x2,%x1; mov x4,%x2\n\t" + "1:\n\t" + "ldxr %x0,[%x3]; cmp %x0,x2; bne 2f; stxr w6,x4,[%x3]\n\t" + "cmp w6,#0; bne 1b\n\t" + "2:\n\t" + :"+r"(tmp):"r"(comp),"r"(write),"r"(dst):"x2","x4","w6","cc","memory"); + } +#else + uint64_t *comp_addr = ∁ + uint64_t *write_loc = &write; + uint64_t *tmp_addr = &tmp; + __asm__ __volatile__( + "ldr r2,[%0]; ldr r3,[%0,#4]; ldr r4,[%1]; ldr r5,[%1,#4]\n\t" + "1:\n\t" + "ldrexd r0,r1,[%2]; cmp r0,r2; bne 2f\n\t" + "cmp r1,r3; bne 2f; strexd r6,r4,r5,[%2]\n\t" + "cmp r6,#0; bne 1b\n\t" + "2:\n\t" + "str 
r0,[%3]; str r1,[%3,#4]\n\t" + ::"r"(comp_addr),"r"(write_loc),"r"(dst),"r"(tmp_addr) + :"r0","r1","r2","r3","r4","r5","r6","cc","memory"); +#endif + return tmp; +} + diff --git a/src/arm_neon/ARMCpuType.c b/src/arm_neon/ARMCpuType.c new file mode 100644 index 0000000..6d58056 --- /dev/null +++ b/src/arm_neon/ARMCpuType.c @@ -0,0 +1,451 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/ARMCpuType.h" +#include +#include +#include +#include +#include +#include +#include +#ifndef _POSIX_C_SOURCE +#define _POSIX_C_SOURCE 200809L +#endif +#include +#include + +#define MAX_CPU_COUNT 20 + +struct ARM_CpuType { + bool m_init; + uint8_t m_cpuType[MAX_CPU_COUNT]; +}; + +static struct ARM_CpuType blas_arm_cpu_type = {false, {0}}; + +static pthread_mutex_t blas_arm_get_cpu_type_lock + = PTHREAD_MUTEX_INITIALIZER; + +static bool is_hex(char test) { + if (test >= 48 && test <= 57) return true; //0-9 + else if (test >= 65 && test <= 70) return true; //A-F + else if (test >= 97 && test <= 102) return true; //a-f + else return false; +} + +static uint16_t hex2num(char test) { + if (test >= 48 && test <= 57) return (test - 48); //0-9 + else if (test >= 65 && test <= 70) return (test - 55); //A-F + else if (test >= 97 && test <= 102) return (test - 87); //a-f + else return 0; +} + +static uint16_t extract_id(const char *line_header, + unsigned int header_start, unsigned int size) { + + unsigned int header_read = header_start; + /* find the first colon */ + for (; header_read < size; ++header_read) { + if (line_header[header_read] == ':') { + header_read++; + break; + } + } + /* skip backspace chars after the colon */ + for (; header_read < size; ++header_read) { + if (line_header[header_read] != ' ') break; + } + /* detect 0x or 0X header */ + bool hex_id = false; + if (header_read + 2 < size) { + if (line_header[header_read] == '0' && + (line_header[header_read + 1] == 'x' || line_header[header_read + 1] == 'X')) { + hex_id = true; + header_read += 2; + } + } + /* read number */ + uint16_t id = 0; + if (hex_id) { + for (; header_read < size; ++header_read) { + char test = line_header[header_read]; + if (!is_hex(test)) break; + id = id * 16 + hex2num(test); + } + } else {//decimal + for (; header_read < size; ++header_read) { + char test = line_header[header_read]; + if (test < 48 || test > 57) break; + id = id * 10 + (test - 48); + } + } + return id; +} + +/* parse_midr: get CPU model information from MIDR bits */ +static uint8_t parse_midr(uint32_t midr) { + + uint8_t cputype = 0; //0 = generic + uint32_t implementer = midr >> 24; + uint32_t part = (midr >> 4) & 0xFFF; + uint32_t variant = (midr >> 20) & 0xF; + if (implementer == 0x41) { //0x41 == ARM + if (part == 0xD03) cputype = 53; //Cortex-A53 + else if (part == 0xD04) cputype = 35; //Cortex-A35 
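+      /* Editorial note (not in the original source): MIDR bit layout is
+         implementer[31:24], variant[23:20], architecture[19:16],
+         part[15:4], revision[3:0]. As a worked example, a MIDR value of
+         0x410FD034 decodes to implementer 0x41 (Arm), variant 0x0,
+         part 0xD03, revision 4, which this function maps to cputype 53
+         (Cortex-A53). */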
+ else if (part == 0xD05) { + if (variant > 0) cputype = 55; //Cortex-A55 + else cputype = 53; //dual-issue ability of Cortex-A55r0 is limited + } + } + else if (implementer == 0x51) { //0x51 == Qualcomm + if (part == 0x803 || part == 0x801) cputype = 53; + if (part == 0x805) cputype = 55; + } + return cputype; +} + +/* MIDR: Main ID Register in ARM processor */ +/* direct access of MIDR is not possible in user mode without kernel modules */ +/* however the system (Linux/Android) reads MIDR and stores its info to /proc/cpuinfo */ +/* so we can assemble the bits of MIDR from the informations in /proc/cpuinfo */ +static int read_midr(uint32_t *midr, uint8_t midr_size) { + + FILE *fp = fopen("/proc/cpuinfo", "r"); + if (fp == NULL) { + return -1; //file open failed + } + + unsigned char num_cpu_detected = 0; + unsigned char num_cpu_part_parsed = 0; + unsigned char num_cpu_vendor_parsed = 0; + + char buffer[300], line_header[30]; + unsigned int header_read = 0, buffer_read = 0; + bool continue_find_endline = false, line_fill = false; + size_t bytes_read = 0; + unsigned int cpuid = 0; + do { + bytes_read = fread(buffer, 1, sizeof(buffer), fp); + if (ferror(fp)) { + fclose(fp); + return -2; //error during file read + } + for (buffer_read = 0; buffer_read < bytes_read; ) { + if (continue_find_endline) { + for (; buffer_read < bytes_read; ++buffer_read) { + if (buffer[buffer_read] == '\n') { + continue_find_endline = false; + buffer_read++; + break; + } + } + } + for (; buffer_read < bytes_read; ++buffer_read) { + if (header_read == sizeof(line_header) || buffer[buffer_read] == '\n') { + line_fill = true; + break; + } + line_header[header_read] = buffer[buffer_read]; header_read++; + } + if (line_fill) { + for (; header_read < sizeof(line_header); ++header_read) { + line_header[header_read] = '\0'; + } + /* extract MIDR information from /proc/cpuinfo */ + /* "CPU implementer : " */ + /* "CPU variant : " */ + /* "CPU architecture: " */ + /* "CPU part : " */ + /* "CPU revision : */ + if (line_header[0] == 'C' && line_header[1] == 'P' && line_header[2] == 'U' + && cpuid < midr_size) { + + for (header_read = 3; header_read < sizeof(line_header); ++header_read) { + if (line_header[header_read] != ' ') break; + } + bool skip_detection = false; + /* extract architecture (MIDR[16:19]) */ + if (header_read + 12 < sizeof(line_header)) { + if (line_header[header_read] == 'a' && line_header[header_read + 1] == 'r' + && line_header[header_read + 2] == 'c' && line_header[header_read + 3] == 'h' + && line_header[header_read + 4] == 'i' && line_header[header_read + 5] == 't') { + + skip_detection = true; + header_read += 12; + midr[cpuid] |= + ((uint32_t)extract_id(line_header, header_read, sizeof(line_header)) << 16); + } + } + /* extract revision (MIDR[0:3]) */ + if (!skip_detection && header_read + 8 < sizeof(line_header)) { + if (line_header[header_read] == 'r' && line_header[header_read + 1] == 'e' + && line_header[header_read + 2] == 'v' && line_header[header_read + 3] == 'i' + && line_header[header_read + 4] == 's' && line_header[header_read + 5] == 'i') { + + skip_detection = true; + header_read += 8; + midr[cpuid] |= + ((uint32_t)extract_id(line_header, header_read, sizeof(line_header))); + } + } + /* extract variant (MIDR[20:23]) */ + if (!skip_detection && header_read + 7 < sizeof(line_header)) { + if (line_header[header_read] == 'v' && line_header[header_read + 1] == 'a' + && line_header[header_read + 2] == 'r' && line_header[header_read + 3] == 'i' + && line_header[header_read + 4] == 'a' && 
line_header[header_read + 5] == 'n') { + + skip_detection = true; + header_read += 7; + midr[cpuid] |= + ((uint32_t)extract_id(line_header, header_read, sizeof(line_header)) << 20); + } + } + /* extract implementer (MIDR[24:31]) */ + if (!skip_detection && header_read + 11 < sizeof(line_header)) { + if (line_header[header_read] == 'i' && line_header[header_read + 1] == 'm' + && line_header[header_read + 2] == 'p' && line_header[header_read + 3] == 'l' + && line_header[header_read + 4] == 'e' && line_header[header_read + 5] == 'm') { + + skip_detection = true; + header_read += 11; + midr[cpuid] |= + ((uint32_t)extract_id(line_header, header_read, sizeof(line_header))) << 24; + num_cpu_vendor_parsed++; + } + } + /* extract part number (MIDR[4:15]) */ + if (!skip_detection && header_read + 4 < sizeof(line_header)) { + if (line_header[header_read] == 'p' && line_header[header_read + 1] == 'a' + && line_header[header_read + 2] == 'r' && line_header[header_read + 3] == 't') { + + skip_detection = true; + header_read += 4; + midr[cpuid] |= + ((uint32_t)extract_id(line_header, header_read, sizeof(line_header))) << 4; + num_cpu_part_parsed++; + } + } + } + /* read processor id from /proc/cpuinfo */ + /* "processor : " */ + if (line_header[0] == 'p' && line_header[1] == 'r' && line_header[2] == 'o' + && line_header[3] == 'c' && line_header[4] == 'e' && line_header[5] == 's' + && line_header[6] == 's' && line_header[7] == 'o' && line_header[8] == 'r') { + + header_read = 9; + cpuid = extract_id(line_header, header_read, sizeof(line_header)); + if (cpuid < midr_size) midr[cpuid] = 0; + num_cpu_detected++; + } + line_fill = false; + header_read = 0; + } + for (; buffer_read < bytes_read; ++buffer_read) { + continue_find_endline = true; + if (buffer[buffer_read] == '\n') { + continue_find_endline = false; + buffer_read++; + break; + } + } + } + } while(bytes_read == sizeof(buffer)); + + fclose(fp); + + /* on some platforms the Linux kernel is buggy, + * info from /proc/cpuinfo lack some fields. 
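+   * (Editorial note: "lacks some fields" means the number of "CPU part"
+   * or "CPU implementer" lines does not match the number of "processor"
+   * entries; read_midr() then returns -3 and the caller falls back to
+   * get_cputype_from_uevent() defined below.)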
*/ + if (num_cpu_detected != num_cpu_part_parsed) return -3; + if (num_cpu_detected != num_cpu_vendor_parsed) return -3; + return num_cpu_detected; +} + +static char cpu_uevent[40] = "/sys/devices/system/cpu/cpu"; + +static uint8_t get_cputype_from_uevent(uint8_t cpuid) { + /* first form the file path */ + uint8_t digits[8]; + uint8_t n_digits = 0; + uint8_t tmp = cpuid; + do { + digits[n_digits] = tmp % 10; + tmp /= 10; + n_digits++; + } while (tmp > 0); + for (uint8_t i = 0; i < n_digits; ++i) { + cpu_uevent[27 + i] = digits[n_digits - i - 1] + 48; + } + uint8_t tail_pos = 27 + n_digits; + cpu_uevent[tail_pos] = '/'; + cpu_uevent[tail_pos + 1] = 'u'; + cpu_uevent[tail_pos + 2] = 'e'; + cpu_uevent[tail_pos + 3] = 'v'; + cpu_uevent[tail_pos + 4] = 'e'; + cpu_uevent[tail_pos + 5] = 'n'; + cpu_uevent[tail_pos + 6] = 't'; + cpu_uevent[tail_pos + 7] = '\0'; + /* then open the file */ + FILE *fp = fopen(cpu_uevent, "r"); + if (fp == NULL) { + return 0; //file open failed + } + unsigned char buffer[100]; + fread(buffer, 1, sizeof(buffer), fp); + if (ferror(fp)) { + return 0; //error during read + } + uint8_t cputype = 0; + /* search for patterns like "OF_COMPATIBLE_0=arm,cortex-a72" */ + for (uint8_t i = 0; i < sizeof(buffer) - 40; ++i) { + if (buffer[i] == 'O' && buffer[i + 1] == 'F' && buffer[i + 2] == '_') { + i += 3; + if (buffer[i] == 'C' && buffer[i + 1] == 'O' && buffer[i + 2] == 'M') { + i += 3; + if (buffer[i] == 'P' && buffer[i + 1] == 'A' && buffer[i + 2] == 'T') { + i += 10; + if (buffer[i] == 'a' && buffer[i + 1] == 'r' && buffer[i + 2] == 'm') { + i += 4; + if (buffer[i] == 'c' && buffer[i + 1] == 'o' && buffer[i + 2] == 'r') { + i += 5; + if (buffer[i] == 'x' && buffer[i + 1] == '-' && buffer[i + 2] == 'a') { + char tmp = buffer[i + 3]; + if (tmp >= 48 && tmp <= 57) cputype = tmp - 48; + tmp = buffer[i + 4]; + if (tmp >= 48 && tmp <= 57) cputype = cputype * 10 + (tmp - 48); + break; + } + } + } + } + } + } + } + return cputype; +} + +uint8_t blas_arm_get_cpu_type(uint8_t cpuid) { + if (cpuid >= MAX_CPU_COUNT) return 0; + if (!blas_arm_cpu_type.m_init) { + int acc_lock = pthread_mutex_lock(&blas_arm_get_cpu_type_lock); + if (acc_lock != 0) return 0; + if (!blas_arm_cpu_type.m_init) { + uint32_t midr[MAX_CPU_COUNT]; + for (int cpupos = 0; cpupos < MAX_CPU_COUNT; ++cpupos) { + midr[cpupos] = 0; + } + int midr_read_status = read_midr(midr, MAX_CPU_COUNT); + if (midr_read_status > MAX_CPU_COUNT) midr_read_status = MAX_CPU_COUNT; + if (midr_read_status >= 0) { + for (int cpupos = 0; cpupos < midr_read_status; ++cpupos) { + blas_arm_cpu_type.m_cpuType[cpupos] = + parse_midr(midr[cpupos]); + } + } else { + for (int cpupos = 0; cpupos < MAX_CPU_COUNT; ++cpupos) { + blas_arm_cpu_type.m_cpuType[cpupos] = + get_cputype_from_uevent(cpupos); + } + } + blas_arm_cpu_type.m_init = true; + } + pthread_mutex_unlock(&blas_arm_get_cpu_type_lock); + } + return blas_arm_cpu_type.m_cpuType[cpuid]; +} + +static __thread uint8_t blas_arm_fp16_type = 0; +static __thread uint8_t blas_arm_fp16_init = 0; + +#ifndef HWCAP_ASIMDHP +#define HWCAP_ASIMDHP (1 << 10) +#endif +#ifndef HWCAP_FPHP +#define HWCAP_FPHP (1 << 9) +#endif + +uint8_t blas_arm_get_fp16_support() { + if (!blas_arm_fp16_init) { + unsigned long hwcap = getauxval(AT_HWCAP); +#if __aarch64__ + blas_arm_fp16_type = + ((hwcap & HWCAP_ASIMDHP) && (hwcap & HWCAP_FPHP)) ? 2 : 1; +#else + blas_arm_fp16_type = + ((hwcap & HWCAP_VFPv4) && (hwcap & HWCAP_NEON)) ? 
1 : 0; +#endif + blas_arm_fp16_init = 1; + } + return blas_arm_fp16_type; +} + +#if __aarch64__ +#define GEMM_DEFAULT_I8I32_INST 1 +#else +#define GEMM_DEFAULT_I8I32_INST 0 +#endif + +static uint8_t blas_arm_i8i32_type = GEMM_DEFAULT_I8I32_INST + 1; +static uint8_t blas_arm_i8i32_init = 0; +static pthread_mutex_t blas_arm_set_int_lock + = PTHREAD_MUTEX_INITIALIZER; +static jmp_buf i8i32_ret_env; +static pthread_t int_tid; + +static void i8i32gemm_sigill_handler(int sigill) { + if (pthread_equal(int_tid, pthread_self()) != 0) { + blas_arm_i8i32_type = GEMM_DEFAULT_I8I32_INST; + longjmp(i8i32_ret_env, 1); + } else { + _Exit(EXIT_FAILURE); + } +} + +static void test_i8i32() { +#if __aarch64__ + __asm__ __volatile__("sdot v1.4s,v0.16b,v2.4b[0]":::"v0","v1","v2"); +#else + __asm__ __volatile__("vmlal.s16 q1,d0,d1[0]":::"q0","q1"); +#endif +} + +uint8_t blas_arm_get_i8i32_support() { + if (!blas_arm_i8i32_init) { + int acc_lock = pthread_mutex_lock(&blas_arm_set_int_lock); + if (acc_lock != 0) return GEMM_DEFAULT_I8I32_INST; + if (!blas_arm_i8i32_init) { + struct sigaction i8i32_act, old_act; + memset(&i8i32_act, '\0', sizeof(i8i32_act)); + i8i32_act.sa_handler = &i8i32gemm_sigill_handler; + int_tid = pthread_self(); + if (setjmp(i8i32_ret_env)) { + sigaction(SIGILL, &old_act, NULL); + blas_arm_i8i32_init = 1; + pthread_mutex_unlock(&blas_arm_set_int_lock); + return GEMM_DEFAULT_I8I32_INST; + } + __asm__ __volatile__("dsb sy":::"memory"); + sigaction(SIGILL, &i8i32_act, &old_act); + test_i8i32(); + sigaction(SIGILL, &old_act, NULL); + blas_arm_i8i32_init = 1; + } + pthread_mutex_unlock(&blas_arm_set_int_lock); + } + return blas_arm_i8i32_type; +} + diff --git a/src/neon_armv7a/Bias.c b/src/neon_armv7a/Bias.c new file mode 100644 index 0000000..e77da7c --- /dev/null +++ b/src/neon_armv7a/Bias.c @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/NeonBias.h" +#include "arm_neon/NeonSum.h" + +NEON_BIAS(float, float32x4_t, f32, 4, mla) + +NEON_BIAS(int32_t, int32x4_t, s32, 4, mla) + +NEON_I8I32_SUM(u, uint) + +NEON_I16_SUMSQUARE(s, int) + diff --git a/src/neon_armv7a/Layer.c b/src/neon_armv7a/Layer.c new file mode 100644 index 0000000..c64cabd --- /dev/null +++ b/src/neon_armv7a/Layer.c @@ -0,0 +1,24 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv7a/SgemmDriver.h" +#include "neon_armv7a/Bias.h" +#include "common/CommonLayer.h" +#include + +SIMPLE_FC_FUNC(sgemm, float, float, float) + diff --git a/src/neon_armv7a/Quant.c b/src/neon_armv7a/Quant.c new file mode 100644 index 0000000..3835294 --- /dev/null +++ b/src/neon_armv7a/Quant.c @@ -0,0 +1,52 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "common/CommonQuant.h" +#include "arm_neon/NeonQuant.h" + +NEON_FIND_EXTREME(float32_t, f32, float32x2_t, float32x4_t, 2) + +QUANTIZE_ASYMMETRIC(32, 8) + +QUANTIZE_SYMMETRIC(32, 8) + +QUANTIZE_ASYMMETRIC(32, 16) + +QUANTIZE_SYMMETRIC(32, 16) + +void dequantize_symmetric_f32_s32(const int32_t *src, float32_t *dst, + float32_t scale, uint32_t size) { + + inline_dequant_cvt_f32_s32(dst, src, scale, size); +} + +NEON_FIND_EXTREME(int32_t, s32, int32x2_t, int32x4_t, 2) + +NEON_FIND_EXTREME(int16_t, s16, int16x4_t, int16x8_t, 4) + +REQUANTIZE_ASYMMETRIC_MULHI(float, 32, 8, 64) + +REQUANTIZE_SYMMETRIC_MULHI(float, 32, 8, 64) + +REQUANTIZE_ASYMMETRIC_MULHI(float, 32, 16, 64) + +REQUANTIZE_SYMMETRIC_MULHI(float, 32, 16, 64) + +REQUANTIZE_ASYMMETRIC_MULHI(float, 16, 8, 32) + +REQUANTIZE_SYMMETRIC_MULHI(float, 16, 8, 32) + diff --git a/src/neon_armv7a/S8S32GemmDriver.c b/src/neon_armv7a/S8S32GemmDriver.c new file mode 100644 index 0000000..6089838 --- /dev/null +++ b/src/neon_armv7a/S8S32GemmDriver.c @@ -0,0 +1,43 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "neon_armv7a/S8S32MlaGemmDriver.h" +#include "arm_neon/ARMCpuType.h" + +int s8s32gemm_serial(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, int32_t beta_inp) { + + if (blas_arm_get_i8i32_support() == 0) { + return 2; + } + return s8s32mlagemm_serial(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp); +} + +int s8s32gemm(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, + int32_t beta_inp, uint32_t num_threads) { + + if (blas_arm_get_i8i32_support() == 0) { + return 2; + } + return s8s32mlagemm(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp, num_threads); +} + diff --git a/src/neon_armv7a/S8S32MlaGemmCopy.c b/src/neon_armv7a/S8S32MlaGemmCopy.c new file mode 100644 index 0000000..5e7edb8 --- /dev/null +++ b/src/neon_armv7a/S8S32MlaGemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "common/CommonCopy.h" +#include "arm_neon/NeonI8I32MlaGemmCopy.h" + +GENERIC_NCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 6) +GENERIC_NCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 8) + +GENERIC_TCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 6) +GENERIC_TCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 8) + diff --git a/src/neon_armv7a/S8S32MlaGemmDriver.c b/src/neon_armv7a/S8S32MlaGemmDriver.c new file mode 100644 index 0000000..7b09908 --- /dev/null +++ b/src/neon_armv7a/S8S32MlaGemmDriver.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "neon_armv7a/S8S32MlaGemmCopy.h" +#include "neon_armv7a/S8S32MlaGemmKernel.h" +#include "neon_armv7a/S8S32MlaGemmSkinnyGer.h" +#include "neon_armv7a/S8S32MlaGemmSkinnyDot.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +GEMM_PARALLEL_FUNC(s8s32mlagemm, int8_t, int16_t, int8_t, int16_t, int32_t, + 6, 8, 4, 4, 4, 4) + diff --git a/src/neon_armv7a/S8S32MlaGemmKernel.c b/src/neon_armv7a/S8S32MlaGemmKernel.c new file mode 100644 index 0000000..bc0dcf1 --- /dev/null +++ b/src/neon_armv7a/S8S32MlaGemmKernel.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "common/CommonKernel.h" +#include "neon_armv7a/I8I32MlaGemmKernel.h" + +DUALPACK_KERNEL_FUNC_LM(s8s32mlagemm, int16_t, int16_t, int32_t, 6, 8) +DUALPACK_KERNEL_FUNC_LN(s8s32mlagemm, int16_t, int16_t, int32_t, 8, 6) + diff --git a/src/neon_armv7a/S8S32MlaGemmSkinnyDot.c b/src/neon_armv7a/S8S32MlaGemmSkinnyDot.c new file mode 100644 index 0000000..2380dc3 --- /dev/null +++ b/src/neon_armv7a/S8S32MlaGemmSkinnyDot.c @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/NeonI8I32MlaGemmSkinnyDot.h" +#include "common/CommonSkinnyDot.h" + +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 1, 15, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 2, 15, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 3, 15, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 4, 15, 3, 131072, int8_t, int8_t) \ No newline at end of file diff --git a/src/neon_armv7a/S8S32MlaGemmSkinnyGer.c b/src/neon_armv7a/S8S32MlaGemmSkinnyGer.c new file mode 100644 index 0000000..0ae4d36 --- /dev/null +++ b/src/neon_armv7a/S8S32MlaGemmSkinnyGer.c @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/NeonI8I32MlaGemmSkinnyGer.h" + +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 1, 5, 5, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 2, 5, 5, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 3, 5, 5, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 4, 5, 5, 8192, int8_t, int8_t) + diff --git a/src/neon_armv7a/SgemmCopy.c b/src/neon_armv7a/SgemmCopy.c new file mode 100644 index 0000000..58929e6 --- /dev/null +++ b/src/neon_armv7a/SgemmCopy.c @@ -0,0 +1,31 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "common/CommonCopy.h" +#include "arm_neon/NeonSgemmCopy.h" + +#define NCOPY_float_float(unroll) NCOPY_UNROLL_##unroll + +GENERIC_NCOPY_FUNC(sgemm, float, float, 6) +GENERIC_NCOPY_FUNC(sgemm, float, float, 8) + +#define TCOPY_UNIT_float_float(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) + +GENERIC_TCOPY_FUNC(sgemm, float, float, 6) +GENERIC_TCOPY_FUNC(sgemm, float, float, 8) + diff --git a/src/neon_armv7a/SgemmDriver.c b/src/neon_armv7a/SgemmDriver.c new file mode 100644 index 0000000..1fa8a15 --- /dev/null +++ b/src/neon_armv7a/SgemmDriver.c @@ -0,0 +1,26 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv7a/SgemmKernel.h" +#include "neon_armv7a/SgemmCopy.h" +#include "neon_armv7a/SgemmSkinnyGer.h" +#include "neon_armv7a/SgemmSkinnyDot.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +GEMM_PARALLEL_FUNC(sgemm, float, float, float, float, float, 6, 8, 8, 8, 8, 8) + diff --git a/src/neon_armv7a/SgemmKernel.c b/src/neon_armv7a/SgemmKernel.c new file mode 100644 index 0000000..1b01135 --- /dev/null +++ b/src/neon_armv7a/SgemmKernel.c @@ -0,0 +1,328 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "common/CommonKernel.h" +#include "arm_neon/NeonSgemmKernel.h" + +#define NEON_SGEMM_KERNEL_M6N8_PRELOAD_A53 \ + "vldr d0,[%13]; vldr d1,[%13,#8]; add %13,%13,#24\n\t"\ + "vldr d4,[%14]; vldr d5,[%14,#8]; ldr r2,[%14,#16]; ldr r3,[%14,#20]\n\t"\ + "add %14,%14,#32\n\t" + +#define NEON_SGEMM_KERNEL_M6N8_MAIN2_A53 \ + "vldr d3,[%13,#-8]; vmov d2,d1\n\t"\ + "vmla.f32 %q0,q0,d4[0]; ldr r0,[%14]\n\t"\ + "vmla.f32 %q1,q0,d4[1]; ldr r1,[%14,#4]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d7,[%14,#-8]; vmov d6,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]; ldr r2,[%13]\n\t"\ + "vmla.f32 %q4,q2,d3[0]; ldr r3,[%13,#4]\n\t"\ + "vmla.f32 %q5,q2,d3[1]\n\t"\ + "vldr d5,[%14,#8]; vmov d4,r0,r1\n\t"\ + "vmla.f32 %q6,q3,d0[0]; add %13,%13,#48\n\t"\ + "vmla.f32 %q7,q3,d0[1]; add %14,%14,#64\n\t"\ + "vmla.f32 %q8,q1,d6[0]; pld [%13,#128]\n\t"\ + "vldr d1,[%13,#-40]; vmov d0,r2,r3\n\t"\ + "vmla.f32 %q9,q1,d6[1]; ldr r2,[%14,#-48]\n\t"\ + "vmla.f32 %q10,q1,d7[0]; ldr r3,[%14,#-44]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t"\ + "vldr d3,[%13,#-32]; vmov d2,d1\n\t"\ + "vmla.f32 %q0,q0,d4[0]; ldr r0,[%14,#-32]\n\t"\ + "vmla.f32 %q1,q0,d4[1]; ldr r1,[%14,#-28]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d7,[%14,#-40]; vmov d6,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]; ldr r2,[%13,#-24]\n\t"\ + "vmla.f32 %q4,q2,d3[0]; ldr r3,[%13,#-20]\n\t"\ + "vmla.f32 %q5,q2,d3[1]\n\t"\ + "vldr d5,[%14,#-24]; vmov d4,r0,r1\n\t"\ + "vmla.f32 %q6,q3,d0[0]; sub %12,%12,#2\n\t"\ + "vmla.f32 %q7,q3,d0[1]; cmp %12,#2\n\t"\ + "vmla.f32 %q8,q1,d6[0]; pld [%14,#192]\n\t"\ + "vldr d1,[%13,#-16]; vmov d0,r2,r3\n\t"\ + "vmla.f32 %q9,q1,d6[1]; ldr r2,[%14,#-16]\n\t"\ + "vmla.f32 %q10,q1,d7[0]; ldr r3,[%14,#-12]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t" + +#define NEON_SGEMM_KERNEL_M6N8_TAIL2_A53 \ + "vldr d3,[%13,#-8]; vmov d2,d1\n\t"\ + "vmla.f32 %q0,q0,d4[0]; ldr r0,[%14]\n\t"\ + "vmla.f32 %q1,q0,d4[1]; ldr r1,[%14,#4]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d7,[%14,#-8]; vmov d6,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]; ldr r2,[%13]\n\t"\ + "vmla.f32 %q4,q2,d3[0]; ldr r3,[%13,#4]\n\t"\ + "vmla.f32 %q5,q2,d3[1]\n\t"\ + "vldr d5,[%14,#8]; vmov d4,r0,r1\n\t"\ + "vmla.f32 %q6,q3,d0[0]; add %13,%13,#24\n\t"\ + "vmla.f32 %q7,q3,d0[1]; add %14,%14,#32\n\t"\ + "vmla.f32 %q8,q1,d6[0]\n\t"\ + "vldr d1,[%13,#-16]; vmov d0,r2,r3\n\t"\ + "vmla.f32 %q9,q1,d6[1]; ldr r2,[%14,#-16]\n\t"\ + "vmla.f32 %q10,q1,d7[0]; ldr r3,[%14,#-12]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t"\ + "vldr d3,[%13,#-8]; vmov d2,d1\n\t"\ + "vmla.f32 %q0,q0,d4[0]\n\t"\ + "vmla.f32 %q1,q0,d4[1]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d7,[%14,#-8]; vmov d6,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]\n\t"\ + "vmla.f32 %q4,q2,d3[0]\n\t"\ + "vmla.f32 %q5,q2,d3[1]\n\t"\ + "vmla.f32 %q6,q3,d0[0]\n\t"\ + "vmla.f32 %q7,q3,d0[1]\n\t"\ + "vmla.f32 %q8,q1,d6[0]\n\t"\ + "vmla.f32 %q9,q1,d6[1]\n\t"\ + "vmla.f32 %q10,q1,d7[0]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t" + +#define NEON_SGEMM_KERNEL_M6N8_TAIL1_A53 \ + "vldr d3,[%13,#-8]; vmov d2,d1\n\t"\ + "vmla.f32 %q0,q0,d4[0]\n\t"\ + "vmla.f32 %q1,q0,d4[1]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d7,[%14,#-8]; vmov d6,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]\n\t"\ + "vmla.f32 %q4,q2,d3[0]\n\t"\ + "vmla.f32 %q5,q2,d3[1]\n\t"\ + "vmla.f32 %q6,q3,d0[0]\n\t"\ + "vmla.f32 %q7,q3,d0[1]\n\t"\ + "vmla.f32 %q8,q1,d6[0]\n\t"\ + "vmla.f32 %q9,q1,d6[1]\n\t"\ + "vmla.f32 %q10,q1,d7[0]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t" + +#define NEON_SGEMM_SAVE_M6N8_ASM \ + float32x4x2_t ct1 = vzipq_f32(cq05, 
cq06);\ + float32x2_t cd1 = vget_low_f32(ct1.val[0]);\ + float32x2_t cd2 = vget_high_f32(ct1.val[0]);\ +\ + cq01 = vmlaq_n_f32(cq01, vld1q_f32(c_tmp), beta);\ + cd1 = vmla_n_f32(cd1, vld1_f32(c_tmp + 4), beta);\ + cq02 = vmlaq_n_f32(cq02, vld1q_f32(c_tmp + ldc), beta);\ + cd2 = vmla_n_f32(cd2, vld1_f32(c_tmp + ldc + 4), beta);\ +\ + vst1q_f32(c_tmp, cq01); vst1_f32(c_tmp + 4, cd1); c_tmp += ldc;\ + vst1q_f32(c_tmp, cq02); vst1_f32(c_tmp + 4, cd2); c_tmp += ldc;\ + cd1 = vget_low_f32(ct1.val[1]);\ + cd2 = vget_high_f32(ct1.val[1]);\ +\ + cq03 = vmlaq_n_f32(cq03, vld1q_f32(c_tmp), beta);\ + cd1 = vmla_n_f32(cd1, vld1_f32(c_tmp + 4), beta);\ + cq04 = vmlaq_n_f32(cq04, vld1q_f32(c_tmp + ldc), beta);\ + cd2 = vmla_n_f32(cd2, vld1_f32(c_tmp + ldc + 4), beta);\ +\ + vst1q_f32(c_tmp, cq03); vst1_f32(c_tmp + 4, cd1); c_tmp += ldc;\ + vst1q_f32(c_tmp, cq04); vst1_f32(c_tmp + 4, cd2); c_tmp += ldc;\ + ct1 = vzipq_f32(cq07, cq08);\ + cd1 = vget_low_f32(ct1.val[0]);\ + cd2 = vget_high_f32(ct1.val[0]);\ +\ + cd1 = vmla_n_f32(cd1, vld1_f32(c_tmp), beta);\ + cq09 = vmlaq_n_f32(cq09, vld1q_f32(c_tmp + 2), beta);\ + cd2 = vmla_n_f32(cd2, vld1_f32(c_tmp + ldc), beta);\ + cq10 = vmlaq_n_f32(cq10, vld1q_f32(c_tmp + ldc + 2), beta);\ +\ + vst1_f32(c_tmp, cd1); vst1q_f32(c_tmp + 2, cq09); c_tmp += ldc;\ + vst1_f32(c_tmp, cd2); vst1q_f32(c_tmp + 2, cq10); c_tmp += ldc;\ + cd1 = vget_low_f32(ct1.val[1]);\ + cd2 = vget_high_f32(ct1.val[1]);\ +\ + cd1 = vmla_n_f32(cd1, vld1_f32(c_tmp), beta);\ + cq11 = vmlaq_n_f32(cq11, vld1q_f32(c_tmp + 2), beta);\ + cd2 = vmla_n_f32(cd2, vld1_f32(c_tmp + ldc), beta);\ + cq12 = vmlaq_n_f32(cq12, vld1q_f32(c_tmp + ldc + 2), beta);\ +\ + vst1_f32(c_tmp, cd1); vst1q_f32(c_tmp + 2, cq11); c_tmp += ldc;\ + vst1_f32(c_tmp, cd2); vst1q_f32(c_tmp + 2, cq12); + +#define NEON_SGEMM_KERNEL_M8N6_PRELOAD_A53 \ + "vldr d0,[%13]; vldr d1,[%13,#8]\n\t"\ + "ldr r2,[%13,#16]; ldr r3,[%13,#20]; add %13,%13,#32\n\t"\ + "vldr d4,[%14]; vldr d5,[%14,#8]; add %14,%14,#24\n\t" + +#define NEON_SGEMM_KERNEL_M8N6_MAIN2_A53 \ + "vldr d7,[%14,#-8]; vmov d6,d5\n\t"\ + "vmla.f32 %q0,q0,d4[0]; ldr r0,[%13]\n\t"\ + "vmla.f32 %q1,q0,d4[1]; ldr r1,[%13,#4]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d3,[%13,#-8]; vmov d2,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]; ldr r2,[%14]\n\t"\ + "vmla.f32 %q4,q0,d7[0]; ldr r3,[%14,#4]\n\t"\ + "vmla.f32 %q5,q0,d7[1]\n\t"\ + "vldr d1,[%13,#8]; vmov d0,r0,r1\n\t"\ + "vmla.f32 %q6,q1,d4[0]; add %13,%13,#64\n\t"\ + "vmla.f32 %q7,q1,d4[1]; add %14,%14,#48\n\t"\ + "vmla.f32 %q8,q1,d6[0]; pld [%13,#192]\n\t"\ + "vldr d5,[%14,#-40]; vmov d4,r2,r3\n\t"\ + "vmla.f32 %q9,q1,d6[1]; ldr r2,[%13,#-48]\n\t"\ + "vmla.f32 %q10,q1,d7[0]; ldr r3,[%13,#-44]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t"\ + "vldr d7,[%14,#-32]; vmov d6,d5\n\t"\ + "vmla.f32 %q0,q0,d4[0]; ldr r0,[%13,#-32]\n\t"\ + "vmla.f32 %q1,q0,d4[1]; ldr r1,[%13,#-28]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d3,[%13,#-40]; vmov d2,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]; ldr r2,[%14,#-24]\n\t"\ + "vmla.f32 %q4,q0,d7[0]; ldr r3,[%14,#-20]\n\t"\ + "vmla.f32 %q5,q0,d7[1]\n\t"\ + "vldr d1,[%13,#-24]; vmov d0,r0,r1\n\t"\ + "vmla.f32 %q6,q1,d4[0]; sub %12,%12,#2\n\t"\ + "vmla.f32 %q7,q1,d4[1]; cmp %12,#2\n\t"\ + "vmla.f32 %q8,q1,d6[0]; pld [%14,#128]\n\t"\ + "vldr d5,[%14,#-16]; vmov d4,r2,r3\n\t"\ + "vmla.f32 %q9,q1,d6[1]; ldr r2,[%13,#-16]\n\t"\ + "vmla.f32 %q10,q1,d7[0]; ldr r3,[%13,#-12]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t" + +#define NEON_SGEMM_KERNEL_M8N6_TAIL2_A53 \ + "vldr d7,[%14,#-8]; vmov d6,d5\n\t"\ + "vmla.f32 %q0,q0,d4[0]; ldr 
r0,[%13]\n\t"\ + "vmla.f32 %q1,q0,d4[1]; ldr r1,[%13,#4]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d3,[%13,#-8]; vmov d2,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]; ldr r2,[%14]\n\t"\ + "vmla.f32 %q4,q0,d7[0]; ldr r3,[%14,#4]\n\t"\ + "vmla.f32 %q5,q0,d7[1]\n\t"\ + "vldr d1,[%13,#8]; vmov d0,r0,r1\n\t"\ + "vmla.f32 %q6,q1,d4[0]; add %13,%13,#32\n\t"\ + "vmla.f32 %q7,q1,d4[1]; add %14,%14,#24\n\t"\ + "vmla.f32 %q8,q1,d6[0]\n\t"\ + "vldr d5,[%14,#-16]; vmov d4,r2,r3\n\t"\ + "vmla.f32 %q9,q1,d6[1]; ldr r2,[%13,#-16]\n\t"\ + "vmla.f32 %q10,q1,d7[0]; ldr r3,[%13,#-12]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t"\ + "vldr d7,[%14,#-8]; vmov d6,d5\n\t"\ + "vmla.f32 %q0,q0,d4[0]\n\t"\ + "vmla.f32 %q1,q0,d4[1]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d3,[%13,#-8]; vmov d2,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]\n\t"\ + "vmla.f32 %q4,q0,d7[0]\n\t"\ + "vmla.f32 %q5,q0,d7[1]\n\t"\ + "vmla.f32 %q6,q1,d4[0]\n\t"\ + "vmla.f32 %q7,q1,d4[1]\n\t"\ + "vmla.f32 %q8,q1,d6[0]\n\t"\ + "vmla.f32 %q9,q1,d6[1]\n\t"\ + "vmla.f32 %q10,q1,d7[0]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t" + +#define NEON_SGEMM_KERNEL_M8N6_TAIL1_A53 \ + "vldr d7,[%14,#-8]; vmov d6,d5\n\t"\ + "vmla.f32 %q0,q0,d4[0]\n\t"\ + "vmla.f32 %q1,q0,d4[1]\n\t"\ + "vmla.f32 %q2,q0,d5[0]\n\t"\ + "vldr d3,[%13,#-8]; vmov d2,r2,r3\n\t"\ + "vmla.f32 %q3,q0,d5[1]\n\t"\ + "vmla.f32 %q4,q0,d7[0]\n\t"\ + "vmla.f32 %q5,q0,d7[1]\n\t"\ + "vmla.f32 %q6,q1,d4[0]\n\t"\ + "vmla.f32 %q7,q1,d4[1]\n\t"\ + "vmla.f32 %q8,q1,d6[0]\n\t"\ + "vmla.f32 %q9,q1,d6[1]\n\t"\ + "vmla.f32 %q10,q1,d7[0]\n\t"\ + "vmla.f32 %q11,q1,d7[1]\n\t" + +#define NEON_SGEMM_SAVE_M8N6_ASM \ +\ + cq01 = vmlaq_n_f32(cq01, vld1q_f32(c_tmp), beta);\ + cq07 = vmlaq_n_f32(cq07, vld1q_f32(c_tmp + 4), beta);\ + cq02 = vmlaq_n_f32(cq02, vld1q_f32(c_tmp + ldc), beta);\ + cq08 = vmlaq_n_f32(cq08, vld1q_f32(c_tmp + ldc + 4), beta);\ +\ + vst1q_f32(c_tmp, cq01); vst1q_f32(c_tmp + 4, cq07); c_tmp += ldc;\ + vst1q_f32(c_tmp, cq02); vst1q_f32(c_tmp + 4, cq08); c_tmp += ldc;\ +\ + cq03 = vmlaq_n_f32(cq03, vld1q_f32(c_tmp), beta);\ + cq09 = vmlaq_n_f32(cq09, vld1q_f32(c_tmp + 4), beta);\ + cq04 = vmlaq_n_f32(cq04, vld1q_f32(c_tmp + ldc), beta);\ + cq10 = vmlaq_n_f32(cq10, vld1q_f32(c_tmp + ldc + 4), beta);\ +\ + vst1q_f32(c_tmp, cq03); vst1q_f32(c_tmp + 4, cq09); c_tmp += ldc;\ + vst1q_f32(c_tmp, cq04); vst1q_f32(c_tmp + 4, cq10); c_tmp += ldc;\ +\ + cq05 = vmlaq_n_f32(cq05, vld1q_f32(c_tmp), beta);\ + cq11 = vmlaq_n_f32(cq11, vld1q_f32(c_tmp + 4), beta);\ + cq06 = vmlaq_n_f32(cq06, vld1q_f32(c_tmp + ldc), beta);\ + cq12 = vmlaq_n_f32(cq12, vld1q_f32(c_tmp + ldc + 4), beta);\ +\ + vst1q_f32(c_tmp, cq05); vst1q_f32(c_tmp + 4, cq11); c_tmp += ldc;\ + vst1q_f32(c_tmp, cq06); vst1q_f32(c_tmp + 4, cq12); + +#define PREF_C_1_LANE(n, mdim) \ + pref_c(c_pref); pref_c(c_pref + mdim - 1); c_pref += ldc; +#define PREF_C(mdim, ndim) \ + MACRO_EXPANSION_##ndim(VOID_BASE, PREF_C_1_LANE, mdim) + +#define NEON_SGEMM_ASM(mdim, ndim, cputype) {\ + float *c_pref = c_ptr; PREF_C(mdim, ndim)\ + register float32x4_t cq01 __asm("q4");\ + register float32x4_t cq02 __asm("q5");\ + register float32x4_t cq03 __asm("q6");\ + register float32x4_t cq04 __asm("q7");\ + register float32x4_t cq05 __asm("q8");\ + register float32x4_t cq06 __asm("q9");\ + register float32x4_t cq07 __asm("q10");\ + register float32x4_t cq08 __asm("q11");\ + register float32x4_t cq09 __asm("q12");\ + register float32x4_t cq10 __asm("q13");\ + register float32x4_t cq11 __asm("q14");\ + register float32x4_t cq12 __asm("q15");\ + const float *a_ptr, *b_ptr;\ + uint32_t k_left;\ + 
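  /* Annotation added in review, not part of the original patch: the twelve        \
   * accumulators cq01..cq12 are pinned to q4-q15 so the inline asm below keeps    \
   * the whole 6x8 (or 8x6) C tile in registers; q0-q3 and r0-r3 stage the packed  \
   * A/B panels, and the main loop retires two k iterations per pass, scheduled    \
   * for Cortex-A53 dual issue (the _A53 kernel macro variants). */                \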
b_ptr = b_head;\ + a_ptr = a_head;\ + k_left = K;\ + __asm__ __volatile__ (\ + "vmov.i8 %q0,#0; vmov.i8 %q1,#0; vmov %q2,%q0; vmov %q3,%q1\n\t"\ + "vmov %q4,%q0; vmov %q5,%q1; vmov %q6,%q0; vmov %q7,%q1\n\t"\ + "vmov %q8,%q0; vmov %q9,%q1; vmov %q10,%q0; vmov %q11,%q1\n\t"\ + "cmp %12,#0; beq 4f\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_PRELOAD_##cputype\ + "cmp %12,#2; ble 2f\n\t"\ + ".balign 16\n\t"\ + "1:\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_MAIN2_##cputype "bgt 1b\n\t"\ + "2:\n\t"\ + "cmp %12,#2; bne 3f\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_TAIL2_##cputype "b 4f\n\t"\ + "3:\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_TAIL1_##cputype\ + "4:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04),"=w"(cq05),"=w"(cq06),\ + "=w"(cq07),"=w"(cq08),"=w"(cq09),"=w"(cq10),"=w"(cq11),"=w"(cq12),\ + "+r"(k_left),"+r"(a_ptr),"+r"(b_ptr)\ + ::"d0","d1","d2","d3","d4","d5","d6","d7",\ + "r0","r1","r2","r3","cc","memory");\ + float *c_tmp = c_ptr;\ + NEON_SGEMM_SAVE_M##mdim##N##ndim##_ASM\ +} + +static inline void inline_dualpack_gemm_afloat_bfloat_cfloat_m6_n8( + const float *a_head, const float *b_head, float *c_ptr, + uint32_t K, float beta, uint32_t ldc) { + NEON_SGEMM_ASM(6, 8, A53) +} + +static inline void inline_dualpack_gemm_afloat_bfloat_cfloat_m8_n6( + const float *a_head, const float *b_head, float *c_ptr, + uint32_t K, float beta, uint32_t ldc) { + NEON_SGEMM_ASM(8, 6, A53) +} + +DUALPACK_KERNEL_FUNC_LM(sgemm, float, float, float, 6, 8) +DUALPACK_KERNEL_FUNC_LN(sgemm, float, float, float, 8, 6) + diff --git a/src/neon_armv7a/SgemmSkinnyDot.c b/src/neon_armv7a/SgemmSkinnyDot.c new file mode 100644 index 0000000..40c12d1 --- /dev/null +++ b/src/neon_armv7a/SgemmSkinnyDot.c @@ -0,0 +1,495 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/ARMCpuType.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyDot.h" +#include + +typedef float sgemm_skinnydot_ascalar; +typedef float sgemm_skinnydot_bscalar; +typedef float sgemm_skinnydot_cscalar; + +static inline void inline_sgemm_arowmajor_bskinny_m4n1(const float *a_ptr1, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, uint32_t LDM, + float beta, bool c_rowmajor) { + + float32x2_t cd1, cd2, cd3, cd4, cd5, cd6, cd7, cd8; + const float *a_ptr2 = a_ptr1 + LDK; + const float *a_ptr3 = a_ptr1 + LDK * 2; + const float *a_ptr4 = a_ptr2 + LDK * 2; + const float *a_pref = a_ptr4 + LDK; + const uint32_t pref_inc = LDK > k_inc ? 
(LDK - k_inc) * sizeof(float) : 0; + uint32_t k_left = k_inc; + __asm__ __volatile__( + "mov r0,#0\n\t" + "vmov.i8 %[cd1],#0; vmov.i8 %[cd2],#0\n\t" + "vmov.i8 %[cd3],#0; vmov.i8 %[cd4],#0\n\t" + "vmov.i8 %[cd5],#0; vmov.i8 %[cd6],#0\n\t" + "vmov.i8 %[cd7],#0; vmov.i8 %[cd8],#0\n\t" + "cmp %[k_left],#4; blt 3f\n\t" + "vldr d2,[%[a_ptr1]]; vldr d6,[%[a_ptr1],#8]; add %[a_ptr1],%[a_ptr1],#16\n\t" + "vldr d3,[%[a_ptr2]]; vldr d7,[%[a_ptr2],#8]; add %[a_ptr2],%[a_ptr2],#16\n\t" + "vldr d4,[%[a_ptr3]]; vldr d8,[%[a_ptr3],#8]; add %[a_ptr3],%[a_ptr3],#16\n\t" + "vldr d5,[%[a_ptr4]]; vldr d9,[%[a_ptr4],#8]; add %[a_ptr4],%[a_ptr4],#16\n\t" + "vldm %[b_ptr]!,{d0,d1}\n\t" + "cmp %[k_left],#8; blt 2f\n\t" + ".balign 16; 1:\n\t" + "pld [%[a_pref]]; add %[a_pref],%[a_pref],#64; add r0,r0,#16\n\t" + "vmla.f32 %[cd1],d2,d0; vldr d2,[%[a_ptr1]]\n\t" + "cmp r0,%[k_inc]\n\t" + "vmla.f32 %[cd2],d3,d0; vldr d3,[%[a_ptr2]]\n\t" + "addgt %[a_pref],%[a_pref],%[pref_inc]\n\t" + "vmla.f32 %[cd3],d4,d0; vldr d4,[%[a_ptr3]]\n\t" + "movgt r0,#0\n\t" + "vmla.f32 %[cd4],d5,d0; vldr d5,[%[a_ptr4]]\n\t" + "vldr d0,[%[b_ptr]]; sub %[k_left],%[k_left],#4\n\t" + "vmla.f32 %[cd5],d6,d1; vldr d6,[%[a_ptr1],#8]\n\t" + "add %[a_ptr1],%[a_ptr1],#16\n\t" + "vmla.f32 %[cd6],d7,d1; vldr d7,[%[a_ptr2],#8]\n\t" + "add %[a_ptr2],%[a_ptr2],#16; cmp %[k_left],#8\n\t" + "vmla.f32 %[cd7],d8,d1; vldr d8,[%[a_ptr3],#8]\n\t" + "add %[a_ptr3],%[a_ptr3],#16\n\t" + "vmla.f32 %[cd8],d9,d1; vldr d9,[%[a_ptr4],#8]\n\t" + "add %[a_ptr4],%[a_ptr4],#16\n\t" + "vldr d1,[%[b_ptr],#8]; add %[b_ptr],%[b_ptr],#16; bge 1b\n\t" + "2:\n\t" + "vmla.f32 %[cd1],d2,d0; vmla.f32 %[cd2],d3,d0\n\t" + "vmla.f32 %[cd3],d4,d0; vmla.f32 %[cd4],d5,d0\n\t" + "sub %[k_left],%[k_left],#4\n\t" + "vmla.f32 %[cd5],d6,d1; vmla.f32 %[cd6],d7,d1\n\t" + "vmla.f32 %[cd7],d8,d1; vmla.f32 %[cd8],d9,d1\n\t" + "3:\n\t" + :[cd1]"=w"(cd1), [cd2]"=w"(cd2), [cd3]"=w"(cd3), [cd4]"=w"(cd4), + [cd5]"=w"(cd5), [cd6]"=w"(cd6), [cd7]"=w"(cd7), [cd8]"=w"(cd8), + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), [a_ptr3]"+r"(a_ptr3), + [a_ptr4]"+r"(a_ptr4), [b_ptr]"+r"(b_ptr), + [k_left]"+r"(k_left), [a_pref]"+r"(a_pref) + :[pref_inc]"r"(pref_inc), [k_inc]"r"(k_inc) + :"d0","d1","d2","d3","d4","d5","d6","d7","d8","d9", + "r0","cc","memory"); + + cd1 = vadd_f32(cd1, cd5); cd2 = vadd_f32(cd2, cd6); + cd3 = vadd_f32(cd3, cd7); cd4 = vadd_f32(cd4, cd8); + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + float cs4 = vget_lane_f32(cd4, 0) + vget_lane_f32(cd4, 1); + for (; k_left > 0; k_left--) { + float bs1 = *b_ptr; b_ptr++; + cs1 += (*a_ptr1) * bs1; a_ptr1++; + cs2 += (*a_ptr2) * bs1; a_ptr2++; + cs3 += (*a_ptr3) * bs1; a_ptr3++; + cs4 += (*a_ptr4) * bs1; a_ptr4++; + } + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr[1] = c_ptr[1] * beta + cs2; + c_ptr[2] = c_ptr[2] * beta + cs3; c_ptr[3] = c_ptr[3] * beta + cs4; +} + +static inline void inline_sgemm_arowmajor_bskinny_m1n1(const float *a_ptr, + const float *b_ptr, float *c_ptr, uint32_t k_left, uint32_t LDK, uint32_t LDM, + float beta, bool c_rowmajor) { + + float32x4_t cq1; + __asm__ __volatile__( + "vmov.i8 d16,#0; vmov.i8 d17,#0\n\t" + "vmov d18,d16; vmov d19,d17\n\t" + "vmov d20,d16; vmov d21,d17\n\t" + "vmov d22,d16; vmov d23,d17\n\t" + "cmp %[K],#16; blt 4f\n\t" + "pld [%[a_ptr],#256]\n\t" + "add %[a_ptr],%[a_ptr],#64; add %[b_ptr],%[b_ptr],#64\n\t" + "vldr d24,[%[a_ptr],#-64]; vldr d8,[%[b_ptr],#-64]\n\t" + "vldr 
d25,[%[a_ptr],#-56]; vldr d9,[%[b_ptr],#-56]\n\t" + "vldr d26,[%[a_ptr],#-48]; vldr d10,[%[b_ptr],#-48]\n\t" + "vldr d27,[%[a_ptr],#-40]; vldr d11,[%[b_ptr],#-40]\n\t" + "vldr d28,[%[a_ptr],#-32]; vldr d12,[%[b_ptr],#-32]\n\t" + "vldr d29,[%[a_ptr],#-24]; vldr d13,[%[b_ptr],#-24]\n\t" + "vldr d30,[%[a_ptr],#-16]; vldr d14,[%[b_ptr],#-16]\n\t" + "vldr d31,[%[a_ptr],#-8]; vldr d15,[%[b_ptr],#-8]\n\t" + "cmp %[K],#32; blt 3f\n\t" + "2:\n\t" + "pld [%[a_ptr],#256]\n\t" + "add %[a_ptr],%[a_ptr],#64; add %[b_ptr],%[b_ptr],#64\n\t" + "vmla.f32 d16,d24,d8; vldr d24,[%[a_ptr],#-64]; vldr d8,[%[b_ptr],#-64]\n\t" + "vmla.f32 d17,d25,d9; vldr d25,[%[a_ptr],#-56]; vldr d9,[%[b_ptr],#-56]\n\t" + "vmla.f32 d18,d26,d10; vldr d26,[%[a_ptr],#-48]; vldr d10,[%[b_ptr],#-48]\n\t" + "vmla.f32 d19,d27,d11; vldr d27,[%[a_ptr],#-40]; vldr d11,[%[b_ptr],#-40]\n\t" + "sub %[K],%[K],#16\n\t" + "vmla.f32 d20,d28,d12; vldr d28,[%[a_ptr],#-32]; vldr d12,[%[b_ptr],#-32]\n\t" + "vmla.f32 d21,d29,d13; vldr d29,[%[a_ptr],#-24]; vldr d13,[%[b_ptr],#-24]\n\t" + "cmp %[K],#32\n\t" + "vmla.f32 d22,d30,d14; vldr d30,[%[a_ptr],#-16]; vldr d14,[%[b_ptr],#-16]\n\t" + "vmla.f32 d23,d31,d15; vldr d31,[%[a_ptr],#-8]; vldr d15,[%[b_ptr],#-8]\n\t" + "bge 2b\n\t" + "3:\n\t" + "vmla.f32 d16,d24,d8; vmla.f32 d17,d25,d9\n\t" + "vmla.f32 d18,d26,d10; vmla.f32 d19,d27,d11; sub %[K],%[K],#16\n\t" + "vmla.f32 d20,d28,d12; vmla.f32 d21,d29,d13\n\t" + "vmla.f32 d22,d30,d14; vmla.f32 d23,d31,d15\n\t" + "4:\n\t" + "vadd.f32 d16,d16,d20; vadd.f32 d17,d17,d21\n\t" + "vadd.f32 d18,d18,d22; vadd.f32 d19,d19,d23\n\t" + "cmp %[K],#8; blt 5f; add %[a_ptr],%[a_ptr],#32; add %[b_ptr],%[b_ptr],#32\n\t" + "vldr d24,[%[a_ptr],#-32]; vldr d8,[%[b_ptr],#-32]; vmla.f32 d16,d24,d8\n\t" + "vldr d25,[%[a_ptr],#-24]; vldr d9,[%[b_ptr],#-24]; vmla.f32 d17,d25,d9\n\t" + "sub %[K],%[K],#8\n\t" + "vldr d26,[%[a_ptr],#-16]; vldr d10,[%[b_ptr],#-16]; vmla.f32 d18,d26,d10\n\t" + "vldr d27,[%[a_ptr],#-8]; vldr d11,[%[b_ptr],#-8]; vmla.f32 d19,d27,d11\n\t" + "5:\n\t" + "vadd.f32 %e[cq1],d16,d17; vadd.f32 %f[cq1],d18,d19\n\t" + "cmp %[K],#4; blt 6f\n\t" + "add %[a_ptr],%[a_ptr],#16; add %[b_ptr],%[b_ptr],#16\n\t" + "vldr d24,[%[a_ptr],#-16]; vldr d8,[%[b_ptr],#-16]; vmla.f32 %e[cq1],d24,d8\n\t" + "sub %[K],%[K],#4\n\t" + "vldr d25,[%[a_ptr],#-8]; vldr d9,[%[b_ptr],#-8]; vmla.f32 %f[cq1],d25,d9\n\t" + "6:\n\t" + :[cq1]"=w"(cq1), [a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr), [K]"+r"(k_left) + ::"cc","memory","q12","q13","q14","q15", + "q4","q5","q6","q7","q8","q9","q10","q11"); + + float32x2_t cd1 = vadd_f32(vget_low_f32(cq1), vget_high_f32(cq1)); + if (k_left > 1) { + float32x2_t ad1 = vld1_f32(a_ptr); a_ptr += 2; + float32x2_t bd1 = vld1_f32(b_ptr); b_ptr += 2; + cd1 = vmla_f32(cd1, ad1, bd1); + k_left -= 2; + } + + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + if (k_left > 0) { + cs1 += a_ptr[0] * b_ptr[0]; + } + c_ptr[0] = c_ptr[0] * beta + cs1; +} + +/* k_mask = 7 */ +static inline void inline_sgemm_arowmajor_bskinny_m4n2(const float *a_ptr1, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, uint32_t LDM, + float beta, bool c_rowmajor) { + + const float *a_ptr2 = a_ptr1 + LDK; + const float *a_ptr3 = a_ptr1 + LDK * 2; + const float *a_ptr4 = a_ptr1 + LDK * 3; + const float *a_pref = a_ptr1 + LDK * 4; + uint32_t k_left = k_inc; + const uint32_t pref_inc = LDK > k_inc ? 
(LDK - k_inc) * sizeof(float) : 0; + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8; + __asm__ __volatile__( + "mov r0,#0\n\t" + "vmov.i8 %q[cq1],#0; vmov.i8 %q[cq2],#0\n\t" + "vmov.i8 %q[cq3],#0; vmov.i8 %q[cq4],#0\n\t" + "vmov.i8 %q[cq5],#0; vmov.i8 %q[cq6],#0\n\t" + "vmov.i8 %q[cq7],#0; vmov.i8 %q[cq8],#0\n\t" + "cmp %[k_left],#4; blt 3f\n\t" + "vldm %[a_ptr1]!,{q2}; vldm %[a_ptr2]!,{q3}\n\t" + "vldm %[a_ptr3]!,{q4}; vldm %[a_ptr4]!,{q5}\n\t" + "vldm %[b_ptr]!,{q0}; vldm %[b_ptr]!,{q1}\n\t" + "cmp %[k_left],#8; blt 2f\n\t" + ".balign 16; 1:\n\t" + "pld [%[a_pref]]; add %[a_pref],%[a_pref],#64; add r0,r0,#16\n\t" + "vmla.f32 %q[cq1],q2,q0; cmp r0,%[k_inc]\n\t" + "vmla.f32 %q[cq5],q2,q1; vldm %[a_ptr1]!,{q2}\n\t" + "vmla.f32 %q[cq2],q3,q0; addgt %[a_pref],%[a_pref],%[pref_inc]\n\t" + "vmla.f32 %q[cq6],q3,q1; vldm %[a_ptr2]!,{q3}\n\t" + "sub %[k_left],%[k_left],#4\n\t" + "vmla.f32 %q[cq3],q4,q0; movgt r0,#0\n\t" + "vmla.f32 %q[cq7],q4,q1; vldm %[a_ptr3]!,{q4}\n\t" + "vmla.f32 %q[cq4],q5,q0; cmp %[k_left],#8\n\t" + "vmla.f32 %q[cq8],q5,q1; vldm %[a_ptr4]!,{q5}\n\t" + "vldm %[b_ptr]!,{q0}; vldm %[b_ptr]!,{q1}; bge 1b\n\t" + "2:\n\t" + "vmla.f32 %q[cq1],q2,q0; vmla.f32 %q[cq5],q2,q1\n\t" + "vmla.f32 %q[cq2],q3,q0; vmla.f32 %q[cq6],q3,q1\n\t" + "vmla.f32 %q[cq3],q4,q0; vmla.f32 %q[cq7],q4,q1\n\t" + "vmla.f32 %q[cq4],q5,q0; vmla.f32 %q[cq8],q5,q1\n\t" + "sub %[k_left],%[k_left],#4\n\t" + "3:\n\t" + :[cq1]"=w"(cq1), [cq2]"=w"(cq2), [cq3]"=w"(cq3), [cq4]"=w"(cq4), + [cq5]"=w"(cq5), [cq6]"=w"(cq6), [cq7]"=w"(cq7), [cq8]"=w"(cq8), + [k_left]"+r"(k_left), [a_pref]"+r"(a_pref), [b_ptr]"+r"(b_ptr), + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), + [a_ptr3]"+r"(a_ptr3), [a_ptr4]"+r"(a_ptr4) + :[pref_inc]"r"(pref_inc), [k_inc]"r"(k_inc) + :"d0","d1","d2","d3","d4","d5","d6","d7","d8","d9","d10","d11", + "r0","cc","memory"); + + float32x2_t cd1 = vadd_f32(vget_low_f32(cq1), vget_high_f32(cq1)); + float32x2_t cd2 = vadd_f32(vget_low_f32(cq2), vget_high_f32(cq2)); + float32x2_t cd3 = vadd_f32(vget_low_f32(cq3), vget_high_f32(cq3)); + float32x2_t cd4 = vadd_f32(vget_low_f32(cq4), vget_high_f32(cq4)); + float32x2_t cd5 = vadd_f32(vget_low_f32(cq5), vget_high_f32(cq5)); + float32x2_t cd6 = vadd_f32(vget_low_f32(cq6), vget_high_f32(cq6)); + float32x2_t cd7 = vadd_f32(vget_low_f32(cq7), vget_high_f32(cq7)); + float32x2_t cd8 = vadd_f32(vget_low_f32(cq8), vget_high_f32(cq8)); + if (k_left >= 2) { + float32x2_t bd1 = vld1_f32(b_ptr); + float32x2_t bd2 = vld1_f32(b_ptr + 2); b_ptr += 4; + float32x2_t ad1 = vld1_f32(a_ptr1); a_ptr1 += 2; + float32x2_t ad2 = vld1_f32(a_ptr2); a_ptr2 += 2; + float32x2_t ad3 = vld1_f32(a_ptr3); a_ptr3 += 2; + float32x2_t ad4 = vld1_f32(a_ptr4); a_ptr4 += 2; + cd1 = vmla_f32(cd1, ad1, bd1); + cd2 = vmla_f32(cd2, ad2, bd1); + cd3 = vmla_f32(cd3, ad3, bd1); + cd4 = vmla_f32(cd4, ad4, bd1); + cd5 = vmla_f32(cd5, ad1, bd2); + cd6 = vmla_f32(cd6, ad2, bd2); + cd7 = vmla_f32(cd7, ad3, bd2); + cd8 = vmla_f32(cd8, ad4, bd2); + k_left -= 2; + } + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + float cs4 = vget_lane_f32(cd4, 0) + vget_lane_f32(cd4, 1); + float cs5 = vget_lane_f32(cd5, 0) + vget_lane_f32(cd5, 1); + float cs6 = vget_lane_f32(cd6, 0) + vget_lane_f32(cd6, 1); + float cs7 = vget_lane_f32(cd7, 0) + vget_lane_f32(cd7, 1); + float cs8 = vget_lane_f32(cd8, 0) + vget_lane_f32(cd8, 1); + if (k_left > 0) { + float bs1 = b_ptr[0]; + float bs2 = b_ptr[1]; + 
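    /* Annotation (not in the original patch): at most one k element can remain here,
     * since the asm loop consumes multiples of 4 and the block above another 2; the
     * leftover is folded in with plain scalar multiply-adds for the 4 rows of A
     * against the 2 columns of B. */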
float as1 = *a_ptr1; + float as2 = *a_ptr2; + float as3 = *a_ptr3; + float as4 = *a_ptr4; + cs1 += as1 * bs1; cs2 += as2 * bs1; + cs3 += as3 * bs1; cs4 += as4 * bs1; + cs5 += as1 * bs2; cs6 += as2 * bs2; + cs7 += as3 * bs2; cs8 += as4 * bs2; + } + if (c_rowmajor) { + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr[1] = c_ptr[1] * beta + cs5; + c_ptr[2] = c_ptr[2] * beta + cs2; c_ptr[3] = c_ptr[3] * beta + cs6; + c_ptr[4] = c_ptr[4] * beta + cs3; c_ptr[5] = c_ptr[5] * beta + cs7; + c_ptr[6] = c_ptr[6] * beta + cs4; c_ptr[7] = c_ptr[7] * beta + cs8; + } else { + c_ptr[0] = c_ptr[0] * beta + cs1; c_ptr[1] = c_ptr[1] * beta + cs2; + c_ptr[2] = c_ptr[2] * beta + cs3; c_ptr[3] = c_ptr[3] * beta + cs4; + c_ptr += LDM; + c_ptr[0] = c_ptr[0] * beta + cs5; c_ptr[1] = c_ptr[1] * beta + cs6; + c_ptr[2] = c_ptr[2] * beta + cs7; c_ptr[3] = c_ptr[3] * beta + cs8; + } +} + +static inline void inline_sgemm_arowmajor_bskinny_m1n2(const float *a_ptr, + const float *b_ptr, float *c_ptr, uint32_t k_left, uint32_t LDK, uint32_t LDM, + float beta, bool c_rowmajor) { + + register float32x4_t cq1 __asm("q8"); + __asm__ __volatile__( + "vmov.i8 %q[cq1],#0; vmov.i8 q9,#0\n\t" + "vmov.i8 q10,#0; vmov.i8 q11,#0\n\t" + "cmp %[k_left],#16; blt 4f\n\t" + "pld [%[a_ptr],#256]\n\t" + "vldm %[a_ptr]!,{q12,q13,q14,q15}\n\t" + "vldm %[b_ptr]!,{q0,q1,q2,q3}\n\t" + "vldm %[b_ptr]!,{q4,q5,q6,q7}\n\t" + "cmp %[k_left],#32; blt 3f\n\t" + ".balign 16; 2:\n\t" + "pld [%[a_ptr],#256]\n\t" + "vmla.f32 %q[cq1],q12,q0; vldm %[b_ptr]!,{q0}\n\t" + "vmla.f32 q10,q12,q1; vldm %[b_ptr]!,{q1}; vldm %[a_ptr]!,{q12}\n\t" + "vmla.f32 q9,q13,q2; vldm %[b_ptr]!,{q2}\n\t" + "vmla.f32 q11,q13,q3; vldm %[b_ptr]!,{q3}; vldm %[a_ptr]!,{q13}\n\t" + "sub %[k_left],%[k_left],#16\n\t" + "vmla.f32 %q[cq1],q14,q4; vldm %[b_ptr]!,{q4}\n\t" + "vmla.f32 q10,q14,q5; vldm %[b_ptr]!,{q5}; vldm %[a_ptr]!,{q14}\n\t" + "cmp %[k_left],#32\n\t" + "vmla.f32 q9,q15,q6; vldm %[b_ptr]!,{q6}\n\t" + "vmla.f32 q11,q15,q7; vldm %[b_ptr]!,{q7}; vldm %[a_ptr]!,{q15}\n\t" + "bge 2b\n\t" + "3:\n\t" + "vmla.f32 %q[cq1],q12,q0; vmla.f32 q10,q12,q1; sub %[k_left],%[k_left],#16\n\t" + "vmla.f32 q9,q13,q2; vmla.f32 q11,q13,q3\n\t" + "vmla.f32 %q[cq1],q14,q4; vmla.f32 q10,q14,q5\n\t" + "vmla.f32 q9,q15,q6; vmla.f32 q11,q15,q7\n\t" + "4:\n\t" + "cmp %[k_left],#8; blt 5f\n\t" + "vldm %[a_ptr]!,{q12}; vldm %[b_ptr]!,{q0,q1}\n\t" + "vldm %[a_ptr]!,{q13}; vldm %[b_ptr]!,{q2,q3}\n\t" + "vmla.f32 %q[cq1],q12,q0; vmla.f32 q10,q12,q1\n\t" + "sub %[k_left],%[k_left],#8\n\t" + "vmla.f32 q9,q13,q2; vmla.f32 q11,q13,q3\n\t" + "5:\n\t" + "vadd.f32 %q[cq1],%q[cq1],q9; vadd.f32 q10,q10,q11\n\t" + "cmp %[k_left],#4; blt 6f\n\t" + "vldm %[a_ptr]!,{q12}; vldm %[b_ptr]!,{q4}; vldm %[b_ptr]!,{q0}\n\t" + "vmla.f32 %q[cq1],q12,q4; vmla.f32 q10,q12,q0\n\t" + "sub %[k_left],%[k_left],#4\n\t" + "6:\n\t" + "vadd.f32 %e[cq1],%e[cq1],%f[cq1]; vadd.f32 %f[cq1],d20,d21\n\t" + "cmp %[k_left],#2; blt 7f\n\t" + "vld1.32 {d24},[%[a_ptr]]!\n\t" + "vld1.32 {d8},[%[b_ptr]]!; vld1.32 {d0},[%[b_ptr]]!\n\t" + "vmla.f32 %e[cq1],d24,d8; vmla.f32 %f[cq1],d24,d0\n\t" + "sub %[k_left],%[k_left],#2\n\t" + "7:\n\t" + :[cq1]"=w"(cq1), [a_ptr]"+r"(a_ptr), + [k_left]"+r"(k_left), [b_ptr]"+r"(b_ptr) + ::"cc","memory","q0","q1","q2","q3","q4","q5","q6","q7", + "q9","q10","q11","q12","q13","q14","q15"); + + float32x2_t cd1 = vpadd_f32(vget_low_f32(cq1), vget_high_f32(cq1)); + if (k_left > 0) { + float as1 = *a_ptr; + float32x2_t bd1 = vld1_f32(b_ptr); + cd1 = vmla_n_f32(cd1, bd1, as1); + } + + if (c_rowmajor) { + cd1 = vmla_n_f32(cd1, 
vld1_f32(c_ptr), beta); + vst1_f32(c_ptr, cd1); + } else { + c_ptr[0] = c_ptr[0] * beta + vget_lane_f32(cd1, 0); + c_ptr[LDM] = c_ptr[LDM] * beta + vget_lane_f32(cd1, 1); + } +} + +static inline bool unroll_test_m4n1(uint32_t M, uint32_t K) { + return K <= 512; +} + +static inline bool unroll_test_m1n1(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m4n2(uint32_t M, uint32_t K) { + return K <= 512; +} + +static inline bool unroll_test_m1n2(uint32_t M, uint32_t K) { + return true; +} + +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(sgemm, 1, 5, 5, 32768, float, float, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(sgemm, 2, 7, 5, 32768, float, float, unroll_test) + +typedef float sgemm_skinnydot_avec1; +typedef float sgemm_skinnydot_bvec1; +typedef float sgemm_skinnydot_cvec1; + +typedef float32x2_t sgemm_skinnydot_avec2; +typedef float32x2_t sgemm_skinnydot_bvec2; +typedef float32x2_t sgemm_skinnydot_cvec2; + +typedef float32x4_t sgemm_skinnydot_avec4; +typedef float32x4_t sgemm_skinnydot_bvec4; +typedef float32x4_t sgemm_skinnydot_cvec4; + +typedef float32x4x2_t sgemm_skinnydot_avec8; +typedef float32x4x2_t sgemm_skinnydot_bvec8; +typedef float32x4x2_t sgemm_skinnydot_cvec8; + +GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vmlaq_f32(c_vec.val[0], a_vec.val[0], b_vec.val[0]); + ret.val[1] = vmlaq_f32(c_vec.val[1], a_vec.val[1], b_vec.val[1]); + return ret; +} + +GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 4) { + return vmlaq_f32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 2) { + return vmla_f32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_CALC_UNIT(sgemm, 1) { + return c_vec + a_vec * b_vec; +} + +GEMM_SKINNY_DOT_LOADA_UNIT(sgemm, 8) { + __asm__("pld [%0,#96]"::"r"(a_ptr):); + float32x4x2_t ret; + ret.val[0] = vld1q_f32(a_ptr); + ret.val[1] = vld1q_f32(a_ptr + 4); + return ret; +} + +GEMM_SKINNY_DOT_LOADA_UNIT(sgemm, 4) { + __asm__("pld [%0,#80]"::"r"(a_ptr):); + return vld1q_f32(a_ptr); +} + +GEMM_SKINNY_DOT_LOADA_UNIT(sgemm, 2) { + __asm__("pld [%0,#72]"::"r"(a_ptr):); + return vld1_f32(a_ptr); +} + +GEMM_SKINNY_DOT_LOADA_UNIT(sgemm, 1) { + return *a_ptr; +} + +GEMM_SKINNY_DOT_LOADB_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vld1q_f32(b_ptr); + ret.val[1] = vld1q_f32(b_ptr + 4); + return ret; +} + +GEMM_SKINNY_DOT_LOADB_UNIT(sgemm, 4) { + return vld1q_f32(b_ptr); +} + +GEMM_SKINNY_DOT_LOADB_UNIT(sgemm, 2) { + return vld1_f32(b_ptr); +} + +GEMM_SKINNY_DOT_LOADB_UNIT(sgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_DOT_REDUC_UNIT(sgemm, 8, 4) { + return vaddq_f32(c_vec.val[0], c_vec.val[1]); +} + +GEMM_SKINNY_DOT_REDUC_UNIT(sgemm, 4, 2) { + return vadd_f32(vget_low_f32(c_vec), vget_high_f32(c_vec)); +} + +GEMM_SKINNY_DOT_REDUC_UNIT(sgemm, 2, 1) { + return vget_lane_f32(c_vec, 0) + vget_lane_f32(c_vec, 1); +} + +GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vdupq_n_f32(0); + ret.val[1] = vdupq_n_f32(0); + return ret; +} + +GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 4) { + return vdupq_n_f32(0); +} + +GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 2) { + return vdup_n_f32(0); +} + +GEMM_SKINNY_DOT_INITC_UNIT(sgemm, 1) { + return 0; +} + +GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 3, 3, 7, 32768, float, float) +GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 4, 3, 7, 32768, float, float) +GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 5, 3, 7, 32768, float, float) +GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 6, 3, 7, 32768, float, float) +GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 7, 3, 3, 32768, float, float) +GEMM_SKINNY_DOT_PARALLEL_FUNC(sgemm, 8, 
3, 3, 32768, float, float) diff --git a/src/neon_armv7a/SgemmSkinnyGer.c b/src/neon_armv7a/SgemmSkinnyGer.c new file mode 100644 index 0000000..a051de7 --- /dev/null +++ b/src/neon_armv7a/SgemmSkinnyGer.c @@ -0,0 +1,280 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyGer.h" + +#include + +typedef float sgemm_skinnyger_ascalar; +typedef float sgemm_skinnyger_bscalar; +typedef float sgemm_skinnyger_cscalar; + +typedef float sgemm_skinnyger_avec1; +typedef float sgemm_skinnyger_bvec1; +typedef float sgemm_skinnyger_cvec1; + +typedef float32x2_t sgemm_skinnyger_avec2; +typedef float32x2_t sgemm_skinnyger_bvec2; +typedef float32x2_t sgemm_skinnyger_cvec2; + +typedef float32x4_t sgemm_skinnyger_avec4; +typedef float32x4_t sgemm_skinnyger_bvec4; +typedef float32x4_t sgemm_skinnyger_cvec4; + +typedef float32x4x2_t sgemm_skinnyger_avec8; +typedef float32x4x2_t sgemm_skinnyger_bvec8; +typedef float32x4x2_t sgemm_skinnyger_cvec8; + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 1) { + float32x4x2_t ret; + ret.val[0] = vmlaq_lane_f32(c_vec.val[0], a_vec.val[0], vget_low_f32(b_vec), 0); + ret.val[1] = vmlaq_lane_f32(c_vec.val[1], a_vec.val[1], vget_low_f32(b_vec), 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 2) { + float32x4x2_t ret; + ret.val[0] = vmlaq_lane_f32(c_vec.val[0], a_vec.val[0], vget_low_f32(b_vec), 1); + ret.val[1] = vmlaq_lane_f32(c_vec.val[1], a_vec.val[1], vget_low_f32(b_vec), 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 3) { + float32x4x2_t ret; + ret.val[0] = vmlaq_lane_f32(c_vec.val[0], a_vec.val[0], vget_high_f32(b_vec), 0); + ret.val[1] = vmlaq_lane_f32(c_vec.val[1], a_vec.val[1], vget_high_f32(b_vec), 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 4) { + float32x4x2_t ret; + ret.val[0] = vmlaq_lane_f32(c_vec.val[0], a_vec.val[0], vget_high_f32(b_vec), 1); + ret.val[1] = vmlaq_lane_f32(c_vec.val[1], a_vec.val[1], vget_high_f32(b_vec), 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 2, 1) { + float32x4x2_t ret; + ret.val[0] = vmlaq_lane_f32(c_vec.val[0], a_vec.val[0], b_vec, 0); + ret.val[1] = vmlaq_lane_f32(c_vec.val[1], a_vec.val[1], b_vec, 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 2, 2) { + float32x4x2_t ret; + ret.val[0] = vmlaq_lane_f32(c_vec.val[0], a_vec.val[0], b_vec, 1); + ret.val[1] = vmlaq_lane_f32(c_vec.val[1], a_vec.val[1], b_vec, 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 1, 1) { + float32x4x2_t ret; + ret.val[0] = vmlaq_n_f32(c_vec.val[0], a_vec.val[0], b_vec); + ret.val[1] = vmlaq_n_f32(c_vec.val[1], a_vec.val[1], b_vec); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 1) { + return vmlaq_lane_f32(c_vec, a_vec, vget_low_f32(b_vec), 0); +} + 
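/* Annotation added in review, not part of the original patch: each
 * GEMM_SKINNY_GER_CALC_UNIT(sgemm, mvec, nvec, lane) definition in this file expands
 * to one micro-op of the rank-1 ("GER") update c_vec += a_vec * b_vec[lane-1], where
 * mvec is the vector width of the A/C fragment, nvec the width of the loaded B
 * fragment, and lane is 1-based. A rough scalar sketch of the 4-wide, lane-1 variant,
 * for illustration only (the hypothetical helper name is not part of the library):
 *
 *   static inline void sgemm_ger_unit_scalar_ref(const float *a, const float *b,
 *                                                float *c) {
 *       for (int i = 0; i < 4; ++i)
 *           c[i] += a[i] * b[0];  // mirrors vmlaq_lane_f32(c_vec, a_vec, vget_low_f32(b_vec), 0)
 *   }
 */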
+GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 2) { + return vmlaq_lane_f32(c_vec, a_vec, vget_low_f32(b_vec), 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 3) { + return vmlaq_lane_f32(c_vec, a_vec, vget_high_f32(b_vec), 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 4) { + return vmlaq_lane_f32(c_vec, a_vec, vget_high_f32(b_vec), 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 2, 1) { + return vmlaq_lane_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 2, 2) { + return vmlaq_lane_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 1, 1) { + return vmlaq_n_f32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 1) { + return vmla_lane_f32(c_vec, a_vec, vget_low_f32(b_vec), 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 2) { + return vmla_lane_f32(c_vec, a_vec, vget_low_f32(b_vec), 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 3) { + return vmla_lane_f32(c_vec, a_vec, vget_high_f32(b_vec), 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 4) { + return vmla_lane_f32(c_vec, a_vec, vget_high_f32(b_vec), 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 2, 1) { + return vmla_lane_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 2, 2) { + return vmla_lane_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 1, 1) { + return vmla_n_f32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 1) { + return c_vec + a_vec * vgetq_lane_f32(b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 2) { + return c_vec + a_vec * vgetq_lane_f32(b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 3) { + return c_vec + a_vec * vgetq_lane_f32(b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 4) { + return c_vec + a_vec * vgetq_lane_f32(b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 2, 1) { + return c_vec + a_vec * vget_lane_f32(b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 2, 2) { + return c_vec + a_vec * vget_lane_f32(b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 1, 1) { + return a_vec * b_vec + c_vec; +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vld1q_f32(a_ptr); + ret.val[1] = vld1q_f32(a_ptr + 4); + __asm__("pld [%0,#96]"::"r"(a_ptr):); + return ret; +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 4) { + __asm__("pld [%0,#80]"::"r"(a_ptr):); + return vld1q_f32(a_ptr); +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 2) { + __asm__("pld [%0,#72]"::"r"(a_ptr):); + return vld1_f32(a_ptr); +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 1) { + return *a_ptr; +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vld1q_f32(c_ptr); + ret.val[1] = vld1q_f32(c_ptr + 4); + return ret; +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 4) { + return vld1q_f32(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 2) { + return vld1_f32(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 1) { + return *c_ptr; +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 8) { + vst1q_f32(c_ptr, c_vec.val[0]); + vst1q_f32(c_ptr + 4, c_vec.val[1]); +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 4) { + vst1q_f32(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 2) { + vst1_f32(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 1) { + *c_ptr = c_vec; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 4) { + float32x4_t ret = vdupq_n_f32(0); + float b1 = *b_ptr; b_ptr += ldb; + float b2 = *b_ptr; b_ptr += ldb; + float b3 = *b_ptr; b_ptr += ldb; + float b4 = *b_ptr; + ret = vsetq_lane_f32(b1, ret, 0); + ret = vsetq_lane_f32(b2, ret, 1); + ret = vsetq_lane_f32(b3, ret, 2); + ret = 
vsetq_lane_f32(b4, ret, 3); + return ret; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 2) { + float32x2_t ret = vdup_n_f32(0); + float b1 = *b_ptr; + float b2 = b_ptr[ldb]; + ret = vset_lane_f32(b1, ret, 0); + ret = vset_lane_f32(b2, ret, 1); + return ret; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 4) { + return vld1q_f32(b_ptr); +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 2) { + return vld1_f32(b_ptr); +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 1, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 2, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 3, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 4, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 5, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 6, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 7, 7, 3, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 8, 7, 3, 8192, float, float) + diff --git a/src/neon_armv7a/U8U32GemmDriver.c b/src/neon_armv7a/U8U32GemmDriver.c new file mode 100644 index 0000000..f67fda7 --- /dev/null +++ b/src/neon_armv7a/U8U32GemmDriver.c @@ -0,0 +1,42 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv7a/U8U32MlaGemmDriver.h" +#include "arm_neon/ARMCpuType.h" + +int u8u32gemm_serial(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, uint32_t beta_inp) { + + if (blas_arm_get_i8i32_support() == 0) { + return 2; + } + return u8u32mlagemm_serial(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp); +} + +int u8u32gemm(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, + uint32_t beta_inp, uint32_t num_threads) { + + if (blas_arm_get_i8i32_support() == 0) { + return 2; + } + return u8u32mlagemm(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp, num_threads); +} diff --git a/src/neon_armv7a/U8U32MlaGemmCopy.c b/src/neon_armv7a/U8U32MlaGemmCopy.c new file mode 100644 index 0000000..be09b7f --- /dev/null +++ b/src/neon_armv7a/U8U32MlaGemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "common/CommonCopy.h" +#include "arm_neon/NeonI8I32MlaGemmCopy.h" + +GENERIC_NCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 6) +GENERIC_NCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 8) + +GENERIC_TCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 6) +GENERIC_TCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 8) + diff --git a/src/neon_armv7a/U8U32MlaGemmDriver.c b/src/neon_armv7a/U8U32MlaGemmDriver.c new file mode 100644 index 0000000..071a898 --- /dev/null +++ b/src/neon_armv7a/U8U32MlaGemmDriver.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv7a/U8U32MlaGemmCopy.h" +#include "neon_armv7a/U8U32MlaGemmKernel.h" +#include "neon_armv7a/U8U32MlaGemmSkinnyDot.h" +#include "neon_armv7a/U8U32MlaGemmSkinnyGer.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +GEMM_PARALLEL_FUNC(u8u32mlagemm, uint8_t, uint16_t, uint8_t, uint16_t, uint32_t, + 6, 8, 4, 4, 4, 4) + diff --git a/src/neon_armv7a/U8U32MlaGemmKernel.c b/src/neon_armv7a/U8U32MlaGemmKernel.c new file mode 100644 index 0000000..5638fd9 --- /dev/null +++ b/src/neon_armv7a/U8U32MlaGemmKernel.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "common/CommonKernel.h" +#include "neon_armv7a/I8I32MlaGemmKernel.h" + +DUALPACK_KERNEL_FUNC_LM(u8u32mlagemm, uint16_t, uint16_t, uint32_t, 6, 8) +DUALPACK_KERNEL_FUNC_LN(u8u32mlagemm, uint16_t, uint16_t, uint32_t, 8, 6) + diff --git a/src/neon_armv7a/U8U32MlaGemmSkinnyDot.c b/src/neon_armv7a/U8U32MlaGemmSkinnyDot.c new file mode 100644 index 0000000..36e1c25 --- /dev/null +++ b/src/neon_armv7a/U8U32MlaGemmSkinnyDot.c @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/NeonI8I32MlaGemmSkinnyDot.h" +#include "common/CommonSkinnyDot.h" + +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 1, 15, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 2, 15, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 3, 15, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 4, 15, 3, 131072, uint8_t, uint8_t) \ No newline at end of file diff --git a/src/neon_armv7a/U8U32MlaGemmSkinnyGer.c b/src/neon_armv7a/U8U32MlaGemmSkinnyGer.c new file mode 100644 index 0000000..8bc28a0 --- /dev/null +++ b/src/neon_armv7a/U8U32MlaGemmSkinnyGer.c @@ -0,0 +1,29 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/NeonI8I32MlaGemmSkinnyGer.h" + +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 1, 5, 5, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 2, 5, 5, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 3, 5, 5, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 4, 5, 5, 8192, uint8_t, uint8_t) + diff --git a/src/neon_armv8a/Bias.c b/src/neon_armv8a/Bias.c new file mode 100644 index 0000000..08519e8 --- /dev/null +++ b/src/neon_armv8a/Bias.c @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/NeonBias.h" +#include "arm_neon/NeonSum.h" + +NEON_BIAS(float, float32x4_t, f32, 4, fma) + +NEON_BIAS(int32_t, int32x4_t, s32, 4, mla) + +NEON_I8I32_SUM(u, uint) + +NEON_I16_SUMSQUARE(s, int) + diff --git a/src/neon_armv8a/HgemmDriver.c b/src/neon_armv8a/HgemmDriver.c new file mode 100644 index 0000000..3c31786 --- /dev/null +++ b/src/neon_armv8a/HgemmDriver.c @@ -0,0 +1,28 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/HgemmKernel.h" +#include "neon_armv8a/HgemmCopy.h" +#include "neon_armv8a/HgemmSkinnyDot.h" +#include "neon_armv8a/HgemmSkinnyGer.h" +#include "arm_neon/ARMCpuType.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +GEMM_PARALLEL_FUNC(hgemm, float16_t, float16_t, float16_t, float16_t, float16_t, + 8, 16, 12, 12, 12, 12, || blas_arm_get_fp16_support() < 2) + diff --git a/src/neon_armv8a/Layer.c b/src/neon_armv8a/Layer.c new file mode 100644 index 0000000..1d8bee2 --- /dev/null +++ b/src/neon_armv8a/Layer.c @@ -0,0 +1,24 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/SgemmDriver.h" +#include "neon_armv8a/Bias.h" +#include "common/CommonLayer.h" +#include + +SIMPLE_FC_FUNC(sgemm, float, float, float) + diff --git a/src/neon_armv8a/Quant.c b/src/neon_armv8a/Quant.c new file mode 100644 index 0000000..3835294 --- /dev/null +++ b/src/neon_armv8a/Quant.c @@ -0,0 +1,52 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "common/CommonQuant.h" +#include "arm_neon/NeonQuant.h" + +NEON_FIND_EXTREME(float32_t, f32, float32x2_t, float32x4_t, 2) + +QUANTIZE_ASYMMETRIC(32, 8) + +QUANTIZE_SYMMETRIC(32, 8) + +QUANTIZE_ASYMMETRIC(32, 16) + +QUANTIZE_SYMMETRIC(32, 16) + +void dequantize_symmetric_f32_s32(const int32_t *src, float32_t *dst, + float32_t scale, uint32_t size) { + + inline_dequant_cvt_f32_s32(dst, src, scale, size); +} + +NEON_FIND_EXTREME(int32_t, s32, int32x2_t, int32x4_t, 2) + +NEON_FIND_EXTREME(int16_t, s16, int16x4_t, int16x8_t, 4) + +REQUANTIZE_ASYMMETRIC_MULHI(float, 32, 8, 64) + +REQUANTIZE_SYMMETRIC_MULHI(float, 32, 8, 64) + +REQUANTIZE_ASYMMETRIC_MULHI(float, 32, 16, 64) + +REQUANTIZE_SYMMETRIC_MULHI(float, 32, 16, 64) + +REQUANTIZE_ASYMMETRIC_MULHI(float, 16, 8, 32) + +REQUANTIZE_SYMMETRIC_MULHI(float, 16, 8, 32) + diff --git a/src/neon_armv8a/S8S32DotGemmDriver.c b/src/neon_armv8a/S8S32DotGemmDriver.c new file mode 100644 index 0000000..9d77d1f --- /dev/null +++ b/src/neon_armv8a/S8S32DotGemmDriver.c @@ -0,0 +1,36 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/S8S32DotGemmCopy.h" +#include "neon_armv8a/S8S32DotGemmKernel.h" +#include "neon_armv8a/S8S32DotGemmSkinnyDot.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +#ifdef SCRATCH_K_CORD +#undef SCRATCH_K_CORD +#define SCRATCH_K_CORD(k) ((k) >> 2) +#endif + +#ifdef GEMM_D_K +#undef GEMM_D_K +#define GEMM_D_K 768 +#endif + +GEMM_PARALLEL_FUNC(s8s32dotgemm, int8_t, int32_t, int8_t, int32_t, int32_t, + 8, 12, 12, 12, 0, 0) + diff --git a/src/neon_armv8a/S8S32GemmDriver.c b/src/neon_armv8a/S8S32GemmDriver.c new file mode 100644 index 0000000..7dbbf90 --- /dev/null +++ b/src/neon_armv8a/S8S32GemmDriver.c @@ -0,0 +1,48 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/S8S32MlaGemmDriver.h" +#include "neon_armv8a/S8S32DotGemmDriver.h" +#include "arm_neon/ARMCpuType.h" + +int s8s32gemm_serial(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, int32_t beta_inp) { + + if (blas_arm_get_i8i32_support() == 2) { + return s8s32dotgemm_serial(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp); + } else { + return s8s32mlagemm_serial(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp); + } +} + +int s8s32gemm(int a_rowmajor, int b_rowmajor, + const int8_t *A, const int8_t *B, + int32_t *C, uint32_t M, uint32_t N, uint32_t K, + int32_t beta_inp, uint32_t num_threads) { + + if (blas_arm_get_i8i32_support() == 2) { + return s8s32dotgemm(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp, num_threads); + } else { + return s8s32mlagemm(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp, num_threads); + } +} + diff --git a/src/neon_armv8a/S8S32MlaGemmCopy.c b/src/neon_armv8a/S8S32MlaGemmCopy.c new file mode 100644 index 0000000..42c9451 --- /dev/null +++ b/src/neon_armv8a/S8S32MlaGemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "common/CommonCopy.h" +#include "arm_neon/NeonI8I32MlaGemmCopy.h" + +GENERIC_NCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 8) +GENERIC_NCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 12) + +GENERIC_TCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 8) +GENERIC_TCOPY_FUNC(s8s32mlagemm, int8_t, int16_t, 12) + diff --git a/src/neon_armv8a/S8S32MlaGemmDriver.c b/src/neon_armv8a/S8S32MlaGemmDriver.c new file mode 100644 index 0000000..e52bd40 --- /dev/null +++ b/src/neon_armv8a/S8S32MlaGemmDriver.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/S8S32MlaGemmCopy.h" +#include "neon_armv8a/S8S32MlaGemmKernel.h" +#include "neon_armv8a/S8S32MlaGemmSkinnyGer.h" +#include "neon_armv8a/S8S32MlaGemmSkinnyDot.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +GEMM_PARALLEL_FUNC(s8s32mlagemm, int8_t, int16_t, int8_t, int16_t, int32_t, + 8, 12, 8, 8, 8, 8) + diff --git a/src/neon_armv8a/S8S32MlaGemmKernel.c b/src/neon_armv8a/S8S32MlaGemmKernel.c new file mode 100644 index 0000000..110f834 --- /dev/null +++ b/src/neon_armv8a/S8S32MlaGemmKernel.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "common/CommonKernel.h" +#include "neon_armv8a/I8I32MlaGemmKernel.h" + +DUALPACK_KERNEL_FUNC_LM(s8s32mlagemm, int16_t, int16_t, int32_t, 8, 12) +DUALPACK_KERNEL_FUNC_LN(s8s32mlagemm, int16_t, int16_t, int32_t, 12, 8) + diff --git a/src/neon_armv8a/S8S32MlaGemmSkinnyDot.c b/src/neon_armv8a/S8S32MlaGemmSkinnyDot.c new file mode 100644 index 0000000..3d45eac --- /dev/null +++ b/src/neon_armv8a/S8S32MlaGemmSkinnyDot.c @@ -0,0 +1,34 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "neon_armv8a/I8I32MlaGemmSkinnyDot.h" +#include "common/CommonSkinnyDot.h" + +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(s8s32mlagemm, 1, 31, 5, 131072, int8_t, int8_t, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(s8s32mlagemm, 2, 31, 5, 131072, int8_t, int8_t, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(s8s32mlagemm, 3, 31, 5, 131072, int8_t, int8_t, unroll_test) + +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 4, 15, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 5, 15, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 6, 15, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 7, 15, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32mlagemm, 8, 15, 3, 131072, int8_t, int8_t) diff --git a/src/neon_armv8a/S8S32MlaGemmSkinnyGer.c b/src/neon_armv8a/S8S32MlaGemmSkinnyGer.c new file mode 100644 index 0000000..c8a739b --- /dev/null +++ b/src/neon_armv8a/S8S32MlaGemmSkinnyGer.c @@ -0,0 +1,32 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/NeonI8I32MlaGemmSkinnyGer.h" + +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 1, 5, 29, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 2, 5, 29, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 3, 5, 29, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 4, 5, 29, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 5, 5, 13, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 6, 5, 13, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 7, 5, 13, 8192, int8_t, int8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(s8s32mlagemm, 8, 5, 13, 8192, int8_t, int8_t) diff --git a/src/neon_armv8a/SgemmCopy.c b/src/neon_armv8a/SgemmCopy.c new file mode 100644 index 0000000..fe6308b --- /dev/null +++ b/src/neon_armv8a/SgemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "common/CommonCopy.h" +#include "arm_neon/NeonSgemmCopy.h" + +#define NCOPY_float_float(unroll) NCOPY_UNROLL_##unroll + +GENERIC_NCOPY_FUNC(sgemm, float, float, 8) +GENERIC_NCOPY_FUNC(sgemm, float, float, 12) + +#define TCOPY_UNIT_float_float(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) + +GENERIC_TCOPY_FUNC(sgemm, float, float, 8) +GENERIC_TCOPY_FUNC(sgemm, float, float, 12) diff --git a/src/neon_armv8a/SgemmDriver.c b/src/neon_armv8a/SgemmDriver.c new file mode 100644 index 0000000..09ed573 --- /dev/null +++ b/src/neon_armv8a/SgemmDriver.c @@ -0,0 +1,26 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/
+/*****************************************************************************/
+
+
+#include "neon_armv8a/SgemmKernel.h"
+#include "neon_armv8a/SgemmCopy.h"
+#include "neon_armv8a/SgemmSkinnyDot.h"
+#include "neon_armv8a/SgemmSkinnyGer.h"
+#include "arm_neon/ARMCompareAndSwap.h"
+#include "common/CommonDriver.h"
+
+GEMM_PARALLEL_FUNC(sgemm, float, float, float, float, float, 8, 12, 50, 50, 12, 12)
+
diff --git a/src/neon_armv8a/SgemmKernel.c b/src/neon_armv8a/SgemmKernel.c
new file mode 100644
index 0000000..7e9bc94
--- /dev/null
+++ b/src/neon_armv8a/SgemmKernel.c
@@ -0,0 +1,1071 @@
+/*****************************************************************************/
+/* Copyright YouDao, Inc.                                                    */
+/*                                                                           */
+/* Licensed under the Apache License, Version 2.0 (the "License");           */
+/* you may not use this file except in compliance with the License.          */
+/* You may obtain a copy of the License at                                   */
+/*                                                                           */
+/*     http://www.apache.org/licenses/LICENSE-2.0                            */
+/*                                                                           */
+/* Unless required by applicable law or agreed to in writing, software       */
+/* distributed under the License is distributed on an "AS IS" BASIS,         */
+/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.  */
+/* See the License for the specific language governing permissions and       */
+/* limitations under the License.                                            */
+/*****************************************************************************/
+
+
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+#include "common/CommonKernel.h"
+#include "arm_neon/NeonSgemmKernel.h"
+#include "arm_neon/ARMCpuType.h"
+#include <sched.h>
+
+#define NEON_SGEMM_KERNEL_M8N12_PRELOAD_A53 \
+  "ldr q0,[%25]; add %25,%25,#32\n\t"\
+  "ldr q3,[%26]; ldr d5,[%26,#16]; ldr x0,[%26,#24]; add %26,%26,#48\n\t"
+
+#define NEON_SGEMM_KERNEL_M8N12_MAIN2_A53 \
+  "fmov v5.d[1],x0; ldr d7,[%26,#-16]\n\t"\
+  "fmla %0.4s,v0.4s,v3.s[0]; ldr x0,[%26,#-8]\n\t"\
+  "fmla %2.4s,v0.4s,v3.s[1]; fmla %4.4s,v0.4s,v3.s[2]\n\t"\
+  "fmov v7.d[1],x0; ldr d2,[%25,#-16]\n\t"\
+  "fmla %6.4s,v0.4s,v3.s[3]; ldr x0,[%25,#-8]\n\t"\
+  "fmla %8.4s,v0.4s,v5.s[0]; fmla %16.4s,v0.4s,v7.s[0]\n\t"\
+  "fmov v2.d[1],x0; ldr d4,[%26]\n\t"\
+  "fmla %18.4s,v0.4s,v7.s[1]; ldr x0,[%26,#8]\n\t"\
+  "fmla %20.4s,v0.4s,v7.s[2]; fmla %22.4s,v0.4s,v7.s[3]\n\t"\
+  "fmov v4.d[1],x0; ldr d1,[%25]\n\t"\
+  "fmla %17.4s,v2.4s,v7.s[0]; ldr x0,[%25,#8]\n\t"\
+  "fmla %19.4s,v2.4s,v7.s[1]; fmla %21.4s,v2.4s,v7.s[2]\n\t"\
+  "fmov v1.d[1],x0; ldr d6,[%26,#16]\n\t"\
+  "fmla %23.4s,v2.4s,v7.s[3]; ldr x0,[%26,#24]\n\t"\
+  "fmla %9.4s,v2.4s,v5.s[0]; fmla %1.4s,v2.4s,v3.s[0]\n\t"\
+  "fmov v6.d[1],x0; ldr d7,[%26,#32]\n\t"\
+  "fmla %3.4s,v2.4s,v3.s[1]; ldr x0,[%26,#40]\n\t"\
+  "fmla %5.4s,v2.4s,v3.s[2]; fmla %7.4s,v2.4s,v3.s[3]\n\t"\
+  "fmov v7.d[1],x0; ldr d3,[%26,#48]\n\t"\
+  "fmla %11.4s,v2.4s,v5.s[1]; ldr x0,[%26,#56]\n\t"\
+  "fmla %13.4s,v2.4s,v5.s[2]; fmla %15.4s,v2.4s,v5.s[3]\n\t"\
+  "fmov v3.d[1],x0; ldr d2,[%25,#16]\n\t"\
+  "fmla %10.4s,v0.4s,v5.s[1]; ldr x0,[%25,#24]\n\t"\
+  "fmla %12.4s,v0.4s,v5.s[2]; fmla %14.4s,v0.4s,v5.s[3]\n\t"\
+  "fmov v2.d[1],x0; ldr d0,[%25,#32]\n\t"\
+  "fmla %0.4s,v1.4s,v4.s[0]; ldr x0,[%25,#40]\n\t"\
+  "fmla %2.4s,v1.4s,v4.s[1]; fmla %4.4s,v1.4s,v4.s[2]\n\t"\
+  "fmov v0.d[1],x0; ldr d5,[%26,#64]\n\t"\
+  "fmla %6.4s,v1.4s,v4.s[3]; ldr x0,[%26,#72]\n\t"\
+  "fmla %8.4s,v1.4s,v6.s[0]; fmla %10.4s,v1.4s,v6.s[1]\n\t"\
+  "add %25,%25,#64\n\t"\
+  "fmla %12.4s,v1.4s,v6.s[2]\n\t"\
+  "fmla %14.4s,v1.4s,v6.s[3]; fmla %16.4s,v1.4s,v7.s[0]\n\t"\
+  "add %26,%26,#96\n\t"\
+  "fmla %18.4s,v1.4s,v7.s[1]\n\t"\
+  "fmla %20.4s,v1.4s,v7.s[2]; fmla 
%22.4s,v1.4s,v7.s[3]\n\t"\ + "prfm pldl1keep,[%25,#128]\n\t"\ + "fmla %1.4s,v2.4s,v4.s[0]\n\t"\ + "fmla %3.4s,v2.4s,v4.s[1]; fmla %5.4s,v2.4s,v4.s[2]\n\t"\ + "prfm pldl1keep,[%26,#192]\n\t"\ + "fmla %7.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %9.4s,v2.4s,v6.s[0]; fmla %11.4s,v2.4s,v6.s[1]\n\t"\ + "sub %w24,%w24,#2\n\t"\ + "fmla %13.4s,v2.4s,v6.s[2]\n\t"\ + "fmla %15.4s,v2.4s,v6.s[3]; fmla %17.4s,v2.4s,v7.s[0]\n\t"\ + "cmp %w24,#2; prfm pldl1keep,[%26,#240]\n\t"\ + "fmla %19.4s,v2.4s,v7.s[1]\n\t"\ + "fmla %21.4s,v2.4s,v7.s[2]; fmla %23.4s,v2.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_TAIL2_A53 \ + "fmov v5.d[1],x0; ldr d7,[%26,#-16]\n\t"\ + "fmla %0.4s,v0.4s,v3.s[0]; ldr x0,[%26,#-8]\n\t"\ + "fmla %2.4s,v0.4s,v3.s[1]; fmla %4.4s,v0.4s,v3.s[2]\n\t"\ + "fmov v7.d[1],x0; ldr d2,[%25,#-16]\n\t"\ + "fmla %6.4s,v0.4s,v3.s[3]; ldr x0,[%25,#-8]\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; fmla %16.4s,v0.4s,v7.s[0]\n\t"\ + "fmov v2.d[1],x0; ldr d4,[%26]\n\t"\ + "fmla %18.4s,v0.4s,v7.s[1]; ldr x0,[%26,#8]\n\t"\ + "fmla %20.4s,v0.4s,v7.s[2]; fmla %22.4s,v0.4s,v7.s[3]\n\t"\ + "fmov v4.d[1],x0; ldr d1,[%25]\n\t"\ + "fmla %17.4s,v2.4s,v7.s[0]; ldr x0,[%25,#8]\n\t"\ + "fmla %19.4s,v2.4s,v7.s[1]; fmla %21.4s,v2.4s,v7.s[2]\n\t"\ + "fmov v1.d[1],x0; ldr d6,[%26,#16]\n\t"\ + "fmla %23.4s,v2.4s,v7.s[3]; ldr x0,[%26,#24]\n\t"\ + "fmla %9.4s,v2.4s,v5.s[0]; fmla %1.4s,v2.4s,v3.s[0]\n\t"\ + "fmov v6.d[1],x0; ldr d7,[%26,#32]\n\t"\ + "fmla %3.4s,v2.4s,v3.s[1]; ldr x0,[%26,#40]\n\t"\ + "fmla %5.4s,v2.4s,v3.s[2]; fmla %7.4s,v2.4s,v3.s[3]\n\t"\ + "fmov v7.d[1],x0\n\t"\ + "fmla %11.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %13.4s,v2.4s,v5.s[2]; fmla %15.4s,v2.4s,v5.s[3]\n\t"\ + "ldr d2,[%25,#16]\n\t"\ + "fmla %10.4s,v0.4s,v5.s[1]; ldr x0,[%25,#24]\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]; fmla %14.4s,v0.4s,v5.s[3]\n\t"\ + "fmov v2.d[1],x0\n\t"\ + "fmla %0.4s,v1.4s,v4.s[0]\n\t"\ + "fmla %2.4s,v1.4s,v4.s[1]; fmla %4.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %6.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v6.s[0]; fmla %10.4s,v1.4s,v6.s[1]\n\t"\ + "add %25,%25,#32\n\t"\ + "fmla %12.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %14.4s,v1.4s,v6.s[3]; fmla %16.4s,v1.4s,v7.s[0]\n\t"\ + "add %26,%26,#48\n\t"\ + "fmla %18.4s,v1.4s,v7.s[1]\n\t"\ + "fmla %20.4s,v1.4s,v7.s[2]; fmla %22.4s,v1.4s,v7.s[3]\n\t"\ + "fmla %1.4s,v2.4s,v4.s[0]\n\t"\ + "fmla %3.4s,v2.4s,v4.s[1]; fmla %5.4s,v2.4s,v4.s[2]\n\t"\ + "fmla %7.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %9.4s,v2.4s,v6.s[0]; fmla %11.4s,v2.4s,v6.s[1]\n\t"\ + "fmla %13.4s,v2.4s,v6.s[2]\n\t"\ + "fmla %15.4s,v2.4s,v6.s[3]; fmla %17.4s,v2.4s,v7.s[0]\n\t"\ + "fmla %19.4s,v2.4s,v7.s[1]\n\t"\ + "fmla %21.4s,v2.4s,v7.s[2]; fmla %23.4s,v2.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_TAIL1_A53 \ + "fmov v5.d[1],x0; ldr d7,[%26,#-16]\n\t"\ + "fmla %0.4s,v0.4s,v3.s[0]; ldr x0,[%26,#-8]\n\t"\ + "fmla %2.4s,v0.4s,v3.s[1]; fmla %4.4s,v0.4s,v3.s[2]\n\t"\ + "fmov v7.d[1],x0; ldr d2,[%25,#-16]\n\t"\ + "fmla %6.4s,v0.4s,v3.s[3]; ldr x0,[%25,#-8]\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; fmla %16.4s,v0.4s,v7.s[0]\n\t"\ + "fmov v2.d[1],x0\n\t"\ + "fmla %18.4s,v0.4s,v7.s[1]\n\t"\ + "fmla %20.4s,v0.4s,v7.s[2]; fmla %22.4s,v0.4s,v7.s[3]\n\t"\ + "fmla %17.4s,v2.4s,v7.s[0]\n\t"\ + "fmla %19.4s,v2.4s,v7.s[1]; fmla %21.4s,v2.4s,v7.s[2]\n\t"\ + "fmla %23.4s,v2.4s,v7.s[3]\n\t"\ + "fmla %9.4s,v2.4s,v5.s[0]; fmla %1.4s,v2.4s,v3.s[0]\n\t"\ + "fmla %3.4s,v2.4s,v3.s[1]\n\t"\ + "fmla %5.4s,v2.4s,v3.s[2]; fmla %7.4s,v2.4s,v3.s[3]\n\t"\ + "fmla %11.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %13.4s,v2.4s,v5.s[2]; fmla %15.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %10.4s,v0.4s,v5.s[1]\n\t"\ + "fmla 
%12.4s,v0.4s,v5.s[2]; fmla %14.4s,v0.4s,v5.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_PRELOAD_A55 \ + "ldr q0,[%25]; ldr q1,[%25,#16]; add %25,%25,#32\n\t"\ + "ldr q4,[%26]; ldr d5,[%26,#16]; ldr x1,[%26,#24]; add %26,%26,#48\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_MAIN2_A55 \ + "fmla %0.4s,v0.4s,v4.s[0]; ldr d2,[%25]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[1]; ldr x0,[%25,#8]\n\t"\ + "fmla %4.4s,v0.4s,v4.s[2]\n\t"\ + "fmla %6.4s,v0.4s,v4.s[3]; fmov v5.d[1],x1\n\t"\ + "fmla %1.4s,v1.4s,v4.s[0]; ldr d6,[%26,#-16]\n\t"\ + "fmla %3.4s,v1.4s,v4.s[1]; ldr x1,[%26,#-8]\n\t"\ + "fmla %5.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %7.4s,v1.4s,v4.s[3]; fmov v2.d[1],x0\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; ldr d3,[%25,#16]\n\t"\ + "fmla %10.4s,v0.4s,v5.s[1]; ldr x0,[%25,#24]\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]\n\t"\ + "fmla %14.4s,v0.4s,v5.s[3]; fmov v6.d[1],x1\n\t"\ + "fmla %9.4s,v1.4s,v5.s[0]; ldr d4,[%26]\n\t"\ + "fmla %11.4s,v1.4s,v5.s[1]; ldr x1,[%26,#8]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v5.s[3]; fmov v3.d[1],x0\n\t"\ + "fmla %16.4s,v0.4s,v6.s[0]; ldr d5,[%26,#16]\n\t"\ + "fmla %18.4s,v0.4s,v6.s[1]; ldr x0,[%26,#24]\n\t"\ + "fmla %20.4s,v0.4s,v6.s[2]\n\t"\ + "fmla %22.4s,v0.4s,v6.s[3]; fmov v4.d[1],x1\n\t"\ + "fmla %17.4s,v1.4s,v6.s[0]; add %25,%25,#64\n\t"\ + "fmla %19.4s,v1.4s,v6.s[1]; add %26,%26,#96\n\t"\ + "fmla %21.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %23.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %0.4s,v2.4s,v4.s[0]; ldr d0,[%25,#-32]\n\t"\ + "fmla %2.4s,v2.4s,v4.s[1]; ldr x1,[%25,#-24]\n\t"\ + "fmla %4.4s,v2.4s,v4.s[2]\n\t"\ + "fmla %6.4s,v2.4s,v4.s[3]; fmov v5.d[1],x0\n\t"\ + "fmla %1.4s,v3.4s,v4.s[0]; ldr d6,[%26,#-64]\n\t"\ + "fmla %3.4s,v3.4s,v4.s[1]; ldr x0,[%26,#-56]\n\t"\ + "fmla %5.4s,v3.4s,v4.s[2]\n\t"\ + "fmla %7.4s,v3.4s,v4.s[3]; fmov v0.d[1],x1\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]; ldr d1,[%25,#-16]\n\t"\ + "fmla %10.4s,v2.4s,v5.s[1]; ldr x1,[%25,#-8]\n\t"\ + "fmla %12.4s,v2.4s,v5.s[2]\n\t"\ + "fmla %14.4s,v2.4s,v5.s[3]; fmov v6.d[1],x0\n\t"\ + "fmla %9.4s,v3.4s,v5.s[0]; ldr d4,[%26,#-48]\n\t"\ + "fmla %11.4s,v3.4s,v5.s[1]; ldr x0,[%26,#-40]\n\t"\ + "fmla %13.4s,v3.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v3.4s,v5.s[3]; fmov v1.d[1],x1\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]; ldr d5,[%26,#-32]\n\t"\ + "fmla %18.4s,v2.4s,v6.s[1]; ldr x1,[%26,#-24]\n\t"\ + "fmla %20.4s,v2.4s,v6.s[2]\n\t"\ + "fmla %22.4s,v2.4s,v6.s[3]; fmov v4.d[1],x0\n\t"\ + "fmla %17.4s,v3.4s,v6.s[0]\n\t"\ + "fmla %19.4s,v3.4s,v6.s[1]; sub %w24,%w24,#2\n\t"\ + "fmla %21.4s,v3.4s,v6.s[2]; cmp %w24,#2\n\t"\ + "fmla %23.4s,v3.4s,v6.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_TAIL2_A55 \ + "fmla %0.4s,v0.4s,v4.s[0]; ldr d2,[%25]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[1]; ldr x0,[%25,#8]\n\t"\ + "fmla %4.4s,v0.4s,v4.s[2]\n\t"\ + "fmla %6.4s,v0.4s,v4.s[3]; fmov v5.d[1],x1\n\t"\ + "fmla %1.4s,v1.4s,v4.s[0]; ldr d6,[%26,#-16]\n\t"\ + "fmla %3.4s,v1.4s,v4.s[1]; ldr x1,[%26,#-8]\n\t"\ + "fmla %5.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %7.4s,v1.4s,v4.s[3]; fmov v2.d[1],x0\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; ldr d3,[%25,#16]\n\t"\ + "fmla %10.4s,v0.4s,v5.s[1]; ldr x0,[%25,#24]\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]\n\t"\ + "fmla %14.4s,v0.4s,v5.s[3]; fmov v6.d[1],x1\n\t"\ + "fmla %9.4s,v1.4s,v5.s[0]; ldr d4,[%26]\n\t"\ + "fmla %11.4s,v1.4s,v5.s[1]; ldr x1,[%26,#8]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v5.s[3]; fmov v3.d[1],x0\n\t"\ + "fmla %16.4s,v0.4s,v6.s[0]; ldr d5,[%26,#16]\n\t"\ + "fmla %18.4s,v0.4s,v6.s[1]; ldr x0,[%26,#24]\n\t"\ + "fmla %20.4s,v0.4s,v6.s[2]\n\t"\ + "fmla %22.4s,v0.4s,v6.s[3]; fmov v4.d[1],x1\n\t"\ + "fmla 
%17.4s,v1.4s,v6.s[0]; add %25,%25,#32\n\t"\ + "fmla %19.4s,v1.4s,v6.s[1]; add %26,%26,#48\n\t"\ + "fmla %21.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %23.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %0.4s,v2.4s,v4.s[0]\n\t"\ + "fmla %2.4s,v2.4s,v4.s[1]\n\t"\ + "fmla %4.4s,v2.4s,v4.s[2]\n\t"\ + "fmla %6.4s,v2.4s,v4.s[3]; fmov v5.d[1],x0\n\t"\ + "fmla %1.4s,v3.4s,v4.s[0]; ldr d6,[%26,#-16]\n\t"\ + "fmla %3.4s,v3.4s,v4.s[1]; ldr x0,[%26,#-8]\n\t"\ + "fmla %5.4s,v3.4s,v4.s[2]\n\t"\ + "fmla %7.4s,v3.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]\n\t"\ + "fmla %10.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %12.4s,v2.4s,v5.s[2]\n\t"\ + "fmla %14.4s,v2.4s,v5.s[3]; fmov v6.d[1],x0\n\t"\ + "fmla %9.4s,v3.4s,v5.s[0]\n\t"\ + "fmla %11.4s,v3.4s,v5.s[1]\n\t"\ + "fmla %13.4s,v3.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v3.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]\n\t"\ + "fmla %18.4s,v2.4s,v6.s[1]\n\t"\ + "fmla %20.4s,v2.4s,v6.s[2]\n\t"\ + "fmla %22.4s,v2.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v3.4s,v6.s[0]\n\t"\ + "fmla %19.4s,v3.4s,v6.s[1]\n\t"\ + "fmla %21.4s,v3.4s,v6.s[2]\n\t"\ + "fmla %23.4s,v3.4s,v6.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_TAIL1_A55 \ + "fmla %0.4s,v0.4s,v4.s[0]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[1]\n\t"\ + "fmla %4.4s,v0.4s,v4.s[2]\n\t"\ + "fmla %6.4s,v0.4s,v4.s[3]; fmov v5.d[1],x1\n\t"\ + "fmla %1.4s,v1.4s,v4.s[0]; ldr d6,[%26,#-16]\n\t"\ + "fmla %3.4s,v1.4s,v4.s[1]; ldr x1,[%26,#-8]\n\t"\ + "fmla %5.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %7.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]\n\t"\ + "fmla %10.4s,v0.4s,v5.s[1]\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]\n\t"\ + "fmla %14.4s,v0.4s,v5.s[3]; fmov v6.d[1],x1\n\t"\ + "fmla %9.4s,v1.4s,v5.s[0]\n\t"\ + "fmla %11.4s,v1.4s,v5.s[1]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v0.4s,v6.s[0]\n\t"\ + "fmla %18.4s,v0.4s,v6.s[1]\n\t"\ + "fmla %20.4s,v0.4s,v6.s[2]\n\t"\ + "fmla %22.4s,v0.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v1.4s,v6.s[0]\n\t"\ + "fmla %19.4s,v1.4s,v6.s[1]\n\t"\ + "fmla %21.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %23.4s,v1.4s,v6.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_PRELOAD_A72 \ + "ldr q0,[%25]; ldr q1,[%25,#16]; add %25,%25,#32\n\t"\ + "ldr q4,[%26]; ldr q5,[%26,#16]; add %26,%26,#48\n\t"\ + +#define NEON_SGEMM_KERNEL_M8N12_MAIN2_A72 \ + "fmla %0.4s,v0.4s,v4.s[0]; fmla %2.4s,v0.4s,v4.s[1]; ldr q6,[%26,#-16]\n\t"\ + "fmla %4.4s,v0.4s,v4.s[2]; fmla %6.4s,v0.4s,v4.s[3]\n\t"\ + "fmla %1.4s,v1.4s,v4.s[0]; fmla %3.4s,v1.4s,v4.s[1]; ldr q2,[%25],#64\n\t"\ + "fmla %5.4s,v1.4s,v4.s[2]; fmla %7.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; fmla %10.4s,v0.4s,v5.s[1]; ldr q4,[%26],#96\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]; fmla %14.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %9.4s,v1.4s,v5.s[0]; fmla %11.4s,v1.4s,v5.s[1]; ldr q3,[%25,#-48]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[2]; fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v0.4s,v6.s[0]; fmla %18.4s,v0.4s,v6.s[1]; ldr q5,[%26,#-80]\n\t"\ + "fmla %20.4s,v0.4s,v6.s[2]; fmla %22.4s,v0.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v1.4s,v6.s[0]; fmla %19.4s,v1.4s,v6.s[1]; sub %w24,%w24,#2\n\t"\ + "fmla %21.4s,v1.4s,v6.s[2]; fmla %23.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %0.4s,v2.4s,v4.s[0]; fmla %2.4s,v2.4s,v4.s[1]; ldr q6,[%26,#-64]\n\t"\ + "fmla %4.4s,v2.4s,v4.s[2]; fmla %6.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %1.4s,v3.4s,v4.s[0]; fmla %3.4s,v3.4s,v4.s[1]; ldr q0,[%25,#-32]\n\t"\ + "fmla %5.4s,v3.4s,v4.s[2]; fmla %7.4s,v3.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]; fmla %10.4s,v2.4s,v5.s[1]; ldr q4,[%26,#-48]\n\t"\ + "fmla %12.4s,v2.4s,v5.s[2]; fmla %14.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %9.4s,v3.4s,v5.s[0]; fmla %11.4s,v3.4s,v5.s[1]; ldr 
q1,[%25,#-16]\n\t"\ + "fmla %13.4s,v3.4s,v5.s[2]; fmla %15.4s,v3.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]; fmla %18.4s,v2.4s,v6.s[1]; ldr q5,[%26,#-32]\n\t"\ + "fmla %20.4s,v2.4s,v6.s[2]; fmla %22.4s,v2.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v3.4s,v6.s[0]; fmla %19.4s,v3.4s,v6.s[1]; cmp %w24,#2\n\t"\ + "fmla %21.4s,v3.4s,v6.s[2]; fmla %23.4s,v3.4s,v6.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_TAIL2_A72 \ + "fmla %0.4s,v0.4s,v4.s[0]; fmla %2.4s,v0.4s,v4.s[1]; ldr q6,[%26,#-16]\n\t"\ + "fmla %4.4s,v0.4s,v4.s[2]; fmla %6.4s,v0.4s,v4.s[3]\n\t"\ + "fmla %1.4s,v1.4s,v4.s[0]; fmla %3.4s,v1.4s,v4.s[1]; ldr q2,[%25],#32\n\t"\ + "fmla %5.4s,v1.4s,v4.s[2]; fmla %7.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; fmla %10.4s,v0.4s,v5.s[1]; ldr q4,[%26],#48\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]; fmla %14.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %9.4s,v1.4s,v5.s[0]; fmla %11.4s,v1.4s,v5.s[1]; ldr q3,[%25,#-16]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[2]; fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v0.4s,v6.s[0]; fmla %18.4s,v0.4s,v6.s[1]; ldr q5,[%26,#-32]\n\t"\ + "fmla %20.4s,v0.4s,v6.s[2]; fmla %22.4s,v0.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v1.4s,v6.s[0]; fmla %19.4s,v1.4s,v6.s[1]\n\t"\ + "fmla %21.4s,v1.4s,v6.s[2]; fmla %23.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %0.4s,v2.4s,v4.s[0]; fmla %2.4s,v2.4s,v4.s[1]; ldr q6,[%26,#-16]\n\t"\ + "fmla %4.4s,v2.4s,v4.s[2]; fmla %6.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %1.4s,v3.4s,v4.s[0]; fmla %3.4s,v3.4s,v4.s[1]\n\t"\ + "fmla %5.4s,v3.4s,v4.s[2]; fmla %7.4s,v3.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]; fmla %10.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %12.4s,v2.4s,v5.s[2]; fmla %14.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %9.4s,v3.4s,v5.s[0]; fmla %11.4s,v3.4s,v5.s[1]\n\t"\ + "fmla %13.4s,v3.4s,v5.s[2]; fmla %15.4s,v3.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]; fmla %18.4s,v2.4s,v6.s[1]\n\t"\ + "fmla %20.4s,v2.4s,v6.s[2]; fmla %22.4s,v2.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v3.4s,v6.s[0]; fmla %19.4s,v3.4s,v6.s[1]\n\t"\ + "fmla %21.4s,v3.4s,v6.s[2]; fmla %23.4s,v3.4s,v6.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M8N12_TAIL1_A72 \ + "fmla %0.4s,v0.4s,v4.s[0]; fmla %2.4s,v0.4s,v4.s[1]; ldr q6,[%26,#-16]\n\t"\ + "fmla %4.4s,v0.4s,v4.s[2]; fmla %6.4s,v0.4s,v4.s[3]\n\t"\ + "fmla %1.4s,v1.4s,v4.s[0]; fmla %3.4s,v1.4s,v4.s[1]\n\t"\ + "fmla %5.4s,v1.4s,v4.s[2]; fmla %7.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %8.4s,v0.4s,v5.s[0]; fmla %10.4s,v0.4s,v5.s[1]\n\t"\ + "fmla %12.4s,v0.4s,v5.s[2]; fmla %14.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %9.4s,v1.4s,v5.s[0]; fmla %11.4s,v1.4s,v5.s[1]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[2]; fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v0.4s,v6.s[0]; fmla %18.4s,v0.4s,v6.s[1]\n\t"\ + "fmla %20.4s,v0.4s,v6.s[2]; fmla %22.4s,v0.4s,v6.s[3]\n\t"\ + "fmla %17.4s,v1.4s,v6.s[0]; fmla %19.4s,v1.4s,v6.s[1]\n\t"\ + "fmla %21.4s,v1.4s,v6.s[2]; fmla %23.4s,v1.4s,v6.s[3]\n\t" + +#define NEON_SGEMM_SAVE_M8N3_UNIT(cq1, cq2, cq3, cq4, cq5, cq6) \ + ct1 = vld1q_f32(c_tmp1); ct2 = vld1q_f32(c_tmp1 + 4);\ + ct3 = vld1q_f32(c_tmp2); ct4 = vld1q_f32(c_tmp2 + 4);\ + ct5 = vld1q_f32(c_tmp3); ct6 = vld1q_f32(c_tmp3 + 4);\ + cq1 = vfmaq_n_f32(cq1, ct1, beta); cq2 = vfmaq_n_f32(cq2, ct2, beta);\ + cq3 = vfmaq_n_f32(cq3, ct3, beta); cq4 = vfmaq_n_f32(cq4, ct4, beta);\ + cq5 = vfmaq_n_f32(cq5, ct5, beta); cq6 = vfmaq_n_f32(cq6, ct6, beta);\ + vst1q_f32(c_tmp1, cq1); vst1q_f32(c_tmp1 + 4, cq2); c_tmp1 += ldc3;\ + vst1q_f32(c_tmp2, cq3); vst1q_f32(c_tmp2 + 4, cq4); c_tmp2 += ldc3;\ + vst1q_f32(c_tmp3, cq5); vst1q_f32(c_tmp3 + 4, cq6); c_tmp3 += ldc3; + +#define NEON_SGEMM_SAVE_M8N12_ASM1 \ + float *c_tmp1 = c_ptr;\ + float *c_tmp2 = c_ptr + ldc;\ 
+ float *c_tmp3 = c_ptr + ldc * 2;\ + uint32_t ldc3 = ldc * 3;\ + float32x4_t ct1, ct2, ct3, ct4, ct5, ct6;\ + NEON_SGEMM_SAVE_M8N3_UNIT(cq01, cq02, cq03, cq04, cq05, cq06)\ + NEON_SGEMM_SAVE_M8N3_UNIT(cq07, cq08, cq09, cq10, cq11, cq12)\ + NEON_SGEMM_SAVE_M8N3_UNIT(cq13, cq14, cq15, cq16, cq17, cq18)\ + NEON_SGEMM_SAVE_M8N3_UNIT(cq19, cq20, cq21, cq22, cq23, cq24) + +#define NEON_SGEMM_KERNEL_M12N8_PRELOAD_A53 \ + "ldr q5,[%26]; add %26,%26,#32\n\t"\ + "ldr q0,[%25]; ldr d2,[%25,#16]; ldr x0,[%25,#24]; add %25,%25,#48\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_MAIN2_A53 \ + "fmov v2.d[1],x0; ldr d4,[%25,#-16]\n\t"\ + "fmla %0.4s,v0.4s,v5.s[0]; ldr x0,[%25,#-8]\n\t"\ + "fmla %1.4s,v0.4s,v5.s[1]; fmla %2.4s,v0.4s,v5.s[2]\n\t"\ + "fmov v4.d[1],x0; ldr d7,[%26,#-16]\n\t"\ + "fmla %3.4s,v0.4s,v5.s[3]; ldr x0,[%26,#-8]\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]; fmla %16.4s,v4.4s,v5.s[0]\n\t"\ + "fmov v7.d[1],x0; ldr d6,[%26]\n\t"\ + "fmla %17.4s,v4.4s,v5.s[1]; ldr x0,[%26,#8]\n\t"\ + "fmla %18.4s,v4.4s,v5.s[2]; fmla %20.4s,v4.4s,v7.s[0]\n\t"\ + "fmov v6.d[1],x0; ldr d1,[%25]\n\t"\ + "fmla %21.4s,v4.4s,v7.s[1]; ldr x0,[%25,#8]\n\t"\ + "fmla %22.4s,v4.4s,v7.s[2]; fmla %23.4s,v4.4s,v7.s[3]\n\t"\ + "fmov v1.d[1],x0; ldr d3,[%25,#16]\n\t"\ + "fmla %19.4s,v4.4s,v5.s[3]; ldr x0,[%25,#24]\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; fmla %5.4s,v0.4s,v7.s[1]\n\t"\ + "fmov v3.d[1],x0; ldr d4,[%25,#32]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]; ldr x0,[%25,#40]\n\t"\ + "fmla %7.4s,v0.4s,v7.s[3]; fmla %12.4s,v2.4s,v7.s[0]\n\t"\ + "fmov v4.d[1],x0; ldr d0,[%25,#48]\n\t"\ + "fmla %13.4s,v2.4s,v7.s[1]; ldr x0,[%25,#56]\n\t"\ + "fmla %14.4s,v2.4s,v7.s[2]; fmla %15.4s,v2.4s,v7.s[3]\n\t"\ + "fmov v0.d[1],x0; ldr d7,[%26,#16]\n\t"\ + "fmla %9.4s,v2.4s,v5.s[1]; ldr x0,[%26,#24]\n\t"\ + "fmla %10.4s,v2.4s,v5.s[2]; fmla %11.4s,v2.4s,v5.s[3]\n\t"\ + "fmov v7.d[1],x0; ldr d5,[%26,#32]\n\t"\ + "fmla %0.4s,v1.4s,v6.s[0]; ldr x0,[%26,#40]\n\t"\ + "fmla %1.4s,v1.4s,v6.s[1]; fmla %2.4s,v1.4s,v6.s[2]\n\t"\ + "fmov v5.d[1],x0; ldr d2,[%25,#64]\n\t"\ + "fmla %3.4s,v1.4s,v6.s[3]; ldr x0,[%25,#72]\n\t"\ + "fmla %4.4s,v1.4s,v7.s[0]; fmla %5.4s,v1.4s,v7.s[1]\n\t"\ + "add %25,%25,#96\n\t"\ + "fmla %6.4s,v1.4s,v7.s[2]\n\t"\ + "fmla %7.4s,v1.4s,v7.s[3]; fmla %8.4s,v3.4s,v6.s[0]\n\t"\ + "prfm pldl1keep,[%25,#192]\n\t"\ + "fmla %9.4s,v3.4s,v6.s[1]\n\t"\ + "fmla %10.4s,v3.4s,v6.s[2]; fmla %11.4s,v3.4s,v6.s[3]\n\t"\ + "add %26,%26,#64\n\t"\ + "fmla %12.4s,v3.4s,v7.s[0]\n\t"\ + "fmla %13.4s,v3.4s,v7.s[1]; fmla %14.4s,v3.4s,v7.s[2]\n\t"\ + "prfm pldl1keep,[%26,#128]\n\t"\ + "fmla %15.4s,v3.4s,v7.s[3]\n\t"\ + "fmla %16.4s,v4.4s,v6.s[0]; fmla %17.4s,v4.4s,v6.s[1]\n\t"\ + "sub %w24,%w24,#2\n\t"\ + "fmla %18.4s,v4.4s,v6.s[2]\n\t"\ + "fmla %19.4s,v4.4s,v6.s[3]; fmla %20.4s,v4.4s,v7.s[0]\n\t"\ + "cmp %w24,#2; prfm pldl1keep,[%25,#240]\n\t"\ + "fmla %21.4s,v4.4s,v7.s[1]\n\t"\ + "fmla %22.4s,v4.4s,v7.s[2]; fmla %23.4s,v4.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_TAIL2_A53 \ + "fmov v2.d[1],x0; ldr d4,[%25,#-16]\n\t"\ + "fmla %0.4s,v0.4s,v5.s[0]; ldr x0,[%25,#-8]\n\t"\ + "fmla %1.4s,v0.4s,v5.s[1]; fmla %2.4s,v0.4s,v5.s[2]\n\t"\ + "fmov v4.d[1],x0; ldr d7,[%26,#-16]\n\t"\ + "fmla %3.4s,v0.4s,v5.s[3]; ldr x0,[%26,#-8]\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]; fmla %16.4s,v4.4s,v5.s[0]\n\t"\ + "fmov v7.d[1],x0; ldr d6,[%26]\n\t"\ + "fmla %17.4s,v4.4s,v5.s[1]; ldr x0,[%26,#8]\n\t"\ + "fmla %18.4s,v4.4s,v5.s[2]; fmla %20.4s,v4.4s,v7.s[0]\n\t"\ + "fmov v6.d[1],x0; ldr d1,[%25]\n\t"\ + "fmla %21.4s,v4.4s,v7.s[1]; ldr x0,[%25,#8]\n\t"\ + "fmla %22.4s,v4.4s,v7.s[2]; fmla 
%23.4s,v4.4s,v7.s[3]\n\t"\ + "fmov v1.d[1],x0; ldr d3,[%25,#16]\n\t"\ + "fmla %19.4s,v4.4s,v5.s[3]; ldr x0,[%25,#24]\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; fmla %5.4s,v0.4s,v7.s[1]\n\t"\ + "fmov v3.d[1],x0; ldr d4,[%25,#32]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]; ldr x0,[%25,#40]\n\t"\ + "fmla %7.4s,v0.4s,v7.s[3]; fmla %12.4s,v2.4s,v7.s[0]\n\t"\ + "fmov v4.d[1],x0\n\t"\ + "fmla %13.4s,v2.4s,v7.s[1]\n\t"\ + "fmla %14.4s,v2.4s,v7.s[2]; fmla %15.4s,v2.4s,v7.s[3]\n\t"\ + "ldr d7,[%26,#16]\n\t"\ + "fmla %9.4s,v2.4s,v5.s[1]; ldr x0,[%26,#24]\n\t"\ + "fmla %10.4s,v2.4s,v5.s[2]; fmla %11.4s,v2.4s,v5.s[3]\n\t"\ + "fmov v7.d[1],x0\n\t"\ + "fmla %0.4s,v1.4s,v6.s[0]\n\t"\ + "fmla %1.4s,v1.4s,v6.s[1]; fmla %2.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %3.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %4.4s,v1.4s,v7.s[0]; fmla %5.4s,v1.4s,v7.s[1]\n\t"\ + "add %25,%25,#48\n\t"\ + "fmla %6.4s,v1.4s,v7.s[2]\n\t"\ + "fmla %7.4s,v1.4s,v7.s[3]; fmla %8.4s,v3.4s,v6.s[0]\n\t"\ + "add %26,%26,#32\n\t"\ + "fmla %9.4s,v3.4s,v6.s[1]\n\t"\ + "fmla %10.4s,v3.4s,v6.s[2]; fmla %11.4s,v3.4s,v6.s[3]\n\t"\ + "fmla %12.4s,v3.4s,v7.s[0]\n\t"\ + "fmla %13.4s,v3.4s,v7.s[1]; fmla %14.4s,v3.4s,v7.s[2]\n\t"\ + "fmla %15.4s,v3.4s,v7.s[3]\n\t"\ + "fmla %16.4s,v4.4s,v6.s[0]; fmla %17.4s,v4.4s,v6.s[1]\n\t"\ + "fmla %18.4s,v4.4s,v6.s[2]\n\t"\ + "fmla %19.4s,v4.4s,v6.s[3]; fmla %20.4s,v4.4s,v7.s[0]\n\t"\ + "fmla %21.4s,v4.4s,v7.s[1]\n\t"\ + "fmla %22.4s,v4.4s,v7.s[2]; fmla %23.4s,v4.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_TAIL1_A53 \ + "fmov v2.d[1],x0; ldr d4,[%25,#-16]\n\t"\ + "fmla %0.4s,v0.4s,v5.s[0]; ldr x0,[%25,#-8]\n\t"\ + "fmla %1.4s,v0.4s,v5.s[1]; fmla %2.4s,v0.4s,v5.s[2]\n\t"\ + "fmov v4.d[1],x0; ldr d7,[%26,#-16]\n\t"\ + "fmla %3.4s,v0.4s,v5.s[3]; ldr x0,[%26,#-8]\n\t"\ + "fmla %8.4s,v2.4s,v5.s[0]; fmla %16.4s,v4.4s,v5.s[0]\n\t"\ + "fmov v7.d[1],x0\n\t"\ + "fmla %17.4s,v4.4s,v5.s[1]\n\t"\ + "fmla %18.4s,v4.4s,v5.s[2]; fmla %20.4s,v4.4s,v7.s[0]\n\t"\ + "fmla %21.4s,v4.4s,v7.s[1]\n\t"\ + "fmla %22.4s,v4.4s,v7.s[2]; fmla %23.4s,v4.4s,v7.s[3]\n\t"\ + "fmla %19.4s,v4.4s,v5.s[3]\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; fmla %5.4s,v0.4s,v7.s[1]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]\n\t"\ + "fmla %7.4s,v0.4s,v7.s[3]; fmla %12.4s,v2.4s,v7.s[0]\n\t"\ + "fmla %13.4s,v2.4s,v7.s[1]\n\t"\ + "fmla %14.4s,v2.4s,v7.s[2]; fmla %15.4s,v2.4s,v7.s[3]\n\t"\ + "fmla %9.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %10.4s,v2.4s,v5.s[2]; fmla %11.4s,v2.4s,v5.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_PRELOAD_A55 \ + "ldr q4,[%26]; ldr q5,[%26,#16]; add %26,%26,#32\n\t"\ + "ldr q0,[%25]; ldr d1,[%25,#16]; ldr x1,[%25,#24]; add %25,%25,#48\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_MAIN2_A55 \ + "fmla %0.4s,v0.4s,v4.s[0]; ldr d6,[%26]\n\t"\ + "fmla %1.4s,v0.4s,v4.s[1]; ldr x0,[%26,#8]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[2]\n\t"\ + "fmla %3.4s,v0.4s,v4.s[3]; fmov v1.d[1],x1\n\t"\ + "fmla %4.4s,v0.4s,v5.s[0]; ldr d2,[%25,#-16]\n\t"\ + "fmla %5.4s,v0.4s,v5.s[1]; ldr x1,[%25,#-8]\n\t"\ + "fmla %6.4s,v0.4s,v5.s[2]\n\t"\ + "fmla %7.4s,v0.4s,v5.s[3]; fmov v6.d[1],x0\n\t"\ + "fmla %8.4s,v1.4s,v4.s[0]; ldr d7,[%26,#16]\n\t"\ + "fmla %9.4s,v1.4s,v4.s[1]; ldr x0,[%26,#24]\n\t"\ + "fmla %10.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %11.4s,v1.4s,v4.s[3]; fmov v2.d[1],x1\n\t"\ + "fmla %12.4s,v1.4s,v5.s[0]; ldr d0,[%25]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[1]; ldr x1,[%25,#8]\n\t"\ + "fmla %14.4s,v1.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v5.s[3]; fmov v7.d[1],x0\n\t"\ + "fmla %16.4s,v2.4s,v4.s[0]; ldr d1,[%25,#16]\n\t"\ + "fmla %17.4s,v2.4s,v4.s[1]; ldr x0,[%25,#24]\n\t"\ + "fmla %18.4s,v2.4s,v4.s[2]\n\t"\ + "fmla %19.4s,v2.4s,v4.s[3]; 
fmov v0.d[1],x1\n\t"\ + "fmla %20.4s,v2.4s,v5.s[0]; add %25,%25,#96\n\t"\ + "fmla %21.4s,v2.4s,v5.s[1]; add %26,%26,#64\n\t"\ + "fmla %22.4s,v2.4s,v5.s[2]\n\t"\ + "fmla %23.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %0.4s,v0.4s,v6.s[0]; ldr d4,[%26,#-32]\n\t"\ + "fmla %1.4s,v0.4s,v6.s[1]; ldr x1,[%26,#-24]\n\t"\ + "fmla %2.4s,v0.4s,v6.s[2]\n\t"\ + "fmla %3.4s,v0.4s,v6.s[3]; fmov v1.d[1],x0\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; ldr d2,[%25,#-64]\n\t"\ + "fmla %5.4s,v0.4s,v7.s[1]; ldr x0,[%25,#-56]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]\n\t"\ + "fmla %7.4s,v0.4s,v7.s[3]; fmov v4.d[1],x1\n\t"\ + "fmla %8.4s,v1.4s,v6.s[0]; ldr d5,[%26,#-16]\n\t"\ + "fmla %9.4s,v1.4s,v6.s[1]; ldr x1,[%26,#-8]\n\t"\ + "fmla %10.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %11.4s,v1.4s,v6.s[3]; fmov v2.d[1],x0\n\t"\ + "fmla %12.4s,v1.4s,v7.s[0]; ldr d0,[%25,#-48]\n\t"\ + "fmla %13.4s,v1.4s,v7.s[1]; ldr x0,[%25,#-40]\n\t"\ + "fmla %14.4s,v1.4s,v7.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v7.s[3]; fmov v5.d[1],x1\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]; ldr d1,[%25,#-32]\n\t"\ + "fmla %17.4s,v2.4s,v6.s[1]; ldr x1,[%25,#-24]\n\t"\ + "fmla %18.4s,v2.4s,v6.s[2]\n\t"\ + "fmla %19.4s,v2.4s,v6.s[3]; fmov v0.d[1],x0\n\t"\ + "fmla %20.4s,v2.4s,v7.s[0]\n\t"\ + "fmla %21.4s,v2.4s,v7.s[1]; sub %w24,%w24,#2\n\t"\ + "fmla %22.4s,v2.4s,v7.s[2]; cmp %w24,#2\n\t"\ + "fmla %23.4s,v2.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_TAIL2_A55 \ + "fmla %0.4s,v0.4s,v4.s[0]; ldr d6,[%26]\n\t"\ + "fmla %1.4s,v0.4s,v4.s[1]; ldr x0,[%26,#8]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[2]\n\t"\ + "fmla %3.4s,v0.4s,v4.s[3]; fmov v1.d[1],x1\n\t"\ + "fmla %4.4s,v0.4s,v5.s[0]; ldr d2,[%25,#-16]\n\t"\ + "fmla %5.4s,v0.4s,v5.s[1]; ldr x1,[%25,#-8]\n\t"\ + "fmla %6.4s,v0.4s,v5.s[2]\n\t"\ + "fmla %7.4s,v0.4s,v5.s[3]; fmov v6.d[1],x0\n\t"\ + "fmla %8.4s,v1.4s,v4.s[0]; ldr d7,[%26,#16]\n\t"\ + "fmla %9.4s,v1.4s,v4.s[1]; ldr x0,[%26,#24]\n\t"\ + "fmla %10.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %11.4s,v1.4s,v4.s[3]; fmov v2.d[1],x1\n\t"\ + "fmla %12.4s,v1.4s,v5.s[0]; ldr d0,[%25]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[1]; ldr x1,[%25,#8]\n\t"\ + "fmla %14.4s,v1.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v5.s[3]; fmov v7.d[1],x0\n\t"\ + "fmla %16.4s,v2.4s,v4.s[0]; ldr d1,[%25,#16]\n\t"\ + "fmla %17.4s,v2.4s,v4.s[1]; ldr x0,[%25,#24]\n\t"\ + "fmla %18.4s,v2.4s,v4.s[2]\n\t"\ + "fmla %19.4s,v2.4s,v4.s[3]; fmov v0.d[1],x1\n\t"\ + "fmla %20.4s,v2.4s,v5.s[0]; add %25,%25,#48\n\t"\ + "fmla %21.4s,v2.4s,v5.s[1]; add %26,%26,#32\n\t"\ + "fmla %22.4s,v2.4s,v5.s[2]\n\t"\ + "fmla %23.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %0.4s,v0.4s,v6.s[0]\n\t"\ + "fmla %1.4s,v0.4s,v6.s[1]\n\t"\ + "fmla %2.4s,v0.4s,v6.s[2]\n\t"\ + "fmla %3.4s,v0.4s,v6.s[3]; fmov v1.d[1],x0\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; ldr d2,[%25,#-16]\n\t"\ + "fmla %5.4s,v0.4s,v7.s[1]; ldr x0,[%25,#-8]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]\n\t"\ + "fmla %7.4s,v0.4s,v7.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v6.s[0]\n\t"\ + "fmla %9.4s,v1.4s,v6.s[1]\n\t"\ + "fmla %10.4s,v1.4s,v6.s[2]\n\t"\ + "fmla %11.4s,v1.4s,v6.s[3]; fmov v2.d[1],x0\n\t"\ + "fmla %12.4s,v1.4s,v7.s[0]\n\t"\ + "fmla %13.4s,v1.4s,v7.s[1]\n\t"\ + "fmla %14.4s,v1.4s,v7.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v7.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]\n\t"\ + "fmla %17.4s,v2.4s,v6.s[1]\n\t"\ + "fmla %18.4s,v2.4s,v6.s[2]\n\t"\ + "fmla %19.4s,v2.4s,v6.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v7.s[0]\n\t"\ + "fmla %21.4s,v2.4s,v7.s[1]\n\t"\ + "fmla %22.4s,v2.4s,v7.s[2]\n\t"\ + "fmla %23.4s,v2.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_TAIL1_A55 \ + "fmla %0.4s,v0.4s,v4.s[0]\n\t"\ + "fmla %1.4s,v0.4s,v4.s[1]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[2]\n\t"\ + "fmla 
%3.4s,v0.4s,v4.s[3]; fmov v1.d[1],x1\n\t"\ + "fmla %4.4s,v0.4s,v5.s[0]; ldr d2,[%25,#-16]\n\t"\ + "fmla %5.4s,v0.4s,v5.s[1]; ldr x1,[%25,#-8]\n\t"\ + "fmla %6.4s,v0.4s,v5.s[2]\n\t"\ + "fmla %7.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v4.s[0]\n\t"\ + "fmla %9.4s,v1.4s,v4.s[1]\n\t"\ + "fmla %10.4s,v1.4s,v4.s[2]\n\t"\ + "fmla %11.4s,v1.4s,v4.s[3]; fmov v2.d[1],x1\n\t"\ + "fmla %12.4s,v1.4s,v5.s[0]\n\t"\ + "fmla %13.4s,v1.4s,v5.s[1]\n\t"\ + "fmla %14.4s,v1.4s,v5.s[2]\n\t"\ + "fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v4.s[0]\n\t"\ + "fmla %17.4s,v2.4s,v4.s[1]\n\t"\ + "fmla %18.4s,v2.4s,v4.s[2]\n\t"\ + "fmla %19.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v5.s[0]\n\t"\ + "fmla %21.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %22.4s,v2.4s,v5.s[2]\n\t"\ + "fmla %23.4s,v2.4s,v5.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_PRELOAD_A72 \ + "ldr q0,[%25]; ldr q1,[%25,#16]; add %25,%25,#48\n\t"\ + "ldr q4,[%26]; ldr q5,[%26,#16]; add %26,%26,#32\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_MAIN2_A72 \ + "fmla %0.4s,v0.4s,v4.s[0]; fmla %1.4s,v0.4s,v4.s[1]; ldr q2,[%25,#-16]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[2]; fmla %3.4s,v0.4s,v4.s[3]\n\t"\ + "fmla %4.4s,v0.4s,v5.s[0]; fmla %5.4s,v0.4s,v5.s[1]; ldr q6,[%26],#64\n\t"\ + "fmla %6.4s,v0.4s,v5.s[2]; fmla %7.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v4.s[0]; fmla %9.4s,v1.4s,v4.s[1]; ldr q0,[%25],#96\n\t"\ + "fmla %10.4s,v1.4s,v4.s[2]; fmla %11.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %12.4s,v1.4s,v5.s[0]; fmla %13.4s,v1.4s,v5.s[1]; ldr q7,[%26,#-48]\n\t"\ + "fmla %14.4s,v1.4s,v5.s[2]; fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v4.s[0]; fmla %17.4s,v2.4s,v4.s[1]; ldr q1,[%25,#-80]\n\t"\ + "fmla %18.4s,v2.4s,v4.s[2]; fmla %19.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v5.s[0]; fmla %21.4s,v2.4s,v5.s[1]; sub %w24,%w24,#2\n\t"\ + "fmla %22.4s,v2.4s,v5.s[2]; fmla %23.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %0.4s,v0.4s,v6.s[0]; fmla %1.4s,v0.4s,v6.s[1]; ldr q2,[%25,#-64]\n\t"\ + "fmla %2.4s,v0.4s,v6.s[2]; fmla %3.4s,v0.4s,v6.s[3]\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; fmla %5.4s,v0.4s,v7.s[1]; ldr q4,[%26,#-32]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]; fmla %7.4s,v0.4s,v7.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v6.s[0]; fmla %9.4s,v1.4s,v6.s[1]; ldr q0,[%25,#-48]\n\t"\ + "fmla %10.4s,v1.4s,v6.s[2]; fmla %11.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %12.4s,v1.4s,v7.s[0]; fmla %13.4s,v1.4s,v7.s[1]; ldr q5,[%26,#-16]\n\t"\ + "fmla %14.4s,v1.4s,v7.s[2]; fmla %15.4s,v1.4s,v7.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]; fmla %17.4s,v2.4s,v6.s[1]; ldr q1,[%25,#-32]\n\t"\ + "fmla %18.4s,v2.4s,v6.s[2]; fmla %19.4s,v2.4s,v6.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v7.s[0]; fmla %21.4s,v2.4s,v7.s[1]; cmp %w24,#2\n\t"\ + "fmla %22.4s,v2.4s,v7.s[2]; fmla %23.4s,v2.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_TAIL2_A72 \ + "fmla %0.4s,v0.4s,v4.s[0]; fmla %1.4s,v0.4s,v4.s[1]; ldr q2,[%25,#-16]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[2]; fmla %3.4s,v0.4s,v4.s[3]\n\t"\ + "fmla %4.4s,v0.4s,v5.s[0]; fmla %5.4s,v0.4s,v5.s[1]; ldr q6,[%26],#32\n\t"\ + "fmla %6.4s,v0.4s,v5.s[2]; fmla %7.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v4.s[0]; fmla %9.4s,v1.4s,v4.s[1]; ldr q0,[%25],#48\n\t"\ + "fmla %10.4s,v1.4s,v4.s[2]; fmla %11.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %12.4s,v1.4s,v5.s[0]; fmla %13.4s,v1.4s,v5.s[1]; ldr q7,[%26,#-16]\n\t"\ + "fmla %14.4s,v1.4s,v5.s[2]; fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v4.s[0]; fmla %17.4s,v2.4s,v4.s[1]; ldr q1,[%25,#-32]\n\t"\ + "fmla %18.4s,v2.4s,v4.s[2]; fmla %19.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v5.s[0]; fmla %21.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %22.4s,v2.4s,v5.s[2]; fmla 
%23.4s,v2.4s,v5.s[3]\n\t"\ + "fmla %0.4s,v0.4s,v6.s[0]; fmla %1.4s,v0.4s,v6.s[1]; ldr q2,[%25,#-16]\n\t"\ + "fmla %2.4s,v0.4s,v6.s[2]; fmla %3.4s,v0.4s,v6.s[3]\n\t"\ + "fmla %4.4s,v0.4s,v7.s[0]; fmla %5.4s,v0.4s,v7.s[1]\n\t"\ + "fmla %6.4s,v0.4s,v7.s[2]; fmla %7.4s,v0.4s,v7.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v6.s[0]; fmla %9.4s,v1.4s,v6.s[1]\n\t"\ + "fmla %10.4s,v1.4s,v6.s[2]; fmla %11.4s,v1.4s,v6.s[3]\n\t"\ + "fmla %12.4s,v1.4s,v7.s[0]; fmla %13.4s,v1.4s,v7.s[1]\n\t"\ + "fmla %14.4s,v1.4s,v7.s[2]; fmla %15.4s,v1.4s,v7.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v6.s[0]; fmla %17.4s,v2.4s,v6.s[1]\n\t"\ + "fmla %18.4s,v2.4s,v6.s[2]; fmla %19.4s,v2.4s,v6.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v7.s[0]; fmla %21.4s,v2.4s,v7.s[1]\n\t"\ + "fmla %22.4s,v2.4s,v7.s[2]; fmla %23.4s,v2.4s,v7.s[3]\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_TAIL1_A72 \ + "fmla %0.4s,v0.4s,v4.s[0]; fmla %1.4s,v0.4s,v4.s[1]; ldr q2,[%25,#-16]\n\t"\ + "fmla %2.4s,v0.4s,v4.s[2]; fmla %3.4s,v0.4s,v4.s[3]\n\t"\ + "fmla %4.4s,v0.4s,v5.s[0]; fmla %5.4s,v0.4s,v5.s[1]\n\t"\ + "fmla %6.4s,v0.4s,v5.s[2]; fmla %7.4s,v0.4s,v5.s[3]\n\t"\ + "fmla %8.4s,v1.4s,v4.s[0]; fmla %9.4s,v1.4s,v4.s[1]\n\t"\ + "fmla %10.4s,v1.4s,v4.s[2]; fmla %11.4s,v1.4s,v4.s[3]\n\t"\ + "fmla %12.4s,v1.4s,v5.s[0]; fmla %13.4s,v1.4s,v5.s[1]\n\t"\ + "fmla %14.4s,v1.4s,v5.s[2]; fmla %15.4s,v1.4s,v5.s[3]\n\t"\ + "fmla %16.4s,v2.4s,v4.s[0]; fmla %17.4s,v2.4s,v4.s[1]\n\t"\ + "fmla %18.4s,v2.4s,v4.s[2]; fmla %19.4s,v2.4s,v4.s[3]\n\t"\ + "fmla %20.4s,v2.4s,v5.s[0]; fmla %21.4s,v2.4s,v5.s[1]\n\t"\ + "fmla %22.4s,v2.4s,v5.s[2]; fmla %23.4s,v2.4s,v5.s[3]\n\t" + +#define NEON_SGEMM_SAVE_M12N2_UNIT(cq1, cq2, cq3, cq4, cq5, cq6) \ + ct1 = vld1q_f32(c_tmp1);\ + ct2 = vld1q_f32(c_tmp1 + 4);\ + ct3 = vld1q_f32(c_tmp1 + 8);\ + ct4 = vld1q_f32(c_tmp2);\ + ct5 = vld1q_f32(c_tmp2 + 4);\ + ct6 = vld1q_f32(c_tmp2 + 8);\ + cq1 = vfmaq_n_f32(cq1, ct1, beta); cq2 = vfmaq_n_f32(cq2, ct2, beta);\ + cq3 = vfmaq_n_f32(cq3, ct3, beta); cq4 = vfmaq_n_f32(cq4, ct4, beta);\ + cq5 = vfmaq_n_f32(cq5, ct5, beta); cq6 = vfmaq_n_f32(cq6, ct6, beta);\ + vst1q_f32(c_tmp1, cq1);\ + vst1q_f32(c_tmp1 + 4, cq2);\ + vst1q_f32(c_tmp1 + 8, cq3); c_tmp1 += ldc2;\ + vst1q_f32(c_tmp2, cq4);\ + vst1q_f32(c_tmp2 + 4, cq5);\ + vst1q_f32(c_tmp2 + 8, cq6); c_tmp2 += ldc2; + +#define NEON_SGEMM_SAVE_M12N8_ASM1 \ + float *c_tmp1 = c_ptr;\ + float *c_tmp2 = c_ptr + ldc;\ + uint32_t ldc2 = ldc * 2;\ + float32x4_t ct1, ct2, ct3, ct4, ct5, ct6;\ + NEON_SGEMM_SAVE_M12N2_UNIT(cq01, cq09, cq17, cq02, cq10, cq18)\ + NEON_SGEMM_SAVE_M12N2_UNIT(cq03, cq11, cq19, cq04, cq12, cq20)\ + NEON_SGEMM_SAVE_M12N2_UNIT(cq05, cq13, cq21, cq06, cq14, cq22)\ + NEON_SGEMM_SAVE_M12N2_UNIT(cq07, cq15, cq23, cq08, cq16, cq24) + +#define PREF_C_1_LANE(n, mdim) \ + pref_c(c_pref); pref_c(c_pref + mdim - 1); c_pref += ldc; +#define PREF_C(mdim, ndim) \ + MACRO_EXPANSION_##ndim(VOID_BASE, PREF_C_1_LANE, mdim) + +#define NEON_SGEMM_COMPUTE_ASM1(mdim, ndim, cputype) \ + float *c_pref = c_ptr; PREF_C(mdim, ndim)\ + const float *b_ptr = b_head;\ + const float *a_ptr = a_head;\ + uint32_t k_left = K;\ + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + float32x4_t cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16;\ + float32x4_t cq17, cq18, cq19, cq20, cq21, cq22, cq23, cq24;\ + __asm__ __volatile__ (\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov %6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "mov %8.16b,%0.16b; mov %9.16b,%1.16b\n\t"\ + "mov %10.16b,%0.16b; mov %11.16b,%1.16b\n\t"\ 
+ "mov %12.16b,%0.16b; mov %13.16b,%1.16b\n\t"\ + "mov %14.16b,%0.16b; mov %15.16b,%1.16b\n\t"\ + "mov %16.16b,%0.16b; mov %17.16b,%1.16b\n\t"\ + "mov %18.16b,%0.16b; mov %19.16b,%1.16b\n\t"\ + "mov %20.16b,%0.16b; mov %21.16b,%1.16b\n\t"\ + "mov %22.16b,%0.16b; mov %23.16b,%1.16b\n\t"\ + "cmp %w24,#0; b.eq 4f\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_PRELOAD_##cputype\ + "cmp %w24,#2; b.le 2f\n\t"\ + ".balign 16\n\t"\ + "1:\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_MAIN2_##cputype "b.gt 1b\n\t"\ + "2:\n\t"\ + "cmp %w24,#2; b.ne 3f\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_TAIL2_##cputype "b 4f\n\t"\ + "3:\n\t"\ + NEON_SGEMM_KERNEL_M##mdim##N##ndim##_TAIL1_##cputype\ + "4:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04),"=w"(cq05),"=w"(cq06),\ + "=w"(cq07),"=w"(cq08),"=w"(cq09),"=w"(cq10),"=w"(cq11),"=w"(cq12),\ + "=w"(cq13),"=w"(cq14),"=w"(cq15),"=w"(cq16),"=w"(cq17),"=w"(cq18),\ + "=w"(cq19),"=w"(cq20),"=w"(cq21),"=w"(cq22),"=w"(cq23),"=w"(cq24),\ + "+r"(k_left),"+r"(a_ptr),"+r"(b_ptr)\ + ::"cc","memory","x0","x1","v0","v1","v2","v3","v4","v5","v6","v7");\ + NEON_SGEMM_SAVE_M##mdim##N##ndim##_ASM1 + +#define NEON_SGEMM_KERNEL_M12N8_HALF_PRELOAD_A35 \ + "ld1r {v0.2s},[%25],#4\n\t"\ + "ldr d4,[%26]; ldr d5,[%26,#8]; ldr d6,[%26,#16]; add %26,%26,#32\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_HALF_MAIN2_A35 \ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %0.2s,v0.2s,v4.2s; fmla %1.2s,v0.2s,v5.2s; fmla %2.2s,v0.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %4.2s,v1.2s,v4.2s; fmla %5.2s,v1.2s,v5.2s; fmla %6.2s,v1.2s,v6.2s\n\t"\ + "ldr d7,[%26,#-8]\n\t"\ + "fmla %8.2s,v2.2s,v4.2s; fmla %9.2s,v2.2s,v5.2s; fmla %10.2s,v2.2s,v6.2s\n\t"\ + "ld1r {v3.2s},[%25],#4\n\t"\ + "fmla %3.2s,v0.2s,v7.2s; fmla %7.2s,v1.2s,v7.2s; fmla %11.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %12.2s,v3.2s,v4.2s; fmla %13.2s,v3.2s,v5.2s; fmla %14.2s,v3.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %16.2s,v1.2s,v4.2s; add %25,%25,#24\n\t"\ + "fmla %17.2s,v1.2s,v5.2s; fmla %18.2s,v1.2s,v6.2s\n\t"\ + "ld1r {v0.2s},[%25],#4\n\t"\ + "fmla %20.2s,v2.2s,v4.2s; fmla %21.2s,v2.2s,v5.2s; fmla %22.2s,v2.2s,v6.2s\n\t"\ + "ldr d4,[%26]; ldr d5,[%26,#8]; ldr d6,[%26,#16]\n\t"\ + "fmla %15.2s,v3.2s,v7.2s; add %26,%26,#64\n\t"\ + "fmla %19.2s,v1.2s,v7.2s\n\t"\ + "fmla %23.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %0.2s,v0.2s,v4.2s; fmla %1.2s,v0.2s,v5.2s; fmla %2.2s,v0.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %4.2s,v1.2s,v4.2s; fmla %5.2s,v1.2s,v5.2s; fmla %6.2s,v1.2s,v6.2s\n\t"\ + "ldr d7,[%26,#-40]\n\t"\ + "fmla %8.2s,v2.2s,v4.2s; fmla %9.2s,v2.2s,v5.2s; fmla %10.2s,v2.2s,v6.2s\n\t"\ + "ld1r {v3.2s},[%25],#4\n\t"\ + "fmla %3.2s,v0.2s,v7.2s; fmla %7.2s,v1.2s,v7.2s; fmla %11.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %12.2s,v3.2s,v4.2s; fmla %13.2s,v3.2s,v5.2s; fmla %14.2s,v3.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %16.2s,v1.2s,v4.2s; add %25,%25,#24\n\t"\ + "fmla %17.2s,v1.2s,v5.2s; fmla %18.2s,v1.2s,v6.2s\n\t"\ + "ld1r {v0.2s},[%25],#4\n\t"\ + "fmla %20.2s,v2.2s,v4.2s; fmla %21.2s,v2.2s,v5.2s; fmla %22.2s,v2.2s,v6.2s\n\t"\ + "ldr d4,[%26,#-32]; ldr d5,[%26,#-24]; ldr d6,[%26,#-16]\n\t"\ + "fmla %15.2s,v3.2s,v7.2s; sub %w24,%w24,#2\n\t"\ + "fmla %19.2s,v1.2s,v7.2s; cmp %w24,#2\n\t"\ + "fmla %23.2s,v2.2s,v7.2s\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_HALF_TAIL2_A35 \ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %0.2s,v0.2s,v4.2s; fmla %1.2s,v0.2s,v5.2s; fmla %2.2s,v0.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %4.2s,v1.2s,v4.2s; fmla 
%5.2s,v1.2s,v5.2s; fmla %6.2s,v1.2s,v6.2s\n\t"\ + "ldr d7,[%26,#-8]\n\t"\ + "fmla %8.2s,v2.2s,v4.2s; fmla %9.2s,v2.2s,v5.2s; fmla %10.2s,v2.2s,v6.2s\n\t"\ + "ld1r {v3.2s},[%25],#4\n\t"\ + "fmla %3.2s,v0.2s,v7.2s; fmla %7.2s,v1.2s,v7.2s; fmla %11.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %12.2s,v3.2s,v4.2s; fmla %13.2s,v3.2s,v5.2s; fmla %14.2s,v3.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %16.2s,v1.2s,v4.2s; add %25,%25,#24\n\t"\ + "fmla %17.2s,v1.2s,v5.2s; fmla %18.2s,v1.2s,v6.2s\n\t"\ + "ld1r {v0.2s},[%25],#4\n\t"\ + "fmla %20.2s,v2.2s,v4.2s; fmla %21.2s,v2.2s,v5.2s; fmla %22.2s,v2.2s,v6.2s\n\t"\ + "ldr d4,[%26]; ldr d5,[%26,#8]; ldr d6,[%26,#16]\n\t"\ + "fmla %15.2s,v3.2s,v7.2s; add %26,%26,#32\n\t"\ + "fmla %19.2s,v1.2s,v7.2s\n\t"\ + "fmla %23.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %0.2s,v0.2s,v4.2s; fmla %1.2s,v0.2s,v5.2s; fmla %2.2s,v0.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %4.2s,v1.2s,v4.2s; fmla %5.2s,v1.2s,v5.2s; fmla %6.2s,v1.2s,v6.2s\n\t"\ + "ldr d7,[%26,#-8]\n\t"\ + "fmla %8.2s,v2.2s,v4.2s; fmla %9.2s,v2.2s,v5.2s; fmla %10.2s,v2.2s,v6.2s\n\t"\ + "ld1r {v3.2s},[%25],#4\n\t"\ + "fmla %3.2s,v0.2s,v7.2s; fmla %7.2s,v1.2s,v7.2s; fmla %11.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %12.2s,v3.2s,v4.2s; fmla %13.2s,v3.2s,v5.2s; fmla %14.2s,v3.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %16.2s,v1.2s,v4.2s; fmla %17.2s,v1.2s,v5.2s; fmla %18.2s,v1.2s,v6.2s\n\t"\ + "fmla %20.2s,v2.2s,v4.2s; fmla %21.2s,v2.2s,v5.2s; fmla %22.2s,v2.2s,v6.2s\n\t"\ + "fmla %15.2s,v3.2s,v7.2s; add %25,%25,#24\n\t"\ + "fmla %19.2s,v1.2s,v7.2s\n\t"\ + "fmla %23.2s,v2.2s,v7.2s\n\t" + +#define NEON_SGEMM_KERNEL_M12N8_HALF_TAIL1_A35 \ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %0.2s,v0.2s,v4.2s; fmla %1.2s,v0.2s,v5.2s; fmla %2.2s,v0.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %4.2s,v1.2s,v4.2s; fmla %5.2s,v1.2s,v5.2s; fmla %6.2s,v1.2s,v6.2s\n\t"\ + "ldr d7,[%26,#-8]\n\t"\ + "fmla %8.2s,v2.2s,v4.2s; fmla %9.2s,v2.2s,v5.2s; fmla %10.2s,v2.2s,v6.2s\n\t"\ + "ld1r {v3.2s},[%25],#4\n\t"\ + "fmla %3.2s,v0.2s,v7.2s; fmla %7.2s,v1.2s,v7.2s; fmla %11.2s,v2.2s,v7.2s\n\t"\ + "ld1r {v1.2s},[%25],#4\n\t"\ + "fmla %12.2s,v3.2s,v4.2s; fmla %13.2s,v3.2s,v5.2s; fmla %14.2s,v3.2s,v6.2s\n\t"\ + "ld1r {v2.2s},[%25],#4\n\t"\ + "fmla %16.2s,v1.2s,v4.2s; fmla %17.2s,v1.2s,v5.2s; fmla %18.2s,v1.2s,v6.2s\n\t"\ + "fmla %20.2s,v2.2s,v4.2s; fmla %21.2s,v2.2s,v5.2s; fmla %22.2s,v2.2s,v6.2s\n\t"\ + "fmla %15.2s,v3.2s,v7.2s; add %25,%25,#24\n\t"\ + "fmla %19.2s,v1.2s,v7.2s\n\t"\ + "fmla %23.2s,v2.2s,v7.2s\n\t" + +#define NEON_SGEMM_SAVE_M6N2_UNIT_A35(c1, c2, c3, c4, c5, c6) \ + ct1 = vzip1_f32(c1, c2); ct2 = vzip1_f32(c3, c4); ct3 = vzip1_f32(c5, c6);\ + ct4 = vld1_f32(c_tmp), ct5 = vld1_f32(c_tmp + 2); ct6 = vld1_f32(c_tmp + 4);\ + ct1 = vfma_f32(ct1, ct4, beta_d);\ + ct2 = vfma_f32(ct2, ct5, beta_d);\ + ct3 = vfma_f32(ct3, ct6, beta_d);\ + vst1_f32(c_tmp, ct1); vst1_f32(c_tmp + 2, ct2); vst1_f32(c_tmp + 4, ct3);\ + c_tmp += ldc;\ + ct1 = vzip2_f32(c1, c2); ct2 = vzip2_f32(c3, c4); ct3 = vzip2_f32(c5, c6);\ + ct4 = vld1_f32(c_tmp), ct5 = vld1_f32(c_tmp + 2); ct6 = vld1_f32(c_tmp + 4);\ + ct1 = vfma_f32(ct1, ct4, beta_d);\ + ct2 = vfma_f32(ct2, ct5, beta_d);\ + ct3 = vfma_f32(ct3, ct6, beta_d);\ + vst1_f32(c_tmp, ct1); vst1_f32(c_tmp + 2, ct2); vst1_f32(c_tmp + 4, ct3);\ + c_tmp += ldc; + +#define NEON_SGEMM_SAVE_M6N8_A35 \ + NEON_SGEMM_SAVE_M6N2_UNIT_A35(c01, c05, c09, c13, c17, c21)\ + NEON_SGEMM_SAVE_M6N2_UNIT_A35(c02, c06, c10, c14, c18, c22)\ + 
NEON_SGEMM_SAVE_M6N2_UNIT_A35(c03, c07, c11, c15, c19, c23)\ + NEON_SGEMM_SAVE_M6N2_UNIT_A35(c04, c08, c12, c16, c20, c24) + +#define NEON_SGEMM_SAVE_M8N1_UNIT_A35(c1, c2, c3, c4) \ + ct1 = vld1_f32(c_tmp); ct2 = vld1_f32(c_tmp + 2);\ + ct3 = vld1_f32(c_tmp + 4); ct4 = vld1_f32(c_tmp + 6);\ + c1 = vfma_f32(c1, ct1, beta_d); c2 = vfma_f32(c2, ct2, beta_d);\ + c3 = vfma_f32(c3, ct3, beta_d); c4 = vfma_f32(c4, ct4, beta_d);\ + vst1_f32(c_tmp, c1); vst1_f32(c_tmp + 2, c2);\ + vst1_f32(c_tmp + 4, c3); vst1_f32(c_tmp + 6, c4); c_tmp += ldc; + +#define NEON_SGEMM_SAVE_M8N6_A35 \ + NEON_SGEMM_SAVE_M8N1_UNIT_A35(c01, c02, c03, c04)\ + NEON_SGEMM_SAVE_M8N1_UNIT_A35(c05, c06, c07, c08)\ + NEON_SGEMM_SAVE_M8N1_UNIT_A35(c09, c10, c11, c12)\ + NEON_SGEMM_SAVE_M8N1_UNIT_A35(c13, c14, c15, c16)\ + NEON_SGEMM_SAVE_M8N1_UNIT_A35(c17, c18, c19, c20)\ + NEON_SGEMM_SAVE_M8N1_UNIT_A35(c21, c22, c23, c24) + +#define NEON_SGEMM_KERNEL_M12N8_HALF_A35(a_ptr, b_ptr) \ + k_left = K;\ + __asm__ __volatile__ (\ + "movi %0.8b,#0; movi %1.8b,#0\n\t"\ + "mov %2.8b,%0.8b; mov %3.8b,%1.8b\n\t"\ + "mov %4.8b,%0.8b; mov %5.8b,%1.8b\n\t"\ + "mov %6.8b,%0.8b; mov %7.8b,%1.8b\n\t"\ + "mov %8.8b,%0.8b; mov %9.8b,%1.8b\n\t"\ + "mov %10.8b,%0.8b; mov %11.8b,%1.8b\n\t"\ + "mov %12.8b,%0.8b; mov %13.8b,%1.8b\n\t"\ + "mov %14.8b,%0.8b; mov %15.8b,%1.8b\n\t"\ + "mov %16.8b,%0.8b; mov %17.8b,%1.8b\n\t"\ + "mov %18.8b,%0.8b; mov %19.8b,%1.8b\n\t"\ + "mov %20.8b,%0.8b; mov %21.8b,%1.8b\n\t"\ + "mov %22.8b,%0.8b; mov %23.8b,%1.8b\n\t"\ + "cmp %w24,#0; b.eq 4f\n\t"\ + NEON_SGEMM_KERNEL_M12N8_HALF_PRELOAD_A35\ + "cmp %w24,#2; b.le 2f\n\t"\ + ".balign 16\n\t"\ + "1:\n\t"\ + NEON_SGEMM_KERNEL_M12N8_HALF_MAIN2_A35 "b.gt 1b\n\t"\ + "2:\n\t"\ + "cmp %w24,#2; b.ne 3f\n\t"\ + NEON_SGEMM_KERNEL_M12N8_HALF_TAIL2_A35 "b 4f\n\t"\ + "3:\n\t"\ + NEON_SGEMM_KERNEL_M12N8_HALF_TAIL1_A35\ + "4:\n\t"\ + :"=w"(c01),"=w"(c02),"=w"(c03),"=w"(c04),"=w"(c05),"=w"(c06),\ + "=w"(c07),"=w"(c08),"=w"(c09),"=w"(c10),"=w"(c11),"=w"(c12),\ + "=w"(c13),"=w"(c14),"=w"(c15),"=w"(c16),"=w"(c17),"=w"(c18),\ + "=w"(c19),"=w"(c20),"=w"(c21),"=w"(c22),"=w"(c23),"=w"(c24),\ + "+r"(k_left),"+r"(a_ptr),"+r"(b_ptr)\ + ::"cc","memory","v0","v1","v2","v3","v4","v5","v6","v7"); + +#define NEON_SGEMM_COMPUTE_M8N12_A35 \ + uint32_t k_left;\ + float32x2_t c01, c02, c03, c04, c05, c06, c07, c08;\ + float32x2_t c09, c10, c11, c12, c13, c14, c15, c16;\ + float32x2_t c17, c18, c19, c20, c21, c22, c23, c24;\ + float *c_pref = c_ptr; PREF_C(8, 6)\ + const float *a_ptr = a_head;\ + const float *b_ptr = b_head;\ + NEON_SGEMM_KERNEL_M12N8_HALF_A35(b_ptr, a_ptr)\ + const float32x2_t beta_d = vdup_n_f32(beta);\ + float *c_tmp = c_ptr;\ + float32x2_t ct1, ct2, ct3, ct4;\ + NEON_SGEMM_SAVE_M8N6_A35\ + a_ptr = a_head; b_ptr = b_head + 6;\ + PREF_C(8, 6)\ + NEON_SGEMM_KERNEL_M12N8_HALF_A35(b_ptr, a_ptr)\ + NEON_SGEMM_SAVE_M8N6_A35 + +#define NEON_SGEMM_COMPUTE_M12N8_A35 \ + uint32_t k_left;\ + float32x2_t c01, c02, c03, c04, c05, c06, c07, c08;\ + float32x2_t c09, c10, c11, c12, c13, c14, c15, c16;\ + float32x2_t c17, c18, c19, c20, c21, c22, c23, c24;\ + float *c_pref = c_ptr; PREF_C(6, 8)\ + const float *a_ptr = a_head;\ + const float *b_ptr = b_head;\ + NEON_SGEMM_KERNEL_M12N8_HALF_A35(a_ptr, b_ptr)\ + const float32x2_t beta_d = vdup_n_f32(beta);\ + float *c_tmp = c_ptr;\ + float32x2_t ct1, ct2, ct3, ct4, ct5, ct6;\ + NEON_SGEMM_SAVE_M6N8_A35\ + c_tmp -= 8 * ldc;\ + c_tmp += 6;\ + c_pref = c_ptr + 6; PREF_C(6, 8)\ + b_ptr = b_head; a_ptr = a_head + 6;\ + NEON_SGEMM_KERNEL_M12N8_HALF_A35(a_ptr, 
b_ptr)\ + NEON_SGEMM_SAVE_M6N8_A35 + +#define CPUID_DETECT_MNK 1000000 + +void sgemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t K, float beta, + const float * __restrict__ sa, const float * __restrict__ sb, + float * __restrict__ C, uint32_t ldc) { + uint32_t n_left = N; + const float *b_head = sb; + float *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + for (; n_left > 11; n_left -= 12) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const float *a_head = sa; + float *c_ptr = c_head; + uint32_t m_left = M; + if (cputype == 53) { + for (; m_left > 7; m_left -= 8) { + NEON_SGEMM_COMPUTE_ASM1(8, 12, A53) + a_head += 8 * K; + c_ptr += 8; + } + } else if (cputype == 55) { + for (; m_left > 7; m_left -= 8) { + NEON_SGEMM_COMPUTE_ASM1(8, 12, A55) + a_head += 8 * K; + c_ptr += 8; + } + } else if (cputype == 35) { + for (; m_left > 7; m_left -= 8) { + NEON_SGEMM_COMPUTE_M8N12_A35 + a_head += 8 * K; + c_ptr += 8; + } + } else { + for (; m_left > 7; m_left -= 8) { + NEON_SGEMM_COMPUTE_ASM1(8, 12, A72) + a_head += 8 * K; + c_ptr += 8; + } + } + MICRO_COMPUTE_LM(4, 12, float, float, float) + b_head += K * 12; + c_head += ldc * 12; + acc_mnk += 12 * K * M; + } + ASSEMBLE_DUALPACK_COMPUTE_LM(8, float, float, float, 8) +} + +void sgemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t K, float beta, + const float * __restrict__ sa, const float * __restrict__ sb, + float * __restrict__ C, uint32_t ldc) { + uint32_t m_left = M; + const float *a_head = sa; + float *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + for (; m_left > 11; m_left -= 12) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const float *b_head = sb; + float *c_ptr = c_head; + uint32_t n_left = N; + if (cputype == 53) { + for (; n_left > 7; n_left -= 8) { + NEON_SGEMM_COMPUTE_ASM1(12, 8, A53) + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } else if (cputype == 55) { + for (; n_left > 7; n_left -= 8) { + NEON_SGEMM_COMPUTE_ASM1(12, 8, A55) + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } else if (cputype == 35) { + for (; n_left > 7; n_left -= 8) { + NEON_SGEMM_COMPUTE_M12N8_A35 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } else { + for (; n_left > 7; n_left -= 8) { + NEON_SGEMM_COMPUTE_ASM1(12, 8, A72) + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } + MICRO_COMPUTE_LN(12, 4, float, float, float) + a_head += K * 12; + c_head += 12; + acc_mnk += 12 * N * K; + } + ASSEMBLE_DUALPACK_COMPUTE_LN(8, float, float, float, 8) +} + diff --git a/src/neon_armv8a/SgemmSkinnyDot.c b/src/neon_armv8a/SgemmSkinnyDot.c new file mode 100644 index 0000000..3c20438 --- /dev/null +++ b/src/neon_armv8a/SgemmSkinnyDot.c @@ -0,0 +1,800 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +/* See the License for the specific language governing permissions and    */ +/* limitations under the License.                                           */ +/*****************************************************************************/ + + +#define _GNU_SOURCE +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/ARMCpuType.h" +#include "common/CommonSkinnyDot.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.h" +#include <arm_neon.h> +#include <sched.h> + +typedef float sgemm_skinnydot_ascalar; +typedef float sgemm_skinnydot_bscalar; +typedef float sgemm_skinnydot_cscalar; + +static inline void inline_sgemm_arowmajor_bskinny_m4n1(const float *a_ptr1, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, + uint32_t LDM, float beta, bool c_rowmajor) { + + const float *a_ptr2 = a_ptr1 + LDK; + const float *a_ptr3 = a_ptr1 + LDK * 2; + const float *a_ptr4 = a_ptr2 + LDK * 2; + + float32x2_t cd1, cd2, cd3, cd4, cd5, cd6, cd7, cd8; + const float *a_pref = a_ptr4 + LDK; + const uint32_t pref_inc = (LDK > k_inc) ? + (LDK - k_inc) * sizeof(float) : 64; + uint32_t k_left = k_inc; + __asm__ __volatile__( + "movz w0,#0; movz w1,#64\n\t" //pref + "movi %[cd1].8b,#0; movi %[cd2].8b,#0\n\t" + "movi %[cd3].8b,#0; movi %[cd4].8b,#0\n\t" + "movi %[cd5].8b,#0; movi %[cd6].8b,#0\n\t" + "movi %[cd7].8b,#0; movi %[cd8].8b,#0\n\t" + "cmp %w[k_left],#4; b.lt 3f\n\t" + "ldr d2,[%[a_ptr1]],#16; ldr d3,[%[a_ptr2]],#16\n\t" + "ldr d4,[%[a_ptr3]],#16; ldr d5,[%[a_ptr4]],#16\n\t" + "ldr d0,[%[b_ptr]],#16\n\t" + "ldr d6,[%[a_ptr1],#-8]; ldr d7,[%[a_ptr2],#-8]\n\t" + "ldr d8,[%[a_ptr3],#-8]; ldr d9,[%[a_ptr4],#-8]\n\t" + "ldr d1,[%[b_ptr],#-8]\n\t" + "cmp %w[k_left],#8; b.lt 2f\n\t" + ".balign 16; 1:\n\t" + "prfm pldl2keep,[%[a_pref]]; add w0,w0,#16\n\t" + "fmla %[cd1].2s,v2.2s,v0.2s; ldr d2,[%[a_ptr1]],#16\n\t" + "cmp w0,%w[k_inc]\n\t" + "fmla %[cd2].2s,v3.2s,v0.2s; ldr d3,[%[a_ptr2]],#16\n\t" + "csel w2,%w[pref_inc],w1,gt\n\t" + "fmla %[cd3].2s,v4.2s,v0.2s; ldr d4,[%[a_ptr3]],#16\n\t" + "fmla %[cd4].2s,v5.2s,v0.2s; ldr d5,[%[a_ptr4]],#16\n\t" + "csel w0,wzr,w0,gt\n\t" + "ldr d0,[%[b_ptr]],#16; sub %w[k_left],%w[k_left],#4\n\t" + "fmla %[cd5].2s,v6.2s,v1.2s; ldr d6,[%[a_ptr1],#-8]\n\t" + "add %[a_pref],%[a_pref],x2\n\t" + "fmla %[cd6].2s,v7.2s,v1.2s; ldr d7,[%[a_ptr2],#-8]\n\t" + "cmp %w[k_left],#8\n\t" + "fmla %[cd7].2s,v8.2s,v1.2s; ldr d8,[%[a_ptr3],#-8]\n\t" + "fmla %[cd8].2s,v9.2s,v1.2s; ldr d9,[%[a_ptr4],#-8]\n\t" + "ldr d1,[%[b_ptr],#-8]; b.ge 1b\n\t" + "2:\n\t" + "fmla %[cd1].2s,v2.2s,v0.2s; fmla %[cd2].2s,v3.2s,v0.2s\n\t" + "fmla %[cd3].2s,v4.2s,v0.2s; fmla %[cd4].2s,v5.2s,v0.2s\n\t" + "sub %w[k_left],%w[k_left],#4\n\t" + "fmla %[cd5].2s,v6.2s,v1.2s; fmla %[cd6].2s,v7.2s,v1.2s\n\t" + "fmla %[cd7].2s,v8.2s,v1.2s; fmla %[cd8].2s,v9.2s,v1.2s\n\t" + "3:\n\t" + :[cd1]"=w"(cd1), [cd2]"=w"(cd2), [cd3]"=w"(cd3), [cd4]"=w"(cd4), + [cd5]"=w"(cd5), [cd6]"=w"(cd6), [cd7]"=w"(cd7), [cd8]"=w"(cd8), + [k_left]"+r"(k_left), [a_pref]"+r"(a_pref), [b_ptr]"+r"(b_ptr), + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), + [a_ptr3]"+r"(a_ptr3), [a_ptr4]"+r"(a_ptr4) + :[k_inc]"r"(k_inc), [pref_inc]"r"(pref_inc) + :"cc","memory","x0","x1","x2", + "v0","v1","v2","v3","v4","v5","v6","v7","v8","v9"); + + cd1 = vadd_f32(cd1, cd5); cd2 = vadd_f32(cd2, cd6); + cd3 = vadd_f32(cd3, cd7); cd4 = vadd_f32(cd4, cd8); + float cs1 = vpadds_f32(cd1); + float cs2 = vpadds_f32(cd2); + float cs3 = vpadds_f32(cd3); + float cs4 = vpadds_f32(cd4); + + 
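/* scalar tail: accumulate the K % 4 leftover elements (or all of K when K < 4) that the NEON asm block above did not consume, then merge the four dot products with beta * C below */ +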
for (; k_left > 0; k_left--) { + float bs1 = *b_ptr; b_ptr++; + cs1 += (*a_ptr1) * bs1; a_ptr1++; + cs2 += (*a_ptr2) * bs1; a_ptr2++; + cs3 += (*a_ptr3) * bs1; a_ptr3++; + cs4 += (*a_ptr4) * bs1; a_ptr4++; + } + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[1] = c_ptr[1] * beta + cs2; + c_ptr[2] = c_ptr[2] * beta + cs3; + c_ptr[3] = c_ptr[3] * beta + cs4; +} + +static inline void inline_sgemm_arowmajor_bskinny_m1n1(const float *a_ptr, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, + uint32_t LDM, float beta, bool c_rowmajor) { + + float cs1; + __asm__ __volatile__( + "movi v16.8b,#0; movi v17.8b,#0\n\t" + "mov v18.8b,v16.8b; mov v19.8b,v17.8b\n\t" + "mov v20.8b,v16.8b; mov v21.8b,v17.8b\n\t" + "mov v22.8b,v16.8b; mov v23.8b,v17.8b\n\t" + "cmp %w[K],#16; b.lt 4f\n\t" + "prfm pldl1keep,[%[a_ptr],#256]\n\t" + "ldr d0,[%[a_ptr]],#64; ldr d8,[%[b_ptr]],#64\n\t" + "ldr d1,[%[a_ptr],#-56]; ldr d9,[%[b_ptr],#-56]\n\t" + "ldr d2,[%[a_ptr],#-48]; ldr d10,[%[b_ptr],#-48]\n\t" + "ldr d3,[%[a_ptr],#-40]; ldr d11,[%[b_ptr],#-40]\n\t" + "ldr d4,[%[a_ptr],#-32]; ldr d12,[%[b_ptr],#-32]\n\t" + "ldr d5,[%[a_ptr],#-24]; ldr d13,[%[b_ptr],#-24]\n\t" + "ldr d6,[%[a_ptr],#-16]; ldr d14,[%[b_ptr],#-16]\n\t" + "ldr d7,[%[a_ptr],#-8]; ldr d15,[%[b_ptr],#-8]\n\t" + "cmp %w[K],#32; b.lt 3f\n\t" + "2:\n\t" + "prfm pldl1keep,[%[a_ptr],#256]\n\t" + "fmla v16.2s,v0.2s,v8.2s; ldr d0,[%[a_ptr]],#64; ldr d8,[%[b_ptr]],#64\n\t" + "fmla v17.2s,v1.2s,v9.2s; ldr d1,[%[a_ptr],#-56]; ldr d9,[%[b_ptr],#-56]\n\t" + "fmla v18.2s,v2.2s,v10.2s; ldr d2,[%[a_ptr],#-48]; ldr d10,[%[b_ptr],#-48]\n\t" + "fmla v19.2s,v3.2s,v11.2s; ldr d3,[%[a_ptr],#-40]; ldr d11,[%[b_ptr],#-40]\n\t" + "sub %w[K],%w[K],#16\n\t" + "fmla v20.2s,v4.2s,v12.2s; ldr d4,[%[a_ptr],#-32]; ldr d12,[%[b_ptr],#-32]\n\t" + "fmla v21.2s,v5.2s,v13.2s; ldr d5,[%[a_ptr],#-24]; ldr d13,[%[b_ptr],#-24]\n\t" + "cmp %w[K],#32\n\t" + "fmla v22.2s,v6.2s,v14.2s; ldr d6,[%[a_ptr],#-16]; ldr d14,[%[b_ptr],#-16]\n\t" + "fmla v23.2s,v7.2s,v15.2s; ldr d7,[%[a_ptr],#-8]; ldr d15,[%[b_ptr],#-8]\n\t" + "b.ge 2b\n\t" + "3:\n\t" + "fmla v16.2s,v0.2s,v8.2s; fmla v17.2s,v1.2s,v9.2s\n\t" + "fmla v18.2s,v2.2s,v10.2s; fmla v19.2s,v3.2s,v11.2s; sub %w[K],%w[K],#16\n\t" + "fmla v20.2s,v4.2s,v12.2s; fmla v21.2s,v5.2s,v13.2s\n\t" + "fmla v22.2s,v6.2s,v14.2s; fmla v23.2s,v7.2s,v15.2s\n\t" + "4:\n\t" + "fadd v16.2s,v16.2s,v20.2s; fadd v17.2s,v17.2s,v21.2s\n\t" + "fadd v18.2s,v18.2s,v22.2s; fadd v19.2s,v19.2s,v23.2s\n\t" + "cmp %w[K],#8; b.lt 5f\n\t" + "ldr d0,[%[a_ptr]],#32; ldr d8,[%[b_ptr]],#32; fmla v16.2s,v0.2s,v8.2s\n\t" + "ldr d1,[%[a_ptr],#-24]; ldr d9,[%[b_ptr],#-24]; fmla v17.2s,v1.2s,v9.2s\n\t" + "sub %w[K],%w[K],#8\n\t" + "ldr d2,[%[a_ptr],#-16]; ldr d10,[%[b_ptr],#-16]; fmla v18.2s,v2.2s,v10.2s\n\t" + "ldr d3,[%[a_ptr],#-8]; ldr d11,[%[b_ptr],#-8]; fmla v19.2s,v3.2s,v11.2s\n\t" + "5:\n\t" + "fadd v16.2s,v16.2s,v18.2s; fadd v17.2s,v17.2s,v19.2s\n\t" + "cmp %w[K],#4; b.lt 6f\n\t" + "ldr d0,[%[a_ptr]],#16; ldr d8,[%[b_ptr]],#16; fmla v16.2s,v0.2s,v8.2s\n\t" + "sub %w[K],%w[K],#4\n\t" + "ldr d1,[%[a_ptr],#-8]; ldr d9,[%[b_ptr],#-8]; fmla v17.2s,v1.2s,v9.2s\n\t" + "6:\n\t" + "fadd v16.2s,v16.2s,v17.2s\n\t" + "cmp %w[K],#2; b.lt 7f\n\t" + "ldr d0,[%[a_ptr]],#8; ldr d8,[%[b_ptr]],#8; fmla v16.2s,v0.2s,v8.2s\n\t" + "sub %w[K],%w[K],#2\n\t" + "7:\n\t" + "faddp %s[cs1],v16.2s\n\t" + "cmp %w[K],#1; b.lt 10f\n\t" + "ldr s0,[%[a_ptr]],#4; ldr s8,[%[b_ptr]],#4; fmla %s[cs1],s0,v8.s[0]\n\t" + "10:\n\t" + :[cs1]"=w"(cs1), [a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr), [K]"+r"(k_inc) + 
::"cc","memory","v0","v1","v2","v3","v4","v5", + "v6","v7","v8","v9","v10","v11","v12","v13","v14","v15","v16","v17", + "v18","v19","v20","v21","v22","v23"); + c_ptr[0] = c_ptr[0] * beta + cs1; +} + +/* k_mask = 7 */ +static inline void inline_sgemm_arowmajor_bskinny_m4n2(const float *a_ptr1, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, + uint32_t LDM, float beta, bool c_rowmajor) { + + float32x4_t cq1, cq2, cq3, cq4, cq5, cq6, cq7, cq8; + const float *a_ptr2 = a_ptr1 + LDK; + const float *a_ptr3 = a_ptr1 + LDK * 2; + const float *a_ptr4 = a_ptr2 + LDK * 2; + uint32_t k_left = k_inc; + const float *a_pref = a_ptr4 + LDK; + const uint32_t pref_inc = (LDK > k_inc) ? + (LDK - k_inc) * sizeof(float) : 64; + __asm__ __volatile__( + "movz w0,#0; movz w1,#64\n\t" //pref + "movi %[cq1].16b,#0; movi %[cq2].16b,#0\n\t" + "movi %[cq3].16b,#0; movi %[cq4].16b,#0\n\t" + "movi %[cq5].16b,#0; movi %[cq6].16b,#0\n\t" + "movi %[cq7].16b,#0; movi %[cq8].16b,#0\n\t" + "cmp %w[k_left],#4; b.lt 3f\n\t" + "ldr q2,[%[a_ptr1]],#16; ldr q3,[%[a_ptr2]],#16\n\t" + "ldr q4,[%[a_ptr3]],#16; ldr q5,[%[a_ptr4]],#16\n\t" + "ldr q0,[%[b_ptr]]; ldr q1,[%[b_ptr],#16]; add %[b_ptr],%[b_ptr],#32\n\t" + "cmp %w[k_left],#8; b.lt 2f\n\t" + ".balign 16; 1:\n\t" + "prfm pldl2keep,[%[a_pref]]; add w0,w0,#16\n\t" + "fmla %[cq1].4s,v2.4s,v0.4s; fmla %[cq5].4s,v2.4s,v1.4s\n\t" + "ldr q2,[%[a_ptr1]],#16; cmp w0,%w[k_inc]\n\t" + "fmla %[cq2].4s,v3.4s,v0.4s; fmla %[cq6].4s,v3.4s,v1.4s\n\t" + "ldr q3,[%[a_ptr2]],#16; csel w2,%w[pref_inc],w1,gt\n\t" + "sub %w[k_left],%w[k_left],#4\n\t" + "fmla %[cq3].4s,v4.4s,v0.4s; fmla %[cq7].4s,v4.4s,v1.4s\n\t" + "ldr q4,[%[a_ptr3]],#16; csel w0,wzr,w0,gt\n\t" + "cmp %w[k_left],#8\n\t" + "fmla %[cq4].4s,v5.4s,v0.4s; fmla %[cq8].4s,v5.4s,v1.4s\n\t" + "ldr q5,[%[a_ptr4]],#16; add %[a_pref],%[a_pref],x2\n\t" + "ldr q0,[%[b_ptr]]; ldr q1,[%[b_ptr],#16]\n\t" + "add %[b_ptr],%[b_ptr],#32; b.ge 1b\n\t" + "2:\n\t" + "fmla %[cq1].4s,v2.4s,v0.4s; fmla %[cq5].4s,v2.4s,v1.4s\n\t" + "fmla %[cq2].4s,v3.4s,v0.4s; fmla %[cq6].4s,v3.4s,v1.4s\n\t" + "fmla %[cq3].4s,v4.4s,v0.4s; fmla %[cq7].4s,v4.4s,v1.4s\n\t" + "fmla %[cq4].4s,v5.4s,v0.4s; fmla %[cq8].4s,v5.4s,v1.4s\n\t" + "sub %w[k_left],%w[k_left],#4\n\t" + "3:\n\t" + :[cq1]"=w"(cq1), [cq2]"=w"(cq2), [cq3]"=w"(cq3), [cq4]"=w"(cq4), + [cq5]"=w"(cq5), [cq6]"=w"(cq6), [cq7]"=w"(cq7), [cq8]"=w"(cq8), + [k_left]"+r"(k_left), [a_pref]"+r"(a_pref), + [b_ptr]"+r"(b_ptr), [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), + [a_ptr3]"+r"(a_ptr3), [a_ptr4]"+r"(a_ptr4) + :[k_inc]"r"(k_inc), [pref_inc]"r"(pref_inc) + :"x0","x1","x2","v0","v1","v2","v3","v4","v5","cc","memory"); + + cq1 = vpaddq_f32(cq1, cq5); + cq2 = vpaddq_f32(cq2, cq6); + cq3 = vpaddq_f32(cq3, cq7); + cq4 = vpaddq_f32(cq4, cq8); + + if (k_left >= 2) { + float32x4_t bq1 = vld1q_f32(b_ptr); b_ptr += 4; + float32x2_t ad1 = vld1_f32(a_ptr1); a_ptr1 += 2; + float32x2_t ad2 = vld1_f32(a_ptr2); a_ptr2 += 2; + float32x2_t ad3 = vld1_f32(a_ptr3); a_ptr3 += 2; + float32x2_t ad4 = vld1_f32(a_ptr4); a_ptr4 += 2; + float32x4_t aq1 = vcombine_f32(ad1, ad1); + float32x4_t aq2 = vcombine_f32(ad2, ad2); + float32x4_t aq3 = vcombine_f32(ad3, ad3); + float32x4_t aq4 = vcombine_f32(ad4, ad4); + cq1 = vfmaq_f32(cq1, aq1, bq1); + cq2 = vfmaq_f32(cq2, aq2, bq1); + cq3 = vfmaq_f32(cq3, aq3, bq1); + cq4 = vfmaq_f32(cq4, aq4, bq1); + k_left -= 2; + } + + float32x2_t cd1 = vget_low_f32(vpaddq_f32(cq1, cq1)); + float32x2_t cd2 = vget_low_f32(vpaddq_f32(cq2, cq2)); + float32x2_t cd3 = vget_low_f32(vpaddq_f32(cq3, cq3)); + 
float32x2_t cd4 = vget_low_f32(vpaddq_f32(cq4, cq4)); + + if (k_left > 0) { + float32x2_t bd1 = vld1_f32(b_ptr); + float32x2_t ad1 = vld1_dup_f32(a_ptr1); + float32x2_t ad2 = vld1_dup_f32(a_ptr2); + float32x2_t ad3 = vld1_dup_f32(a_ptr3); + float32x2_t ad4 = vld1_dup_f32(a_ptr4); + cd1 = vfma_f32(cd1, ad1, bd1); + cd2 = vfma_f32(cd2, ad2, bd1); + cd3 = vfma_f32(cd3, ad3, bd1); + cd4 = vfma_f32(cd4, ad4, bd1); + } + + if (c_rowmajor) { + cd1 = vfma_n_f32(cd1, vld1_f32(c_ptr), beta); + cd2 = vfma_n_f32(cd2, vld1_f32(c_ptr + 2), beta); + cd3 = vfma_n_f32(cd3, vld1_f32(c_ptr + 4), beta); + cd4 = vfma_n_f32(cd4, vld1_f32(c_ptr + 6), beta); + vst1_f32(c_ptr, cd1); + vst1_f32(c_ptr + 2, cd2); + vst1_f32(c_ptr + 4, cd3); + vst1_f32(c_ptr + 6, cd4); + } else { + float32x2_t cd00 = vzip1_f32(cd1, cd2); + float32x2_t cd01 = vzip1_f32(cd3, cd4); + float32x2_t cd10 = vzip2_f32(cd1, cd2); + float32x2_t cd11 = vzip2_f32(cd3, cd4); + float *c_ptr1 = c_ptr; + float *c_ptr2 = c_ptr + LDM; + cd00 = vfma_n_f32(cd00, vld1_f32(c_ptr1), beta); + cd01 = vfma_n_f32(cd01, vld1_f32(c_ptr1 + 2), beta); + cd10 = vfma_n_f32(cd10, vld1_f32(c_ptr2), beta); + cd11 = vfma_n_f32(cd11, vld1_f32(c_ptr2 + 2), beta); + vst1_f32(c_ptr1, cd00); + vst1_f32(c_ptr1 + 2, cd01); + vst1_f32(c_ptr2, cd10); + vst1_f32(c_ptr2 + 2, cd11); + } +} + +static inline void inline_sgemm_arowmajor_bskinny_m1n2(const float *a_ptr, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, + uint32_t LDM, float beta, bool c_rowmajor) { + + uint32_t k_left = k_inc; + float cs1, cs2; + __asm__ __volatile__ ( + "movi v8.16b,#0; movi v9.16b,#0\n\t" + "mov v10.16b,v8.16b; mov v11.16b,v9.16b\n\t" + "mov v16.16b,v8.16b; mov v17.16b,v9.16b\n\t" + "mov v18.16b,v8.16b; mov v19.16b,v9.16b\n\t" + "cmp %w[k_left],#16; b.lt 4f\n\t" + "prfm pldl1keep,[%[a_ptr],#256]\n\t" + "ldr q0,[%[a_ptr]]; ldr q1,[%[a_ptr],#16]\n\t" + "ldr q2,[%[a_ptr],#32]; ldr q3,[%[a_ptr],#48]\n\t" + "add %[a_ptr],%[a_ptr],#64\n\t" + "ldr q4,[%[b_ptr]]; ldr q12,[%[b_ptr],#16]\n\t" + "ldr q5,[%[b_ptr],#32]; ldr q13,[%[b_ptr],#48]\n\t" + "ldr q6,[%[b_ptr],#64]; ldr q14,[%[b_ptr],#80]\n\t" + "ldr q7,[%[b_ptr],#96]; ldr q15,[%[b_ptr],#112]\n\t" + "add %[b_ptr],%[b_ptr],#128\n\t" + "cmp %w[k_left],#32; b.lt 3f\n\t" + ".balign 16; 2:\n\t" + "prfm pldl1keep,[%[a_ptr],#256]\n\t" + "fmla v8.4s,v0.4s,v4.4s; ldr q4,[%[b_ptr]]\n\t" + "fmla v10.4s,v0.4s,v12.4s; ldr q12,[%[b_ptr],#16]; ldr q0,[%[a_ptr]],#64\n\t" + "fmla v9.4s,v1.4s,v5.4s; ldr q5,[%[b_ptr],#32]\n\t" + "fmla v11.4s,v1.4s,v13.4s; ldr q13,[%[b_ptr],#48]; ldr q1,[%[a_ptr],#-48]\n\t" + "sub %w[k_left],%w[k_left],#16\n\t" + "fmla v16.4s,v2.4s,v6.4s; ldr q6,[%[b_ptr],#64]\n\t" + "fmla v18.4s,v2.4s,v14.4s; ldr q14,[%[b_ptr],#80]; ldr q2,[%[a_ptr],#-32]\n\t" + "cmp %w[k_left],#32\n\t" + "fmla v17.4s,v3.4s,v7.4s; ldr q7,[%[b_ptr],#96]\n\t" + "fmla v19.4s,v3.4s,v15.4s; ldr q15,[%[b_ptr],#112]; ldr q3,[%[a_ptr],#-16]\n\t" + "add %[b_ptr],%[b_ptr],#128; b.ge 2b\n\t" + "3:\n\t" + "fmla v8.4s,v0.4s,v4.4s; fmla v10.4s,v0.4s,v12.4s\n\t" + "fmla v9.4s,v1.4s,v5.4s; fmla v11.4s,v1.4s,v13.4s\n\t" + "sub %w[k_left],%w[k_left],#16\n\t" + "fmla v16.4s,v2.4s,v6.4s; fmla v18.4s,v2.4s,v14.4s\n\t" + "fmla v17.4s,v3.4s,v7.4s; fmla v19.4s,v3.4s,v15.4s\n\t" + "4:\n\t" + "fadd v8.4s,v8.4s,v16.4s; fadd v9.4s,v9.4s,v17.4s\n\t" + "fadd v10.4s,v10.4s,v18.4s; fadd v11.4s,v11.4s,v19.4s\n\t" + "cmp %w[k_left],#8; b.lt 5f\n\t" + "ldr q0,[%[a_ptr]],#32; ldr q4,[%[b_ptr]]; ldr q12,[%[b_ptr],#16]\n\t" + "fmla v8.4s,v0.4s,v4.4s; fmla v10.4s,v0.4s,v12.4s\n\t" + "sub 
%w[k_left],%w[k_left],#8\n\t" + "ldr q1,[%[a_ptr],#-16]; ldr q5,[%[b_ptr],#32]; ldr q13,[%[b_ptr],#48]\n\t" + "add %[b_ptr],%[b_ptr],#64\n\t" + "fmla v9.4s,v1.4s,v5.4s; fmla v11.4s,v1.4s,v13.4s\n\t" + "5:\n\t" + "fadd v8.4s,v8.4s,v9.4s; fadd v10.4s,v10.4s,v11.4s\n\t" + "cmp %w[k_left],#4; b.lt 6f\n\t" + "ldr q0,[%[a_ptr]],#16; ldr q4,[%[b_ptr]]; ldr q12,[%[b_ptr],#16]\n\t" + "fmla v8.4s,v0.4s,v4.4s; fmla v10.4s,v0.4s,v12.4s\n\t" + "add %[b_ptr],%[b_ptr],#32; sub %w[k_left],%w[k_left],#4\n\t" + "6:\n\t" + "movi v9.16b,#0; faddp v8.4s,v8.4s,v9.4s; faddp v10.4s,v10.4s,v9.4s\n\t" + "cmp %w[k_left],#2; b.lt 7f\n\t" + "ldr d0,[%[a_ptr]],#8; ldr d4,[%[b_ptr]]; ldr d12,[%[b_ptr],#8]\n\t" + "fmla v8.2s,v0.2s,v4.2s; fmla v10.2s,v0.2s,v12.2s\n\t" + "add %[b_ptr],%[b_ptr],#16; sub %w[k_left],%w[k_left],#2\n\t" + "7:\n\t" + "faddp %s[cs1],v8.2s; faddp %s[cs2],v10.2s\n\t" + "cmp %w[k_left],#1; b.lt 10f\n\t" + "ldr s0,[%[a_ptr]],#4; ldr s4,[%[b_ptr]]; ldr s12,[%[b_ptr],#4]\n\t" + "fmla %s[cs1],s0,v4.s[0]; fmla %s[cs2],s0,v12.s[0]\n\t" + "10:\n\t" + :[cs1]"=w"(cs1), [cs2]"=w"(cs2), + [a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr), [k_left]"+r"(k_left) + ::"cc","memory","v0","v1","v2","v3","v4","v5","v6","v7", + "v8","v9","v10","v11","v12","v13","v14","v15","v16","v17","v18","v19"); + + if (c_rowmajor) { + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[1] = c_ptr[1] * beta + cs2; + } else { + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[LDM] = c_ptr[LDM] * beta + cs2; + } +} + +/* k_mask = 7 */ +static inline void inline_sgemm_arowmajor_bskinny_m4n3(const float *a_ptr1, + const float *b_ptr, float *c_ptr, uint32_t k_inc, uint32_t LDK, + uint32_t LDM, float beta, bool c_rowmajor) { + + const float *a_ptr2 = a_ptr1 + LDK; + const float *a_ptr3 = a_ptr1 + LDK * 2; + const float *a_ptr4 = a_ptr2 + LDK * 2; + uint32_t k_left = k_inc; + uint32_t next_pref = (LDK * 4 >= k_inc) ? 
+ (LDK * 4 - k_inc + 4) * sizeof(float) : 64; + float32x4_t cq1, cq2, cq3; + __asm__ __volatile__( + "movi %[q1].16b,#0; movi %[q2].16b,#0; movi %[q3].16b,#0\n\t" + "movi v10.16b,#0; movi v11.16b,#0; movi v12.16b,#0\n\t" + "movi v13.16b,#0; movi v14.16b,#0; movi v15.16b,#0\n\t" + "movi v16.16b,#0; movi v17.16b,#0; movi v18.16b,#0\n\t" + "cmp %w[k_left],#4; b.lt 4f\n\t" + "ldr q0,[%[a_ptr1]],#16; ldr q1,[%[a_ptr2]],#16\n\t" + "ldr q2,[%[a_ptr3]],#16; ldr q3,[%[a_ptr4]],#16\n\t" + "ldr q4,[%[b_ptr]]; ldr q5,[%[b_ptr],#16]\n\t" + "ldr q6,[%[b_ptr],#32]; add %[b_ptr],%[b_ptr],#48\n\t" + "cmp %w[k_left],#12; b.lt 2f\n\t" + ".balign 16; 1:\n\t" + "fmla %[q1].4s,v0.4s,v4.4s; ldr q7,[%[b_ptr]],#96\n\t" + "fmla %[q2].4s,v0.4s,v5.4s\n\t" + "fmla %[q3].4s,v0.4s,v6.4s; ldr q0,[%[a_ptr1]],#32\n\t" + "fmla v10.4s,v1.4s,v4.4s; ldr q8,[%[b_ptr],#-80]\n\t" + "fmla v11.4s,v1.4s,v5.4s; prfm pldl1keep,[%[a_ptr1],#64]\n\t" + "fmla v12.4s,v1.4s,v6.4s; ldr q1,[%[a_ptr2]],#32\n\t" + "fmla v13.4s,v2.4s,v4.4s; ldr q9,[%[b_ptr],#-64]\n\t" + "fmla v14.4s,v2.4s,v5.4s; prfm pldl1keep,[%[a_ptr2],#64]\n\t" + "fmla v15.4s,v2.4s,v6.4s; ldr q2,[%[a_ptr3]],#32\n\t" + "fmla v16.4s,v3.4s,v4.4s\n\t" + "fmla v17.4s,v3.4s,v5.4s; prfm pldl1keep,[%[a_ptr3],#64]\n\t" + "fmla v18.4s,v3.4s,v6.4s; ldr q3,[%[a_ptr4]],#32\n\t" + "fmla %[q1].4s,v0.4s,v7.4s; ldr q4,[%[b_ptr],#-48]\n\t" + "fmla %[q2].4s,v0.4s,v8.4s; prfm pldl1keep,[%[a_ptr4],#64]\n\t" + "fmla %[q3].4s,v0.4s,v9.4s; ldr q0,[%[a_ptr1],#-16]\n\t" + "fmla v10.4s,v1.4s,v7.4s; ldr q5,[%[b_ptr],#-32]\n\t" + "fmla v11.4s,v1.4s,v8.4s\n\t" + "fmla v12.4s,v1.4s,v9.4s; ldr q1,[%[a_ptr2],#-16]\n\t" + "fmla v13.4s,v2.4s,v7.4s; ldr q6,[%[b_ptr],#-16]\n\t" + "fmla v14.4s,v2.4s,v8.4s; sub %w[k_left],%w[k_left],#8\n\t" + "fmla v15.4s,v2.4s,v9.4s; ldr q2,[%[a_ptr3],#-16]\n\t" + "fmla v16.4s,v3.4s,v7.4s; cmp %w[k_left],#12\n\t" + "fmla v17.4s,v3.4s,v8.4s\n\t" + "fmla v18.4s,v3.4s,v9.4s; ldr q3,[%[a_ptr4],#-16]; b.ge 1b\n\t" + "2:\n\t" + "cmp %w[k_left],#8; b.lt 3f\n\t" + "fmla %[q1].4s,v0.4s,v4.4s; ldr q7,[%[b_ptr]],#48\n\t" + "fmla %[q2].4s,v0.4s,v5.4s\n\t" + "fmla %[q3].4s,v0.4s,v6.4s; ldr q0,[%[a_ptr1]],#16\n\t" + "fmla v10.4s,v1.4s,v4.4s; ldr q8,[%[b_ptr],#-32]\n\t" + "fmla v11.4s,v1.4s,v5.4s\n\t" + "prfm pldl1keep,[%[a_ptr1],%w[next_pref],SXTW #0]\n\t" + "fmla v12.4s,v1.4s,v6.4s; ldr q1,[%[a_ptr2]],#16\n\t" + "fmla v13.4s,v2.4s,v4.4s; ldr q9,[%[b_ptr],#-16]\n\t" + "fmla v14.4s,v2.4s,v5.4s\n\t" + "prfm pldl1keep,[%[a_ptr2],%w[next_pref],SXTW #0]\n\t" + "fmla v15.4s,v2.4s,v6.4s; ldr q2,[%[a_ptr3]],#16\n\t" + "fmla v16.4s,v3.4s,v4.4s\n\t" + "fmla v17.4s,v3.4s,v5.4s\n\t" + "prfm pldl1keep,[%[a_ptr3],%w[next_pref],SXTW #0]\n\t" + "fmla v18.4s,v3.4s,v6.4s; ldr q3,[%[a_ptr4]],#16\n\t" + "fmla %[q1].4s,v0.4s,v7.4s\n\t" + "fmla %[q2].4s,v0.4s,v8.4s\n\t" + "prfm pldl1keep,[%[a_ptr4],%w[next_pref],SXTW #0]\n\t" + "fmla %[q3].4s,v0.4s,v9.4s\n\t" + "fmla v10.4s,v1.4s,v7.4s\n\t" + "fmla v11.4s,v1.4s,v8.4s\n\t" + "fmla v12.4s,v1.4s,v9.4s\n\t" + "fmla v13.4s,v2.4s,v7.4s\n\t" + "fmla v14.4s,v2.4s,v8.4s; sub %w[k_left],%w[k_left],#8\n\t" + "fmla v15.4s,v2.4s,v9.4s\n\t" + "fmla v16.4s,v3.4s,v7.4s\n\t" + "fmla v17.4s,v3.4s,v8.4s\n\t" + "fmla v18.4s,v3.4s,v9.4s; b 4f\n\t" + "3:\n\t" + "fmla %[q1].4s,v0.4s,v4.4s\n\t" + "fmla %[q2].4s,v0.4s,v5.4s\n\t" + "prfm pldl1keep,[%[a_ptr1],%w[next_pref],SXTW #0]\n\t" + "fmla %[q3].4s,v0.4s,v6.4s\n\t" + "fmla v10.4s,v1.4s,v4.4s\n\t" + "prfm pldl1keep,[%[a_ptr2],%w[next_pref],SXTW #0]\n\t" + "fmla v11.4s,v1.4s,v5.4s\n\t" + "fmla v12.4s,v1.4s,v6.4s\n\t" + "prfm 
pldl1keep,[%[a_ptr3],%w[next_pref],SXTW #0]\n\t" + "fmla v13.4s,v2.4s,v4.4s\n\t" + "fmla v14.4s,v2.4s,v5.4s; sub %w[k_left],%w[k_left],#4\n\t" + "prfm pldl1keep,[%[a_ptr4],%w[next_pref],SXTW #0]\n\t" + "fmla v15.4s,v2.4s,v6.4s\n\t" + "fmla v16.4s,v3.4s,v4.4s\n\t" + "fmla v17.4s,v3.4s,v5.4s\n\t" + "fmla v18.4s,v3.4s,v6.4s\n\t" + "4:\n\t" + "faddp %[q1].4s,%[q1].4s,v10.4s; faddp v13.4s,v13.4s,v16.4s\n\t" + "faddp %[q2].4s,%[q2].4s,v11.4s; faddp v14.4s,v14.4s,v17.4s\n\t" + "faddp %[q3].4s,%[q3].4s,v12.4s; faddp v15.4s,v15.4s,v18.4s\n\t" + "cmp %w[k_left],#2; b.lt 5f\n\t" + "ldr d0,[%[a_ptr1]],#8; ldr d1,[%[a_ptr2]],#8\n\t" + "ldr d2,[%[a_ptr3]],#8; ldr d3,[%[a_ptr4]],#8\n\t" + "ld1r {v4.2d},[%[b_ptr]],#8; ins v0.d[1],v1.d[0]\n\t" + "ld1r {v5.2d},[%[b_ptr]],#8; ins v2.d[1],v3.d[0]\n\t" + "ld1r {v6.2d},[%[b_ptr]],#8; sub %w[k_left],%w[k_left],#2\n\t" + "fmla %[q1].4s,v0.4s,v4.4s\n\t" + "fmla %[q2].4s,v0.4s,v5.4s\n\t" + "fmla %[q3].4s,v0.4s,v6.4s\n\t" + "fmla v13.4s,v2.4s,v4.4s\n\t" + "fmla v14.4s,v2.4s,v5.4s\n\t" + "fmla v15.4s,v2.4s,v6.4s\n\t" + "5:\n\t" + "faddp %[q1].4s,%[q1].4s,v13.4s\n\t" + "faddp %[q2].4s,%[q2].4s,v14.4s\n\t" + "faddp %[q3].4s,%[q3].4s,v15.4s\n\t" + "cmp %w[k_left],#1; b.lt 6f\n\t" + "ldr s0,[%[a_ptr1]],#4; ldr s1,[%[a_ptr2]],#4\n\t" + "ldr s2,[%[a_ptr3]],#4; ldr s3,[%[a_ptr4]],#4\n\t" + "ldr s4,[%[b_ptr]],#4; ins v0.s[1],v1.s[0]\n\t" + "ldr s5,[%[b_ptr]],#4; ins v2.s[1],v3.s[0]\n\t" + "ldr s6,[%[b_ptr]],#4; ins v0.d[1],v2.d[0]\n\t" + "sub %w[k_left],%w[k_left],#1\n\t" + "fmla %[q1].4s,v0.4s,v4.s[0]\n\t" + "fmla %[q2].4s,v0.4s,v5.s[0]\n\t" + "fmla %[q3].4s,v0.4s,v6.s[0]\n\t" + "6:\n\t" + :[q1]"=w"(cq1), [q2]"=w"(cq2), [q3]"=w"(cq3), [k_left]"+r"(k_left), + [a_ptr1]"+r"(a_ptr1), [a_ptr2]"+r"(a_ptr2), [a_ptr3]"+r"(a_ptr3), + [a_ptr4]"+r"(a_ptr4), [b_ptr]"+r"(b_ptr), [next_pref]"+r"(next_pref) + ::"cc","memory","v0","v1","v2","v3","v4","v5","v6","v7","v8","v9", + "v10","v11","v12","v13","v14","v15","v16","v17","v18"); + + if (c_rowmajor) { + float32x4x3_t cqt1 = vld3q_f32(c_ptr); + cqt1.val[0] = vfmaq_n_f32(cq1, cqt1.val[0], beta); + cqt1.val[1] = vfmaq_n_f32(cq2, cqt1.val[1], beta); + cqt1.val[2] = vfmaq_n_f32(cq3, cqt1.val[2], beta); + vst3q_f32(c_ptr, cqt1); + } else { + cq1 = vfmaq_n_f32(cq1, vld1q_f32(c_ptr), beta); + cq2 = vfmaq_n_f32(cq2, vld1q_f32(c_ptr + LDM), beta); + cq3 = vfmaq_n_f32(cq3, vld1q_f32(c_ptr + LDM * 2), beta); + + vst1q_f32(c_ptr, cq1); c_ptr += LDM; + vst1q_f32(c_ptr, cq2); c_ptr += LDM; + vst1q_f32(c_ptr, cq3); + } +} + +static inline void inline_sgemm_arowmajor_bskinny_m1n3(const float *a_ptr, + const float *b_scr, float *c_ptr, uint32_t k_inc, uint32_t LDK, + uint32_t LDM, float beta, bool c_rowmajor) { + + const float *sb_ptr = b_scr; + + float32x4_t cq01, cq02, cq03, cq04, cq05, cq06; + cq01 = cq02 = cq03 = cq04 = cq05 = cq06 = vdupq_n_f32(0.0f); + float32x4_t cq07, cq08, cq09, cq10, cq11, cq12; + cq07 = cq08 = cq09 = cq10 = cq11 = cq12 = vdupq_n_f32(0.0f); + float32x4_t aq1, aq2, bq01, bq02, bq03, bq04, bq05, bq06; + float32x4_t aq3, aq4, bq07, bq08, bq09, bq10, bq11, bq12; + uint32_t k_left = k_inc; + if (k_left > 7) { + aq1 = vld1q_f32(a_ptr); aq2 = vld1q_f32(a_ptr + 4); a_ptr += 8; + bq01 = vld1q_f32(sb_ptr); bq02 = vld1q_f32(sb_ptr + 4); + bq03 = vld1q_f32(sb_ptr + 8); bq04 = vld1q_f32(sb_ptr + 12); + bq05 = vld1q_f32(sb_ptr + 16); bq06 = vld1q_f32(sb_ptr + 20); + sb_ptr += 24; + } + for (; k_left > 23; k_left -= 16) { + aq3 = vld1q_f32(a_ptr); + cq01 = vfmaq_f32(cq01, aq1, bq01); bq07 = vld1q_f32(sb_ptr); + cq02 = vfmaq_f32(cq02, aq1, 
bq02); bq08 = vld1q_f32(sb_ptr + 4); + cq03 = vfmaq_f32(cq03, aq1, bq03); bq09 = vld1q_f32(sb_ptr + 8); + aq4 = vld1q_f32(a_ptr + 4); + cq04 = vfmaq_f32(cq04, aq2, bq04); bq10 = vld1q_f32(sb_ptr + 12); + cq05 = vfmaq_f32(cq05, aq2, bq05); bq11 = vld1q_f32(sb_ptr + 16); + cq06 = vfmaq_f32(cq06, aq2, bq06); bq12 = vld1q_f32(sb_ptr + 20); + aq1 = vld1q_f32(a_ptr + 8); + cq07 = vfmaq_f32(cq07, aq3, bq07); bq01 = vld1q_f32(sb_ptr + 24); + cq08 = vfmaq_f32(cq08, aq3, bq08); bq02 = vld1q_f32(sb_ptr + 28); + cq09 = vfmaq_f32(cq09, aq3, bq09); bq03 = vld1q_f32(sb_ptr + 32); + aq2 = vld1q_f32(a_ptr + 12); a_ptr += 16; + cq10 = vfmaq_f32(cq10, aq4, bq10); bq04 = vld1q_f32(sb_ptr + 36); + cq11 = vfmaq_f32(cq11, aq4, bq11); bq05 = vld1q_f32(sb_ptr + 40); + cq12 = vfmaq_f32(cq12, aq4, bq12); bq06 = vld1q_f32(sb_ptr + 44); + sb_ptr += 48; + } + if (k_left > 15) { + aq3 = vld1q_f32(a_ptr); + cq01 = vfmaq_f32(cq01, aq1, bq01); bq07 = vld1q_f32(sb_ptr); + cq02 = vfmaq_f32(cq02, aq1, bq02); bq08 = vld1q_f32(sb_ptr + 4); + cq03 = vfmaq_f32(cq03, aq1, bq03); bq09 = vld1q_f32(sb_ptr + 8); + aq4 = vld1q_f32(a_ptr + 4); a_ptr += 8; + cq04 = vfmaq_f32(cq04, aq2, bq04); bq10 = vld1q_f32(sb_ptr + 12); + cq05 = vfmaq_f32(cq05, aq2, bq05); bq11 = vld1q_f32(sb_ptr + 16); + cq06 = vfmaq_f32(cq06, aq2, bq06); bq12 = vld1q_f32(sb_ptr + 20); + cq07 = vfmaq_f32(cq07, aq3, bq07); sb_ptr += 24; + cq08 = vfmaq_f32(cq08, aq3, bq08); k_left -= 16; + cq09 = vfmaq_f32(cq09, aq3, bq09); + cq10 = vfmaq_f32(cq10, aq4, bq10); + cq11 = vfmaq_f32(cq11, aq4, bq11); + cq12 = vfmaq_f32(cq12, aq4, bq12); + } + if (k_left > 7) { + cq01 = vfmaq_f32(cq01, aq1, bq01); k_left -= 8; + cq02 = vfmaq_f32(cq02, aq1, bq02); + cq03 = vfmaq_f32(cq03, aq1, bq03); + cq04 = vfmaq_f32(cq04, aq2, bq04); + cq05 = vfmaq_f32(cq05, aq2, bq05); + cq06 = vfmaq_f32(cq06, aq2, bq06); + } + cq01 = vaddq_f32(cq01, cq07); cq02 = vaddq_f32(cq02, cq08); + cq03 = vaddq_f32(cq03, cq09); cq04 = vaddq_f32(cq04, cq10); + cq05 = vaddq_f32(cq05, cq11); cq06 = vaddq_f32(cq06, cq12); + cq01 = vaddq_f32(cq01, cq04); cq02 = vaddq_f32(cq02, cq05); + cq03 = vaddq_f32(cq03, cq06); + + if (k_left > 3) { + aq1 = vld1q_f32(a_ptr); a_ptr += 4; + bq01 = vld1q_f32(sb_ptr); bq02 = vld1q_f32(sb_ptr + 4); + bq03 = vld1q_f32(sb_ptr + 8); sb_ptr += 12; + cq01 = vfmaq_f32(cq01, aq1, bq01); k_left -= 4; + cq02 = vfmaq_f32(cq02, aq1, bq02); + cq03 = vfmaq_f32(cq03, aq1, bq03); + } + float32x2_t cd1 = vadd_f32(vget_low_f32(cq01), vget_high_f32(cq01)); + float32x2_t cd2 = vadd_f32(vget_low_f32(cq02), vget_high_f32(cq02)); + float32x2_t cd3 = vadd_f32(vget_low_f32(cq03), vget_high_f32(cq03)); + if (k_left > 1) { + float32x2_t ad1 = vld1_f32(a_ptr); a_ptr += 2; + float32x2_t bd1 = vld1_f32(sb_ptr); + float32x2_t bd2 = vld1_f32(sb_ptr + 2); + float32x2_t bd3 = vld1_f32(sb_ptr + 4); sb_ptr += 6; + cd1 = vfma_f32(cd1, ad1, bd1); k_left -= 2; + cd2 = vfma_f32(cd2, ad1, bd2); + cd3 = vfma_f32(cd3, ad1, bd3); + } + float cs1 = vget_lane_f32(cd1, 0) + vget_lane_f32(cd1, 1); + float cs2 = vget_lane_f32(cd2, 0) + vget_lane_f32(cd2, 1); + float cs3 = vget_lane_f32(cd3, 0) + vget_lane_f32(cd3, 1); + if (k_left > 0) { + float as1 = *a_ptr++; + cs1 += as1 * sb_ptr[0]; + cs2 += as1 * sb_ptr[1]; + cs3 += as1 * sb_ptr[2]; + } + + if (c_rowmajor) { + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[1] = c_ptr[1] * beta + cs2; + c_ptr[2] = c_ptr[2] * beta + cs3; + } else { + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[LDM] = c_ptr[LDM] * beta + cs2; + c_ptr[LDM * 2] = c_ptr[LDM * 2] * beta + cs3; + } +} + +#define 
DEFAULT_SGEMV1_THRESH_K_UNROLL_M 512 +#define DEFAULT_SGEMV1_THRESH_DETECT_CPU 30000 + +static inline bool unroll_test_m4n1(uint32_t M, uint32_t K) { + unsigned char cpu_type = 0, cpu_id = 0; + uint32_t gemv1_thresh_k_unroll_m = DEFAULT_SGEMV1_THRESH_K_UNROLL_M; + if ((uint64_t)M * (uint64_t)K > DEFAULT_SGEMV1_THRESH_DETECT_CPU) { + cpu_id = sched_getcpu(); + cpu_type = blas_arm_get_cpu_type(cpu_id); + /* Based on a number of BLAS tests, + * unrolling M on Cortex-A55 degrades performance in all cases */ + /* Unrolling M on other ARM cores can improve performance when K is small */ + gemv1_thresh_k_unroll_m = cpu_type == 55 ? + 0 : DEFAULT_SGEMV1_THRESH_K_UNROLL_M; + } + return K <= gemv1_thresh_k_unroll_m; +} + +static inline bool unroll_test_m1n1(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m4n2(uint32_t M, uint32_t K) { + return unroll_test_m4n1(M, K); +} + +static inline bool unroll_test_m1n2(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m4n3(uint32_t M, uint32_t K) { + unsigned char cpu_type = 0, cpu_id = 0; + if ((uint64_t)M * (uint64_t)K > DEFAULT_SGEMV1_THRESH_DETECT_CPU) { + cpu_id = sched_getcpu(); + cpu_type = blas_arm_get_cpu_type(cpu_id); + if (cpu_type == 53 || cpu_type == 35) { + return true; + } + return false; + } + return false; +} + +static inline bool unroll_test_m1n3(uint32_t M, uint32_t K) { + return true; +} + +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(sgemm, 1, 7, 5, 32768, float, float, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(sgemm, 2, 7, 5, 32768, float, float, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(sgemm, 3, 7, 5, 32768, float, float, unroll_test) + +#define SGEMM_SKINNY1_FUNC_TEMPLATE(ndim) \ +void sgemm_arowmajor_bskinny_afloat_bfloat_n##ndim(\ + const float *A, const float *B, float *C,\ + uint32_t M, uint32_t K, uint8_t b_c_order, float beta) {\ +\ + unsigned char cpu_type = 0;\ + if ((uint64_t)M * (uint64_t)K * ndim > \ + DEFAULT_SGEMV1_THRESH_DETECT_CPU << 3) {\ + unsigned char cpu_id = sched_getcpu();\ + cpu_type = blas_arm_get_cpu_type(cpu_id);\ + }\ +\ + const uint32_t LDB = (b_c_order & 1) ? ndim : K;\ + const uint32_t LDC = (b_c_order & 2) ? ndim : M;\ + if (cpu_type == 35) {\ + sgemm_skinny1_arowmajor_n##ndim##_a35(A, B, C, M, K, K, LDB, LDC,\ + b_c_order, beta);\ + } else if (cpu_type == 53 || cpu_type == 55) {\ + sgemm_skinny1_arowmajor_n##ndim##_a53(A, B, C, M, K, K, LDB, LDC,\ + b_c_order, beta);\ + } else {\ + sgemm_skinny1_arowmajor_n##ndim##_a7x(A, B, C, M, K, K, LDB, LDC,\ + b_c_order, beta);\ + }\ +}\ +\ +void sgemm_arowmajor_bskinny_afloat_bfloat_n##ndim##_omp(const float *A,\ + const float *B, float *C,\ + uint32_t M, uint32_t K, uint8_t b_c_order,\ + float beta, uint32_t num_threads) {\ +\ + const uint32_t LDC = (b_c_order & 2) ? ndim : M;\ + if (num_threads <= 1) {\ + sgemm_arowmajor_bskinny_afloat_bfloat_n##ndim(A, B, C, M, K,\ + b_c_order, beta);\ + return;\ + }\ +\ + unsigned char cpu_type = 0;\ + if ((uint64_t)M * (uint64_t)K * ndim > \ + DEFAULT_SGEMV1_THRESH_DETECT_CPU << 3) {\ + unsigned char cpu_id = sched_getcpu();\ + cpu_type = blas_arm_get_cpu_type(cpu_id);\ + }\ +\ + const uint32_t LDB = (b_c_order & 1) ? 
ndim : K;\ + if (cpu_type == 35) {\ + sgemm_skinny1_arowmajor_n##ndim##_a35_omp(A, B, C, M, K, K,\ + LDB, LDC, b_c_order, beta, num_threads);\ + } else if (cpu_type == 53 || cpu_type == 55) {\ + sgemm_skinny1_arowmajor_n##ndim##_a53_omp(A, B, C, M, K, K,\ + LDB, LDC, b_c_order, beta, num_threads);\ + } else {\ + sgemm_skinny1_arowmajor_n##ndim##_a7x_omp(A, B, C, M, K, K,\ + LDB, LDC, b_c_order, beta, num_threads);\ + }\ +} + +SGEMM_SKINNY1_FUNC_TEMPLATE(4) +SGEMM_SKINNY1_FUNC_TEMPLATE(5) +SGEMM_SKINNY1_FUNC_TEMPLATE(6) +SGEMM_SKINNY1_FUNC_TEMPLATE(7) +SGEMM_SKINNY1_FUNC_TEMPLATE(8) +SGEMM_SKINNY1_FUNC_TEMPLATE(9) +SGEMM_SKINNY1_FUNC_TEMPLATE(10) +SGEMM_SKINNY1_FUNC_TEMPLATE(11) +SGEMM_SKINNY1_FUNC_TEMPLATE(12) +SGEMM_SKINNY1_FUNC_TEMPLATE(13) +SGEMM_SKINNY1_FUNC_TEMPLATE(14) +SGEMM_SKINNY1_FUNC_TEMPLATE(15) +SGEMM_SKINNY1_FUNC_TEMPLATE(16) +SGEMM_SKINNY1_FUNC_TEMPLATE(17) +SGEMM_SKINNY1_FUNC_TEMPLATE(18) +SGEMM_SKINNY1_FUNC_TEMPLATE(19) +SGEMM_SKINNY1_FUNC_TEMPLATE(20) +SGEMM_SKINNY1_FUNC_TEMPLATE(21) +SGEMM_SKINNY1_FUNC_TEMPLATE(22) +SGEMM_SKINNY1_FUNC_TEMPLATE(23) +SGEMM_SKINNY1_FUNC_TEMPLATE(24) +SGEMM_SKINNY1_FUNC_TEMPLATE(25) +SGEMM_SKINNY1_FUNC_TEMPLATE(26) +SGEMM_SKINNY1_FUNC_TEMPLATE(27) +SGEMM_SKINNY1_FUNC_TEMPLATE(28) +SGEMM_SKINNY1_FUNC_TEMPLATE(29) +SGEMM_SKINNY1_FUNC_TEMPLATE(30) +SGEMM_SKINNY1_FUNC_TEMPLATE(31) +SGEMM_SKINNY1_FUNC_TEMPLATE(32) +SGEMM_SKINNY1_FUNC_TEMPLATE(33) +SGEMM_SKINNY1_FUNC_TEMPLATE(34) +SGEMM_SKINNY1_FUNC_TEMPLATE(35) +SGEMM_SKINNY1_FUNC_TEMPLATE(36) +SGEMM_SKINNY1_FUNC_TEMPLATE(37) +SGEMM_SKINNY1_FUNC_TEMPLATE(38) +SGEMM_SKINNY1_FUNC_TEMPLATE(39) +SGEMM_SKINNY1_FUNC_TEMPLATE(40) +SGEMM_SKINNY1_FUNC_TEMPLATE(41) +SGEMM_SKINNY1_FUNC_TEMPLATE(42) +SGEMM_SKINNY1_FUNC_TEMPLATE(43) +SGEMM_SKINNY1_FUNC_TEMPLATE(44) +SGEMM_SKINNY1_FUNC_TEMPLATE(45) +SGEMM_SKINNY1_FUNC_TEMPLATE(46) +SGEMM_SKINNY1_FUNC_TEMPLATE(47) +SGEMM_SKINNY1_FUNC_TEMPLATE(48) +SGEMM_SKINNY1_FUNC_TEMPLATE(49) +SGEMM_SKINNY1_FUNC_TEMPLATE(50) + diff --git a/src/neon_armv8a/SgemmSkinnyGer.c b/src/neon_armv8a/SgemmSkinnyGer.c new file mode 100644 index 0000000..0d64809 --- /dev/null +++ b/src/neon_armv8a/SgemmSkinnyGer.c @@ -0,0 +1,283 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyGer.h" + +#include <arm_neon.h> + +typedef float sgemm_skinnyger_ascalar; +typedef float sgemm_skinnyger_bscalar; +typedef float sgemm_skinnyger_cscalar; + +typedef float sgemm_skinnyger_avec1; +typedef float sgemm_skinnyger_bvec1; +typedef float sgemm_skinnyger_cvec1; + +typedef float32x2_t sgemm_skinnyger_avec2; +typedef float32x2_t sgemm_skinnyger_bvec2; +typedef float32x2_t sgemm_skinnyger_cvec2; + +typedef float32x4_t sgemm_skinnyger_avec4; +typedef float32x4_t sgemm_skinnyger_bvec4; +typedef float32x4_t sgemm_skinnyger_cvec4; + +typedef float32x4x2_t sgemm_skinnyger_avec8; +typedef float32x4x2_t sgemm_skinnyger_bvec8; +typedef float32x4x2_t sgemm_skinnyger_cvec8; + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 1) { + float32x4x2_t ret; + ret.val[0] = vfmaq_laneq_f32(c_vec.val[0], a_vec.val[0], b_vec, 0); + ret.val[1] = vfmaq_laneq_f32(c_vec.val[1], a_vec.val[1], b_vec, 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 2) { + float32x4x2_t ret; + ret.val[0] = vfmaq_laneq_f32(c_vec.val[0], a_vec.val[0], b_vec, 1); + ret.val[1] = vfmaq_laneq_f32(c_vec.val[1], a_vec.val[1], b_vec, 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 3) { + float32x4x2_t ret; + ret.val[0] = vfmaq_laneq_f32(c_vec.val[0], a_vec.val[0], b_vec, 2); + ret.val[1] = vfmaq_laneq_f32(c_vec.val[1], a_vec.val[1], b_vec, 2); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 4, 4) { + float32x4x2_t ret; + ret.val[0] = vfmaq_laneq_f32(c_vec.val[0], a_vec.val[0], b_vec, 3); + ret.val[1] = vfmaq_laneq_f32(c_vec.val[1], a_vec.val[1], b_vec, 3); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 2, 1) { + float32x4x2_t ret; + ret.val[0] = vfmaq_lane_f32(c_vec.val[0], a_vec.val[0], b_vec, 0); + ret.val[1] = vfmaq_lane_f32(c_vec.val[1], a_vec.val[1], b_vec, 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 2, 2) { + float32x4x2_t ret; + ret.val[0] = vfmaq_lane_f32(c_vec.val[0], a_vec.val[0], b_vec, 1); + ret.val[1] = vfmaq_lane_f32(c_vec.val[1], a_vec.val[1], b_vec, 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 8, 1, 1) { + float32x4x2_t ret; + ret.val[0] = vfmaq_n_f32(c_vec.val[0], a_vec.val[0], b_vec); + ret.val[1] = vfmaq_n_f32(c_vec.val[1], a_vec.val[1], b_vec); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 1) { + return vfmaq_laneq_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 2) { + return vfmaq_laneq_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 3) { + return vfmaq_laneq_f32(c_vec, a_vec, b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 4, 4) { + return vfmaq_laneq_f32(c_vec, a_vec, b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 2, 1) { + return vfmaq_lane_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 2, 2) { + return vfmaq_lane_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 4, 1, 1) { + return vfmaq_n_f32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 1) { + return vfma_laneq_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 2) { + return vfma_laneq_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 3) { + return vfma_laneq_f32(c_vec, a_vec, b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 4, 4) { + return vfma_laneq_f32(c_vec, a_vec, b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 2, 1) { + return vfma_lane_f32(c_vec, a_vec, 
b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 2, 2) { + return vfma_lane_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 2, 1, 1) { + return vfma_n_f32(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 1) { + return vfmas_laneq_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 2) { + return vfmas_laneq_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 3) { + return vfmas_laneq_f32(c_vec, a_vec, b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 4, 4) { + return vfmas_laneq_f32(c_vec, a_vec, b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 2, 1) { + return vfmas_lane_f32(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 2, 2) { + return vfmas_lane_f32(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(sgemm, 1, 1, 1) { + return a_vec * b_vec + c_vec; +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vld1q_f32(a_ptr); + ret.val[1] = vld1q_f32(a_ptr + 4); + __asm__("prfm pldl1keep,[%0,#96]"::"r"(a_ptr):); + return ret; +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 4) { + __asm__("prfm pldl1keep,[%0,#80]"::"r"(a_ptr):); + return vld1q_f32(a_ptr); +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 2) { + return vld1_f32(a_ptr); +} + +GEMM_SKINNY_GER_LOADA_UNIT(sgemm, 1) { + return *a_ptr; +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 8) { + float32x4x2_t ret; + ret.val[0] = vld1q_f32(c_ptr); + ret.val[1] = vld1q_f32(c_ptr + 4); + return ret; +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 4) { + return vld1q_f32(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 2) { + return vld1_f32(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT(sgemm, 1) { + return *c_ptr; +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 8) { + vst1q_f32(c_ptr, c_vec.val[0]); + vst1q_f32(c_ptr + 4, c_vec.val[1]); +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 4) { + vst1q_f32(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 2) { + vst1_f32(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT(sgemm, 1) { + *c_ptr = c_vec; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 4) { + float32x4_t ret = vdupq_n_f32(0); + float b1 = *b_ptr; b_ptr += ldb; + float b2 = *b_ptr; b_ptr += ldb; + float b3 = *b_ptr; b_ptr += ldb; + float b4 = *b_ptr; + ret = vsetq_lane_f32(b1, ret, 0); + ret = vsetq_lane_f32(b2, ret, 1); + ret = vsetq_lane_f32(b3, ret, 2); + ret = vsetq_lane_f32(b4, ret, 3); + return ret; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 2) { + float32x2_t ret = vdup_n_f32(0); + float b1 = *b_ptr; + float b2 = b_ptr[ldb]; + ret = vset_lane_f32(b1, ret, 0); + ret = vset_lane_f32(b2, ret, 1); + return ret; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(sgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 4) { + return vld1q_f32(b_ptr); +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 2) { + return vld1_f32(b_ptr); +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(sgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 1, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 2, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 3, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 4, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 5, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 6, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 7, 7, 15, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 8, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 9, 7, 7, 8192, float, float) 
+GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 10, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 11, 7, 7, 8192, float, float) +GEMM_SKINNY_GER_PARALLEL_FUNC(sgemm, 12, 7, 7, 8192, float, float) + diff --git a/src/neon_armv8a/U8U32DotGemmDriver.c b/src/neon_armv8a/U8U32DotGemmDriver.c new file mode 100644 index 0000000..b08fd5c --- /dev/null +++ b/src/neon_armv8a/U8U32DotGemmDriver.c @@ -0,0 +1,36 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/U8U32DotGemmCopy.h" +#include "neon_armv8a/U8U32DotGemmKernel.h" +#include "neon_armv8a/U8U32DotGemmSkinnyDot.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +#ifdef SCRATCH_K_CORD +#undef SCRATCH_K_CORD +#define SCRATCH_K_CORD(k) ((k) >> 2) +#endif + +#ifdef GEMM_D_K +#undef GEMM_D_K +#define GEMM_D_K 768 +#endif + +GEMM_PARALLEL_FUNC(u8u32dotgemm, uint8_t, uint32_t, uint8_t, uint32_t, uint32_t, + 8, 12, 12, 12, 0, 0) + diff --git a/src/neon_armv8a/U8U32GemmDriver.c b/src/neon_armv8a/U8U32GemmDriver.c new file mode 100644 index 0000000..88db447 --- /dev/null +++ b/src/neon_armv8a/U8U32GemmDriver.c @@ -0,0 +1,48 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "neon_armv8a/U8U32MlaGemmDriver.h" +#include "neon_armv8a/U8U32DotGemmDriver.h" +#include "arm_neon/ARMCpuType.h" + +int u8u32gemm_serial(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, uint32_t beta_inp) { + + if (blas_arm_get_i8i32_support() == 2) { + return u8u32dotgemm_serial(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp); + } else { + return u8u32mlagemm_serial(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp); + } +} + +int u8u32gemm(int a_rowmajor, int b_rowmajor, + const uint8_t *A, const uint8_t *B, + uint32_t *C, uint32_t M, uint32_t N, uint32_t K, + uint32_t beta_inp, uint32_t num_threads) { + + if (blas_arm_get_i8i32_support() == 2) { + return u8u32dotgemm(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp, num_threads); + } else { + return u8u32mlagemm(a_rowmajor, b_rowmajor, A, B, C, + M, N, K, beta_inp, num_threads); + } +} + diff --git a/src/neon_armv8a/U8U32MlaGemmCopy.c b/src/neon_armv8a/U8U32MlaGemmCopy.c new file mode 100644 index 0000000..a27d2c0 --- /dev/null +++ b/src/neon_armv8a/U8U32MlaGemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "common/CommonCopy.h" +#include "arm_neon/NeonI8I32MlaGemmCopy.h" + +GENERIC_NCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 8) +GENERIC_NCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 12) + +GENERIC_TCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 8) +GENERIC_TCOPY_FUNC(u8u32mlagemm, uint8_t, uint16_t, 12) + diff --git a/src/neon_armv8a/U8U32MlaGemmDriver.c b/src/neon_armv8a/U8U32MlaGemmDriver.c new file mode 100644 index 0000000..f2a562c --- /dev/null +++ b/src/neon_armv8a/U8U32MlaGemmDriver.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "neon_armv8a/U8U32MlaGemmCopy.h" +#include "neon_armv8a/U8U32MlaGemmKernel.h" +#include "neon_armv8a/U8U32MlaGemmSkinnyGer.h" +#include "neon_armv8a/U8U32MlaGemmSkinnyDot.h" +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonDriver.h" + +GEMM_PARALLEL_FUNC(u8u32mlagemm, uint8_t, uint16_t, uint8_t, uint16_t, uint32_t, + 8, 12, 8, 8, 8, 8) + diff --git a/src/neon_armv8a/U8U32MlaGemmKernel.c b/src/neon_armv8a/U8U32MlaGemmKernel.c new file mode 100644 index 0000000..b9f3cd4 --- /dev/null +++ b/src/neon_armv8a/U8U32MlaGemmKernel.c @@ -0,0 +1,27 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "common/CommonKernel.h" +#include "neon_armv8a/I8I32MlaGemmKernel.h" + +DUALPACK_KERNEL_FUNC_LM(u8u32mlagemm, uint16_t, uint16_t, uint32_t, 8, 12) +DUALPACK_KERNEL_FUNC_LN(u8u32mlagemm, uint16_t, uint16_t, uint32_t, 12, 8) + diff --git a/src/neon_armv8a/U8U32MlaGemmSkinnyDot.c b/src/neon_armv8a/U8U32MlaGemmSkinnyDot.c new file mode 100644 index 0000000..4cb495e --- /dev/null +++ b/src/neon_armv8a/U8U32MlaGemmSkinnyDot.c @@ -0,0 +1,34 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "neon_armv8a/I8I32MlaGemmSkinnyDot.h" +#include "common/CommonSkinnyDot.h" + +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(u8u32mlagemm, 1, 31, 5, 131072, uint8_t, uint8_t, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(u8u32mlagemm, 2, 31, 5, 131072, uint8_t, uint8_t, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(u8u32mlagemm, 3, 31, 5, 131072, uint8_t, uint8_t, unroll_test) + +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 4, 15, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 5, 15, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 6, 15, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 7, 15, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32mlagemm, 8, 15, 3, 131072, uint8_t, uint8_t) \ No newline at end of file diff --git a/src/neon_armv8a/U8U32MlaGemmSkinnyGer.c b/src/neon_armv8a/U8U32MlaGemmSkinnyGer.c new file mode 100644 index 0000000..fc9948c --- /dev/null +++ b/src/neon_armv8a/U8U32MlaGemmSkinnyGer.c @@ -0,0 +1,32 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "arm_neon/NeonI8I32MlaGemmSkinnyGer.h" + +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 1, 5, 29, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 2, 5, 29, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 3, 5, 29, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 4, 5, 29, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 5, 5, 13, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 6, 5, 13, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 7, 5, 13, 8192, uint8_t, uint8_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(u8u32mlagemm, 8, 5, 13, 8192, uint8_t, uint8_t) diff --git a/src/neon_armv8a/extension/HgemmCopy.c b/src/neon_armv8a/extension/HgemmCopy.c new file mode 100644 index 0000000..d31fd58 --- /dev/null +++ b/src/neon_armv8a/extension/HgemmCopy.c @@ -0,0 +1,111 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at                                  */ +/*                                                                           */ +/*     http://www.apache.org/licenses/LICENSE-2.0                           */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software      */ +/* distributed under the License is distributed on an "AS IS" BASIS,        */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and      */ +/* limitations under the License.                                           */ +/*****************************************************************************/ + + +#include "common/CommonCopy.h" +#include <arm_neon.h> + +static inline void pref_ab(const float16_t *dat) { + __asm__ ("prfm pldl1keep,[%0,#64]\n\t"::"r"(dat):); +} + +#define NCOPY_NEON_LOOP_K8_UNROLL4(inc, dst_ptr, src1, src2, src3, src4) \ + for (dim1_count = dim1_cache; dim1_count > 7; dim1_count -= 8) {\ + float16x8x4_t t1;\ + t1.val[0] = vld1q_f16(src1); src1 += 8; pref_ab(src1);\ + t1.val[1] = vld1q_f16(src2); src2 += 8; pref_ab(src2);\ + t1.val[2] = vld1q_f16(src3); src3 += 8; pref_ab(src3);\ + t1.val[3] = vld1q_f16(src4); src4 += 8; pref_ab(src4);\ + vst4q_lane_f16(dst_ptr, t1, 0);\ + vst4q_lane_f16(dst_ptr + inc, t1, 1);\ + vst4q_lane_f16(dst_ptr + inc * 2, t1, 2);\ + vst4q_lane_f16(dst_ptr + inc * 3, t1, 3);\ + vst4q_lane_f16(dst_ptr + inc * 4, t1, 4);\ + vst4q_lane_f16(dst_ptr + inc * 5, t1, 5);\ + vst4q_lane_f16(dst_ptr + inc * 6, t1, 6);\ + vst4q_lane_f16(dst_ptr + inc * 7, t1, 7);\ + dst_ptr += inc * 8;\ + } + +#define NCOPY_UNROLL_16 {\ + float16_t *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\ + NCOPY_NEON_LOOP_K8_UNROLL4(16, dst_h1, src1, src2, src3, src4)\ + dst_h1 = dst1 + 4;\ + NCOPY_NEON_LOOP_K8_UNROLL4(16, dst_h1, src5, src6, src7, src8)\ + dst_h1 = dst1 + 8;\ + NCOPY_NEON_LOOP_K8_UNROLL4(16, dst_h1, src9, src10, src11, src12)\ + dst_h1 = dst1 + 12;\ + NCOPY_NEON_LOOP_K8_UNROLL4(16, dst_h1, src13, src14, src15, src16)\ + dst1 = dst_h1 - 12;\ + NCOPY_STD(16)\ +} + +#define NCOPY_UNROLL_8 {\ + float16_t *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\ + NCOPY_NEON_LOOP_K8_UNROLL4(8, dst_h1, src1, src2, src3, src4)\ + dst_h1 = dst1 + 4;\ + NCOPY_NEON_LOOP_K8_UNROLL4(8, dst_h1, src5, src6, src7, src8)\ + dst1 = dst_h1 - 4;\ + NCOPY_STD(8)\ +} + +#define NCOPY_UNROLL_4 {\ + float16_t *dst_h1 = dst1; uint32_t dim1_cache = dim1_count;\ + NCOPY_NEON_LOOP_K8_UNROLL4(4, dst_h1, src1, src2, src3, src4)\ + dst1 = dst_h1;\ + NCOPY_STD(4)\ +} + +#define NCOPY_UNROLL_2 NCOPY_STD(2) +#define NCOPY_UNROLL_1 NCOPY_STD(1) + +#define NCOPY_float16_t_float16_t(unroll) NCOPY_UNROLL_##unroll + + +#define TCOPY_UNIT_1(src_ptr, dst_ptr, dst_offset) \ + dst_ptr[dst_offset] = *src_ptr; + +#define TCOPY_UNIT_2(src_ptr, dst_ptr, dst_offset) {\ + dst_ptr[dst_offset] = *src_ptr;\ + dst_ptr[dst_offset + 1] = src_ptr[1];\ +} + +#define TCOPY_UNIT_4(src_ptr, dst_ptr, dst_offset) {\ + float16x4_t tmp = vld1_f16(src_ptr); pref_ab(src_ptr + 4);\ + vst1_f16(dst_ptr + dst_offset, tmp);\ +} + +#define TCOPY_UNIT_8(src_ptr, dst_ptr, dst_offset) {\ + float16x8_t tmp = vld1q_f16(src_ptr); pref_ab(src_ptr + 8);\ + vst1q_f16(dst_ptr + dst_offset, tmp);\ +} + +#define TCOPY_UNIT_16(src_ptr, dst_ptr, dst_offset) {\ + float16x8_t tmp1 = vld1q_f16(src_ptr);\ + float16x8_t tmp2 = vld1q_f16(src_ptr + 8); pref_ab(src_ptr + 16);\ + vst1q_f16(dst_ptr + dst_offset, tmp1);\ + vst1q_f16(dst_ptr + dst_offset + 8, tmp2);\ +} + +#define TCOPY_UNIT_float16_t_float16_t(src_ptr, dst_ptr, dst_offset, num_elements) \ + TCOPY_UNIT_##num_elements(src_ptr, dst_ptr, dst_offset) + +GENERIC_NCOPY_FUNC(hgemm, 
float16_t, float16_t, 8) +GENERIC_NCOPY_FUNC(hgemm, float16_t, float16_t, 16) + +GENERIC_TCOPY_FUNC(hgemm, float16_t, float16_t, 8) +GENERIC_TCOPY_FUNC(hgemm, float16_t, float16_t, 16) + diff --git a/src/neon_armv8a/extension/HgemmKernel.c b/src/neon_armv8a/extension/HgemmKernel.c new file mode 100644 index 0000000..b7cb731 --- /dev/null +++ b/src/neon_armv8a/extension/HgemmKernel.c @@ -0,0 +1,1420 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc.                                                    */ +/*                                                                           */ +/* Licensed under the Apache License, Version 2.0 (the "License");          */ +/* you may not use this file except in compliance with the License.         */ +/* You may obtain a copy of the License at                                  */ +/*                                                                           */ +/*     http://www.apache.org/licenses/LICENSE-2.0                           */ +/*                                                                           */ +/* Unless required by applicable law or agreed to in writing, software      */ +/* distributed under the License is distributed on an "AS IS" BASIS,        */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and      */ +/* limitations under the License.                                           */ +/*****************************************************************************/ + + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif +#include "common/CommonKernel.h" +#include "arm_neon/ARMCpuType.h" +#include <arm_neon.h> +#include <sched.h> + +static inline void pref_c(float16_t *dat) { + __asm__ ("prfm pstl1keep,[%0]\n\t"::"r"(dat):); +} + +#define PREF_N1 pref_c(c_pref); c_pref += ldc; +#define PREF_N2 PREF_N1 PREF_N1 +#define PREF_N4 PREF_N2 PREF_N2 +#define PREF_N8 PREF_N4 PREF_N4 +#define PREF_N16 PREF_N8 PREF_N8 + +#define DECLARE_C_8X8 \ + float16x8_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08; + +#define DECLARE_C_8X16 DECLARE_C_8X8 \ + float16x8_t cq09, cq10, cq11, cq12, cq13, cq14, cq15, cq16; + +/* fp16-fma kernel for general out-of-order ARM processors */ +/* q0 and q1 for holding data from matrix A */ +/* q2 and q3 for holding data from matrix B */ +#define KERNEL_M8N16_A76 \ + DECLARE_C_8X16\ + float16_t *c_pref = c_ptr + 7; PREF_N16\ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + uint32_t k_left = K;\ + __asm__ __volatile__(\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov %6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "mov %8.16b,%0.16b; mov %9.16b,%1.16b\n\t"\ + "mov %10.16b,%0.16b; mov %11.16b,%1.16b\n\t"\ + "mov %12.16b,%0.16b; mov %13.16b,%1.16b\n\t"\ + "mov %14.16b,%0.16b; mov %15.16b,%1.16b\n\t"\ + "cmp %w16,#0; b.eq 004f\n\t"\ + "ldr q0,[%17],#16; ldr q2,[%18],#16; ldr q3,[%18],#16\n\t"\ + "cmp %w16,#2; b.le 002f\n\t"\ + "001:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "ldr q1,[%17],#32\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[4]; fmla %5.8h,v0.8h,v2.h[5]\n\t"\ + "prfm pldl1keep,[%17,#128]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[6]; fmla %7.8h,v0.8h,v2.h[7]\n\t"\ + "ldr q2,[%18],#64\n\t"\ + "fmla %8.8h,v0.8h,v3.h[0]; fmla %9.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %10.8h,v0.8h,v3.h[2]; fmla %11.8h,v0.8h,v3.h[3]\n\t"\ + "sub %w16,%w16,#2\n\t"\ + "fmla %12.8h,v0.8h,v3.h[4]; fmla %13.8h,v0.8h,v3.h[5]\n\t"\ + "fmla %14.8h,v0.8h,v3.h[6]; fmla %15.8h,v0.8h,v3.h[7]\n\t"\ + "ldr q3,[%18,#-48]\n\t"\ + "fmla %0.8h,v1.8h,v2.h[0]; fmla %1.8h,v1.8h,v2.h[1]\n\t"\ + "ldr q0,[%17,#-16]\n\t"\ + "fmla %2.8h,v1.8h,v2.h[2]; fmla %3.8h,v1.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v1.8h,v2.h[4]; fmla %5.8h,v1.8h,v2.h[5]\n\t"\ + "cmp %w16,#2\n\t"\ + 
"fmla %6.8h,v1.8h,v2.h[6]; fmla %7.8h,v1.8h,v2.h[7]\n\t"\ + "ldr q2,[%18,#-32]\n\t"\ + "fmla %8.8h,v1.8h,v3.h[0]; fmla %9.8h,v1.8h,v3.h[1]\n\t"\ + "fmla %10.8h,v1.8h,v3.h[2]; fmla %11.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %12.8h,v1.8h,v3.h[4]; fmla %13.8h,v1.8h,v3.h[5]\n\t"\ + "fmla %14.8h,v1.8h,v3.h[6]; fmla %15.8h,v1.8h,v3.h[7]\n\t"\ + "ldr q3,[%18,#-16]; b.gt 001b\n\t"\ + "002:\n\t"\ + "cmp %w16,#2; b.ne 003f\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "ldr q1,[%17],#16\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[4]; fmla %5.8h,v0.8h,v2.h[5]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[6]; fmla %7.8h,v0.8h,v2.h[7]\n\t"\ + "ldr q2,[%18],#32\n\t"\ + "fmla %8.8h,v0.8h,v3.h[0]; fmla %9.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %10.8h,v0.8h,v3.h[2]; fmla %11.8h,v0.8h,v3.h[3]\n\t"\ + "sub %w16,%w16,#2\n\t"\ + "fmla %12.8h,v0.8h,v3.h[4]; fmla %13.8h,v0.8h,v3.h[5]\n\t"\ + "fmla %14.8h,v0.8h,v3.h[6]; fmla %15.8h,v0.8h,v3.h[7]\n\t"\ + "ldr q3,[%18,#-16]\n\t"\ + "fmla %0.8h,v1.8h,v2.h[0]; fmla %1.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v1.8h,v2.h[2]; fmla %3.8h,v1.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v1.8h,v2.h[4]; fmla %5.8h,v1.8h,v2.h[5]\n\t"\ + "fmla %6.8h,v1.8h,v2.h[6]; fmla %7.8h,v1.8h,v2.h[7]\n\t"\ + "fmla %8.8h,v1.8h,v3.h[0]; fmla %9.8h,v1.8h,v3.h[1]\n\t"\ + "fmla %10.8h,v1.8h,v3.h[2]; fmla %11.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %12.8h,v1.8h,v3.h[4]; fmla %13.8h,v1.8h,v3.h[5]\n\t"\ + "fmla %14.8h,v1.8h,v3.h[6]; fmla %15.8h,v1.8h,v3.h[7]\n\t"\ + "b 004f\n\t"\ + "003:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[4]; fmla %5.8h,v0.8h,v2.h[5]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[6]; fmla %7.8h,v0.8h,v2.h[7]\n\t"\ + "fmla %8.8h,v0.8h,v3.h[0]; fmla %9.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %10.8h,v0.8h,v3.h[2]; fmla %11.8h,v0.8h,v3.h[3]\n\t"\ + "sub %w16,%w16,#1\n\t"\ + "fmla %12.8h,v0.8h,v3.h[4]; fmla %13.8h,v0.8h,v3.h[5]\n\t"\ + "fmla %14.8h,v0.8h,v3.h[6]; fmla %15.8h,v0.8h,v3.h[7]\n\t"\ + "004:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04)\ + ,"=w"(cq05),"=w"(cq06),"=w"(cq07),"=w"(cq08)\ + ,"=w"(cq09),"=w"(cq10),"=w"(cq11),"=w"(cq12)\ + ,"=w"(cq13),"=w"(cq14),"=w"(cq15),"=w"(cq16)\ + ,"+r"(k_left),"+r"(a_ptr),"+r"(b_ptr1)\ + ::"cc","memory","v0","v1","v2","v3"); + +#define KERNEL_M16N8_A76 \ + DECLARE_C_8X16\ + float16_t *c_pref = c_ptr + 15; PREF_N8\ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + uint32_t k_left = K;\ + __asm__ __volatile__(\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov %6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "mov %8.16b,%0.16b; mov %9.16b,%1.16b\n\t"\ + "mov %10.16b,%0.16b; mov %11.16b,%1.16b\n\t"\ + "mov %12.16b,%0.16b; mov %13.16b,%1.16b\n\t"\ + "mov %14.16b,%0.16b; mov %15.16b,%1.16b\n\t"\ + "cmp %w16,#0; b.eq 004f\n\t"\ + "ldr q0,[%17],#32; ldr q2,[%18],#16; ldr q1,[%17,#-16]\n\t"\ + "cmp %w16,#2; b.le 002f\n\t"\ + "001:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %2.8h,v0.8h,v2.h[1]\n\t"\ + "ldr q3,[%18],#32\n\t"\ + "fmla %4.8h,v0.8h,v2.h[2]; fmla %6.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v2.h[4]; fmla %10.8h,v0.8h,v2.h[5]\n\t"\ + "prfm pldl1keep,[%18,#128]\n\t"\ + "fmla %12.8h,v0.8h,v2.h[6]; fmla %14.8h,v0.8h,v2.h[7]\n\t"\ + "ldr q0,[%17],#64\n\t"\ + "fmla %1.8h,v1.8h,v2.h[0]; fmla %3.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v2.h[2]; fmla %7.8h,v1.8h,v2.h[3]\n\t"\ + "sub %w16,%w16,#2\n\t"\ + "fmla %9.8h,v1.8h,v2.h[4]; fmla 
%11.8h,v1.8h,v2.h[5]\n\t"\ + "fmla %13.8h,v1.8h,v2.h[6]; fmla %15.8h,v1.8h,v2.h[7]\n\t"\ + "ldr q1,[%17,#-48]\n\t"\ + "fmla %0.8h,v0.8h,v3.h[0]; fmla %2.8h,v0.8h,v3.h[1]\n\t"\ + "ldr q2,[%18,#-16]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[2]; fmla %6.8h,v0.8h,v3.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v3.h[4]; fmla %10.8h,v0.8h,v3.h[5]\n\t"\ + "cmp %w16,#2\n\t"\ + "fmla %12.8h,v0.8h,v3.h[6]; fmla %14.8h,v0.8h,v3.h[7]\n\t"\ + "ldr q0,[%17,#-32]\n\t"\ + "fmla %1.8h,v1.8h,v3.h[0]; fmla %3.8h,v1.8h,v3.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v3.h[2]; fmla %7.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %9.8h,v1.8h,v3.h[4]; fmla %11.8h,v1.8h,v3.h[5]\n\t"\ + "fmla %13.8h,v1.8h,v3.h[6]; fmla %15.8h,v1.8h,v3.h[7]\n\t"\ + "ldr q1,[%17,#-16]; b.gt 001b\n\t"\ + "002:\n\t"\ + "cmp %w16,#2; b.ne 003f\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %2.8h,v0.8h,v2.h[1]\n\t"\ + "ldr q3,[%18],#16\n\t"\ + "fmla %4.8h,v0.8h,v2.h[2]; fmla %6.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v2.h[4]; fmla %10.8h,v0.8h,v2.h[5]\n\t"\ + "fmla %12.8h,v0.8h,v2.h[6]; fmla %14.8h,v0.8h,v2.h[7]\n\t"\ + "ldr q0,[%17],#32\n\t"\ + "fmla %1.8h,v1.8h,v2.h[0]; fmla %3.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v2.h[2]; fmla %7.8h,v1.8h,v2.h[3]\n\t"\ + "sub %w16,%w16,#2\n\t"\ + "fmla %9.8h,v1.8h,v2.h[4]; fmla %11.8h,v1.8h,v2.h[5]\n\t"\ + "fmla %13.8h,v1.8h,v2.h[6]; fmla %15.8h,v1.8h,v2.h[7]\n\t"\ + "ldr q1,[%17,#-16]\n\t"\ + "fmla %0.8h,v0.8h,v3.h[0]; fmla %2.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[2]; fmla %6.8h,v0.8h,v3.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v3.h[4]; fmla %10.8h,v0.8h,v3.h[5]\n\t"\ + "fmla %12.8h,v0.8h,v3.h[6]; fmla %14.8h,v0.8h,v3.h[7]\n\t"\ + "fmla %1.8h,v1.8h,v3.h[0]; fmla %3.8h,v1.8h,v3.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v3.h[2]; fmla %7.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %9.8h,v1.8h,v3.h[4]; fmla %11.8h,v1.8h,v3.h[5]\n\t"\ + "fmla %13.8h,v1.8h,v3.h[6]; fmla %15.8h,v1.8h,v3.h[7]\n\t"\ + "b 004f\n\t"\ + "003:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %2.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[2]; fmla %6.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v2.h[4]; fmla %10.8h,v0.8h,v2.h[5]\n\t"\ + "fmla %12.8h,v0.8h,v2.h[6]; fmla %14.8h,v0.8h,v2.h[7]\n\t"\ + "fmla %1.8h,v1.8h,v2.h[0]; fmla %3.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v2.h[2]; fmla %7.8h,v1.8h,v2.h[3]\n\t"\ + "sub %w16,%w16,#1\n\t"\ + "fmla %9.8h,v1.8h,v2.h[4]; fmla %11.8h,v1.8h,v2.h[5]\n\t"\ + "fmla %13.8h,v1.8h,v2.h[6]; fmla %15.8h,v1.8h,v2.h[7]\n\t"\ + "004:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04)\ + ,"=w"(cq05),"=w"(cq06),"=w"(cq07),"=w"(cq08)\ + ,"=w"(cq09),"=w"(cq10),"=w"(cq11),"=w"(cq12)\ + ,"=w"(cq13),"=w"(cq14),"=w"(cq15),"=w"(cq16)\ + ,"+r"(k_left),"+r"(a_ptr),"+r"(b_ptr1)\ + ::"cc","memory","v0","v1","v2","v3"); + +#define KERNEL_M8N8_A76 \ + DECLARE_C_8X8\ + float16_t *c_pref = c_ptr + 7; PREF_N8\ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + uint32_t k_left = K;\ + __asm__ __volatile__(\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov %6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "cmp %w8,#0; b.eq 104f\n\t"\ + "ldr q0,[%9],#16; ldr q2,[%10],#16\n\t"\ + "cmp %w8,#2; b.le 102f\n\t"\ + "101:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "ldr q1,[%9],#32\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "ldr q3,[%10],#32\n\t"\ + "fmla %4.8h,v0.8h,v2.h[4]; fmla %5.8h,v0.8h,v2.h[5]\n\t"\ + "prfm pldl1keep,[%9,#128]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[6]; fmla %7.8h,v0.8h,v2.h[7]\n\t"\ + "ldr q0,[%9,#-16]\n\t"\ + "fmla %0.8h,v1.8h,v3.h[0]; 
fmla %1.8h,v1.8h,v3.h[1]\n\t"\ + "ldr q2,[%10,#-16]\n\t"\ + "fmla %2.8h,v1.8h,v3.h[2]; fmla %3.8h,v1.8h,v3.h[3]\n\t"\ + "sub %w8,%w8,#2\n\t"\ + "fmla %4.8h,v1.8h,v3.h[4]; fmla %5.8h,v1.8h,v3.h[5]\n\t"\ + "cmp %w8,#2\n\t"\ + "fmla %6.8h,v1.8h,v3.h[6]; fmla %7.8h,v1.8h,v3.h[7]\n\t"\ + "b.gt 101b\n\t"\ + "102:\n\t"\ + "cmp %w8,#2; b.ne 103f\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "ldr q1,[%9],#16\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "ldr q3,[%10],#16\n\t"\ + "fmla %4.8h,v0.8h,v2.h[4]; fmla %5.8h,v0.8h,v2.h[5]\n\t"\ + "prfm pldl1keep,[%9,#128]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[6]; fmla %7.8h,v0.8h,v2.h[7]\n\t"\ + "fmla %0.8h,v1.8h,v3.h[0]; fmla %1.8h,v1.8h,v3.h[1]\n\t"\ + "fmla %2.8h,v1.8h,v3.h[2]; fmla %3.8h,v1.8h,v3.h[3]\n\t"\ + "sub %w8,%w8,#2\n\t"\ + "fmla %4.8h,v1.8h,v3.h[4]; fmla %5.8h,v1.8h,v3.h[5]\n\t"\ + "fmla %6.8h,v1.8h,v3.h[6]; fmla %7.8h,v1.8h,v3.h[7]\n\t"\ + "b 104f\n\t"\ + "103:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[4]; fmla %5.8h,v0.8h,v2.h[5]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[6]; fmla %7.8h,v0.8h,v2.h[7]\n\t"\ + "sub %w8,%w8,#1\n\t"\ + "104:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04)\ + ,"=w"(cq05),"=w"(cq06),"=w"(cq07),"=w"(cq08)\ + ,"+r"(k_left),"+r"(a_ptr),"+r"(b_ptr1)\ + ::"cc","memory","v0","v1","v2","v3"); + +/* fp16-fma kernel for A55 specially */ +#define KERNEL_M8N16_A55 \ + DECLARE_C_8X16\ + float16_t *c_pref = c_ptr + 7; PREF_N16\ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + uint32_t k_left = K;\ + __asm__ __volatile__(\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov %6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "mov %8.16b,%0.16b; mov %9.16b,%1.16b\n\t"\ + "mov %10.16b,%0.16b; mov %11.16b,%1.16b\n\t"\ + "mov %12.16b,%0.16b; mov %13.16b,%1.16b\n\t"\ + "mov %14.16b,%0.16b; mov %15.16b,%1.16b\n\t"\ + "cmp %w16,#0; b.eq 004f\n\t"\ + "ldr q0,[%17],#16; ldr d2,[%18],#32; ldr d3,[%18,#-24]\n\t"\ + "ldr d4,[%18,#-16]; ldr d5,[%18,#-8]\n\t"\ + "cmp %w16,#2; b.le 002f\n\t"\ + "001:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "ldr d1,[%17],#32\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "ldr x0,[%17,#-24]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[0]\n\t"\ + "ldr d2,[%18],#64\n\t"\ + "fmla %5.8h,v0.8h,v3.h[1]\n\t"\ + "prfm pldl1keep,[%17,#128]\n\t"\ + "fmla %6.8h,v0.8h,v3.h[2]; fmla %7.8h,v0.8h,v3.h[3]\n\t"\ + "ldr d3,[%18,#-56]\n\t"\ + "fmla %8.8h,v0.8h,v4.h[0]; fmla %9.8h,v0.8h,v4.h[1]\n\t"\ + "fmov v1.d[1],x0\n\t"\ + "fmla %10.8h,v0.8h,v4.h[2]; fmla %11.8h,v0.8h,v4.h[3]\n\t"\ + "sub %w16,%w16,#2\n\t"\ + "fmla %12.8h,v0.8h,v5.h[0]\n\t"\ + "ldr d4,[%18,#-48]\n\t"\ + "fmla %13.8h,v0.8h,v5.h[1]\n\t"\ + "fmla %14.8h,v0.8h,v5.h[2]; fmla %15.8h,v0.8h,v5.h[3]\n\t"\ + "ldr d5,[%18,#-40]\n\t"\ + "fmla %0.8h,v1.8h,v2.h[0]; fmla %1.8h,v1.8h,v2.h[1]\n\t"\ + "ldr d0,[%17,#-16]\n\t"\ + "fmla %2.8h,v1.8h,v2.h[2]; fmla %3.8h,v1.8h,v2.h[3]\n\t"\ + "ldr x0,[%17,#-8]\n\t"\ + "fmla %4.8h,v1.8h,v3.h[0]\n\t"\ + "ldr d2,[%18,#-32]\n\t"\ + "fmla %5.8h,v1.8h,v3.h[1]\n\t"\ + "cmp %w16,#2\n\t"\ + "fmla %6.8h,v1.8h,v3.h[2]; fmla %7.8h,v1.8h,v3.h[3]\n\t"\ + "ldr d3,[%18,#-24]\n\t"\ + "fmla %8.8h,v1.8h,v4.h[0]; fmla %9.8h,v1.8h,v4.h[1]\n\t"\ + "fmla %10.8h,v1.8h,v4.h[2]; fmla %11.8h,v1.8h,v4.h[3]\n\t"\ + "fmov v0.d[1],x0\n\t"\ + "fmla %12.8h,v1.8h,v5.h[0]\n\t"\ + "ldr 
d4,[%18,#-16]\n\t"\ + "fmla %13.8h,v1.8h,v5.h[1]\n\t"\ + "fmla %14.8h,v1.8h,v5.h[2]; fmla %15.8h,v1.8h,v5.h[3]\n\t"\ + "ldr d5,[%18,#-8]; b.gt 001b\n\t"\ + "002:\n\t"\ + "cmp %w16,#2; b.ne 003f\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "ldr d1,[%17],#16\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "ldr x0,[%17,#-8]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[0]\n\t"\ + "ldr d2,[%18],#32\n\t"\ + "fmla %5.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %6.8h,v0.8h,v3.h[2]; fmla %7.8h,v0.8h,v3.h[3]\n\t"\ + "ldr d3,[%18,#-24]\n\t"\ + "fmla %8.8h,v0.8h,v4.h[0]; fmla %9.8h,v0.8h,v4.h[1]\n\t"\ + "fmla %10.8h,v0.8h,v4.h[2]; fmla %11.8h,v0.8h,v4.h[3]\n\t"\ + "sub %w16,%w16,#2\n\t"\ + "fmla %12.8h,v0.8h,v5.h[0]\n\t"\ + "ldr d4,[%18,#-16]\n\t"\ + "fmla %13.8h,v0.8h,v5.h[1]\n\t"\ + "fmov v1.d[1],x0\n\t"\ + "fmla %14.8h,v0.8h,v5.h[2]; fmla %15.8h,v0.8h,v5.h[3]\n\t"\ + "ldr d5,[%18,#-8]\n\t"\ + "fmla %0.8h,v1.8h,v2.h[0]; fmla %1.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v1.8h,v2.h[2]; fmla %3.8h,v1.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v1.8h,v3.h[0]; fmla %5.8h,v1.8h,v3.h[1]\n\t"\ + "fmla %6.8h,v1.8h,v3.h[2]; fmla %7.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %8.8h,v1.8h,v4.h[0]; fmla %9.8h,v1.8h,v4.h[1]\n\t"\ + "fmla %10.8h,v1.8h,v4.h[2]; fmla %11.8h,v1.8h,v4.h[3]\n\t"\ + "fmla %12.8h,v1.8h,v5.h[0]; fmla %13.8h,v1.8h,v5.h[1]\n\t"\ + "fmla %14.8h,v1.8h,v5.h[2]; fmla %15.8h,v1.8h,v5.h[3]\n\t"\ + "b 004f\n\t"\ + "003:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[0]; fmla %5.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %6.8h,v0.8h,v3.h[2]; fmla %7.8h,v0.8h,v3.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v4.h[0]; fmla %9.8h,v0.8h,v4.h[1]\n\t"\ + "fmla %10.8h,v0.8h,v4.h[2]; fmla %11.8h,v0.8h,v4.h[3]\n\t"\ + "sub %w16,%w16,#1\n\t"\ + "fmla %12.8h,v0.8h,v5.h[0]; fmla %13.8h,v0.8h,v5.h[1]\n\t"\ + "fmla %14.8h,v0.8h,v5.h[2]; fmla %15.8h,v0.8h,v5.h[3]\n\t"\ + "004:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04)\ + ,"=w"(cq05),"=w"(cq06),"=w"(cq07),"=w"(cq08)\ + ,"=w"(cq09),"=w"(cq10),"=w"(cq11),"=w"(cq12)\ + ,"=w"(cq13),"=w"(cq14),"=w"(cq15),"=w"(cq16)\ + ,"+r"(k_left),"+r"(a_ptr),"+r"(b_ptr1)\ + ::"cc","memory","v0","v1","v2","v3","v4","v5","x0"); + +#define KERNEL_M16N8_A55 \ + DECLARE_C_8X16\ + float16_t *c_pref = c_ptr + 15; PREF_N8\ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + uint32_t k_left = K;\ + __asm__ __volatile__(\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov %6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "mov %8.16b,%0.16b; mov %9.16b,%1.16b\n\t"\ + "mov %10.16b,%0.16b; mov %11.16b,%1.16b\n\t"\ + "mov %12.16b,%0.16b; mov %13.16b,%1.16b\n\t"\ + "mov %14.16b,%0.16b; mov %15.16b,%1.16b\n\t"\ + "cmp %w16,#0; b.eq 004f\n\t"\ + "ldr q0,[%17],#32\n\t"\ + "ldr d2,[%18],#16; ldr d3,[%18,#-8]\n\t"\ + "cmp %w16,#2; b.le 002f\n\t"\ + "001:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[1]; ldr d1,[%17,#-16]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[2]; ldr x0,[%17,#-8]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[3]; prfm pldl2keep,[%18,#128]\n\t"\ + "fmla %8.8h,v0.8h,v3.h[0]; ldr d4,[%18],#32\n\t"\ + "fmla %10.8h,v0.8h,v3.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %12.8h,v0.8h,v3.h[2]\n\t"\ + "fmla %14.8h,v0.8h,v3.h[3]; ldr d5,[%18,#-24]\n\t"\ + "fmla %1.8h,v1.8h,v2.h[0]; ldr d0,[%17],#64\n\t"\ + "fmla %3.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v2.h[2]; ldr x0,[%17,#-56]\n\t"\ + "fmla %7.8h,v1.8h,v2.h[3]\n\t"\ + "fmla 
%9.8h,v1.8h,v3.h[0]\n\t"\ + "fmla %11.8h,v1.8h,v3.h[1]; fmov v0.d[1],x0\n\t"\ + "fmla %13.8h,v1.8h,v3.h[2]\n\t"\ + "fmla %15.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %0.8h,v0.8h,v4.h[0]; ldr d1,[%17,#-48]\n\t"\ + "fmla %2.8h,v0.8h,v4.h[1]; ldr x0,[%17,#-40]\n\t"\ + "fmla %4.8h,v0.8h,v4.h[2]; ldr d2,[%18,#-16]\n\t"\ + "fmla %6.8h,v0.8h,v4.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v5.h[0]\n\t"\ + "fmla %10.8h,v0.8h,v5.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %12.8h,v0.8h,v5.h[2]; ldr d3,[%18,#-8]\n\t"\ + "fmla %14.8h,v0.8h,v5.h[3]\n\t"\ + "fmla %1.8h,v1.8h,v4.h[0]; ldr d0,[%17,#-32]\n\t"\ + "fmla %3.8h,v1.8h,v4.h[1]; ldr x0,[%17,#-24]\n\t"\ + "fmla %5.8h,v1.8h,v4.h[2]\n\t"\ + "fmla %7.8h,v1.8h,v4.h[3]; sub %w16,%w16,#2\n\t"\ + "fmla %9.8h,v1.8h,v5.h[0]\n\t"\ + "fmla %11.8h,v1.8h,v5.h[1]; fmov v0.d[1],x0\n\t"\ + "fmla %13.8h,v1.8h,v5.h[2]; cmp %w16,#2\n\t"\ + "fmla %15.8h,v1.8h,v5.h[3]; b.gt 001b\n\t"\ + "002:\n\t"\ + "cmp %w16,#2; b.ne 003f\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[1]; ldr d1,[%17,#-16]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[2]; ldr x0,[%17,#-8]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v3.h[0]; ldr d4,[%18],#16\n\t"\ + "fmla %10.8h,v0.8h,v3.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %12.8h,v0.8h,v3.h[2]\n\t"\ + "fmla %14.8h,v0.8h,v3.h[3]; ldr d5,[%18,#-8]\n\t"\ + "fmla %1.8h,v1.8h,v2.h[0]; ldr d0,[%17],#32\n\t"\ + "fmla %3.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v2.h[2]; ldr x0,[%17,#-24]\n\t"\ + "fmla %7.8h,v1.8h,v2.h[3]\n\t"\ + "fmla %9.8h,v1.8h,v3.h[0]\n\t"\ + "fmla %11.8h,v1.8h,v3.h[1]; fmov v0.d[1],x0\n\t"\ + "fmla %13.8h,v1.8h,v3.h[2]\n\t"\ + "fmla %15.8h,v1.8h,v3.h[3]\n\t"\ + "fmla %0.8h,v0.8h,v4.h[0]; ldr d1,[%17,#-16]\n\t"\ + "fmla %2.8h,v0.8h,v4.h[1]; ldr x0,[%17,#-8]\n\t"\ + "fmla %4.8h,v0.8h,v4.h[2]\n\t"\ + "fmla %6.8h,v0.8h,v4.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v5.h[0]\n\t"\ + "fmla %10.8h,v0.8h,v5.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %12.8h,v0.8h,v5.h[2]\n\t"\ + "fmla %14.8h,v0.8h,v5.h[3]\n\t"\ + "fmla %1.8h,v1.8h,v4.h[0]\n\t"\ + "fmla %3.8h,v1.8h,v4.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v4.h[2]\n\t"\ + "fmla %7.8h,v1.8h,v4.h[3]; sub %w16,%w16,#2\n\t"\ + "fmla %9.8h,v1.8h,v5.h[0]\n\t"\ + "fmla %11.8h,v1.8h,v5.h[1]\n\t"\ + "fmla %13.8h,v1.8h,v5.h[2]\n\t"\ + "fmla %15.8h,v1.8h,v5.h[3]; b 004f\n\t"\ + "003:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[1]; ldr d1,[%17,#-16]\n\t"\ + "fmla %4.8h,v0.8h,v2.h[2]; ldr x0,[%17,#-8]\n\t"\ + "fmla %6.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %8.8h,v0.8h,v3.h[0]\n\t"\ + "fmla %10.8h,v0.8h,v3.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %12.8h,v0.8h,v3.h[2]\n\t"\ + "fmla %14.8h,v0.8h,v3.h[3]\n\t"\ + "fmla %1.8h,v1.8h,v2.h[0]\n\t"\ + "fmla %3.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %5.8h,v1.8h,v2.h[2]\n\t"\ + "fmla %7.8h,v1.8h,v2.h[3]\n\t"\ + "fmla %9.8h,v1.8h,v3.h[0]\n\t"\ + "fmla %11.8h,v1.8h,v3.h[1]; sub %w16,%w16,#1\n\t"\ + "fmla %13.8h,v1.8h,v3.h[2]\n\t"\ + "fmla %15.8h,v1.8h,v3.h[3]\n\t"\ + "004:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04)\ + ,"=w"(cq05),"=w"(cq06),"=w"(cq07),"=w"(cq08)\ + ,"=w"(cq09),"=w"(cq10),"=w"(cq11),"=w"(cq12)\ + ,"=w"(cq13),"=w"(cq14),"=w"(cq15),"=w"(cq16)\ + ,"+r"(k_left),"+r"(a_ptr),"+r"(b_ptr1)\ + ::"cc","memory","v0","v1","v2","v3","v4","v5","x0"); + +#define KERNEL_M8N8_A55 \ + DECLARE_C_8X8\ + float16_t *c_pref = c_ptr + 7; PREF_N8\ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + uint32_t k_left = K;\ + __asm__ __volatile__(\ + "movi %0.16b,#0; movi %1.16b,#0\n\t"\ + "mov %2.16b,%0.16b; mov %3.16b,%1.16b\n\t"\ + "mov %4.16b,%0.16b; mov %5.16b,%1.16b\n\t"\ + "mov 
%6.16b,%0.16b; mov %7.16b,%1.16b\n\t"\ + "cmp %w8,#0; b.eq 104f\n\t"\ + "ldr q0,[%9],#16; ldr d2,[%10],#16; ldr d3,[%10,#-8]\n\t"\ + "cmp %w8,#2; b.le 102f\n\t"\ + "101:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; ldr d1,[%9],#32\n\t"\ + "fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; ldr x0,[%9,#-24]\n\t"\ + "fmla %3.8h,v0.8h,v2.h[3]; prfm pldl1keep,[%9,#128]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[0]; ldr d2,[%10],#32\n\t"\ + "fmla %5.8h,v0.8h,v3.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %6.8h,v0.8h,v3.h[2]\n\t"\ + "fmla %7.8h,v0.8h,v3.h[3]; ldr d3,[%10,#-24]\n\t"\ + "fmla %0.8h,v1.8h,v2.h[0]; ldr d0,[%9,#-16]\n\t"\ + "fmla %1.8h,v1.8h,v2.h[1]; ldr x0,[%9,#-8]\n\t"\ + "fmla %2.8h,v1.8h,v2.h[2]\n\t"\ + "fmla %3.8h,v1.8h,v2.h[3]; ldr d2,[%10,#-16]\n\t"\ + "fmla %4.8h,v1.8h,v3.h[0]; fmov v0.d[1],x0\n\t"\ + "fmla %5.8h,v1.8h,v3.h[1]; sub %w8,%w8,#2\n\t"\ + "fmla %6.8h,v1.8h,v3.h[2]; cmp %w8,#2\n\t"\ + "fmla %7.8h,v1.8h,v3.h[3]; ldr d3,[%10,#-8]\n\t"\ + "b.gt 101b\n\t"\ + "102:\n\t"\ + "cmp %w8,#2; b.ne 103f\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; ldr d1,[%9],#16\n\t"\ + "fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; ldr x0,[%9,#-8]\n\t"\ + "fmla %3.8h,v0.8h,v2.h[3]; prfm pldl1keep,[%9,#128]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[0]; ldr d2,[%10],#16\n\t"\ + "fmla %5.8h,v0.8h,v3.h[1]; fmov v1.d[1],x0\n\t"\ + "fmla %6.8h,v0.8h,v3.h[2]\n\t"\ + "fmla %7.8h,v0.8h,v3.h[3]; ldr d3,[%10,#-8]\n\t"\ + "fmla %0.8h,v1.8h,v2.h[0]\n\t"\ + "fmla %1.8h,v1.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v1.8h,v2.h[2]\n\t"\ + "fmla %3.8h,v1.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v1.8h,v3.h[0]\n\t"\ + "fmla %5.8h,v1.8h,v3.h[1]; sub %w8,%w8,#2\n\t"\ + "fmla %6.8h,v1.8h,v3.h[2]\n\t"\ + "fmla %7.8h,v1.8h,v3.h[3]\n\t"\ + "b 104f\n\t"\ + "103:\n\t"\ + "fmla %0.8h,v0.8h,v2.h[0]; fmla %1.8h,v0.8h,v2.h[1]\n\t"\ + "fmla %2.8h,v0.8h,v2.h[2]; fmla %3.8h,v0.8h,v2.h[3]\n\t"\ + "fmla %4.8h,v0.8h,v3.h[0]; fmla %5.8h,v0.8h,v3.h[1]\n\t"\ + "fmla %6.8h,v0.8h,v3.h[2]; fmla %7.8h,v0.8h,v3.h[3]\n\t"\ + "sub %w8,%w8,#1\n\t"\ + "104:\n\t"\ + :"=w"(cq01),"=w"(cq02),"=w"(cq03),"=w"(cq04)\ + ,"=w"(cq05),"=w"(cq06),"=w"(cq07),"=w"(cq08)\ + ,"+r"(k_left),"+r"(a_ptr),"+r"(b_ptr1)\ + ::"cc","memory","v0","v1","v2","v3","x0"); + +#define KERNEL_M8N4_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x8_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f16(0.0f);\ + cq05 = cq06 = cq07 = cq08 = vdupq_n_f16(0.0f);\ + float16x8_t aq01, aq02, bq01;\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + aq01 = vld1q_f16(a_ptr);\ + aq02 = vld1q_f16(a_ptr + 8); a_ptr += 16;\ + bq01 = vld1q_f16(b_ptr1); b_ptr1 += 8;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_laneq_f16(cq01, aq01, bq01, 0);\ + cq02 = vfmaq_laneq_f16(cq02, aq01, bq01, 1);\ + cq03 = vfmaq_laneq_f16(cq03, aq01, bq01, 2);\ + cq04 = vfmaq_laneq_f16(cq04, aq01, bq01, 3);\ + aq01 = vld1q_f16(a_ptr);\ + cq05 = vfmaq_laneq_f16(cq05, aq02, bq01, 4);\ + cq06 = vfmaq_laneq_f16(cq06, aq02, bq01, 5);\ + cq07 = vfmaq_laneq_f16(cq07, aq02, bq01, 6);\ + cq08 = vfmaq_laneq_f16(cq08, aq02, bq01, 7);\ + aq02 = vld1q_f16(a_ptr + 8); a_ptr += 16;\ + bq01 = vld1q_f16(b_ptr1); b_ptr1 += 8;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_laneq_f16(cq01, aq01, bq01, 0);\ + cq02 = vfmaq_laneq_f16(cq02, aq01, bq01, 1);\ + cq03 = vfmaq_laneq_f16(cq03, aq01, bq01, 2);\ + cq04 = vfmaq_laneq_f16(cq04, aq01, bq01, 3);\ + cq05 = vfmaq_laneq_f16(cq05, aq02, bq01, 4);\ + cq06 = vfmaq_laneq_f16(cq06, aq02, bq01, 5);\ + cq07 = vfmaq_laneq_f16(cq07, aq02, bq01, 
6);\ + cq08 = vfmaq_laneq_f16(cq08, aq02, bq01, 7);\ + k_left -= 2;\ + }\ + cq01 = vaddq_f16(cq01, cq05);\ + cq02 = vaddq_f16(cq02, cq06);\ + cq03 = vaddq_f16(cq03, cq07);\ + cq04 = vaddq_f16(cq04, cq08);\ + if (k_left > 0) {\ + float16x4_t bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + aq01 = vld1q_f16(a_ptr); a_ptr += 8;\ + cq01 = vfmaq_lane_f16(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f16(cq02, aq01, bd01, 1);\ + cq03 = vfmaq_lane_f16(cq03, aq01, bd01, 2);\ + cq04 = vfmaq_lane_f16(cq04, aq01, bd01, 3);\ + } + +#define KERNEL_M8N2_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x8_t cq01, cq02, cq03, cq04;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f16(0.0f);\ + float16x8_t aq01, aq02; float16x4_t bd01;\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + aq01 = vld1q_f16(a_ptr);\ + aq02 = vld1q_f16(a_ptr + 8); a_ptr += 16;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_lane_f16(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f16(cq02, aq01, bd01, 1);\ + aq01 = vld1q_f16(a_ptr);\ + cq03 = vfmaq_lane_f16(cq03, aq02, bd01, 2);\ + cq04 = vfmaq_lane_f16(cq04, aq02, bd01, 3);\ + aq02 = vld1q_f16(a_ptr + 8); a_ptr += 16;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_lane_f16(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f16(cq02, aq01, bd01, 1);\ + cq03 = vfmaq_lane_f16(cq03, aq02, bd01, 2);\ + cq04 = vfmaq_lane_f16(cq04, aq02, bd01, 3);\ + k_left -= 2;\ + }\ + cq01 = vaddq_f16(cq01, cq03);\ + cq02 = vaddq_f16(cq02, cq04);\ + if (k_left > 0) {\ + aq01 = vld1q_f16(a_ptr); a_ptr += 8;\ + float16_t bs1 = b_ptr1[0];\ + float16_t bs2 = b_ptr1[1]; b_ptr1 += 2;\ + cq01 = vfmaq_n_f16(cq01, aq01, bs1);\ + cq02 = vfmaq_n_f16(cq02, aq01, bs2);\ + } + +#define KERNEL_M8N1_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x8_t cq01, cq02, cq03, cq04;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f16(0.0f);\ + float16x8_t aq01, aq02, aq03, aq04;\ + float16x4_t bd01;\ + uint32_t k_left = K;\ + if (k_left > 3) {\ + aq01 = vld1q_f16(a_ptr);\ + aq02 = vld1q_f16(a_ptr + 8);\ + aq03 = vld1q_f16(a_ptr + 16);\ + aq04 = vld1q_f16(a_ptr + 24); a_ptr += 32;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + }\ + for (; k_left > 7; k_left -= 4) {\ + cq01 = vfmaq_lane_f16(cq01, aq01, bd01, 0);\ + aq01 = vld1q_f16(a_ptr);\ + cq02 = vfmaq_lane_f16(cq02, aq02, bd01, 1);\ + aq02 = vld1q_f16(a_ptr + 8);\ + cq03 = vfmaq_lane_f16(cq03, aq03, bd01, 2);\ + aq03 = vld1q_f16(a_ptr + 16);\ + cq04 = vfmaq_lane_f16(cq04, aq04, bd01, 3);\ + aq04 = vld1q_f16(a_ptr + 24); a_ptr += 32;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + }\ + if (k_left > 3) {\ + cq01 = vfmaq_lane_f16(cq01, aq01, bd01, 0);\ + cq02 = vfmaq_lane_f16(cq02, aq02, bd01, 1);\ + cq03 = vfmaq_lane_f16(cq03, aq03, bd01, 2);\ + cq04 = vfmaq_lane_f16(cq04, aq04, bd01, 3);\ + k_left -= 4;\ + }\ + cq01 = vaddq_f16(cq01, cq02);\ + cq03 = vaddq_f16(cq03, cq04);\ + cq01 = vaddq_f16(cq01, cq03);\ + for (; k_left > 0; k_left--) {\ + aq01 = vld1q_f16(a_ptr); a_ptr += 8;\ + float16_t bs1 = *b_ptr1; b_ptr1++;\ + cq01 = vfmaq_n_f16(cq01, aq01, bs1);\ + } + +#define KERNEL_M8N4 KERNEL_M8N4_UNIT(a_head, b_head) +#define KERNEL_M8N2 KERNEL_M8N2_UNIT(a_head, b_head) +#define KERNEL_M8N1 KERNEL_M8N1_UNIT(a_head, b_head) +#define KERNEL_M4N8 KERNEL_M8N4_UNIT(b_head, a_head) +#define KERNEL_M2N8 KERNEL_M8N2_UNIT(b_head, a_head) +#define KERNEL_M1N8 KERNEL_M8N1_UNIT(b_head, a_head) + +#define KERNEL_M4N16_UNIT(a_head, b_head) \ + 
const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x8_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f16(0.0f);\ + cq05 = cq06 = cq07 = cq08 = vdupq_n_f16(0.0f);\ + float16x8_t aq01, bq01, bq02, bq03, bq04;\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + aq01 = vld1q_f16(a_ptr); a_ptr += 8;\ + bq01 = vld1q_f16(b_ptr1);\ + bq02 = vld1q_f16(b_ptr1 + 8);\ + bq03 = vld1q_f16(b_ptr1 + 16);\ + bq04 = vld1q_f16(b_ptr1 + 24); b_ptr1 += 32;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_laneq_f16(cq01, bq01, aq01, 0);\ + cq03 = vfmaq_laneq_f16(cq03, bq01, aq01, 1);\ + cq05 = vfmaq_laneq_f16(cq05, bq01, aq01, 2);\ + cq07 = vfmaq_laneq_f16(cq07, bq01, aq01, 3);\ + bq01 = vld1q_f16(b_ptr1);\ + cq02 = vfmaq_laneq_f16(cq02, bq02, aq01, 0);\ + cq04 = vfmaq_laneq_f16(cq04, bq02, aq01, 1);\ + cq06 = vfmaq_laneq_f16(cq06, bq02, aq01, 2);\ + cq08 = vfmaq_laneq_f16(cq08, bq02, aq01, 3);\ + bq02 = vld1q_f16(b_ptr1 + 8);\ + cq01 = vfmaq_laneq_f16(cq01, bq03, aq01, 4);\ + cq03 = vfmaq_laneq_f16(cq03, bq03, aq01, 5);\ + cq05 = vfmaq_laneq_f16(cq05, bq03, aq01, 6);\ + cq07 = vfmaq_laneq_f16(cq07, bq03, aq01, 7);\ + bq03 = vld1q_f16(b_ptr1 + 16);\ + cq02 = vfmaq_laneq_f16(cq02, bq04, aq01, 4);\ + cq04 = vfmaq_laneq_f16(cq04, bq04, aq01, 5);\ + cq06 = vfmaq_laneq_f16(cq06, bq04, aq01, 6);\ + cq08 = vfmaq_laneq_f16(cq08, bq04, aq01, 7);\ + bq04 = vld1q_f16(b_ptr1 + 24); b_ptr1 += 32;\ + aq01 = vld1q_f16(a_ptr); a_ptr += 8;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_laneq_f16(cq01, bq01, aq01, 0);\ + cq03 = vfmaq_laneq_f16(cq03, bq01, aq01, 1);\ + cq05 = vfmaq_laneq_f16(cq05, bq01, aq01, 2);\ + cq07 = vfmaq_laneq_f16(cq07, bq01, aq01, 3);\ + cq02 = vfmaq_laneq_f16(cq02, bq02, aq01, 0);\ + cq04 = vfmaq_laneq_f16(cq04, bq02, aq01, 1);\ + cq06 = vfmaq_laneq_f16(cq06, bq02, aq01, 2);\ + cq08 = vfmaq_laneq_f16(cq08, bq02, aq01, 3);\ + cq01 = vfmaq_laneq_f16(cq01, bq03, aq01, 4);\ + cq03 = vfmaq_laneq_f16(cq03, bq03, aq01, 5);\ + cq05 = vfmaq_laneq_f16(cq05, bq03, aq01, 6);\ + cq07 = vfmaq_laneq_f16(cq07, bq03, aq01, 7);\ + cq02 = vfmaq_laneq_f16(cq02, bq04, aq01, 4);\ + cq04 = vfmaq_laneq_f16(cq04, bq04, aq01, 5);\ + cq06 = vfmaq_laneq_f16(cq06, bq04, aq01, 6);\ + cq08 = vfmaq_laneq_f16(cq08, bq04, aq01, 7);\ + k_left -= 2;\ + }\ + if (k_left > 0) {\ + float16x4_t ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + bq01 = vld1q_f16(b_ptr1);\ + bq02 = vld1q_f16(b_ptr1 + 8); b_ptr1 += 16;\ + cq01 = vfmaq_lane_f16(cq01, bq01, ad01, 0);\ + cq03 = vfmaq_lane_f16(cq03, bq01, ad01, 1);\ + cq05 = vfmaq_lane_f16(cq05, bq01, ad01, 2);\ + cq07 = vfmaq_lane_f16(cq07, bq01, ad01, 3);\ + cq02 = vfmaq_lane_f16(cq02, bq02, ad01, 0);\ + cq04 = vfmaq_lane_f16(cq04, bq02, ad01, 1);\ + cq06 = vfmaq_lane_f16(cq06, bq02, ad01, 2);\ + cq08 = vfmaq_lane_f16(cq08, bq02, ad01, 3);\ + } + +#define KERNEL_M2N16_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x8_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f16(0.0f);\ + cq05 = cq06 = cq07 = cq08 = vdupq_n_f16(0.0f);\ + float16x8_t bq01, bq02, bq03, bq04;\ + float16x4_t ad01;\ + uint32_t k_left = K;\ + if (k_left > 1) {\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + bq01 = vld1q_f16(b_ptr1);\ + bq02 = vld1q_f16(b_ptr1 + 8);\ + bq03 = vld1q_f16(b_ptr1 + 16);\ + bq04 = vld1q_f16(b_ptr1 + 24); b_ptr1 += 32;\ + }\ + for (; k_left > 3; k_left -= 2) {\ + cq01 = vfmaq_lane_f16(cq01, bq01, ad01, 0);\ + cq03 = vfmaq_lane_f16(cq03, bq01, ad01, 1);\ + 
bq01 = vld1q_f16(b_ptr1);\ + cq02 = vfmaq_lane_f16(cq02, bq02, ad01, 0);\ + cq04 = vfmaq_lane_f16(cq04, bq02, ad01, 1);\ + bq02 = vld1q_f16(b_ptr1 + 8);\ + cq05 = vfmaq_lane_f16(cq05, bq03, ad01, 2);\ + cq07 = vfmaq_lane_f16(cq07, bq03, ad01, 3);\ + bq03 = vld1q_f16(b_ptr1 + 16);\ + cq06 = vfmaq_lane_f16(cq06, bq04, ad01, 2);\ + cq08 = vfmaq_lane_f16(cq08, bq04, ad01, 3);\ + bq04 = vld1q_f16(b_ptr1 + 24); b_ptr1 += 32;\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + }\ + if (k_left > 1) {\ + cq01 = vfmaq_lane_f16(cq01, bq01, ad01, 0);\ + cq03 = vfmaq_lane_f16(cq03, bq01, ad01, 1);\ + cq05 = vfmaq_lane_f16(cq05, bq03, ad01, 2);\ + cq07 = vfmaq_lane_f16(cq07, bq03, ad01, 3);\ + cq02 = vfmaq_lane_f16(cq02, bq02, ad01, 0);\ + cq04 = vfmaq_lane_f16(cq04, bq02, ad01, 1);\ + cq06 = vfmaq_lane_f16(cq06, bq04, ad01, 2);\ + cq08 = vfmaq_lane_f16(cq08, bq04, ad01, 3);\ + k_left -= 2;\ + }\ + cq01 = vaddq_f16(cq01, cq05);\ + cq02 = vaddq_f16(cq02, cq06);\ + cq03 = vaddq_f16(cq03, cq07);\ + cq04 = vaddq_f16(cq04, cq08);\ + if (k_left > 0) {\ + bq01 = vld1q_f16(b_ptr1);\ + bq02 = vld1q_f16(b_ptr1 + 8); b_ptr1 += 16;\ + float16_t as1 = a_ptr[0];\ + float16_t as2 = a_ptr[1]; a_ptr += 2;\ + cq01 = vfmaq_n_f16(cq01, bq01, as1);\ + cq02 = vfmaq_n_f16(cq02, bq02, as1);\ + cq03 = vfmaq_n_f16(cq03, bq01, as2);\ + cq04 = vfmaq_n_f16(cq04, bq02, as2);\ + } + +#define KERNEL_M1N16_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x8_t cq01, cq02, cq03, cq04, cq05, cq06, cq07, cq08;\ + cq01 = cq02 = cq03 = cq04 = vdupq_n_f16(0.0f);\ + cq05 = cq06 = cq07 = cq08 = vdupq_n_f16(0.0f);\ + float16x8_t bq01, bq02, bq03, bq04, bq05, bq06, bq07, bq08;\ + float16x4_t ad01;\ + uint32_t k_left = K;\ + if (k_left > 3) {\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + bq01 = vld1q_f16(b_ptr1);\ + bq02 = vld1q_f16(b_ptr1 + 8);\ + bq03 = vld1q_f16(b_ptr1 + 16);\ + bq04 = vld1q_f16(b_ptr1 + 24);\ + bq05 = vld1q_f16(b_ptr1 + 32);\ + bq06 = vld1q_f16(b_ptr1 + 40);\ + bq07 = vld1q_f16(b_ptr1 + 48);\ + bq08 = vld1q_f16(b_ptr1 + 56); b_ptr1 += 64;\ + }\ + for (; k_left > 7; k_left -= 4) {\ + cq01 = vfmaq_lane_f16(cq01, bq01, ad01, 0);\ + bq01 = vld1q_f16(b_ptr1);\ + cq02 = vfmaq_lane_f16(cq02, bq02, ad01, 0);\ + bq02 = vld1q_f16(b_ptr1 + 8);\ + cq03 = vfmaq_lane_f16(cq03, bq03, ad01, 1);\ + bq03 = vld1q_f16(b_ptr1 + 16);\ + cq04 = vfmaq_lane_f16(cq04, bq04, ad01, 1);\ + bq04 = vld1q_f16(b_ptr1 + 24);\ + cq05 = vfmaq_lane_f16(cq05, bq05, ad01, 2);\ + bq05 = vld1q_f16(b_ptr1 + 32);\ + cq06 = vfmaq_lane_f16(cq06, bq06, ad01, 2);\ + bq06 = vld1q_f16(b_ptr1 + 40);\ + cq07 = vfmaq_lane_f16(cq07, bq07, ad01, 3);\ + bq07 = vld1q_f16(b_ptr1 + 48);\ + cq08 = vfmaq_lane_f16(cq08, bq08, ad01, 3);\ + bq08 = vld1q_f16(b_ptr1 + 56); b_ptr1 += 64;\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + }\ + if (k_left > 3) {\ + cq01 = vfmaq_lane_f16(cq01, bq01, ad01, 0);\ + cq03 = vfmaq_lane_f16(cq03, bq03, ad01, 1);\ + cq05 = vfmaq_lane_f16(cq05, bq05, ad01, 2);\ + cq07 = vfmaq_lane_f16(cq07, bq07, ad01, 3);\ + cq02 = vfmaq_lane_f16(cq02, bq02, ad01, 0);\ + cq04 = vfmaq_lane_f16(cq04, bq04, ad01, 1);\ + cq06 = vfmaq_lane_f16(cq06, bq06, ad01, 2);\ + cq08 = vfmaq_lane_f16(cq08, bq08, ad01, 3);\ + k_left -= 4;\ + }\ + cq01 = vaddq_f16(cq01, cq03);\ + cq05 = vaddq_f16(cq05, cq07);\ + cq02 = vaddq_f16(cq02, cq04);\ + cq06 = vaddq_f16(cq06, cq08);\ + cq01 = vaddq_f16(cq01, cq05);\ + cq02 = vaddq_f16(cq02, cq06);\ + for (; k_left > 0; k_left--) {\ + float16_t as1 = *a_ptr; a_ptr++;\ + bq01 = vld1q_f16(b_ptr1);\ + bq02 = 
vld1q_f16(b_ptr1 + 8); b_ptr1 += 16;\ + cq01 = vfmaq_n_f16(cq01, bq01, as1);\ + cq02 = vfmaq_n_f16(cq02, bq02, as1);\ + } + +#define KERNEL_M4N16 KERNEL_M4N16_UNIT(a_head, b_head) +#define KERNEL_M2N16 KERNEL_M2N16_UNIT(a_head, b_head) +#define KERNEL_M1N16 KERNEL_M1N16_UNIT(a_head, b_head) +#define KERNEL_M16N4 KERNEL_M4N16_UNIT(b_head, a_head) +#define KERNEL_M16N2 KERNEL_M2N16_UNIT(b_head, a_head) +#define KERNEL_M16N1 KERNEL_M1N16_UNIT(b_head, a_head) + +#define KERNEL_M4N4 \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x4_t cd01, cd02, cd03, cd04;\ + cd01 = cd02 = cd03 = cd04 = vdup_n_f16(0.0f);\ + float16x4_t ad01, bd01;\ + uint32_t k_left = K;\ + for (; k_left > 0; k_left--) {\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + cd01 = vfma_lane_f16(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f16(cd02, ad01, bd01, 1);\ + cd03 = vfma_lane_f16(cd03, ad01, bd01, 2);\ + cd04 = vfma_lane_f16(cd04, ad01, bd01, 3);\ + } + +#define KERNEL_M4N2_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x4_t cd01, cd02, cd03, cd04;\ + cd01 = cd02 = cd03 = cd04 = vdup_n_f16(0.0f);\ + float16x4_t ad01, ad02, bd01;\ + uint32_t k_left = K;\ + for (; k_left > 1; k_left -= 2) {\ + ad01 = vld1_f16(a_ptr);\ + ad02 = vld1_f16(a_ptr + 4); a_ptr += 8;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + cd01 = vfma_lane_f16(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f16(cd02, ad01, bd01, 1);\ + cd03 = vfma_lane_f16(cd03, ad02, bd01, 2);\ + cd04 = vfma_lane_f16(cd04, ad02, bd01, 3);\ + }\ + cd01 = vadd_f16(cd01, cd03);\ + cd02 = vadd_f16(cd02, cd04);\ + if (k_left > 0) {\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + float16_t bs1 = b_ptr1[0];\ + float16_t bs2 = b_ptr1[1]; b_ptr1 += 2;\ + cd01 = vfma_n_f16(cd01, ad01, bs1);\ + cd02 = vfma_n_f16(cd02, ad01, bs2);\ + } + +#define KERNEL_M4N1_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x4_t cd01, cd02, cd03, cd04;\ + cd01 = cd02 = cd03 = cd04 = vdup_n_f16(0.0f);\ + float16x4_t ad01, ad02, ad03, ad04, bd01;\ + uint32_t k_left = K;\ + for (; k_left > 3; k_left -= 4) {\ + ad01 = vld1_f16(a_ptr);\ + ad02 = vld1_f16(a_ptr + 4);\ + ad03 = vld1_f16(a_ptr + 8);\ + ad04 = vld1_f16(a_ptr + 12); a_ptr += 16;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + cd01 = vfma_lane_f16(cd01, ad01, bd01, 0);\ + cd02 = vfma_lane_f16(cd02, ad02, bd01, 1);\ + cd03 = vfma_lane_f16(cd03, ad03, bd01, 2);\ + cd04 = vfma_lane_f16(cd04, ad04, bd01, 3);\ + }\ + cd01 = vadd_f16(cd01, cd03);\ + cd02 = vadd_f16(cd02, cd04);\ + cd01 = vadd_f16(cd01, cd02);\ + for (; k_left > 0; k_left--) {\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + float16_t bs1 = *b_ptr1; b_ptr1++;\ + cd01 = vfma_n_f16(cd01, ad01, bs1);\ + } + +#define KERNEL_M4N2 KERNEL_M4N2_UNIT(a_head, b_head) +#define KERNEL_M4N1 KERNEL_M4N1_UNIT(a_head, b_head) +#define KERNEL_M2N4 KERNEL_M4N2_UNIT(b_head, a_head) +#define KERNEL_M1N4 KERNEL_M4N1_UNIT(b_head, a_head) + +#define KERNEL_M2N2 \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16_t cs1, cs2, cs3, cs4;\ + cs1 = cs2 = cs3 = cs4 = 0.0f;\ + float16_t as1, as2, bs1, bs2;\ + uint32_t k_left = K;\ + for (; k_left > 0; k_left--) {\ + as1 = a_ptr[0]; as2 = a_ptr[1]; a_ptr += 2;\ + bs1 = b_ptr1[0]; bs2 = b_ptr1[1]; b_ptr1 += 2;\ + cs1 += as1 * bs1;\ + cs2 += as2 * bs1;\ + cs3 += as1 * bs2;\ + cs4 += as2 * bs2;\ + } + +#define KERNEL_M2N1_UNIT(a_head, b_head) \ + const float16_t *a_ptr = a_head;\ + const 
float16_t *b_ptr1 = b_head;\ + float16_t cs1, cs2; cs1 = cs2 = 0.0f;\ + float16_t as1, as2, bs1;\ + uint32_t k_left = K;\ + for (; k_left > 0; k_left--) {\ + as1 = a_ptr[0]; as2 = a_ptr[1]; a_ptr += 2;\ + bs1 = b_ptr1[0]; b_ptr1++;\ + cs1 += as1 * bs1;\ + cs2 += as2 * bs1;\ + } + +#define KERNEL_M2N1 KERNEL_M2N1_UNIT(a_head, b_head) +#define KERNEL_M1N2 KERNEL_M2N1_UNIT(b_head, a_head) + +#define KERNEL_M1N1 \ + const float16_t *a_ptr = a_head;\ + const float16_t *b_ptr1 = b_head;\ + float16x4_t cd01 = vdup_n_f16(0.0f);\ + float16x4_t ad01, bd01;\ + uint32_t k_left = K;\ + for (; k_left > 3; k_left -= 4) {\ + ad01 = vld1_f16(a_ptr); a_ptr += 4;\ + bd01 = vld1_f16(b_ptr1); b_ptr1 += 4;\ + cd01 = vfma_f16(cd01, ad01, bd01);\ + }\ + float16_t cs1 = vget_lane_f16(cd01, 0) + vget_lane_f16(cd01, 1) + \ + vget_lane_f16(cd01, 2) + vget_lane_f16(cd01, 3);\ + for (; k_left > 0; k_left--) {\ + cs1 += (*a_ptr) * (*b_ptr1); a_ptr++; b_ptr1++;\ + } + + +#define SAVE_M1N8_UNIT(cq01, c_tmp) {\ + float16_t cs1 = vgetq_lane_f16(cq01, 0);\ + float16_t cs2 = vgetq_lane_f16(cq01, 1);\ + float16_t cs3 = vgetq_lane_f16(cq01, 2);\ + float16_t cs4 = vgetq_lane_f16(cq01, 3);\ + float16_t cs5 = vgetq_lane_f16(cq01, 4);\ + float16_t cs6 = vgetq_lane_f16(cq01, 5);\ + float16_t cs7 = vgetq_lane_f16(cq01, 6);\ + float16_t cs8 = vgetq_lane_f16(cq01, 7);\ + *c_tmp = *c_tmp * beta + cs1; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs2; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs3; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs4; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs5; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs6; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs7; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs8; c_tmp += ldc;\ +} + +#define SAVE_M2N8_UNIT(cq01, cq02, c_tmp) {\ + float16x8x2_t cqd1;\ + cqd1.val[0] = vdupq_n_f16(0.0f);\ + cqd1.val[1] = vdupq_n_f16(0.0f);\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 0); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 1); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 2); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 3); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 4); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 5); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 6); c_tmp += ldc;\ + cqd1 = vld2q_lane_f16(c_tmp, cqd1, 7); c_tmp -= ldc * 7;\ + cqd1.val[0] = vfmaq_n_f16(cq01, cqd1.val[0], beta);\ + cqd1.val[1] = vfmaq_n_f16(cq02, cqd1.val[1], beta);\ + vst2q_lane_f16(c_tmp, cqd1, 0); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 1); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 2); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 3); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 4); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 5); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 6); c_tmp += ldc;\ + vst2q_lane_f16(c_tmp, cqd1, 7); c_tmp += ldc;\ +} + +#define SAVE_M4N8_UNIT(cq01, cq02, cq03, cq04, c_tmp) {\ + float16x8x4_t cqq1;\ + cqq1.val[0] = vdupq_n_f16(0.0f);\ + cqq1.val[1] = vdupq_n_f16(0.0f);\ + cqq1.val[2] = vdupq_n_f16(0.0f);\ + cqq1.val[3] = vdupq_n_f16(0.0f);\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 0); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 1); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 2); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 3); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 4); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 5); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 6); c_tmp += ldc;\ + cqq1 = vld4q_lane_f16(c_tmp, cqq1, 7); c_tmp -= ldc * 7;\ + cqq1.val[0] = vfmaq_n_f16(cq01, cqq1.val[0], 
beta);\ + cqq1.val[1] = vfmaq_n_f16(cq02, cqq1.val[1], beta);\ + cqq1.val[2] = vfmaq_n_f16(cq03, cqq1.val[2], beta);\ + cqq1.val[3] = vfmaq_n_f16(cq04, cqq1.val[3], beta);\ + vst4q_lane_f16(c_tmp, cqq1, 0); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 1); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 2); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 3); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 4); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 5); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 6); c_tmp += ldc;\ + vst4q_lane_f16(c_tmp, cqq1, 7); c_tmp += ldc;\ +} + +#define SAVE_M2N4_UNIT(cd01, cd02, c_tmp) {\ + float16x4x2_t cdd1;\ + cdd1.val[0] = vdup_n_f16(0.0f);\ + cdd1.val[1] = vdup_n_f16(0.0f);\ + cdd1 = vld2_lane_f16(c_tmp, cdd1, 0); c_tmp += ldc;\ + cdd1 = vld2_lane_f16(c_tmp, cdd1, 1); c_tmp += ldc;\ + cdd1 = vld2_lane_f16(c_tmp, cdd1, 2); c_tmp += ldc;\ + cdd1 = vld2_lane_f16(c_tmp, cdd1, 3); c_tmp -= ldc * 3;\ + cdd1.val[0] = vfma_n_f16(cd01, cdd1.val[0], beta);\ + cdd1.val[1] = vfma_n_f16(cd02, cdd1.val[1], beta);\ + vst2_lane_f16(c_tmp, cdd1, 0); c_tmp += ldc;\ + vst2_lane_f16(c_tmp, cdd1, 1); c_tmp += ldc;\ + vst2_lane_f16(c_tmp, cdd1, 2); c_tmp += ldc;\ + vst2_lane_f16(c_tmp, cdd1, 3); c_tmp += ldc;\ +} + +#define SAVE_M1N4_UNIT(cd01, c_tmp) {\ + float16_t cs1 = vget_lane_f16(cd01, 0);\ + float16_t cs2 = vget_lane_f16(cd01, 1);\ + float16_t cs3 = vget_lane_f16(cd01, 2);\ + float16_t cs4 = vget_lane_f16(cd01, 3);\ + *c_tmp = *c_tmp * beta + cs1; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs2; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs3; c_tmp += ldc;\ + *c_tmp = *c_tmp * beta + cs4; c_tmp += ldc;\ +} + +#define SAVE_M16N2_UNIT(cq01, cq02, cq03, cq04, c_tmp) \ + cq01 = vfmaq_n_f16(cq01, vld1q_f16(c_tmp), beta);\ + cq02 = vfmaq_n_f16(cq02, vld1q_f16(c_tmp + 8), beta);\ + cq03 = vfmaq_n_f16(cq03, vld1q_f16(c_tmp + ldc), beta);\ + cq04 = vfmaq_n_f16(cq04, vld1q_f16(c_tmp + ldc + 8), beta);\ + vst1q_f16(c_tmp, cq01); vst1q_f16(c_tmp + 8, cq02);\ + vst1q_f16(c_tmp + ldc, cq03); vst1q_f16(c_tmp + ldc + 8, cq04);\ + c_tmp += ldc * 2; + +#define SAVE_M8N2_UNIT(cq01, cq02, c_tmp) \ + cq01 = vfmaq_n_f16(cq01, vld1q_f16(c_tmp), beta);\ + cq02 = vfmaq_n_f16(cq02, vld1q_f16(c_tmp + ldc), beta);\ + vst1q_f16(c_tmp, cq01);\ + vst1q_f16(c_tmp + ldc, cq02); c_tmp += ldc * 2; + +#define SAVE_M4N2_UNIT(cd01, cd02, c_tmp) \ + cd01 = vfma_n_f16(cd01, vld1_f16(c_tmp), beta);\ + cd02 = vfma_n_f16(cd02, vld1_f16(c_tmp + ldc), beta);\ + vst1_f16(c_tmp, cd01);\ + vst1_f16(c_tmp + ldc, cd02); c_tmp += ldc * 2; + +#define SAVE_M8N16 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M8N2_UNIT(cq01, cq02, c_tmp)\ + SAVE_M8N2_UNIT(cq03, cq04, c_tmp)\ + SAVE_M8N2_UNIT(cq05, cq06, c_tmp)\ + SAVE_M8N2_UNIT(cq07, cq08, c_tmp)\ + SAVE_M8N2_UNIT(cq09, cq10, c_tmp)\ + SAVE_M8N2_UNIT(cq11, cq12, c_tmp)\ + SAVE_M8N2_UNIT(cq13, cq14, c_tmp)\ + SAVE_M8N2_UNIT(cq15, cq16, c_tmp) + +#define SAVE_M4N16 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M4N8_UNIT(cq01, cq03, cq05, cq07, c_tmp)\ + SAVE_M4N8_UNIT(cq02, cq04, cq06, cq08, c_tmp) + +#define SAVE_M2N16 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M2N8_UNIT(cq01, cq03, c_tmp)\ + SAVE_M2N8_UNIT(cq02, cq04, c_tmp) + +#define SAVE_M1N16 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M1N8_UNIT(cq01, c_tmp)\ + SAVE_M1N8_UNIT(cq02, c_tmp) + +#define SAVE_M16N8 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M16N2_UNIT(cq01, cq02, cq03, cq04, c_tmp)\ + SAVE_M16N2_UNIT(cq05, cq06, cq07, cq08, c_tmp)\ + SAVE_M16N2_UNIT(cq09, cq10, cq11, cq12, c_tmp)\ + SAVE_M16N2_UNIT(cq13, cq14, cq15, cq16, c_tmp) + +#define SAVE_M8N8 \ 
+ float16_t *c_tmp = c_ptr;\ + SAVE_M8N2_UNIT(cq01, cq02, c_tmp)\ + SAVE_M8N2_UNIT(cq03, cq04, c_tmp)\ + SAVE_M8N2_UNIT(cq05, cq06, c_tmp)\ + SAVE_M8N2_UNIT(cq07, cq08, c_tmp) + +#define SAVE_M4N8 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M4N8_UNIT(cq01, cq02, cq03, cq04, c_tmp) + +#define SAVE_M2N8 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M2N8_UNIT(cq01, cq02, c_tmp) + +#define SAVE_M1N8 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M1N8_UNIT(cq01, c_tmp) + +#define SAVE_M16N4 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M16N2_UNIT(cq01, cq02, cq03, cq04, c_tmp)\ + SAVE_M16N2_UNIT(cq05, cq06, cq07, cq08, c_tmp) + +#define SAVE_M8N4 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M8N2_UNIT(cq01, cq02, c_tmp)\ + SAVE_M8N2_UNIT(cq03, cq04, c_tmp) + +#define SAVE_M4N4 \ + float16_t *c_tmp = c_ptr;\ + SAVE_M4N2_UNIT(cd01, cd02, c_tmp)\ + SAVE_M4N2_UNIT(cd03, cd04, c_tmp) + +#define SAVE_M2N4 \ + float16_t *c_tmp = c_ptr; SAVE_M2N4_UNIT(cd01, cd02, c_tmp) + +#define SAVE_M1N4 \ + float16_t *c_tmp = c_ptr; SAVE_M1N4_UNIT(cd01, c_tmp) + +#define SAVE_M16N2 \ + float16_t *c_tmp = c_ptr; SAVE_M16N2_UNIT(cq01, cq02, cq03, cq04, c_tmp) + +#define SAVE_M8N2 \ + float16_t *c_tmp = c_ptr; SAVE_M8N2_UNIT(cq01, cq02, c_tmp) + +#define SAVE_M4N2 \ + float16_t *c_tmp = c_ptr; SAVE_M4N2_UNIT(cd01, cd02, c_tmp) + +#define SAVE_M2N2 \ + c_ptr[0] = c_ptr[0] * beta + cs1;\ + c_ptr[1] = c_ptr[1] * beta + cs2;\ + c_ptr[ldc] = c_ptr[ldc] * beta + cs3;\ + c_ptr[ldc + 1] = c_ptr[ldc + 1] * beta + cs4;\ + +#define SAVE_M1N2 \ + c_ptr[0] = c_ptr[0] * beta + cs1;\ + c_ptr[ldc] = c_ptr[ldc] * beta + cs2;\ + +#define SAVE_M16N1 \ + cq01 = vfmaq_n_f16(cq01, vld1q_f16(c_ptr), beta);\ + cq02 = vfmaq_n_f16(cq02, vld1q_f16(c_ptr + 8), beta);\ + vst1q_f16(c_ptr, cq01); vst1q_f16(c_ptr + 8, cq02); + +#define SAVE_M8N1 \ + cq01 = vfmaq_n_f16(cq01, vld1q_f16(c_ptr), beta);\ + vst1q_f16(c_ptr, cq01); + +#define SAVE_M4N1 \ + cd01 = vfma_n_f16(cd01, vld1_f16(c_ptr), beta);\ + vst1_f16(c_ptr, cd01); + +#define SAVE_M2N1 \ + c_ptr[0] = c_ptr[0] * beta + cs1;\ + c_ptr[1] = c_ptr[1] * beta + cs2;\ + +#define SAVE_M1N1 \ + c_ptr[0] = c_ptr[0] * beta + cs1; + +#define NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(mdim, ndim) \ +static inline void\ + inline_dualpack_gemm_afloat16_t_bfloat16_t_cfloat16_t_m##mdim##_n##ndim(\ + const float16_t *a_head, const float16_t *b_head, float16_t *c_ptr,\ + uint32_t K, float16_t beta, uint32_t ldc) {\ + KERNEL_M##mdim##N##ndim\ + SAVE_M##mdim##N##ndim\ +} + +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 1) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 2) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 1) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 2) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 4) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 4) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 1) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 2) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 4) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 8) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 8) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 8) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 1) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 2) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(8, 4) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(1, 16) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(2, 16) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(4, 16) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(16, 1) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(16, 2) +NEON_HGEMM_INLINE_DUALPACK_UNIT_FUNC(16, 4) + +#define CPUID_DETECT_MNK 1000000 + +void hgemm_kernel_lm_m8n16(uint32_t M, uint32_t N, uint32_t K, float16_t beta, + const float16_t * __restrict__ sa, const float16_t * 
__restrict__ sb, + float16_t * __restrict__ C, uint32_t ldc) { + + uint32_t n_left = N; + const float16_t *b_head = sb; + float16_t *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + + for (; n_left > 15; n_left -= 16) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const float16_t *a_head = sa; + float16_t *c_ptr = c_head; + uint32_t m_left = M; + if (cputype == 55) { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N16_A55 + SAVE_M8N16 + a_head += 8 * K; + c_ptr += 8; + } + } else { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N16_A76 + SAVE_M8N16 + a_head += 8 * K; + c_ptr += 8; + } + } + MICRO_COMPUTE_LM(4, 16, float16_t, float16_t, float16_t) + b_head += K * 16; + c_head += ldc * 16; + acc_mnk += 16 * K * M; + } + + for (; n_left > 7; n_left -= 8) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const float16_t *a_head = sa; + float16_t *c_ptr = c_head; + uint32_t m_left = M; + if (cputype == 55) { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N8_A55 + SAVE_M8N8 + a_head += 8 * K; + c_ptr += 8; + } + } else { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N8_A76 + SAVE_M8N8 + a_head += 8 * K; + c_ptr += 8; + } + } + MICRO_COMPUTE_LM(4, 8, float16_t, float16_t, float16_t) + b_head += K * 8; + c_head += ldc * 8; + acc_mnk += 8 * K * M; + } + + ASSEMBLE_DUALPACK_COMPUTE_LM(4, float16_t, float16_t, float16_t, 8) +} + +void hgemm_kernel_ln_m16n8(uint32_t M, uint32_t N, uint32_t K, float16_t beta, + const float16_t * __restrict__ sa, const float16_t * __restrict__ sb, + float16_t * __restrict__ C, uint32_t ldc) { + + uint32_t m_left = M; + const float16_t *a_head = sa; + float16_t *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + for (; m_left > 15; m_left -= 16) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const float16_t *b_head = sb; + float16_t *c_ptr = c_head; + uint32_t n_left = N; + if (cputype == 55) { + for (; n_left > 7; n_left -= 8) { + KERNEL_M16N8_A55 + SAVE_M16N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } else { + for (; n_left > 7; n_left -= 8) { + KERNEL_M16N8_A76 + SAVE_M16N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } + MICRO_COMPUTE_LN(16, 4, float16_t, float16_t, float16_t) + a_head += K * 16; + c_head += 16; + acc_mnk += 16 * N * K; + } + + for (; m_left > 7; m_left -= 8) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const float16_t *b_head = sb; + float16_t *c_ptr = c_head; + uint32_t n_left = N; + if (cputype == 55) { + for (; n_left > 7; n_left -= 8) { + KERNEL_M8N8_A55 + SAVE_M8N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } else { + for (; n_left > 7; n_left -= 8) { + KERNEL_M8N8_A76 + SAVE_M8N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } + MICRO_COMPUTE_LN(8, 4, float16_t, float16_t, float16_t) + a_head += K * 8; + c_head += 8; + acc_mnk += 8 * N * K; + } + + ASSEMBLE_DUALPACK_COMPUTE_LN(4, float16_t, float16_t, float16_t, 8) +} + diff --git a/src/neon_armv8a/extension/HgemmSkinnyDot.c b/src/neon_armv8a/extension/HgemmSkinnyDot.c new file mode 100644 index 0000000..0e46b23 --- /dev/null +++ b/src/neon_armv8a/extension/HgemmSkinnyDot.c @@ -0,0 +1,350 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. 
*/ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyDot.h" +#include + +static inline void inline_hgemm_arowmajor_bskinny_m1n1( + const float16_t *a_ptr, const float16_t *b_ptr, float16_t *c_ptr, + uint32_t k_left, uint32_t LDK, uint32_t LDM, + float16_t beta, bool c_rowmajor) { + + float16x8_t cq1; + __asm__ __volatile__ ( + "movi %[cq1].16b,#0; movi v0.16b,#0\n\t" + "mov v1.16b,%[cq1].16b; mov v2.16b,v0.16b\n\t" + "cmp %w[k_left],#32; b.lt 3f\n\t" + "ldr q3,[%[a_ptr]],#64; ldr q7,[%[b_ptr]],#64\n\t" + "ldr q4,[%[a_ptr],#-48]; ldr q8,[%[b_ptr],#-48]\n\t" + "ldr q5,[%[a_ptr],#-32]; ldr q9,[%[b_ptr],#-32]\n\t" + "ldr q6,[%[a_ptr],#-16]; ldr q10,[%[b_ptr],#-16]\n\t" + "cmp %w[k_left],#64; b.lt 2f\n\t" + ".balign 16; 1:\n\t" + "fmla %[cq1].8h,v3.8h,v7.8h; ldr q3,[%[a_ptr]],#64\n\t" + "ldr q7,[%[b_ptr]],#64; sub %w[k_left],%w[k_left],#32\n\t" + "fmla v0.8h,v4.8h,v8.8h; ldr q4,[%[a_ptr],#-48]\n\t" + "ldr q8,[%[b_ptr],#-48]; cmp %w[k_left],#64\n\t" + "fmla v1.8h,v5.8h,v9.8h; ldr q5,[%[a_ptr],#-32]\n\t" + "ldr q9,[%[b_ptr],#-32]\n\t" + "fmla v2.8h,v6.8h,v10.8h; ldr q6,[%[a_ptr],#-16]\n\t" + "ldr q10,[%[b_ptr],#-16]; b.ge 1b\n\t" + "2:\n\t" + "fmla %[cq1].8h,v3.8h,v7.8h; sub %w[k_left],%w[k_left],#32\n\t" + "fmla v0.8h,v4.8h,v8.8h\n\t" + "fmla v1.8h,v5.8h,v9.8h\n\t" + "fmla v2.8h,v6.8h,v10.8h\n\t" + "3:\n\t" + "cmp %w[k_left],#16; fadd %[cq1].8h,%[cq1].8h,v1.8h\n\t" + "fadd v0.8h,v0.8h,v2.8h; b.lt 4f\n\t" + "ldr q3,[%[a_ptr]],#32; ldr q7,[%[b_ptr]],#32\n\t" + "ldr q4,[%[a_ptr],#-16]; ldr q8,[%[b_ptr],#-16]\n\t" + "sub %w[k_left],%w[k_left],#16\n\t" + "fmla %[cq1].8h,v3.8h,v7.8h; fmla v0.8h,v4.8h,v8.8h\n\t" + "4:\n\t" + "cmp %w[k_left],#8; fadd %[cq1].8h,%[cq1].8h,v0.8h; b.lt 5f\n\t" + "ldr q3,[%[a_ptr]],#16; ldr q7,[%[b_ptr]],#16\n\t" + "sub %w[k_left],%w[k_left],#8; fmla %[cq1].8h,v3.8h,v7.8h\n\t" + "5:\n\t" + :[cq1]"=w"(cq1), [k_left]"+r"(k_left), + [a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr) + ::"cc","memory","v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10"); + + float16x4_t cd1 = vget_low_f16(vpaddq_f16(cq1, cq1)); + if (k_left > 3) { + float16x4_t ad1 = vld1_f16(a_ptr); a_ptr += 4; + float16x4_t bd1 = vld1_f16(b_ptr); b_ptr += 4; + cd1 = vfma_f16(cd1, ad1, bd1); k_left -= 4; + } + + float16_t cs1 = vget_lane_f16(cd1, 0) + vget_lane_f16(cd1, 1) + + vget_lane_f16(cd1, 2) + vget_lane_f16(cd1, 3); + for (; k_left > 0; k_left--) { + float16_t as1 = *a_ptr; a_ptr++; + float16_t bs1 = *b_ptr; b_ptr++; + cs1 += as1 * bs1; + } + + *c_ptr = c_ptr[0] * beta + cs1; +} + +/* k_mask = 15 */ +static inline void inline_hgemm_arowmajor_bskinny_m1n2( + const float16_t *a_ptr, const float16_t *b_ptr, float16_t *c_ptr, + uint32_t k_left, uint32_t LDK, uint32_t LDM, + float16_t beta, bool c_rowmajor) { + + float16x8_t cq1, cq2; + __asm__ __volatile__ ( + "movi %[cq1].16b,#0; movi %[cq2].16b,#0\n\t" + "mov 
v0.16b,%[cq1].16b; mov v1.16b,%[cq2].16b\n\t" + "cmp %w[k_left],#16; b.lt 3f\n\t" + "ldr q2,[%[a_ptr]],#32; ldr q4,[%[b_ptr]],#64; ldr q6,[%[b_ptr],#-48]\n\t" + "ldr q3,[%[a_ptr],#-16]; ldr q5,[%[b_ptr],#-32]; ldr q7,[%[b_ptr],#-16]\n\t" + "cmp %w[k_left],#32; b.lt 2f\n\t" + "1:\n\t" + "fmla %[cq1].8h,v2.8h,v4.8h; ldr q4,[%[b_ptr]],#64\n\t" + "sub %w[k_left],%w[k_left],#16\n\t" + "fmla %[cq2].8h,v2.8h,v6.8h; ldr q6,[%[b_ptr],#-48]\n\t" + "ldr q2,[%[a_ptr]],#32\n\t" + "fmla v0.8h,v3.8h,v5.8h; ldr q5,[%[b_ptr],#-32]\n\t" + "cmp %w[k_left],#32\n\t" + "fmla v1.8h,v3.8h,v7.8h; ldr q7,[%[b_ptr],#-16]\n\t" + "ldr q3,[%[a_ptr],#-16]\n\t" + "b.ge 1b\n\t" + "2:\n\t" + "fmla %[cq1].8h,v2.8h,v4.8h; sub %w[k_left],%w[k_left],#16\n\t" + "fmla %[cq2].8h,v2.8h,v6.8h\n\t" + "fmla v0.8h,v3.8h,v5.8h\n\t" + "fmla v1.8h,v3.8h,v7.8h\n\t" + "3:\n\t" + "cmp %w[k_left],#8; fadd %[cq1].8h,%[cq1].8h,v0.8h\n\t" + "fadd %[cq2].8h,%[cq2].8h,v1.8h; b.lt 4f\n\t" + "ldr q2,[%[a_ptr]],#16; ldr q4,[%[b_ptr]],#32; ldr q6,[%[b_ptr],#-16]\n\t" + "sub %w[k_left],%w[k_left],#8\n\t" + "fmla %[cq1].8h,v2.8h,v4.8h; fmla %[cq2].8h,v2.8h,v6.8h\n\t" + "4:\n\t" + :[cq1]"=w"(cq1), [cq2]"=w"(cq2), [k_left]"+r"(k_left), + [a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr) + ::"cc","memory","v0","v1","v2","v3","v4","v5","v6","v7"); + + cq1 = vpaddq_f16(cq1, cq2); + if (k_left > 3) { + float16x4_t ad1 = vld1_f16(a_ptr); a_ptr += 4; + float16x8_t aq1 = vcombine_f16(ad1, ad1); + float16x8_t bq1 = vld1q_f16(b_ptr); b_ptr += 8; + cq1 = vfmaq_f16(cq1, aq1, bq1); k_left -= 4; + } + + const float16x8_t cz1 = vdupq_n_f16(0); + float16x4_t cd1 = vget_low_f16(vpaddq_f16(cq1, cz1)); + if (k_left > 1) { + float16x4_t ad1; + __asm__("ld1r {%0.2s},[%1],#4":"=w"(ad1),"+r"(a_ptr)::"memory"); + float16x4_t bd1 = vld1_f16(b_ptr); b_ptr += 4; + cd1 = vfma_f16(cd1, ad1, bd1); k_left -= 2; + } + + cd1 = vpadd_f16(cd1, vget_low_f16(cz1)); + if (k_left > 0) { + float16x4_t ad1, bd1; + __asm__("ld1r {%0.4h},[%1],#2":"=w"(ad1),"+r"(a_ptr)::"memory"); + __asm__("ldr %s0,[%1],#4":"=w"(bd1),"+r"(b_ptr)::"memory"); + cd1 = vfma_f16(cd1, ad1, bd1); + } + + if (c_rowmajor) { + c_ptr[0] = c_ptr[0] * beta + vget_lane_f16(cd1, 0); + c_ptr[1] = c_ptr[1] * beta + vget_lane_f16(cd1, 1); + } else { + c_ptr[0] = c_ptr[0] * beta + vget_lane_f16(cd1, 0); + c_ptr[LDM] = c_ptr[LDM] * beta + vget_lane_f16(cd1, 1); + } +} + +/* k_mask = 13 */ +static inline void inline_hgemm_arowmajor_bskinny_m1n3( + const float16_t *a_ptr, const float16_t *b_ptr, float16_t *c_ptr, + uint32_t k_left, uint32_t LDK, uint32_t LDM, + float16_t beta, bool c_rowmajor) { + + float16x8_t cq1, cq2, cq3; + __asm__ __volatile__ ( + "movi %[cq1].16b,#0; movi %[cq2].16b,#0; movi %[cq3].16b,#0\n\t" + "mov v0.16b,%[cq1].16b; mov v1.16b,%[cq2].16b; mov v2.16b,%[cq3].16b\n\t" + "cmp %w[k_left],#16; b.lt 3f\n\t" + "ldr q3,[%[a_ptr]],#32; ldr q5,[%[b_ptr]],#96\n\t" + "ldr q7,[%[b_ptr],#-80]; ldr q9,[%[b_ptr],#-64]\n\t" + "ldr q4,[%[a_ptr],#-16]; ldr q6,[%[b_ptr],#-48]\n\t" + "ldr q8,[%[b_ptr],#-32]; ldr q10,[%[b_ptr],#-16]\n\t" + "cmp %w[k_left],#32; b.lt 2f\n\t" + "1:\n\t" + "fmla %[cq1].8h,v3.8h,v5.8h; ldr q5,[%[b_ptr]],#96\n\t" + "sub %w[k_left],%w[k_left],#16\n\t" + "fmla %[cq2].8h,v3.8h,v7.8h; ldr q7,[%[b_ptr],#-80]\n\t" + "fmla %[cq3].8h,v3.8h,v9.8h; ldr q9,[%[b_ptr],#-64]\n\t" + "ldr q3,[%[a_ptr]],#32\n\t" + "fmla v0.8h,v4.8h,v6.8h; ldr q6,[%[b_ptr],#-48]\n\t" + "cmp %w[k_left],#32\n\t" + "fmla v1.8h,v4.8h,v8.8h; ldr q8,[%[b_ptr],#-32]\n\t" + "fmla v2.8h,v4.8h,v10.8h; ldr q10,[%[b_ptr],#-16]\n\t" + "ldr 
q4,[%[a_ptr],#-16]\n\t" + "b.ge 1b\n\t" + "2:\n\t" + "fmla %[cq1].8h,v3.8h,v5.8h; sub %w[k_left],%w[k_left],#16\n\t" + "fmla %[cq2].8h,v3.8h,v7.8h\n\t" + "fmla %[cq3].8h,v3.8h,v9.8h\n\t" + "fmla v0.8h,v4.8h,v6.8h\n\t" + "fmla v1.8h,v4.8h,v8.8h\n\t" + "fmla v2.8h,v4.8h,v10.8h\n\t" + "3:\n\t" + "cmp %w[k_left],#8\n\t" + "fadd %[cq1].8h,%[cq1].8h,v0.8h\n\t" + "fadd %[cq2].8h,%[cq2].8h,v1.8h\n\t" + "fadd %[cq3].8h,%[cq3].8h,v2.8h; b.lt 4f\n\t" + "ldr q3,[%[a_ptr]],#16; ldr q5,[%[b_ptr]],#48\n\t" + "ldr q7,[%[b_ptr],#-32]; ldr q9,[%[b_ptr],#-16]\n\t" + "sub %w[k_left],%w[k_left],#8\n\t" + "fmla %[cq1].8h,v3.8h,v5.8h\n\t" + "fmla %[cq2].8h,v3.8h,v7.8h\n\t" + "fmla %[cq3].8h,v3.8h,v9.8h\n\t" + "4:\n\t" + :[cq1]"=w"(cq1), [cq2]"=w"(cq2), [cq3]"=w"(cq3), + [k_left]"+r"(k_left), [a_ptr]"+r"(a_ptr), [b_ptr]"+r"(b_ptr) + ::"cc","memory","v0","v1","v2","v3","v4","v5","v6","v7","v8","v9","v10"); + + float16x4_t cd1 = vadd_f16(vget_low_f16(cq1), vget_high_f16(cq1)); + float16x4_t cd2 = vadd_f16(vget_low_f16(cq2), vget_high_f16(cq2)); + float16x4_t cd3 = vadd_f16(vget_low_f16(cq3), vget_high_f16(cq3)); + if (k_left > 3) { + float16x4_t ad1 = vld1_f16(a_ptr); a_ptr += 4; + float16x4_t bd1 = vld1_f16(b_ptr); + float16x4_t bd2 = vld1_f16(b_ptr + 4); + float16x4_t bd3 = vld1_f16(b_ptr + 8); b_ptr += 12; + cd1 = vfma_f16(cd1, ad1, bd1); + cd2 = vfma_f16(cd2, ad1, bd2); + cd3 = vfma_f16(cd3, ad1, bd3); k_left -= 4; + } + + float16_t cs1 = vget_lane_f16(cd1, 0) + vget_lane_f16(cd1, 1) + + vget_lane_f16(cd1, 2) + vget_lane_f16(cd1, 3); + float16_t cs2 = vget_lane_f16(cd2, 0) + vget_lane_f16(cd2, 1) + + vget_lane_f16(cd2, 2) + vget_lane_f16(cd2, 3); + float16_t cs3 = vget_lane_f16(cd3, 0) + vget_lane_f16(cd3, 1) + + vget_lane_f16(cd3, 2) + vget_lane_f16(cd3, 3); + for (; k_left > 0; k_left--) { + float16_t as1 = *a_ptr; a_ptr++; + cs1 += as1 * b_ptr[0]; + cs2 += as1 * b_ptr[1]; + cs3 += as1 * b_ptr[2]; b_ptr += 3; + } + + if (c_rowmajor) { + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[1] = c_ptr[1] * beta + cs2; + c_ptr[2] = c_ptr[2] * beta + cs3; + } else { + c_ptr[0] = c_ptr[0] * beta + cs1; + c_ptr[LDM] = c_ptr[LDM] * beta + cs2; + c_ptr[LDM * 2] = c_ptr[LDM * 2] * beta + cs3; + } +} + +typedef float16_t hgemm_skinnydot_ascalar; +typedef float16_t hgemm_skinnydot_bscalar; +typedef float16_t hgemm_skinnydot_cscalar; + +static inline bool unroll_test_m1n1(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m1n2(uint32_t M, uint32_t K) { + return true; +} + +static inline bool unroll_test_m1n3(uint32_t M, uint32_t K) { + return true; +} + +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(hgemm, 1, 13, 1, 65536, float16_t, float16_t, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(hgemm, 2, 15, 1, 65536, float16_t, float16_t, unroll_test) +GEMM_SKINNY_DOT_PARALLEL_FUNC_NOINCINLINE(hgemm, 3, 13, 1, 65536, float16_t, float16_t, unroll_test) + +typedef float16_t hgemm_skinnydot_avec1; +typedef float16_t hgemm_skinnydot_bvec1; +typedef float16_t hgemm_skinnydot_cvec1; + +typedef float16x4_t hgemm_skinnydot_avec4; +typedef float16x4_t hgemm_skinnydot_bvec4; +typedef float16x4_t hgemm_skinnydot_cvec4; + +typedef float16x8_t hgemm_skinnydot_avec8; +typedef float16x8_t hgemm_skinnydot_bvec8; +typedef float16x8_t hgemm_skinnydot_cvec8; + +GEMM_SKINNY_DOT_CALC_UNIT(hgemm, 8) { + return vfmaq_f16(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_CALC_UNIT(hgemm, 4) { + return vfma_f16(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_DOT_CALC_UNIT(hgemm, 1) { + return c_vec + a_vec * b_vec; +} + 
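+/*
+ * Commentary added for readability (a sketch under assumptions; the actual
+ * expansion is defined by the template macros in common/CommonSkinnyDot.h):
+ * the LOADA/LOADB/CALC/REDUC/INITC unit functions below are the building
+ * blocks that GEMM_SKINNY_DOT_PARALLEL_FUNC at the end of this file appears
+ * to expand into kernels for skinny B with n = 4..12 columns, while the
+ * n = 1..3 cases above use the hand-written inline-assembly routines.
+ * Per output element the kernel computes c = beta * c + dot(A_row, B_col),
+ * 8 fp16 lanes at a time, narrowing the accumulator through the REDUC units
+ * (8 -> 4 -> 1) for the short K tail. Rough scalar equivalent, for
+ * illustration only:
+ *
+ *   for (uint32_t m = 0; m < M; ++m)
+ *     for (uint32_t n = 0; n < N; ++n) {
+ *       float16_t acc = 0;
+ *       for (uint32_t k = 0; k < K; ++k)
+ *         acc += A[m * LDK + k] * B[k * N + n];
+ *       c_mn = c_mn * beta + acc;   (store pattern depends on c_rowmajor)
+ *     }
+ */
+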
+GEMM_SKINNY_DOT_LOADA_UNIT(hgemm, 8) { + __asm__("prfm pldl1keep,[%0,#80]"::"r"(a_ptr):); + return vld1q_f16(a_ptr); +} + +GEMM_SKINNY_DOT_LOADA_UNIT(hgemm, 4) { + __asm__("prfm pldl1keep,[%0,#72]"::"r"(a_ptr):); + return vld1_f16(a_ptr); +} + +GEMM_SKINNY_DOT_LOADA_UNIT(hgemm, 1) { + return *a_ptr; +} + +GEMM_SKINNY_DOT_LOADB_UNIT(hgemm, 8) { + return vld1q_f16(b_ptr); +} + +GEMM_SKINNY_DOT_LOADB_UNIT(hgemm, 4) { + return vld1_f16(b_ptr); +} + +GEMM_SKINNY_DOT_LOADB_UNIT(hgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_DOT_REDUC_UNIT(hgemm, 8, 4) { + return vget_low_f16(vpaddq_f16(c_vec, c_vec)); +} + +GEMM_SKINNY_DOT_REDUC_UNIT(hgemm, 4, 1) { + float cs1 = vget_lane_f16(c_vec, 0); + float cs2 = vget_lane_f16(c_vec, 1); + float cs3 = vget_lane_f16(c_vec, 2); + float cs4 = vget_lane_f16(c_vec, 3); + cs1 += cs2; cs3 += cs4; + return cs1 + cs3; +} + +GEMM_SKINNY_DOT_INITC_UNIT(hgemm, 8) { + return vdupq_n_f16(0); +} + +GEMM_SKINNY_DOT_INITC_UNIT(hgemm, 4) { + return vdup_n_f16(0); +} + +GEMM_SKINNY_DOT_INITC_UNIT(hgemm, 1) { + return 0; +} + +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 4, 13, 7, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 5, 13, 7, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 6, 13, 7, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 7, 13, 3, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 8, 13, 3, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 9, 13, 3, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 10, 13, 3, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 11, 13, 3, 65536, float16_t, float16_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(hgemm, 12, 13, 3, 65536, float16_t, float16_t) diff --git a/src/neon_armv8a/extension/HgemmSkinnyGer.c b/src/neon_armv8a/extension/HgemmSkinnyGer.c new file mode 100644 index 0000000..ec47b8e --- /dev/null +++ b/src/neon_armv8a/extension/HgemmSkinnyGer.c @@ -0,0 +1,232 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
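The generated n = 4..12 hgemm functions compose the CALC/LOADA/LOADB/REDUC/INITC units above; the glue that strings them together lives in the common skinny-dot header. The two REDUC units collapse an 8-lane fp16 accumulator to 4 lanes by adjacent-pair addition, then to a scalar by summing the lanes in float. A small standalone demonstration of that chain, assuming an ARMv8.2-A toolchain with fp16 support:

#include <arm_neon.h>
#include <stdio.h>

int main(void) {
  const float16_t v[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float16x8_t acc = vld1q_f16(v);
  /* 8 -> 4 lanes: adjacent-pair sums, as in GEMM_SKINNY_DOT_REDUC_UNIT(hgemm, 8, 4) */
  float16x4_t half = vget_low_f16(vpaddq_f16(acc, acc));
  /* 4 -> 1: widen each lane to float and add, as in the (4, 1) unit */
  float sum = (float)vget_lane_f16(half, 0) + (float)vget_lane_f16(half, 1)
            + (float)vget_lane_f16(half, 2) + (float)vget_lane_f16(half, 3);
  printf("reduced sum = %.1f\n", sum);   /* prints 36.0 */
  return 0;
}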
*/ +/*****************************************************************************/ + + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyGer.h" + +#include + +typedef float16_t hgemm_skinnyger_ascalar; +typedef float16_t hgemm_skinnyger_bscalar; +typedef float16_t hgemm_skinnyger_cscalar; + +typedef float16_t hgemm_skinnyger_avec1; +typedef float16_t hgemm_skinnyger_bvec1; +typedef float16_t hgemm_skinnyger_cvec1; + +typedef float16x4_t hgemm_skinnyger_avec4; +typedef float16x4_t hgemm_skinnyger_bvec4; +typedef float16x4_t hgemm_skinnyger_cvec4; + +typedef float16x8_t hgemm_skinnyger_avec8; +typedef float16x8_t hgemm_skinnyger_bvec8; +typedef float16x8_t hgemm_skinnyger_cvec8; + +typedef float16x8x2_t hgemm_skinnyger_avec16; +typedef float16x8x2_t hgemm_skinnyger_bvec16; +typedef float16x8x2_t hgemm_skinnyger_cvec16; + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 16, 4, 1) { + float16x8x2_t ret; + ret.val[0] = vfmaq_lane_f16(c_vec.val[0], a_vec.val[0], b_vec, 0); + ret.val[1] = vfmaq_lane_f16(c_vec.val[1], a_vec.val[1], b_vec, 0); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 16, 4, 2) { + float16x8x2_t ret; + ret.val[0] = vfmaq_lane_f16(c_vec.val[0], a_vec.val[0], b_vec, 1); + ret.val[1] = vfmaq_lane_f16(c_vec.val[1], a_vec.val[1], b_vec, 1); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 16, 4, 3) { + float16x8x2_t ret; + ret.val[0] = vfmaq_lane_f16(c_vec.val[0], a_vec.val[0], b_vec, 2); + ret.val[1] = vfmaq_lane_f16(c_vec.val[1], a_vec.val[1], b_vec, 2); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 16, 4, 4) { + float16x8x2_t ret; + ret.val[0] = vfmaq_lane_f16(c_vec.val[0], a_vec.val[0], b_vec, 3); + ret.val[1] = vfmaq_lane_f16(c_vec.val[1], a_vec.val[1], b_vec, 3); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 16, 1, 1) { + float16x8x2_t ret; + ret.val[0] = vfmaq_n_f16(c_vec.val[0], a_vec.val[0], b_vec); + ret.val[1] = vfmaq_n_f16(c_vec.val[1], a_vec.val[1], b_vec); + return ret; +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 8, 4, 1) { + return vfmaq_lane_f16(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 8, 4, 2) { + return vfmaq_lane_f16(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 8, 4, 3) { + return vfmaq_lane_f16(c_vec, a_vec, b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 8, 4, 4) { + return vfmaq_lane_f16(c_vec, a_vec, b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 8, 1, 1) { + return vfmaq_n_f16(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 4, 4, 1) { + return vfma_lane_f16(c_vec, a_vec, b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 4, 4, 2) { + return vfma_lane_f16(c_vec, a_vec, b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 4, 4, 3) { + return vfma_lane_f16(c_vec, a_vec, b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 4, 4, 4) { + return vfma_lane_f16(c_vec, a_vec, b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 4, 1, 1) { + return vfma_n_f16(c_vec, a_vec, b_vec); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 1, 4, 1) { + return c_vec + a_vec * vget_lane_f16(b_vec, 0); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 1, 4, 2) { + return c_vec + a_vec * vget_lane_f16(b_vec, 1); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 1, 4, 3) { + return c_vec + a_vec * vget_lane_f16(b_vec, 2); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 1, 4, 4) { + return c_vec + a_vec * vget_lane_f16(b_vec, 3); +} + +GEMM_SKINNY_GER_CALC_UNIT(hgemm, 1, 1, 1) { + return c_vec + a_vec * b_vec; +} + +GEMM_SKINNY_GER_LOADA_UNIT(hgemm, 16) { + float16x8x2_t ret; + ret.val[0] = vld1q_f16(a_ptr); + ret.val[1] = vld1q_f16(a_ptr + 8); + 
__asm__("prfm pldl1keep,[%0,#96]"::"r"(a_ptr):); + return ret; +} + +GEMM_SKINNY_GER_LOADA_UNIT(hgemm, 8) { + __asm__("prfm pldl1keep,[%0,#80]"::"r"(a_ptr):); + return vld1q_f16(a_ptr); +} + +GEMM_SKINNY_GER_LOADA_UNIT(hgemm, 4) { + return vld1_f16(a_ptr); +} + +GEMM_SKINNY_GER_LOADA_UNIT(hgemm, 1) { + return *a_ptr; +} + +GEMM_SKINNY_GER_LOADC_UNIT(hgemm, 16) { + float16x8x2_t ret; + ret.val[0] = vld1q_f16(c_ptr); + ret.val[1] = vld1q_f16(c_ptr + 8); + return ret; +} + +GEMM_SKINNY_GER_LOADC_UNIT(hgemm, 8) { + return vld1q_f16(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT(hgemm, 4) { + return vld1_f16(c_ptr); +} + +GEMM_SKINNY_GER_LOADC_UNIT(hgemm, 1) { + return *c_ptr; +} + +GEMM_SKINNY_GER_STOREC_UNIT(hgemm, 16) { + vst1q_f16(c_ptr, c_vec.val[0]); + vst1q_f16(c_ptr + 8, c_vec.val[1]); +} + +GEMM_SKINNY_GER_STOREC_UNIT(hgemm, 8) { + vst1q_f16(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT(hgemm, 4) { + vst1_f16(c_ptr, c_vec); +} + +GEMM_SKINNY_GER_STOREC_UNIT(hgemm, 1) { + *c_ptr = c_vec; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(hgemm, 4) { + float16x4_t ret = vdup_n_f16(0); + float16_t b1 = *b_ptr; b_ptr += ldb; + float16_t b2 = *b_ptr; b_ptr += ldb; + float16_t b3 = *b_ptr; b_ptr += ldb; + float16_t b4 = *b_ptr; + ret = vset_lane_f16(b1, ret, 0); + ret = vset_lane_f16(b2, ret, 1); + ret = vset_lane_f16(b3, ret, 2); + ret = vset_lane_f16(b4, ret, 3); + return ret; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BROWMAJOR(hgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(hgemm, 4) { + return vld1_f16(b_ptr); +} + +GEMM_SKINNY_GER_LOADB_UNIT_BCOLMAJOR(hgemm, 1) { + return *b_ptr; +} + +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 1, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 2, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 3, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 4, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 5, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 6, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 7, 5, 29, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 8, 5, 13, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 9, 5, 13, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 10, 5, 13, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 11, 5, 13, 16384, float16_t, float16_t) +GEMM_SKINNY_GER_PARALLEL_FUNC(hgemm, 12, 5, 13, 16384, float16_t, float16_t) + diff --git a/src/neon_armv8a/extension/S8S32DotGemmCopy.c b/src/neon_armv8a/extension/S8S32DotGemmCopy.c new file mode 100644 index 0000000..3cb9665 --- /dev/null +++ b/src/neon_armv8a/extension/S8S32DotGemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "common/CommonCopy.h" +#include "neon_armv8a/I8I32DotGemmCopy.h" + +GENERIC_NCOPY_FUNC(s8s32dotgemm, int8_t, int32_t, 8) +GENERIC_NCOPY_FUNC(s8s32dotgemm, int8_t, int32_t, 12) + +TCOPY_FUNC_TEMPLATE(s8s32dotgemm_int8_t_int32_t_tcopy_unroll, 8) +TCOPY_FUNC_TEMPLATE(s8s32dotgemm_int8_t_int32_t_tcopy_unroll, 12) + diff --git a/src/neon_armv8a/extension/S8S32DotGemmKernel.c b/src/neon_armv8a/extension/S8S32DotGemmKernel.c new file mode 100644 index 0000000..409ca4f --- /dev/null +++ b/src/neon_armv8a/extension/S8S32DotGemmKernel.c @@ -0,0 +1,116 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif +#include "common/CommonKernel.h" +#include "arm_neon/ARMCpuType.h" +#include +#include "neon_armv8a/I8I32DotGemmKernel.h" + +#define CPUID_DETECT_MNK 1000000 + +void s8s32dotgemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t K, int32_t beta, + const int32_t * __restrict__ sa, const int32_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc) { + + uint32_t n_left = N; + const int32_t *b_head = sb; + int32_t *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + + for (; n_left > 11; n_left -= 12) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const int32_t *a_head = sa; + int32_t *c_ptr = c_head; + uint32_t m_left = M; + if (cputype == 55) { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N12_TEMPLATE(A55) + SAVE_M8N12 + a_head += 8 * K; + c_ptr += 8; + } + } else { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N12_TEMPLATE(A76) + SAVE_M8N12 + a_head += 8 * K; + c_ptr += 8; + } + } + MICRO_COMPUTE_LM(4, 12, int32_t, int32_t, int32_t) + b_head += K * 12; + c_head += ldc * 12; + acc_mnk += 12 * K * M; + } + + ASSEMBLE_DUALPACK_COMPUTE_LM(8, int32_t, int32_t, int32_t, 8) +} + +void s8s32dotgemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t K, int32_t beta, + const int32_t * __restrict__ sa, const int32_t * __restrict__ sb, + int32_t * __restrict__ C, uint32_t ldc) { + + uint32_t m_left = M; + const int32_t *a_head = sa; + int32_t *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + for (; m_left > 11; m_left -= 12) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const int32_t *b_head = sb; + int32_t *c_ptr = c_head; + uint32_t n_left = N; + if (cputype == 55) { + for (; n_left > 7; n_left -= 8) { + KERNEL_M12N8_TEMPLATE(A55) + SAVE_M12N8 + b_head += 8 * K; + c_ptr 
+= 8 * ldc; + } + } else { + for (; n_left > 7; n_left -= 8) { + KERNEL_M12N8_TEMPLATE(A76) + SAVE_M12N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } + MICRO_COMPUTE_LN(12, 4, int32_t, int32_t, int32_t) + a_head += K * 12; + c_head += 12; + acc_mnk += 12 * N * K; + } + + ASSEMBLE_DUALPACK_COMPUTE_LN(8, int32_t, int32_t, int32_t, 8) +} + diff --git a/src/neon_armv8a/extension/S8S32DotGemmSkinnyDot.c b/src/neon_armv8a/extension/S8S32DotGemmSkinnyDot.c new file mode 100644 index 0000000..c2fab56 --- /dev/null +++ b/src/neon_armv8a/extension/S8S32DotGemmSkinnyDot.c @@ -0,0 +1,37 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifdef GEMM_UNSIGNED_INT +#undef GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyDot.h" +#include "arm_neon/NeonI8I32DotGemmSkinnyDot.h" + +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 1, 29, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 2, 29, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 3, 29, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 4, 29, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 5, 29, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 6, 29, 7, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 7, 29, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 8, 29, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 9, 29, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 10, 29, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 11, 29, 3, 131072, int8_t, int8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(s8s32dotgemm, 12, 29, 3, 131072, int8_t, int8_t) diff --git a/src/neon_armv8a/extension/U8U32DotGemmCopy.c b/src/neon_armv8a/extension/U8U32DotGemmCopy.c new file mode 100644 index 0000000..16f42ca --- /dev/null +++ b/src/neon_armv8a/extension/U8U32DotGemmCopy.c @@ -0,0 +1,30 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
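Both s8s32dotgemm kernel functions above choose between the Cortex-A55 micro-kernel and the generic (A76-style) one at run time, and they re-check the current core only after roughly CPUID_DETECT_MNK multiply-accumulates, so a thread migrated between big and little cores switches kernels without calling sched_getcpu() on every tile. A stripped-down sketch of that dispatch rhythm; the core-type query stands in for blas_arm_get_cpu_type() from arm_neon/ARMCpuType.h, and the two tile functions are placeholders for the real KERNEL_* templates:

#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include <sched.h>
#include <stdint.h>

#define CPUID_DETECT_MNK 1000000

static uint8_t query_cpu_type(int cpuid) { (void)cpuid; return 0; }  /* stub */
static void tile_a55(void) {}       /* placeholder for KERNEL_*_TEMPLATE(A55) */
static void tile_generic(void) {}   /* placeholder for KERNEL_*_TEMPLATE(A76) */

static void dispatch_tiles(uint32_t num_tiles, uint32_t tile_mnk) {
  uint32_t acc_mnk = CPUID_DETECT_MNK;   /* force a detection on the first tile */
  uint8_t cputype = 0;
  for (uint32_t t = 0; t < num_tiles; ++t) {
    if (acc_mnk >= CPUID_DETECT_MNK) {
      cputype = query_cpu_type(sched_getcpu());
      acc_mnk = 0;
    }
    if (cputype == 55) tile_a55();
    else               tile_generic();
    acc_mnk += tile_mnk;                 /* work done since the last check */
  }
}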
*/ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "common/CommonCopy.h" +#include "neon_armv8a/I8I32DotGemmCopy.h" + +GENERIC_NCOPY_FUNC(u8u32dotgemm, uint8_t, uint32_t, 8) +GENERIC_NCOPY_FUNC(u8u32dotgemm, uint8_t, uint32_t, 12) + +TCOPY_FUNC_TEMPLATE(u8u32dotgemm_uint8_t_uint32_t_tcopy_unroll, 8) +TCOPY_FUNC_TEMPLATE(u8u32dotgemm_uint8_t_uint32_t_tcopy_unroll, 12) + diff --git a/src/neon_armv8a/extension/U8U32DotGemmKernel.c b/src/neon_armv8a/extension/U8U32DotGemmKernel.c new file mode 100644 index 0000000..ca2b01c --- /dev/null +++ b/src/neon_armv8a/extension/U8U32DotGemmKernel.c @@ -0,0 +1,116 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#ifndef _GNU_SOURCE +#define _GNU_SOURCE +#endif +#include "common/CommonKernel.h" +#include "arm_neon/ARMCpuType.h" +#include +#include "neon_armv8a/I8I32DotGemmKernel.h" + +#define CPUID_DETECT_MNK 1000000 + +void u8u32dotgemm_kernel_lm_m8n12(uint32_t M, uint32_t N, uint32_t K, uint32_t beta, + const uint32_t * __restrict__ sa, const uint32_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc) { + + uint32_t n_left = N; + const uint32_t *b_head = sb; + uint32_t *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + + for (; n_left > 11; n_left -= 12) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const uint32_t *a_head = sa; + uint32_t *c_ptr = c_head; + uint32_t m_left = M; + if (cputype == 55) { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N12_TEMPLATE(A55) + SAVE_M8N12 + a_head += 8 * K; + c_ptr += 8; + } + } else { + for (; m_left > 7; m_left -= 8) { + KERNEL_M8N12_TEMPLATE(A76) + SAVE_M8N12 + a_head += 8 * K; + c_ptr += 8; + } + } + MICRO_COMPUTE_LM(4, 12, uint32_t, uint32_t, uint32_t) + b_head += K * 12; + c_head += ldc * 12; + acc_mnk += 12 * K * M; + } + + ASSEMBLE_DUALPACK_COMPUTE_LM(8, uint32_t, uint32_t, uint32_t, 8) +} + +void u8u32dotgemm_kernel_ln_m12n8(uint32_t M, uint32_t N, uint32_t K, uint32_t beta, + const uint32_t * __restrict__ sa, const uint32_t * __restrict__ sb, + uint32_t * __restrict__ C, uint32_t ldc) { + + uint32_t m_left = M; + const uint32_t *a_head = sa; + uint32_t *c_head = C; + uint32_t acc_mnk = CPUID_DETECT_MNK; + uint8_t cpuid = 0, cputype = 0; + for (; m_left > 11; m_left -= 12) { + if (acc_mnk >= CPUID_DETECT_MNK) { + cpuid = sched_getcpu(); + cputype = blas_arm_get_cpu_type(cpuid); + acc_mnk = 0; + } + const uint32_t *b_head = sb; + uint32_t *c_ptr = c_head; + uint32_t n_left = N; + if (cputype == 55) { + for (; n_left > 7; n_left -= 8) { + KERNEL_M12N8_TEMPLATE(A55) + 
SAVE_M12N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } else { + for (; n_left > 7; n_left -= 8) { + KERNEL_M12N8_TEMPLATE(A76) + SAVE_M12N8 + b_head += 8 * K; + c_ptr += 8 * ldc; + } + } + MICRO_COMPUTE_LN(12, 4, uint32_t, uint32_t, uint32_t) + a_head += K * 12; + c_head += 12; + acc_mnk += 12 * N * K; + } + + ASSEMBLE_DUALPACK_COMPUTE_LN(8, uint32_t, uint32_t, uint32_t, 8) +} + diff --git a/src/neon_armv8a/extension/U8U32DotGemmSkinnyDot.c b/src/neon_armv8a/extension/U8U32DotGemmSkinnyDot.c new file mode 100644 index 0000000..c11849e --- /dev/null +++ b/src/neon_armv8a/extension/U8U32DotGemmSkinnyDot.c @@ -0,0 +1,37 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#ifndef GEMM_UNSIGNED_INT +#define GEMM_UNSIGNED_INT +#endif + +#include "arm_neon/ARMCompareAndSwap.h" +#include "common/CommonSkinnyDot.h" +#include "arm_neon/NeonI8I32DotGemmSkinnyDot.h" + +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 1, 29, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 2, 29, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 3, 29, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 4, 29, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 5, 29, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 6, 29, 7, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 7, 29, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 8, 29, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 9, 29, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 10, 29, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 11, 29, 3, 131072, uint8_t, uint8_t) +GEMM_SKINNY_DOT_PARALLEL_FUNC(u8u32dotgemm, 12, 29, 3, 131072, uint8_t, uint8_t) diff --git a/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.c b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.c new file mode 100644 index 0000000..e8295e6 --- /dev/null +++ b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA35.c @@ -0,0 +1,72 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
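The signed and unsigned 8-bit GEMM sources above are near-duplicates on purpose: each translation unit defines or undefines GEMM_UNSIGNED_INT and then includes the shared I8I32DotGemm* headers, which expand to int8_t/int32_t or uint8_t/uint32_t code accordingly. A minimal sketch of that define-then-include technique, with a hypothetical header name used only for illustration:

/* gemm_int8_template.h (hypothetical sketch) -------------------------- */
#include <stdint.h>

#ifdef GEMM_UNSIGNED_INT
typedef uint8_t  gemm_in_t;
typedef uint32_t gemm_out_t;
#define GEMM_FN(name) u8u32_##name
#else
typedef int8_t   gemm_in_t;
typedef int32_t  gemm_out_t;
#define GEMM_FN(name) s8s32_##name
#endif

/* expands to s8s32_dot() or u8u32_dot() depending on the macro */
static gemm_out_t GEMM_FN(dot)(const gemm_in_t *a, const gemm_in_t *b, int k) {
  gemm_out_t acc = 0;
  for (int i = 0; i < k; ++i) acc += (gemm_out_t)a[i] * (gemm_out_t)b[i];
  return acc;
}
/* ----------------------------------------------------------------------
 * a signed .c file:                       an unsigned .c file:
 *   #include "gemm_int8_template.h"         #define GEMM_UNSIGNED_INT
 *                                            #include "gemm_int8_template.h"
 */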
*/ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA35.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotDriver.h" + +DRIVER_PURE_PACK(a35, 4, 10240, 3, 4) +DRIVER_PURE_PACK(a35, 5, 8192, 3, 4) +DRIVER_PURE_PACK(a35, 6, 8192, 3, 4) +DRIVER_PURE_PACK(a35, 7, 6144, 3, 3) +DRIVER_PURE_PACK(a35, 8, 6144, 3, 3) +DRIVER_PURE_PACK(a35, 9, 5120, 4, 4) +DRIVER_PURE_PACK(a35, 10, 5120, 0, 4) +DRIVER_PURE_PACK(a35, 11, 4096, 4, 4) +DRIVER_PURE_PACK(a35, 12, 4096, 0, 4) +DRIVER_PURE_PACK(a35, 13, 3584, 4, 3) +DRIVER_PURE_PACK(a35, 14, 3584, 0, 3) +DRIVER_PURE_PACK(a35, 15, 3072, 4, 3) +DRIVER_PURE_PACK(a35, 16, 3072, 0, 3) +DRIVER_PURE_PACK(a35, 17, 2560, 4, 3) +DRIVER_PURE_PACK(a35, 18, 2560, 4, 3) + +DRIVER_MIX2_PACK(a35, 19, 2560, 0, 4, 10, 9, 4) +DRIVER_MIX2_PACK(a35, 20, 2560, 4, 4, 11, 9, 4) +DRIVER_MIX2_PACK(a35, 21, 2048, 0, 4, 12, 9, 4) +DRIVER_MIX2_PACK(a35, 22, 2048, 0, 3, 14, 8, 3) +DRIVER_MIX2_PACK(a35, 23, 2048, 4, 3, 15, 8, 3) +DRIVER_MIX2_PACK(a35, 24, 2048, 0, 3, 16, 8, 3) +DRIVER_MIX2_PACK(a35, 25, 2048, 4, 3, 17, 8, 3) +DRIVER_MIX2_PACK(a35, 26, 1536, 4, 3, 18, 8, 3) +DRIVER_MIX2_PACK(a35, 27, 1536, 0, 4, 14, 13, 3) +DRIVER_MIX2_PACK(a35, 28, 1536, 4, 4, 15, 13, 3) +DRIVER_MIX2_PACK(a35, 29, 1536, 0, 4, 16, 13, 3) +DRIVER_MIX2_PACK(a35, 30, 1536, 4, 4, 17, 13, 3) +DRIVER_MIX2_PACK(a35, 31, 1536, 4, 4, 18, 13, 3) +DRIVER_MIX2_PACK(a35, 32, 1536, 4, 0, 18, 14, 3) +DRIVER_MIX2_PACK(a35, 33, 1536, 4, 4, 18, 15, 3) +DRIVER_MIX2_PACK(a35, 34, 1280, 4, 0, 18, 16, 3) +DRIVER_MIX2_PACK(a35, 35, 1280, 4, 4, 18, 17, 3) +DRIVER_MIX2_PACK(a35, 36, 1280, 4, 4, 18, 18, 3) + +DRIVER_MIX3_PACK(a35, 37, 1280, 0, 4, 3, 16, 13, 8, 3) +DRIVER_MIX3_PACK(a35, 38, 1280, 4, 4, 3, 17, 13, 8, 3) +DRIVER_MIX3_PACK(a35, 39, 1280, 4, 4, 3, 18, 13, 8, 3) +DRIVER_MIX3_PACK(a35, 40, 1280, 4, 0, 3, 18, 14, 8, 3) +DRIVER_MIX3_PACK(a35, 41, 1024, 4, 4, 3, 18, 15, 8, 3) +DRIVER_MIX3_PACK(a35, 42, 1024, 4, 0, 3, 18, 16, 8, 3) +DRIVER_MIX3_PACK(a35, 43, 1024, 4, 4, 3, 18, 17, 8, 3) +DRIVER_MIX3_PACK(a35, 44, 1024, 4, 4, 3, 18, 18, 8, 3) +DRIVER_MIX3_PACK(a35, 45, 1024, 4, 0, 4, 18, 14, 13, 3) +DRIVER_MIX3_PACK(a35, 46, 1024, 4, 0, 0, 18, 14, 14, 3) +DRIVER_MIX3_PACK(a35, 47, 1024, 4, 4, 0, 18, 15, 14, 3) +DRIVER_MIX3_PACK(a35, 48, 1024, 4, 4, 4, 18, 15, 15, 3) +DRIVER_MIX3_PACK(a35, 49, 1024, 4, 0, 4, 18, 16, 15, 3) +DRIVER_MIX3_PACK(a35, 50, 1024, 4, 0, 0, 18, 16, 16, 3) + + diff --git a/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.c b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.c new file mode 100644 index 0000000..8587cd1 --- /dev/null +++ b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA53.c @@ -0,0 +1,71 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
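In the a35 table above, the MIX2 and MIX3 drivers cover a wide N by running two or three of the narrower generated kernels side by side on column blocks of B and C; for example the N = 19 entry pairs the N = 10 and N = 9 kernels, and the N = 37 entry chains 16 + 13 + 8. The split widths are hand-tuned per core and fixed at compile time. A sketch of that column-block composition, assuming row-major B and C for the pointer offsets and using a plain reference loop in place of the generated per-width drivers:

#include <stdint.h>

/* reference stand-in for a generated per-width skinny kernel (row-major B, C) */
static void skinny_ref(const float *A, const float *B, float *C,
    uint32_t M, uint32_t N, uint32_t K, uint32_t ldb, uint32_t ldc) {
  for (uint32_t m = 0; m < M; ++m)
    for (uint32_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (uint32_t k = 0; k < K; ++k) acc += A[m * K + k] * B[k * ldb + n];
      C[m * ldc + n] += acc;
    }
}

/* N = 19 handled as a 10-column block plus a 9-column block,
 * mirroring DRIVER_MIX2_PACK(a35, 19, ..., 10, 9, ...) */
static void skinny_n19(const float *A, const float *B, float *C,
    uint32_t M, uint32_t K, uint32_t ldb, uint32_t ldc) {
  skinny_ref(A, B,      C,      M, 10, K, ldb, ldc);
  skinny_ref(A, B + 10, C + 10, M,  9, K, ldb, ldc);
}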
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA53.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotDriver.h" + +DRIVER_PURE_PACK(a53, 4, 10240, 1, 4) +DRIVER_PURE_PACK(a53, 5, 8192, 1, 4) +DRIVER_PURE_PACK(a53, 6, 8192, 1, 4) +DRIVER_PURE_PACK(a53, 7, 6144, 1, 4) +DRIVER_PURE_PACK(a53, 8, 6144, 1, 4) +DRIVER_PURE_PACK(a53, 9, 5120, 1, 4) +DRIVER_PURE_PACK(a53, 10, 5120, 1, 4) +DRIVER_PURE_PACK(a53, 11, 4096, 1, 4) +DRIVER_PURE_PACK(a53, 12, 4096, 1, 4) +DRIVER_PURE_PACK(a53, 13, 3584, 1, 4) +DRIVER_PURE_PACK(a53, 14, 3584, 1, 4) +DRIVER_PURE_PACK(a53, 15, 3072, 2, 4) +DRIVER_PURE_PACK(a53, 16, 3072, 2, 4) +DRIVER_PURE_PACK(a53, 17, 2048, 2, 4) +DRIVER_PURE_PACK(a53, 18, 2048, 2, 4) +DRIVER_PURE_PACK(a53, 19, 2048, 2, 4) +DRIVER_PURE_PACK(a53, 20, 2048, 2, 4) +DRIVER_PURE_PACK(a53, 21, 2048, 2, 4) +DRIVER_PURE_PACK(a53, 22, 2048, 2, 4) +DRIVER_PURE_PACK(a53, 23, 2048, 0, 4) +DRIVER_PURE_PACK(a53, 24, 2048, 0, 4) +DRIVER_PURE_PACK(a53, 25, 1536, 0, 4) +DRIVER_PURE_PACK(a53, 26, 1536, 0, 4) + +DRIVER_MIX2_PACK(a53, 27, 1536, 2, 1, 15, 12, 4) +DRIVER_MIX2_PACK(a53, 28, 1536, 2, 1, 16, 12, 4) +DRIVER_MIX2_PACK(a53, 29, 1536, 2, 1, 17, 12, 4) +DRIVER_MIX2_PACK(a53, 30, 1536, 2, 1, 18, 12, 4) +DRIVER_MIX2_PACK(a53, 31, 1536, 2, 1, 19, 12, 4) +DRIVER_MIX2_PACK(a53, 32, 1536, 2, 1, 20, 12, 4) +DRIVER_MIX2_PACK(a53, 33, 1536, 2, 1, 21, 12, 4) +DRIVER_MIX2_PACK(a53, 34, 1280, 2, 1, 22, 12, 4) +DRIVER_MIX2_PACK(a53, 35, 1280, 0, 1, 23, 12, 4) +DRIVER_MIX2_PACK(a53, 36, 1280, 0, 1, 24, 12, 4) +DRIVER_MIX2_PACK(a53, 37, 1280, 0, 1, 25, 12, 4) +DRIVER_MIX2_PACK(a53, 38, 1280, 0, 1, 26, 12, 4) +DRIVER_MIX2_PACK(a53, 39, 1280, 0, 2, 24, 15, 4) +DRIVER_MIX2_PACK(a53, 40, 1280, 0, 2, 24, 16, 4) +DRIVER_MIX2_PACK(a53, 41, 1024, 0, 2, 24, 17, 4) +DRIVER_MIX2_PACK(a53, 42, 1024, 0, 2, 24, 18, 4) +DRIVER_MIX2_PACK(a53, 43, 1024, 0, 2, 24, 19, 4) +DRIVER_MIX2_PACK(a53, 44, 1024, 0, 2, 24, 20, 4) +DRIVER_MIX2_PACK(a53, 45, 1024, 0, 2, 24, 21, 4) +DRIVER_MIX2_PACK(a53, 46, 1024, 0, 2, 24, 22, 4) +DRIVER_MIX2_PACK(a53, 47, 1024, 0, 0, 24, 23, 4) +DRIVER_MIX2_PACK(a53, 48, 1024, 0, 0, 24, 24, 4) +DRIVER_MIX2_PACK(a53, 49, 1024, 0, 0, 25, 24, 4) +DRIVER_MIX2_PACK(a53, 50, 1024, 0, 0, 26, 24, 4) + + diff --git a/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.c b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.c new file mode 100644 index 0000000..b1299ca --- /dev/null +++ b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotA7x.c @@ -0,0 +1,71 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotKernelA7x.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.h" +#include "neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotDriver.h" + +DRIVER_PURE_PACK(a7x, 4, 10240, 1, 4) +DRIVER_PURE_PACK(a7x, 5, 8192, 1, 4) +DRIVER_PURE_PACK(a7x, 6, 8192, 1, 4) +DRIVER_PURE_PACK(a7x, 7, 6144, 1, 4) +DRIVER_PURE_PACK(a7x, 8, 6144, 1, 4) +DRIVER_PURE_PACK(a7x, 9, 5120, 1, 4) +DRIVER_PURE_PACK(a7x, 10, 5120, 1, 4) +DRIVER_PURE_PACK(a7x, 11, 4096, 1, 4) +DRIVER_PURE_PACK(a7x, 12, 4096, 1, 4) +DRIVER_PURE_PACK(a7x, 13, 3584, 1, 3) +DRIVER_PURE_PACK(a7x, 14, 3584, 1, 3) +DRIVER_PURE_PACK(a7x, 15, 3072, 1, 3) +DRIVER_PURE_PACK(a7x, 16, 3072, 1, 3) +DRIVER_PURE_PACK(a7x, 17, 2560, 1, 3) +DRIVER_PURE_PACK(a7x, 18, 2560, 1, 3) +DRIVER_PURE_PACK(a7x, 19, 2560, 1, 3) +DRIVER_PURE_PACK(a7x, 20, 2560, 1, 3) +DRIVER_PURE_PACK(a7x, 21, 2048, 1, 3) +DRIVER_PURE_PACK(a7x, 22, 2048, 1, 3) +DRIVER_PURE_PACK(a7x, 23, 2048, 1, 3) +DRIVER_PURE_PACK(a7x, 24, 2048, 1, 3) +DRIVER_PURE_PACK(a7x, 25, 2048, 1, 3) +DRIVER_PURE_PACK(a7x, 26, 1536, 1, 3) + +DRIVER_MIX2_PACK(a7x, 27, 1536, 1, 1, 14, 13, 3) +DRIVER_MIX2_PACK(a7x, 28, 1536, 1, 1, 15, 13, 3) +DRIVER_MIX2_PACK(a7x, 29, 1536, 1, 1, 16, 13, 3) +DRIVER_MIX2_PACK(a7x, 30, 1536, 1, 1, 17, 13, 3) +DRIVER_MIX2_PACK(a7x, 31, 1536, 1, 1, 18, 13, 3) +DRIVER_MIX2_PACK(a7x, 32, 1536, 1, 1, 19, 13, 3) +DRIVER_MIX2_PACK(a7x, 33, 1536, 1, 1, 20, 13, 3) +DRIVER_MIX2_PACK(a7x, 34, 1280, 1, 1, 21, 13, 3) +DRIVER_MIX2_PACK(a7x, 35, 1280, 1, 1, 22, 13, 3) +DRIVER_MIX2_PACK(a7x, 36, 1280, 1, 1, 23, 13, 3) +DRIVER_MIX2_PACK(a7x, 37, 1280, 1, 1, 24, 13, 3) +DRIVER_MIX2_PACK(a7x, 38, 1280, 1, 1, 25, 13, 3) +DRIVER_MIX2_PACK(a7x, 39, 1280, 1, 1, 26, 13, 3) +DRIVER_MIX2_PACK(a7x, 40, 1280, 1, 1, 26, 14, 3) +DRIVER_MIX2_PACK(a7x, 41, 1024, 1, 1, 26, 15, 3) +DRIVER_MIX2_PACK(a7x, 42, 1024, 1, 1, 26, 16, 3) +DRIVER_MIX2_PACK(a7x, 43, 1024, 1, 1, 26, 17, 3) +DRIVER_MIX2_PACK(a7x, 44, 1024, 1, 1, 26, 18, 3) +DRIVER_MIX2_PACK(a7x, 45, 1024, 1, 1, 26, 19, 3) +DRIVER_MIX2_PACK(a7x, 46, 1024, 1, 1, 26, 20, 3) +DRIVER_MIX2_PACK(a7x, 47, 1024, 1, 1, 26, 21, 3) +DRIVER_MIX2_PACK(a7x, 48, 1024, 1, 1, 26, 22, 3) +DRIVER_MIX2_PACK(a7x, 49, 1024, 1, 1, 26, 23, 3) +DRIVER_MIX2_PACK(a7x, 50, 1024, 1, 1, 26, 24, 3) + + diff --git a/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.c b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.c new file mode 100644 index 0000000..0d2769b --- /dev/null +++ b/src/neon_armv8a/sgemm_skinny_dot_kernel/SgemmSkinnyDotCopy.c @@ -0,0 +1,547 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. 
*/ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include +#include + +static inline void pref_b(const float *src) { + __asm__("prfm pldl1keep,[%0,#64]\n\t"::"r"(src):); +} + +static inline void pack_rm_from_cm_4col(float * __restrict__ b_wt, + const float * __restrict__ b_rd, uint32_t K, uint32_t LDB, + uint32_t N, uint32_t ninc1_2, + uint32_t ninc1_4, uint32_t ninc2_4, uint32_t ninc3_4) { + + const float *b_l1 = b_rd; + const float *b_l2 = b_rd + LDB; + const float *b_l3 = b_rd + LDB * 2; + const float *b_l4 = b_rd + LDB * 3; + float *b_w1 = b_wt; + + uint32_t k_left = K; + + for (; k_left > 3; k_left -= 4) { + float32x4x4_t tmp; + tmp.val[0] = vld1q_f32(b_l1); b_l1 += 4; pref_b(b_l1); + tmp.val[1] = vld1q_f32(b_l2); b_l2 += 4; pref_b(b_l2); + tmp.val[2] = vld1q_f32(b_l3); b_l3 += 4; pref_b(b_l3); + tmp.val[3] = vld1q_f32(b_l4); b_l4 += 4; pref_b(b_l4); + vst4q_lane_f32(b_w1, tmp, 0); + vst4q_lane_f32(b_w1 + ninc1_4, tmp, 1); + vst4q_lane_f32(b_w1 + ninc2_4, tmp, 2); + vst4q_lane_f32(b_w1 + ninc3_4, tmp, 3); + b_w1 += N * 4; + } + if (k_left > 1) { + float32x2x4_t tmp; + tmp.val[0] = vld1_f32(b_l1); b_l1 += 2; + tmp.val[1] = vld1_f32(b_l2); b_l2 += 2; + tmp.val[2] = vld1_f32(b_l3); b_l3 += 2; + tmp.val[3] = vld1_f32(b_l4); b_l4 += 2; + vst4_lane_f32(b_w1, tmp, 0); + vst4_lane_f32(b_w1 + ninc1_2, tmp, 1); + b_w1 += N * 2; + k_left -= 2; + } + if (k_left > 0) { + b_w1[0] = *b_l1; + b_w1[1] = *b_l2; + b_w1[2] = *b_l3; + b_w1[3] = *b_l4; + } +} + +void pack_0_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + const float *b_rd = B; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + pack_rm_from_cm_4col(b_scr + N - n_left, b_rd, K, LDB, N, + N, N, N * 2, N * 3); + b_rd += 4 * LDB; + } + float *b_wt = b_scr + N - n_left; + if (n_left == 3) { + const float *b_rd2 = b_rd + LDB; + const float *b_rd3 = b_rd + LDB * 2; + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4x3_t tmp; + tmp.val[0] = vld1q_f32(b_rd); b_rd += 4; pref_b(b_rd); + tmp.val[1] = vld1q_f32(b_rd2); b_rd2 += 4; pref_b(b_rd2); + tmp.val[2] = vld1q_f32(b_rd3); b_rd3 += 4; pref_b(b_rd3); + vst3q_lane_f32(b_wt, tmp, 0); b_wt += N; + vst3q_lane_f32(b_wt, tmp, 1); b_wt += N; + vst3q_lane_f32(b_wt, tmp, 2); b_wt += N; + vst3q_lane_f32(b_wt, tmp, 3); b_wt += N; + } + if (k_left > 1) { + float32x2x3_t tmp; + tmp.val[0] = vld1_f32(b_rd); b_rd += 2; + tmp.val[1] = vld1_f32(b_rd2); b_rd2 += 2; + tmp.val[2] = vld1_f32(b_rd3); b_rd3 += 2; + vst3_lane_f32(b_wt, tmp, 0); b_wt += N; + vst3_lane_f32(b_wt, tmp, 1); b_wt += N; + k_left -= 2; + } + if (k_left > 0) { + b_wt[0] = *b_rd; b_wt[1] = *b_rd2; b_wt[2] = *b_rd3; + } + } else if (n_left == 2) { + const float *b_rd2 = b_rd + LDB; + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4x2_t tmp; + tmp.val[0] = vld1q_f32(b_rd); b_rd += 4; pref_b(b_rd); + tmp.val[1] = vld1q_f32(b_rd2); b_rd2 += 4; pref_b(b_rd2); + vst2q_lane_f32(b_wt, tmp, 0); b_wt += N; + vst2q_lane_f32(b_wt, tmp, 1); b_wt += N; + 
vst2q_lane_f32(b_wt, tmp, 2); b_wt += N; + vst2q_lane_f32(b_wt, tmp, 3); b_wt += N; + } + if (k_left > 1) { + float32x2x2_t tmp; + tmp.val[0] = vld1_f32(b_rd); b_rd += 2; + tmp.val[1] = vld1_f32(b_rd2); b_rd2 += 2; + vst2_lane_f32(b_wt, tmp, 0); b_wt += N; + vst2_lane_f32(b_wt, tmp, 1); b_wt += N; + k_left -= 2; + } + if (k_left > 0) { + b_wt[0] = *b_rd; b_wt[1] = *b_rd2; + } + } else if (n_left == 1) { + for (uint32_t k_pos = 0; k_pos < K; ++k_pos) { + *b_wt = b_rd[k_pos]; b_wt += N; + } + } +} + +void pack_1_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + const float *b_rd = B; + const uint32_t n_4z = N & 0xFFFFFFFC; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + pack_rm_from_cm_4col(b_scr + N - n_left, b_rd, K, LDB, N, + N, n_4z, n_4z * 2, n_4z * 3); + b_rd += 4 * LDB; + } + float *b_wt = b_scr + (N - n_left) * 4; + if (n_left == 3) { + const float *b_rd2 = b_rd + LDB; + const float *b_rd3 = b_rd + LDB * 2; + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4_t tmp1 = vld1q_f32(b_rd); b_rd += 4; pref_b(b_rd); + float32x4_t tmp2 = vld1q_f32(b_rd2); b_rd2 += 4; pref_b(b_rd2); + float32x4_t tmp3 = vld1q_f32(b_rd3); b_rd3 += 4; pref_b(b_rd3); + vst1q_f32(b_wt, tmp1); vst1q_f32(b_wt + 4, tmp2); + vst1q_f32(b_wt + 8, tmp3); b_wt += N * 4; + } + b_wt -= (N - n_left) * 3; + for (; k_left > 0; k_left--) { + b_wt[0] = *b_rd++; b_wt[1] = *b_rd2++; b_wt[2] = *b_rd3++; + b_wt += N; + } + } else if (n_left == 2) { + const float *b_rd2 = b_rd + LDB; + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4_t tmp1 = vld1q_f32(b_rd); b_rd += 4; pref_b(b_rd); + float32x4_t tmp2 = vld1q_f32(b_rd2); b_rd2 += 4; pref_b(b_rd2); + vst1q_f32(b_wt, tmp1); vst1q_f32(b_wt + 4, tmp2); + b_wt += N * 4; + } + b_wt -= (N - n_left) * 3; + for (; k_left > 0; k_left--) { + b_wt[0] = *b_rd++; b_wt[1] = *b_rd2++; + b_wt += N; + } + } else if (n_left == 1) { + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4_t tmp1 = vld1q_f32(b_rd); b_rd += 4; pref_b(b_rd); + vst1q_f32(b_wt, tmp1); b_wt += N * 4; + } + b_wt -= (N - n_left) * 3; + for (; k_left > 0; k_left--) { + b_wt[0] = *b_rd++; + b_wt += N; + } + } +} + +void pack_2_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + const float *b_rd = B; + const uint32_t n_2z = N & 0xFFFFFFFC; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + pack_rm_from_cm_4col(b_scr + N - n_left, b_rd, K, LDB, N, + n_2z, n_2z, N * 2, N * 2 + n_2z); + b_rd += 4 * LDB; + } + float *b_wt = b_scr + (N - n_left) * 2; + if (n_left == 3) { + const float *b_rd2 = b_rd + LDB; + const float *b_rd3 = b_rd + LDB * 2; + uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t tmp1 = vld1_f32(b_rd); b_rd += 2; + float32x2_t tmp2 = vld1_f32(b_rd2); b_rd2 += 2; + float32x2_t tmp3 = vld1_f32(b_rd3); b_rd3 += 2; + vst1_f32(b_wt, tmp1); vst1_f32(b_wt + 2, tmp2); + vst1_f32(b_wt + 4, tmp3); b_wt += N * 2; + } + b_wt -= N - n_left; + if (k_left > 0) { + b_wt[0] = *b_rd; b_wt[1] = *b_rd2; b_wt[2] = *b_rd3; + } + } else if (n_left == 2) { + const float *b_rd2 = b_rd + LDB; + uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t tmp1 = vld1_f32(b_rd); b_rd += 2; + float32x2_t tmp2 = vld1_f32(b_rd2); b_rd2 += 2; + vst1_f32(b_wt, tmp1); vst1_f32(b_wt + 2, tmp2); + b_wt += N * 2; + } + b_wt -= N - n_left; + if (k_left > 0) { + b_wt[0] = *b_rd; b_wt[1] = *b_rd2; + } + } else if (n_left == 1) { + 
uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t tmp1 = vld1_f32(b_rd); b_rd += 2; + vst1_f32(b_wt, tmp1); + b_wt += N * 2; + } + b_wt -= N - n_left; + if (k_left > 0) { + b_wt[0] = *b_rd; + } + } +} + +void pack_3_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + const float *b_rd = B; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + const float *b_rd3 = b_rd + LDB * 2; + const float *b_rd4 = b_rd2 + LDB * 2; + float *b_wt = b_scr + (N - n_left) * 2; + b_rd += LDB * 4; + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4_t t1 = vld1q_f32(b_rd1); b_rd1 += 4; pref_b(b_rd1); + float32x4_t t2 = vld1q_f32(b_rd2); b_rd2 += 4; pref_b(b_rd2); + float32x4_t t3 = vld1q_f32(b_rd3); b_rd3 += 4; pref_b(b_rd3); + float32x4_t t4 = vld1q_f32(b_rd4); b_rd4 += 4; pref_b(b_rd4); + vst1_f32(b_wt, vget_low_f32(t1)); + vst1_f32(b_wt + 2, vget_low_f32(t2)); + vst1_f32(b_wt + 4, vget_low_f32(t3)); + vst1_f32(b_wt + 6, vget_low_f32(t4)); b_wt += 2 * N; + vst1_f32(b_wt, vget_high_f32(t1)); + vst1_f32(b_wt + 2, vget_high_f32(t2)); + vst1_f32(b_wt + 4, vget_high_f32(t3)); + vst1_f32(b_wt + 6, vget_high_f32(t4)); b_wt += 2 * N; + } + if (k_left > 1) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + float32x2_t t2 = vld1_f32(b_rd2); b_rd2 += 2; + float32x2_t t3 = vld1_f32(b_rd3); b_rd3 += 2; + float32x2_t t4 = vld1_f32(b_rd4); b_rd4 += 2; + vst1_f32(b_wt, t1); vst1_f32(b_wt + 2, t2); + vst1_f32(b_wt + 4, t3); vst1_f32(b_wt + 6, t4); b_wt += 2 * N; + k_left -= 2; + } + b_wt -= N - n_left; + if (k_left > 0) { + b_wt[0] = *b_rd1; b_wt[1] = *b_rd2; + b_wt[2] = *b_rd3; b_wt[3] = *b_rd4; + } + } + if (n_left > 1) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + float *b_wt = b_scr + (N - n_left) * 2; + b_rd += LDB * 2; + uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + float32x2_t t2 = vld1_f32(b_rd2); b_rd2 += 2; + vst1_f32(b_wt, t1); vst1_f32(b_wt + 2, t2); b_wt += 2 * N; + } + b_wt -= N - n_left; + if (k_left > 0) { + b_wt[0] = *b_rd1; b_wt[1] = *b_rd2; + } + n_left -= 2; + } + if (n_left > 0) { + float *b_wt = b_scr + (N - n_left) * 2; + uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t t1 = vld1_f32(b_rd); b_rd += 2; + vst1_f32(b_wt, t1); b_wt += 2 * N; + } + b_wt -= N - n_left; + if (k_left > 0) { + b_wt[0] = *b_rd; + } + } +} + +void pack_4_from_cm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + const float *b_rd = B; + const uint32_t n_2z = (N << 1) - (N & 0xFFFFFFFE); + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + const float *b_rd3 = b_rd + LDB * 2; + const float *b_rd4 = b_rd2 + LDB * 2; + float *b_wt = b_scr + N - n_left; + b_rd += LDB * 4; + uint32_t k_left = K; + for (; k_left > 3; k_left -= 4) { + float32x4x4_t tmp; + tmp.val[0] = vld1q_f32(b_rd1); b_rd1 += 4; pref_b(b_rd1); + tmp.val[1] = vld1q_f32(b_rd2); b_rd2 += 4; pref_b(b_rd2); + tmp.val[2] = vld1q_f32(b_rd3); b_rd3 += 4; pref_b(b_rd3); + tmp.val[3] = vld1q_f32(b_rd4); b_rd4 += 4; pref_b(b_rd4); + tmp.val[1] = vrev64q_f32(tmp.val[1]); + tmp.val[3] = vrev64q_f32(tmp.val[3]); + vst4q_lane_f32(b_wt, tmp, 0); + vst4q_lane_f32(b_wt + n_2z, tmp, 1); b_wt += 2 * N; + vst4q_lane_f32(b_wt, tmp, 2); + vst4q_lane_f32(b_wt + n_2z, tmp, 3); b_wt += 2 * N; + } + 
if (k_left > 1) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + float32x2_t t2 = vld1_f32(b_rd2); b_rd2 += 2; + float32x2_t t3 = vld1_f32(b_rd3); b_rd3 += 2; + float32x2_t t4 = vld1_f32(b_rd4); b_rd4 += 2; + t2 = vrev64_f32(t2); t4 = vrev64_f32(t4); + float32x2_t d1 = vtrn1_f32(t1, t2); + float32x2_t d2 = vtrn1_f32(t3, t4); + float32x2_t d3 = vtrn2_f32(t1, t2); + float32x2_t d4 = vtrn2_f32(t3, t4); + vst1_f32(b_wt, d1); vst1_f32(b_wt + 2, d2); + vst1_f32(b_wt + n_2z, d3); vst1_f32(b_wt + n_2z + 2, d4); + b_wt += 2 * N; k_left -= 2; + } + if (k_left > 0) { + b_wt[0] = *b_rd1; b_wt[1] = *b_rd2; + b_wt[2] = *b_rd3; b_wt[3] = *b_rd4; + } + } + if (n_left > 1) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + float *b_wt = b_scr + N - n_left; + b_rd += LDB * 2; + uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + float32x2_t t2 = vld1_f32(b_rd2); b_rd2 += 2; + t2 = vrev64_f32(t2); + float32x2_t d1 = vtrn1_f32(t1, t2); + float32x2_t d2 = vtrn2_f32(t1, t2); + vst1_f32(b_wt, d1); vst1_f32(b_wt + n_2z, d2); + b_wt += 2 * N; + } + if (k_left > 0) { + b_wt[0] = *b_rd1; b_wt[1] = *b_rd2; + } + n_left -= 2; + } + if (n_left > 0) { + float *b_wt = b_scr + N - n_left; + uint32_t k_left = K; + for (; k_left > 1; k_left -= 2) { + float32x2_t t1 = vld1_f32(b_rd); b_rd += 2; + vst1_f32(b_wt, t1); b_wt += 2 * N; + } + if (k_left > 0) { + b_wt[0] = *b_rd; + } + } +} + +void pack_0_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + uint32_t k_left = K; + const float *b_rd = B; + float *b_wt = b_scr; + for (; k_left > 0; k_left--) { + const float *b_rd1 = b_rd; b_rd += LDB; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + float32x4_t t1 = vld1q_f32(b_rd1); b_rd1 += 4; + vst1q_f32(b_wt, t1); b_wt += 4; + } + if (n_left > 1) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + vst1_f32(b_wt, t1); b_wt += 2; + n_left -= 2; + } + if (n_left > 0) { + *b_wt = *b_rd1; b_wt++; + } + } +} + +void pack_2_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + uint32_t k_left = K; + const uint32_t n_4z = N & 0xFFFFFFFC; + const float *b_rd = B; + float *b_wt = b_scr; + for (; k_left > 1; k_left -= 2) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + b_rd += LDB * 2; + float *b_wt1 = b_wt; + float *b_wt2 = b_wt + n_4z; + b_wt += N * 2; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + float32x4_t t1 = vld1q_f32(b_rd1); b_rd1 += 4; + float32x4_t t2 = vld1q_f32(b_rd2); b_rd2 += 4; + vst1q_f32(b_wt1, t1); b_wt1 += 4; + vst1q_f32(b_wt2, t2); b_wt2 += 4; + } + for (; n_left > 0; n_left--) { + b_wt2[0] = *b_rd1++; + b_wt2[1] = *b_rd2++; + b_wt2 += 2; + } + } + pack_0_from_rm(b_wt, b_rd, LDB, k_left, N); +} + +void pack_1_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + uint32_t k_left = K; + const float *b_rd = B; + float *b_wt = b_scr; + const uint32_t n_4z = N & 0xFFFFFFFC; + for (; k_left > 3; k_left -= 4) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + const float *b_rd3 = b_rd + LDB * 2; + const float *b_rd4 = b_rd2 + LDB * 2; + b_rd += LDB * 4; + float *b_wt1 = b_wt; + float *b_wt2 = b_wt + n_4z; + float *b_wt3 = b_wt + n_4z * 2; + float *b_wt4 = b_wt2 + n_4z * 2; + b_wt += N * 4; + uint32_t n_left = N; + for (; n_left > 3; n_left -= 4) { + float32x4_t t1 = vld1q_f32(b_rd1); b_rd1 += 4; + float32x4_t t2 = 
vld1q_f32(b_rd2); b_rd2 += 4; + float32x4_t t3 = vld1q_f32(b_rd3); b_rd3 += 4; + float32x4_t t4 = vld1q_f32(b_rd4); b_rd4 += 4; + vst1q_f32(b_wt1, t1); b_wt1 += 4; + vst1q_f32(b_wt2, t2); b_wt2 += 4; + vst1q_f32(b_wt3, t3); b_wt3 += 4; + vst1q_f32(b_wt4, t4); b_wt4 += 4; + } + for (; n_left > 0; n_left--) { + b_wt4[0] = *b_rd1++; b_wt4[1] = *b_rd2++; + b_wt4[2] = *b_rd3++; b_wt4[3] = *b_rd4++; b_wt4 += 4; + } + } + pack_0_from_rm(b_wt, b_rd, LDB, k_left, N); +} + +void pack_3_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + uint32_t k_left = K; + const float *b_rd = B; + float *b_wt = b_scr; + for (; k_left > 1; k_left -= 2) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + b_rd += LDB * 2; + float *b_wt1 = b_wt; + b_wt += N * 2; + uint32_t n_left = N; + for (; n_left > 1; n_left -= 2) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + float32x2_t t2 = vld1_f32(b_rd2); b_rd2 += 2; + float32x2_t d1 = vzip1_f32(t1, t2); + float32x2_t d2 = vzip2_f32(t1, t2); + vst1_f32(b_wt1, d1); + vst1_f32(b_wt1 + 2, d2); b_wt1 += 4; + } + if (n_left > 0) { + b_wt1[0] = *b_rd1; b_wt1[1] = *b_rd2; + } + } + pack_0_from_rm(b_wt, b_rd, LDB, k_left, N); +} + +void pack_4_from_rm(float * __restrict__ b_scr, + const float * __restrict__ B, uint32_t LDB, uint32_t K, uint32_t N) { + + uint32_t k_left = K; + const float *b_rd = B; + float *b_wt = b_scr; + const uint32_t n_2z = (N << 1) - (N & 0xFFFFFFFE); + for (; k_left > 1; k_left -= 2) { + const float *b_rd1 = b_rd; + const float *b_rd2 = b_rd + LDB; + b_rd += LDB * 2; + float *b_wt1 = b_wt; + float *b_wt2 = b_wt + n_2z; + b_wt += N * 2; + uint32_t n_left = N; + for (; n_left > 1; n_left -= 2) { + float32x2_t t1 = vld1_f32(b_rd1); b_rd1 += 2; + float32x2_t t2 = vld1_f32(b_rd2); b_rd2 += 2; + t2 = vrev64_f32(t2); + float32x2_t d1 = vzip1_f32(t1, t2); + float32x2_t d2 = vzip2_f32(t2, t1); + vst1_f32(b_wt1, d1); b_wt1 += 2; + vst1_f32(b_wt2, d2); b_wt2 += 2; + } + if (n_left > 0) { + b_wt1[0] = *b_rd1; b_wt1[1] = *b_rd2; + } + } + pack_0_from_rm(b_wt, b_rd, LDB, k_left, N); +} + diff --git a/test/TestBias.c b/test/TestBias.c new file mode 100644 index 0000000..10517a1 --- /dev/null +++ b/test/TestBias.c @@ -0,0 +1,281 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
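The pack_*_from_cm and pack_*_from_rm routines above rearrange the skinny B matrix into the scratch layouts the per-core kernels expect: the _from_cm family transposes a column-major B into a (possibly interleaved) k-major order using vst2/vst3/vst4 lane stores, while pack_0_from_rm is essentially a row copy. A scalar reference for the simplest case, the plain column-major to row-major repack performed by pack_0_from_cm, which is handy for checking the NEON versions against; the numbered variants store the same elements but permute them within small row groups to match specific kernels.

#include <stdint.h>

/* Reference for pack_0_from_cm: B (column-major, K x N, leading dim LDB)
 * is rewritten as a dense row-major K x N block in b_scr. */
static void ref_pack_0_from_cm(float *b_scr, const float *B,
    uint32_t LDB, uint32_t K, uint32_t N) {
  for (uint32_t n = 0; n < N; ++n)
    for (uint32_t k = 0; k < K; ++k)
      b_scr[k * N + n] = B[n * LDB + k];
}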
*/ +/*****************************************************************************/ + + +#if __aarch64__ +#include "neon_armv8a/Bias.h" +#else +#include "neon_armv7a/Bias.h" +#endif + +#include +#include +#include +#include +#include + +#define TEST_BIAS(type) \ +static void test_bias_##type(uint32_t dim1, uint32_t dim2, uint8_t status) {\ + bool dim0_bias = status & 1;\ + bool dim1_bias = status & 2;\ + bool dim2_bias = status & 4;\ + printf("Test info for bias:\n");\ + printf("data type = "#type"\n");\ + printf("dim1 = %u, dim2 = %u\n", dim1, dim2);\ + printf("dim0_bias = %d\n", dim0_bias ? 1 : 0);\ + printf("dim1_bias = %d\n", dim1_bias ? 1 : 0);\ + printf("dim2_bias = %d\n", dim2_bias ? 1 : 0);\ +\ + const uint64_t size_dat = (dim1 + 4) * (dim2 + 4);\ + const uint32_t num_iters = 40000000 / size_dat;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ + type * const ref = (type *)malloc(sizeof(type) * size_dat);\ + type * const dat = (type *)malloc(sizeof(type) *\ + size_dat * num_iters);\ + type * const bias_dim1 = dim1_bias ? (type *)malloc(sizeof(type) *\ + (dim1 + 4) * num_iters) : NULL;\ + type * const bias_dim2 = dim2_bias ? (type *)malloc(sizeof(type) *\ + (dim2 + 4) * num_iters) : NULL;\ +\ + srand(time(NULL));\ + for (uint64_t pos = 0; pos < size_dat; ++pos) {\ + ref[pos] = rand() % 256;\ + }\ + for (uint32_t pos = 0; pos < num_iters; ++pos) {\ + memcpy(dat + pos * size_dat, ref, size_dat * sizeof(type));\ + }\ + if (dim1_bias) {\ + for (uint32_t pos = 0; pos < dim1 + 4; ++pos) {\ + bias_dim1[pos] = rand() % 256;\ + }\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + memcpy(bias_dim1 + pos * (dim1 + 4), bias_dim1, (dim1 + 4) * sizeof(type));\ + }\ + }\ + if (dim2_bias) {\ + for (uint32_t pos = 0; pos < dim2 + 4; ++pos) {\ + bias_dim2[pos] = rand() % 256;\ + }\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + memcpy(bias_dim2 + pos * (dim2 + 4), bias_dim2, (dim2 + 4) * sizeof(type));\ + }\ + }\ +\ + const type bias_v0 = dim0_bias ? (rand() % 256 + 1) : 0;\ + if (dim0_bias) {\ + for (uint32_t pos = 0; pos < dim1 * dim2; ++pos) ref[pos] += bias_v0;\ + }\ +\ + if (dim1_bias) {\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + type *curr = ref + dim2_pos * dim1;\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + curr[dim1_pos] += (type)2.0 * bias_dim1[dim1_pos];\ + }\ + }\ + }\ +\ + if (dim2_bias) {\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + const type bias = (type)3.0 * bias_dim2[dim2_pos];\ + type *curr = ref + dim2_pos * dim1;\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + curr[dim1_pos] += bias;\ + }\ + }\ + }\ +\ + bias_##type(dat, bias_v0, bias_dim1, 2.0, bias_dim2, 3.0, dim1, dim2);\ + double max_diff = 0.0;\ + for (uint64_t pos = 0; pos < size_dat; ++pos) {\ + double tmp = (double)dat[pos] - (double)ref[pos];\ + if (tmp < 0) tmp *= -1.0;\ + if (tmp > max_diff) max_diff = tmp;\ + }\ + printf("Max diff. between calc. and ref.: %.2e\n", max_diff);\ +\ + struct timespec st, et;\ + clock_gettime(CLOCK_MONOTONIC, &st);\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + bias_##type(dat + pos * size_dat, bias_v0,\ + bias_dim1 ? bias_dim1 + pos * dim1 : NULL, 2.0,\ + bias_dim2 ? bias_dim2 + pos * dim2 : NULL, 3.0,\ + dim1, dim2);\ + }\ + clock_gettime(CLOCK_MONOTONIC, &et);\ + double nsec = (double)(et.tv_nsec - st.tv_nsec) + 1.0e9 * (double)\ + (et.tv_sec - st.tv_sec);\ + printf("Avg. 
perf.: %.2e G elements per second.\n", (double)dim1 * \ + (double)dim2 * (double)(num_iters - 1) / nsec);\ +\ + free(ref);\ + free(dat);\ + free(bias_dim1);\ + free(bias_dim2);\ +} + +#define TEST_SUM(signint, sumfunc) \ +void test_sum_##signint##8to32(uint32_t dim1, uint32_t dim2,\ + uint32_t status) {\ +\ + printf("Test info for sum:\n");\ + printf("data type = "#signint"8 -> "#signint"32\n");\ + printf("dim1 = %u, dim2 = %u\n", dim1, dim2);\ + if (status) {\ + printf("sum along dim1 direction, output length = dim2\n");\ + } else {\ + printf("sum along dim2 direction, output length = dim1\n");\ + }\ +\ + const uint64_t size_dat = (dim1 + 4) * (dim2 + 4);\ + const uint32_t num_iters = 40000000 / size_dat;\ + if (num_iters <= 2) {\ + printf("Problem size too large.\n");\ + return;\ + }\ + signint##8_t * const dat = (signint##8_t *)malloc(size_dat * num_iters);\ +\ + const uint32_t size_out = status ? (dim2 + 4) : (dim1 + 4);\ + signint##32_t * const ref = (signint##32_t *)malloc(size_out * 4);\ + signint##32_t * const tst = (signint##32_t *)malloc(size_out * num_iters * 4);\ +\ + srand(time(NULL));\ + for (uint64_t pos = 0; pos < size_dat; ++pos) {\ + dat[pos] = rand();\ + }\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + memcpy(dat + pos * size_dat, dat, size_dat);\ + }\ +\ + if (status) {\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + const signint##8_t *src = dat + dim2_pos * dim1;\ + signint##32_t sum = 0;\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + sum += src[dim1_pos];\ + }\ + ref[dim2_pos] = sum;\ + }\ + for (uint32_t dim2_pos = dim2; dim2_pos < size_out; ++dim2_pos) {\ + ref[dim2_pos] = tst[dim2_pos] = rand();\ + }\ + } else {\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + ref[dim1_pos] = 0;\ + }\ + for (uint32_t dim1_pos = dim1; dim1_pos < size_out; dim1_pos++) {\ + ref[dim1_pos] = tst[dim1_pos] = rand();\ + }\ + for (uint32_t dim2_pos = 0; dim2_pos < dim2; ++dim2_pos) {\ + const signint##8_t *src = dat + dim2_pos * dim1;\ + for (uint32_t dim1_pos = 0; dim1_pos < dim1; ++dim1_pos) {\ + ref[dim1_pos] += src[dim1_pos];\ + }\ + }\ + }\ +\ + sumfunc(dat, tst, dim1, dim2, status);\ + int consistent = 1;\ + for (uint32_t pos = 0; pos < size_out; ++pos) {\ + if (consistent != 0 && ref[pos] != tst[pos]) {\ + consistent = 0;\ + printf("elements at pos %u are unequal.\n", pos);\ + }\ + }\ + if (consistent != 0) {\ + printf("all elements are equal between ref. and tst.\n");\ + struct timespec st, et;\ + clock_gettime(CLOCK_MONOTONIC, &st);\ + for (uint32_t pos = 1; pos < num_iters; ++pos) {\ + sumfunc(dat + pos * size_dat, tst + pos * size_out,\ + dim1, dim2, status);\ + }\ + clock_gettime(CLOCK_MONOTONIC, &et);\ + double nsec = (double)(et.tv_nsec - st.tv_nsec) + 1.0e9 * \ + (double)(et.tv_sec - st.tv_sec);\ + printf("Avg. Perf.: %.2e G elements read per second.\n",\ + (double)dim1 * (double)(dim2) * (double)(num_iters - 1) / nsec);\ + }\ + free(dat);\ + free(ref);\ + free(tst);\ +} + + +TEST_BIAS(float) + +TEST_BIAS(int32_t) + +TEST_SUM(uint, u8u32_sum) + +/************************************************************************ + * cmd usage of the test program for bias functions + * + * dim1: the length of the first dimension of the matrix, + * equals to number of columns for row-major matrices, + * equals to number of rows for column-major matrices. + * dim2: the length of the second dimension of the matrix, + * equals to number of rows for row-major matrices, + * equals to number of columns for column-major matrices. 
+ * bias_status: a number indicating which function to test. + * 0 - 7 for bias function: + * 0: no bias is performed. + * 1: only scalar bias is applied. the bias is identical + * to each element. + * 2: bias only along the first dimension, the size of bias + * vector equals dim1. for row-major matrix, this means + * elem(col_id, row_id) += bias(col_id); + * 3: scalar & first-dimension bias operations. + * 4: bias only along the second dimension, the size of bias + * vector equals dim2. for row-major matrix, this means + * elem(col_id, row_id) += bias(row_id); + * 5: scalar & second-dimension bias operations. + * 6: first-dimension & second-dimension bias operations. + * 7: scalar & first-dimension & second-dimension bias operations. + * 8 - 9 for sum function: + * 8: sum along dim2. for a row-major matrix, the sum of elements + * in each column is calculated. + * 9: sum along dim1. for a row-major matrix, the sum of elements + * in each row is calculated. + * data_type: a string indicating the data type of bias + * float: fp32 + * int32: int32_t + ************************************************************************/ + +int main(int argc, char **argv) { + + const uint32_t dim1 = (argc > 1) ? atoi(argv[1]) : 63; + const uint32_t dim2 = (argc > 2) ? atoi(argv[2]) : 143; + const uint8_t bias_status = (argc > 3) ? atoi(argv[3]) : 7; + const char * const data_type = (argc > 4) ? argv[4] : "float"; + + if (bias_status > 7) { + test_sum_uint8to32(dim1, dim2, bias_status & 1); + return 0; + } + + if (data_type[0] == 'f' || data_type[0] == 'F') { + test_bias_float(dim1, dim2, bias_status); + return 0; + } + + test_bias_int32_t(dim1, dim2, bias_status); + return 0; +} + diff --git a/test/TestCompilerOpenMP.c b/test/TestCompilerOpenMP.c new file mode 100644 index 0000000..b248218 --- /dev/null +++ b/test/TestCompilerOpenMP.c @@ -0,0 +1,12 @@ +#include <omp.h> +#include <stdlib.h> + +int main() { + int *id = (int*)malloc(omp_get_max_threads() * sizeof(int)); +#pragma omp parallel + { + id[omp_get_thread_num()] = omp_get_thread_num(); + } + free(id); + return 0; +} diff --git a/test/TestGemm.c b/test/TestGemm.c new file mode 100644 index 0000000..19d6c9e --- /dev/null +++ b/test/TestGemm.c @@ -0,0 +1,98 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. 
*/ +/*****************************************************************************/ + + +#include "common/CommonTest.h" +#include "Gemm.h" + +STD_TEST(sgemm, float, float, float, (RAND_MAX >> 2), (RAND_MAX >> 2)) + +STD_TEST(s8s32gemm, int8_t, int8_t, int32_t, 64, 1) + +STD_TEST(u8u32gemm, uint8_t, uint8_t, uint32_t, -64, 1) + +#if __aarch64__ +STD_TEST(hgemm, float16_t, float16_t, float16_t, 6, 32) +#endif + +/************************************************************************* + * cmd usage of test program for GEMM functions + * \ + * + * GEMM operation: C[MxN] = A[MxK] B[KxN] + beta * C[MxN] + * Parameters: + * M: the number of rows in matrix A. + * N: the number of columns in matrix B. + * K: the number of columns in matrix A. + * transAB: a number indicating the storage order of source matrices: + * 0: A column-major, B column-major + * 1: A row-major, B column-major + * 2: A column-major, B row-major + * 3: A row-major, B row-major + * num_threads: number of threads for GEMM. + * gemm_type: a string indicating the type of GEMM: + * sgemm: fp32 GEMM + * hgemm: fp16 GEMM + * u8u32: uint8 * uint8 -> uint32 GEMM + * s8s32: int8 * int8 -> int32 GEMM + * beta: the scaling factor applied to matrix C prior to GEMM operation, + * C = AB + beta * C. + *************************************************************************/ +int main(int argc, char **argv) { + uint32_t M = 383; + uint32_t N = 479; + uint32_t K = 319; + uint8_t transAB = 0; + uint32_t num_threads = 0; + const char *gemm_type = "sgemm"; + double beta = 0.5; + if (argc > 1) M = atoi(argv[1]); + if (argc > 2) N = atoi(argv[2]); + if (argc > 3) K = atoi(argv[3]); + if (argc > 4) transAB = atoi(argv[4]); + if (argc > 5) num_threads = atoi(argv[5]); + if (argc > 6) gemm_type = argv[6]; + if (argc > 7) beta = atof(argv[7]); + printf("Test info: M = %u, N = %u, K = %u\n", M, N, K); + printf("Test info: a_rowmajor = %d, b_rowmajor = %d\n", + transAB & 1, (transAB & 2) >> 1); + printf("Test info: num_threads = %u, beta = %.2e\n", num_threads, beta); + +#if __aarch64__ + if (strcmp(gemm_type, "hgemm") == 0) { + printf("Test info: gemmtype = hgemm.\n"); + std_test_hgemm(hgemm, M, N, K, transAB, beta, num_threads); + return 0; + } +#endif + + if (strcmp(gemm_type, "u8u32") == 0) { + printf("Test info: gemmtype = u8u32gemm.\n"); + std_test_u8u32gemm(u8u32gemm, M, N, K, transAB, beta, num_threads); + return 0; + } + + if (strcmp(gemm_type, "s8s32") == 0) { + printf("Test info: gemmtype = s8s32gemm.\n"); + std_test_s8s32gemm(s8s32gemm, M, N, K, transAB, beta, num_threads); + return 0; + } + + printf("Test info: gemmtype = sgemm.\n"); + std_test_sgemm(sgemm, M, N, K, transAB, beta, num_threads); + return 0; +} + diff --git a/test/TestQuant.c b/test/TestQuant.c new file mode 100644 index 0000000..a67f808 --- /dev/null +++ b/test/TestQuant.c @@ -0,0 +1,95 @@ +/*****************************************************************************/ +/* Copyright YouDao, Inc. */ +/* */ +/* Licensed under the Apache License, Version 2.0 (the "License"); */ +/* you may not use this file except in compliance with the License. */ +/* You may obtain a copy of the License at */ +/* */ +/* http://www.apache.org/licenses/LICENSE-2.0 */ +/* */ +/* Unless required by applicable law or agreed to in writing, software */ +/* distributed under the License is distributed on an "AS IS" BASIS, */ +/* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ +/* See the License for the specific language governing permissions and */ +/* limitations under the License. */ +/*****************************************************************************/ + + +#include "common/CommonTest.h" +#include "Quant.h" + +TEST_QUANT_UNSYM(32, 8) + +TEST_QUANT_SYM(32, 8) + +TEST_QUANT_UNSYM(32, 16) + +TEST_QUANT_SYM(32, 16) + +TEST_DEQUANT_SYM(32, 32) + +TEST_REQUANT_UNSYM(float, 32, 8) + +TEST_REQUANT_SYM(float, 32, 8) + +TEST_REQUANT_UNSYM(float, 32, 16) + +TEST_REQUANT_SYM(float, 32, 16) + +TEST_REQUANT_UNSYM(float, 16, 8) + +TEST_REQUANT_SYM(float, 16, 8) + +int main(int argc, char **argv) { + + uint32_t size = argc > 1 ? atoi(argv[1]) : 4; + const char * const type = argc > 2 ? argv[2] : "qu"; + + if (type[0] == 'q') { + if (type[1] == 'u') { + if (type[2] == '1') { + test_quant_asym_f32_u16(size); + } else { + test_quant_asym_f32_u8(size); + } + } else if (type[1] == 's') { + if (type[2] == '1') { + test_quant_sym_f32_s16(size); + } else { + test_quant_sym_f32_s8(size); + } + } + } else if (type[0] == 'd') { + test_dequant_sym_f32_s32(size); + } else if (type[0] == 'r') { + if (type[1] == 'u') { + int32_t max = argc > 3 ? atoi(argv[3]) : 20000000; + int32_t min = argc > 4 ? atoi(argv[4]) : -10000000; + float org_scale = argc > 5 ? atof(argv[5]) : 2.0; + if (type[2] == '1') { + test_requant_int16_t_float_uint8_t(size, + (int16_t)(min & 0xFFFF), (int16_t)(max & 0xFFFF), org_scale); + } else { + if (type[2] == '3' && type[3] == '2' && type[4] == '1') { + test_requant_int32_t_float_uint16_t(size, min, max, org_scale); + } else { + test_requant_int32_t_float_uint8_t(size, min, max, org_scale); + } + } + } else if (type[1] == 's') { + uint32_t max_abs = argc > 3 ? atoi(argv[3]) : 2000000; + float org_scale = argc > 4 ? atof(argv[4]) : 2.0; + if (type[2] == '1') { + test_requant_int16_t_float_int8_t(size, + (uint16_t)(max_abs & 0xFFFF), org_scale); + } else { + if (type[2] == '3' && type[3] == '2' && type[4] == '1') { + test_requant_int32_t_float_int16_t(size, max_abs, org_scale); + } else { + test_requant_int32_t_float_int8_t(size, max_abs, org_scale); + } + } + } + } + return 0; +}
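Editor's note (not part of the patch): the bias and sum kernels exercised by test/TestBias.c can also be called directly from application code. The sketch below is a minimal example; its prototypes are inferred from the call sites in the TEST_BIAS and TEST_SUM macros, since the authoritative declarations live in Bias.h, which is not included in this excerpt. Parameter names, const-qualifiers and the type of the direction argument are therefore assumptions.

/* Minimal usage sketch under the assumptions stated above. */
#include <stdint.h>
#include <stdlib.h>

/* assumed prototypes; consult Bias.h in the repository for the real ones */
void bias_float(float *dat, float bias_dim0,
                const float *bias_dim1, float scale_dim1,
                const float *bias_dim2, float scale_dim2,
                uint32_t dim1, uint32_t dim2);
void u8u32_sum(const uint8_t *src, uint32_t *dst,
               uint32_t dim1, uint32_t dim2, uint32_t direction);

int example(void) {
  const uint32_t dim1 = 4, dim2 = 3;   /* dim1 = number of columns of a row-major matrix */
  float *mat = (float *)calloc((size_t)dim1 * dim2, sizeof(float));
  float b1[4] = {1, 2, 3, 4};          /* one bias value per position along dim1 */
  float b2[3] = {10, 20, 30};          /* one bias value per position along dim2 */
  if (!mat) return 1;

  /* performs mat[i + j * dim1] += 0.5f + 2.0f * b1[i] + 3.0f * b2[j],
     mirroring the reference computation built up in TEST_BIAS */
  bias_float(mat, 0.5f, b1, 2.0f, b2, 3.0f, dim1, dim2);

  uint8_t src[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  uint32_t sums[4];
  /* direction 0: sum along dim2, output length dim1 (per-column sums of a
     row-major matrix), matching the status == 0 branch of TEST_SUM */
  u8u32_sum(src, sums, dim1, dim2, 0);

  free(mat);
  return 0;
}

For the test drivers themselves, the defaults hard-coded in the three main() functions correspond to invocations such as "test_emll_bias 63 143 7 float", "test_emll_gemm 383 479 319 0 0 sgemm 0.5" and "test_emll_quant 4 qu"; trailing arguments may be omitted to fall back to those defaults.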