Skip to content

Commit

Permalink
[MNN:Bugfix] Fix bug for avx512 compile
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaying committed Oct 20, 2023
1 parent b019bef commit 476083a
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 5 deletions.
4 changes: 2 additions & 2 deletions source/backend/cpu/arm/arm32/MNNBinarySqdInt8.S
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,15 @@ L4Loop:
vqmovn.s32 d8, q10
vqmovn.s32 d9, q11
vdup.8 q12, r3
vdup.8 q15, r11
vdup.8 q1, r11

vaddw.s8 q3, q3, d4
vaddw.s8 q4, q4, d4

vqmovn.s16 d12, q3
vqmovn.s16 d13, q4
vmax.s8 q6, q6, q12
vmin.s8 q6, q6, q15
vmin.s8 q6, q6, q1
cmp r6, #4
vst1.32 {q6}, [r0]!
bge L4Loop
Expand Down
2 changes: 1 addition & 1 deletion source/backend/cpu/x86_x64/avx/Vec8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct Vec8 {
return dst;
}
VecType operator>(const VecType& lr) {
__m256 mask = _mm256_cmp_ps(lr.value, value, 0x01);
__m256 mask = _mm256_cmp_ps(value, lr.value, 14);
VecType dst = { _mm256_and_ps(one, mask) } ;
return dst;
}
Expand Down
2 changes: 1 addition & 1 deletion source/backend/cpu/x86_x64/avx512/PackedFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void _AVX512_MNNConvRunForLineDepthwise(float* dst, const float* src, const floa
}

static MNNBinaryExecute _AVX512_MNNSelectBinaryFunctionForFloat(int opType) {
auto vecF = MNN::selectVector<Vec16, 16>(opType);
auto vecF = MNN::selectVector<Vec16, 16, float>(opType);
if (nullptr != vecF) {
return vecF;
}
Expand Down
38 changes: 38 additions & 0 deletions source/backend/cpu/x86_x64/avx512/Vec16.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,41 @@ struct Vec16 {
return value[i];
#endif
}
VecType operator==(const VecType& lr) const {
__m512 one = _mm512_set1_ps(1.0f);
__m512 zero = _mm512_set1_ps(0.0f);
__mmask16 mask = _mm512_cmp_ps_mask(value, lr.value, 0);
VecType dst = { _mm512_mask_blend_ps(mask, zero, one) } ;
return dst;
}
VecType operator>(const VecType& lr) {
__m512 one = _mm512_set1_ps(1.0f);
__m512 zero = _mm512_set1_ps(0.0f);
__mmask16 mask = _mm512_cmp_ps_mask(value, lr.value, 14);
VecType dst = { _mm512_mask_blend_ps(mask, zero, one) } ;
return dst;
}
VecType operator>=(const VecType& lr) {
__m512 one = _mm512_set1_ps(1.0f);
__m512 zero = _mm512_set1_ps(0.0f);
__mmask16 mask = _mm512_cmp_ps_mask(value, lr.value, 13);
VecType dst = { _mm512_mask_blend_ps(mask, zero, one) } ;
return dst;
}
VecType operator<(const VecType& lr) {
__m512 one = _mm512_set1_ps(1.0f);
__m512 zero = _mm512_set1_ps(0.0f);
__mmask16 mask = _mm512_cmp_ps_mask(value, lr.value, 0x01);
VecType dst = { _mm512_mask_blend_ps(mask, zero, one) } ;
return dst;
}
VecType operator<=(const VecType& lr) {
__m512 one = _mm512_set1_ps(1.0f);
__m512 zero = _mm512_set1_ps(0.0f);
__mmask16 mask = _mm512_cmp_ps_mask(value, lr.value, 0x02);
VecType dst = { _mm512_mask_blend_ps(mask, zero, one) } ;
return dst;
}
static VecType load(const float* addr) {
VecType v = { _mm512_loadu_ps(addr) };
return v;
Expand All @@ -270,6 +305,9 @@ struct Vec16 {
static void save(float* addr, const VecType& v) {
_mm512_storeu_ps(addr, v.value);
}
static void save(int32_t* addr, const VecType& v) {
_mm512_storeu_ps((float*)addr, v.value);
}
static VecType max(const VecType& v1, const VecType& v2) {
VecType dst = { _mm512_max_ps(v1.value, v2.value) };
return dst;
Expand Down
2 changes: 1 addition & 1 deletion test/op/BinaryOPTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class BinaryTestCommon : public MNNTestCase {
for (int i = 0; i < shape_y.size(); ++i) {
size_y *= shape_y[i];
}
for (int i = 0; i < shape_y.size(); ++i) {
for (int i = 0; i < shape_out.size(); ++i) {
size_out *= shape_out[i];
}
if (format == NC4HW4 && data_x.size() > size_x) {
Expand Down

0 comments on commit 476083a

Please sign in to comment.