From b019bef591d1d6ae8aadaaf7f44f90bda5a7a0b5 Mon Sep 17 00:00:00 2001 From: xiaying Date: Thu, 19 Oct 2023 10:12:55 +0800 Subject: [PATCH] [MNN:Bugfix] Fix bug of no sse poolgrad --- source/backend/cpu/BinaryUtils.hpp | 32 ++++++++++++------------- source/backend/cpu/x86_x64/avx/Vec8.hpp | 3 +++ source/math/Vec.hpp | 6 +++++ 3 files changed, 25 insertions(+), 16 deletions(-) diff --git a/source/backend/cpu/BinaryUtils.hpp b/source/backend/cpu/BinaryUtils.hpp index be05589ad..368e46b41 100644 --- a/source/backend/cpu/BinaryUtils.hpp +++ b/source/backend/cpu/BinaryUtils.hpp @@ -187,14 +187,14 @@ struct BinaryBitwiseXor { } }; -template +template void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int needBroadcastIndex) { Func compute; const int sizeDivUnit = elementSize / pack; const int remainCount = elementSize - sizeDivUnit * pack; auto src0 = (const U*)(inputRaw0); auto src1 = (const U*)(inputRaw1); - auto dst = (U*)outputRaw; + auto dst = (Tout*)outputRaw; if (-1 == needBroadcastIndex) { if (sizeDivUnit > 0) { @@ -210,7 +210,7 @@ void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, i if (remainCount > 0) { U tempSrc0[pack]; U tempSrc1[pack]; - U tempDst[pack]; + Tout tempDst[pack]; ::memcpy(tempSrc0, src0, remainCount * sizeof(U)); ::memcpy(tempSrc1, src1, remainCount * sizeof(U)); V a = V::load(tempSrc0); @@ -233,7 +233,7 @@ void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, i } if (remainCount > 0) { U tempSrc1[pack]; - U tempDst[pack]; + Tout tempDst[pack]; ::memcpy(tempSrc1, src1, remainCount * sizeof(U)); V b = V::load(tempSrc1); V::save(tempDst, compute(a, b)); @@ -254,7 +254,7 @@ void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, i } if (remainCount > 0) { U tempSrc0[pack]; - U tempDst[pack]; + Tout tempDst[pack]; ::memcpy(tempSrc0, src0, remainCount * sizeof(U)); V a = V::load(tempSrc0); V::save(tempDst, compute(a, b)); @@ -415,27 +415,27 @@ template MNNBinaryExecute selectVector(int type) { switch (type) { case BinaryOpOperation_ADD: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, U>; case BinaryOpOperation_SUB: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, U>; case BinaryOpOperation_MUL: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, U>; case BinaryOpOperation_MINIMUM: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, U>; case BinaryOpOperation_MAXIMUM: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, U>; case BinaryOpOperation_SquaredDifference: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, U>; case BinaryOpOperation_LESS: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, int32_t>; case BinaryOpOperation_LESS_EQUAL: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, int32_t>; case BinaryOpOperation_GREATER: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, int32_t>; case BinaryOpOperation_GREATER_EQUAL: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, int32_t>; case BinaryOpOperation_EQUAL: - return executeVec, V, pack, U>; + return executeVec, V, pack, U, int32_t>; } return nullptr; } diff --git a/source/backend/cpu/x86_x64/avx/Vec8.hpp b/source/backend/cpu/x86_x64/avx/Vec8.hpp index 501768e69..d5a2a7b5c 100644 --- a/source/backend/cpu/x86_x64/avx/Vec8.hpp +++ b/source/backend/cpu/x86_x64/avx/Vec8.hpp @@ -168,6 +168,9 @@ struct Vec8 { static void save(float* addr, const VecType& v) { _mm256_storeu_ps(addr, v.value); } + static void save(int32_t* addr, const VecType& v) { + _mm256_storeu_ps((float*)addr, v.value); + } static VecType max(const VecType& v1, const VecType& v2) { VecType dst = { _mm256_max_ps(v1.value, v2.value) }; return dst; diff --git a/source/math/Vec.hpp b/source/math/Vec.hpp index fc4d8e0d4..7ca82319b 100644 --- a/source/math/Vec.hpp +++ b/source/math/Vec.hpp @@ -413,6 +413,9 @@ struct Vec { static void save(float* addr, const VecTypeInt32& v) { vst1q_f32(addr, reinterpret_cast(v.value)); } + static void save(int32_t* addr, const VecType& v) { + vst1q_s32(addr, reinterpret_cast(v.value)); + } static VecType max(const VecType& v1, const VecType& v2) { VecType dst = { vmaxq_f32(v1.value, v2.value) }; return dst; @@ -763,6 +766,9 @@ struct Vec { static void save(float* addr, const VecTypeInt32& v) { _mm_storeu_ps(addr, _mm_castsi128_ps(v.value)); } + static void save(int32_t* addr, const VecType& v) { + _mm_storeu_si128((__m128i*)addr, _mm_castps_si128(v.value)); + } static VecType max(const VecType& v1, const VecType& v2) { VecType dst = { _mm_max_ps(v1.value, v2.value) }; return dst;