Skip to content

Commit

Permalink
[MNN:Bugfix] Fix bug of no sse poolgrad
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaying committed Oct 19, 2023
1 parent a8c1f1a commit b019bef
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
32 changes: 16 additions & 16 deletions source/backend/cpu/BinaryUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,14 @@ struct BinaryBitwiseXor {
}
};

template<typename Func, typename V, int pack, typename U>
template<typename Func, typename V, int pack, typename U, typename Tout>
void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int needBroadcastIndex) {
Func compute;
const int sizeDivUnit = elementSize / pack;
const int remainCount = elementSize - sizeDivUnit * pack;
auto src0 = (const U*)(inputRaw0);
auto src1 = (const U*)(inputRaw1);
auto dst = (U*)outputRaw;
auto dst = (Tout*)outputRaw;

if (-1 == needBroadcastIndex) {
if (sizeDivUnit > 0) {
Expand All @@ -210,7 +210,7 @@ void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, i
if (remainCount > 0) {
U tempSrc0[pack];
U tempSrc1[pack];
U tempDst[pack];
Tout tempDst[pack];
::memcpy(tempSrc0, src0, remainCount * sizeof(U));
::memcpy(tempSrc1, src1, remainCount * sizeof(U));
V a = V::load(tempSrc0);
Expand All @@ -233,7 +233,7 @@ void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, i
}
if (remainCount > 0) {
U tempSrc1[pack];
U tempDst[pack];
Tout tempDst[pack];
::memcpy(tempSrc1, src1, remainCount * sizeof(U));
V b = V::load(tempSrc1);
V::save(tempDst, compute(a, b));
Expand All @@ -254,7 +254,7 @@ void executeVec(void* outputRaw, const void* inputRaw0, const void* inputRaw1, i
}
if (remainCount > 0) {
U tempSrc0[pack];
U tempDst[pack];
Tout tempDst[pack];
::memcpy(tempSrc0, src0, remainCount * sizeof(U));
V a = V::load(tempSrc0);
V::save(tempDst, compute(a, b));
Expand Down Expand Up @@ -415,27 +415,27 @@ template<typename V, int pack, typename U>
MNNBinaryExecute selectVector(int type) {
switch (type) {
case BinaryOpOperation_ADD:
return executeVec<VecBinaryAdd<V>, V, pack, U>;
return executeVec<VecBinaryAdd<V>, V, pack, U, U>;
case BinaryOpOperation_SUB:
return executeVec<VecBinarySub<V>, V, pack, U>;
return executeVec<VecBinarySub<V>, V, pack, U, U>;
case BinaryOpOperation_MUL:
return executeVec<VecBinaryMul<V>, V, pack, U>;
return executeVec<VecBinaryMul<V>, V, pack, U, U>;
case BinaryOpOperation_MINIMUM:
return executeVec<VecBinaryMin<V>, V, pack, U>;
return executeVec<VecBinaryMin<V>, V, pack, U, U>;
case BinaryOpOperation_MAXIMUM:
return executeVec<VecBinaryMax<V>, V, pack, U>;
return executeVec<VecBinaryMax<V>, V, pack, U, U>;
case BinaryOpOperation_SquaredDifference:
return executeVec<VecBinarySqd<V>, V, pack, U>;
return executeVec<VecBinarySqd<V>, V, pack, U, U>;
case BinaryOpOperation_LESS:
return executeVec<VecBinaryLess<V>, V, pack, U>;
return executeVec<VecBinaryLess<V>, V, pack, U, int32_t>;
case BinaryOpOperation_LESS_EQUAL:
return executeVec<VecBinaryLessEqual<V>, V, pack, U>;
return executeVec<VecBinaryLessEqual<V>, V, pack, U, int32_t>;
case BinaryOpOperation_GREATER:
return executeVec<VecBinaryGreater<V>, V, pack, U>;
return executeVec<VecBinaryGreater<V>, V, pack, U, int32_t>;
case BinaryOpOperation_GREATER_EQUAL:
return executeVec<VecBinaryGreaterEqual<V>, V, pack, U>;
return executeVec<VecBinaryGreaterEqual<V>, V, pack, U, int32_t>;
case BinaryOpOperation_EQUAL:
return executeVec<VecBinaryEqual<V>, V, pack, U>;
return executeVec<VecBinaryEqual<V>, V, pack, U, int32_t>;
}
return nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions source/backend/cpu/x86_x64/avx/Vec8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@ struct Vec8 {
static void save(float* addr, const VecType& v) {
_mm256_storeu_ps(addr, v.value);
}
static void save(int32_t* addr, const VecType& v) {
_mm256_storeu_ps((float*)addr, v.value);
}
static VecType max(const VecType& v1, const VecType& v2) {
VecType dst = { _mm256_max_ps(v1.value, v2.value) };
return dst;
Expand Down
6 changes: 6 additions & 0 deletions source/math/Vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,9 @@ struct Vec<float, 4> {
static void save(float* addr, const VecTypeInt32& v) {
vst1q_f32(addr, reinterpret_cast<float32x4_t>(v.value));
}
static void save(int32_t* addr, const VecType& v) {
vst1q_s32(addr, reinterpret_cast<int32x4_t>(v.value));
}
static VecType max(const VecType& v1, const VecType& v2) {
VecType dst = { vmaxq_f32(v1.value, v2.value) };
return dst;
Expand Down Expand Up @@ -763,6 +766,9 @@ struct Vec<float, 4> {
static void save(float* addr, const VecTypeInt32& v) {
_mm_storeu_ps(addr, _mm_castsi128_ps(v.value));
}
static void save(int32_t* addr, const VecType& v) {
_mm_storeu_si128((__m128i*)addr, _mm_castps_si128(v.value));
}
static VecType max(const VecType& v1, const VecType& v2) {
VecType dst = { _mm_max_ps(v1.value, v2.value) };
return dst;
Expand Down

0 comments on commit b019bef

Please sign in to comment.