From 4a0201ebd0ba61918c67fa8d4f0f27f81c13bc60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=81=E8=A1=8C?= Date: Sun, 25 Jun 2023 16:39:19 +0800 Subject: [PATCH] Bugfix of AVX512 compile. --- source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp | 4 ++-- source/backend/cpu/x86_x64/avx512/GemmCommon.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp index d277e1789..b7e776da3 100644 --- a/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx512/FunctionSummary.hpp @@ -36,8 +36,8 @@ extern "C" { void _AVX512_MNNPackForMatMul_B(float* dest, const float* source, size_t h, size_t l, bool transpose); void _AVX512_MNNPackC8ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); -void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); -void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); +void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void _AVX512_MNNGetSparseMatMulPackMode(int* eP, int *lP, int* hP); void _AVX512_MNNPackedSparseMatMulEpx8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, diff --git a/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp b/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp index cf9120904..0911129fc 100644 --- a/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp +++ b/source/backend/cpu/x86_x64/avx512/GemmCommon.cpp @@ -1131,7 +1131,7 @@ static void _AVX512_MNNPackednMatMulRemainCommon(float* C, const float* A, const } } -void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias) { +void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { #ifdef MNN_X86_USE_ASM if (nullptr == postParameters) { _AVX512_MNNGemmFloatUnit48x8(C, A, B, parameter); @@ -1147,7 +1147,7 @@ void _AVX512_MNNPackedMatMul(float* C, const float* A, const float* B, const siz } //#define MNN_X86_DEBUG -void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias) { +void _AVX512_MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) { #ifdef MNN_X86_DEBUG static std::set gSize; if (gSize.find(eSize) == gSize.end()) {