+add AMX-BF16 of class SynetConvolution16bNchwGemm.
ermig1979 committed Jul 25, 2024
1 parent ae93d8c commit c569d2c
Showing 11 changed files with 442 additions and 14 deletions.
docs/2024.html (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ <h5>New features</h5>
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetRelu16b.</li>
<li>API of SynetAdd16b framework.</li>
<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of class SynetAdd16bUniform.</li>
-<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of class SynetConvolution16bNchwGemm.</li>
+<li>Base implementation, SSE4.1, AVX2, AVX-512BW optimizations, AMX-BF16 of class SynetConvolution16bNchwGemm.</li>
</ul>
<h5>Improving</h5>
<ul>
prj/vs2019/AmxBf16.vcxproj (1 change: 1 addition & 0 deletions)
@@ -64,6 +64,7 @@
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16b.cpp" />
+<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNchwGemm.cpp" />
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNhwcDirect.cpp" />
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNhwcGemm.cpp" />
<ClCompile Include="..\..\src\Simd\SimdAmxBf16BFloat16.cpp" />
prj/vs2019/AmxBf16.vcxproj.filters (3 changes: 3 additions & 0 deletions)
@@ -245,5 +245,8 @@
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetInnerProduct16b.cpp">
<Filter>AmxBf16</Filter>
</ClCompile>
+<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNchwGemm.cpp">
+<Filter>AmxBf16</Filter>
+</ClCompile>
</ItemGroup>
</Project>
prj/vs2022/AmxBf16.vcxproj (1 change: 1 addition & 0 deletions)
@@ -64,6 +64,7 @@
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16b.cpp" />
+<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNchwGemm.cpp" />
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNhwcDirect.cpp" />
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNhwcGemm.cpp" />
<ClCompile Include="..\..\src\Simd\SimdAmxBf16BFloat16.cpp" />
prj/vs2022/AmxBf16.vcxproj.filters (3 changes: 3 additions & 0 deletions)
@@ -245,5 +245,8 @@
<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetInnerProduct16b.cpp">
<Filter>AmxBf16</Filter>
</ClCompile>
+<ClCompile Include="..\..\src\Simd\SimdAmxBf16SynetConvolution16bNchwGemm.cpp">
+<Filter>AmxBf16</Filter>
+</ClCompile>
</ItemGroup>
</Project>
src/Simd/SimdAmxBf16SynetConvolution16b.cpp (2 changes: 1 addition & 1 deletion)
@@ -38,7 +38,7 @@ namespace Simd
if (SynetConvolution16bNhwcGemm::Preferable(param))
return new AmxBf16::SynetConvolution16bNhwcGemm(param);
if (Base::SynetConvolution16bNchwGemm::Preferable(param))
-return new Avx512bw::SynetConvolution16bNchwGemm(param);
+return new AmxBf16::SynetConvolution16bNchwGemm(param);
return new Base::SynetConvolution16bGemm(param);
}
}
src/Simd/SimdAmxBf16SynetConvolution16bNchwGemm.cpp (347 changes: 347 additions & 0 deletions)

Large diffs are not rendered by default.
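The 347-line kernel file itself is not rendered here. For orientation, below is a minimal sketch of the AMX-BF16 tile GEMM pattern that the AmxBf16 convolution kernels in Simd are built around: bf16 operands, fp32 accumulation via _tile_dpbf16ps. The tile shapes, the packing layout, and all names are illustrative assumptions, not the committed code, which additionally handles tails, multiple tile blocks, and fused activation.

// Sketch only (assumed shapes and packing, not the committed kernel).
// C[16x16] (fp32) += A[16xK] * B[Kx16] with bf16 inputs; assumes K % 32 == 0,
// compilation with -mamx-tile -mamx-bf16, and OS-enabled AMX tile state.
#include <immintrin.h>
#include <cstdint>

struct alignas(64) TileConfig
{
    uint8_t palette_id;   // palette 1: eight 1 KB tile registers tmm0..tmm7
    uint8_t start_row;
    uint8_t reserved[14];
    uint16_t colsb[16];   // bytes per row of each tile
    uint8_t rows[16];     // rows of each tile
};

static void AmxGemm16x16(const uint16_t* A, size_t lda, const uint16_t* B, float* C, size_t ldc, size_t K)
{
    TileConfig cfg = {};
    cfg.palette_id = 1;
    cfg.rows[0] = 16; cfg.colsb[0] = 64; // tmm0: 16x16 fp32 accumulator
    cfg.rows[1] = 16; cfg.colsb[1] = 64; // tmm1: 16x32 bf16 slice of A
    cfg.rows[2] = 16; cfg.colsb[2] = 64; // tmm2: 16x32 bf16 slice of B, pair-interleaved
    _tile_loadconfig(&cfg);
    _tile_zero(0);
    for (size_t k = 0; k < K; k += 32)   // one tile pass consumes 32 bf16 values per row
    {
        _tile_loadd(1, A + k, lda);      // lda: A row stride in bytes
        _tile_loadd(2, B + k * 16, 64);  // B assumed pre-packed as a dense 16-wide panel
        _tile_dpbf16ps(0, 1, 2);         // fp32 += dot products of bf16 pairs
    }
    _tile_stored(0, C, ldc);             // ldc: C row stride in bytes
    _tile_release();
}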

src/Simd/SimdSynetConvolution16b.h (8 changes: 8 additions & 0 deletions)
@@ -359,6 +359,14 @@ namespace Simd
virtual String Ext() const { return "AmxBf16"; }
};

+class SynetConvolution16bNchwGemm : public Avx512bw::SynetConvolution16bNchwGemm
+{
+public:
+SynetConvolution16bNchwGemm(const ConvParam& p);
+
+virtual String Ext() const { return "AmxBf16"; }
+};

//-------------------------------------------------------------------------------------------------

void* SynetConvolution16bInit(size_t batch, const SimdConvolutionParameters* conv, SimdSynetCompatibilityType compatibility);
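The header diff above is the whole public surface of the new class: a thin subclass of the Avx512bw implementation that re-tags Ext() and supplies its own constructor. Following the pattern of the neighboring AmxBf16 classes in this header, the constructor presumably re-binds the micro-kernel function pointers chosen by the Avx512bw base to AMX tile versions while inheriting the base's buffer planning and Preferable() logic. A hedged sketch of that pattern (member and kernel names are hypothetical; the real bodies live in the 347-line .cpp added by this commit):

namespace Simd
{
    namespace AmxBf16
    {
        SynetConvolution16bNchwGemm::SynetConvolution16bNchwGemm(const ConvParam& p)
            : Avx512bw::SynetConvolution16bNchwGemm(p) // base selects AVX-512BW kernels
        {
            // Hypothetical re-binding of kernel pointers to AMX tile implementations:
            // _convert = ConvertBf16Nchw;   // repack input to bf16 tiles
            // _gemm    = NchwGemm_32x32;    // tmm-based GEMM body
            // _post    = PostprocessNchw;   // bias + activation epilogue
        }
    }
}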
src/Simd/SimdSynetConvolution16bCommon.h (68 changes: 66 additions & 2 deletions)
@@ -888,6 +888,7 @@ namespace Simd
template<SimdConvolutionActivationType type, int index> static SIMD_INLINE void Save(uint8_t* ptr, float* buf, __m512 value, const __m512* bias, const __m512* params, __mmask16 tail = __mmask16(-1));
template<SimdConvolutionActivationType type, int index> static SIMD_INLINE void Apply(uint8_t* ptr, float* buf, const __m512* bias, const __m512* params, __mmask16 tail = __mmask16(-1));
template<SimdConvolutionActivationType type> static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst, __mmask16 tail = __mmask16(-1));
+template<SimdConvolutionActivationType type, int index> static SIMD_INLINE void Apply(uint8_t* ptr, float* buf, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1));
};

template <> struct Term16b<Term16bLast16b>
@@ -910,7 +911,7 @@
__m512 f32 = Activate<type>(_mm512_add_ps(value, bias[index]), params, index);
_mm256_mask_storeu_epi16((uint16_t*)ptr + index * F, tail, (__m256i)_mm512_cvtneps_pbh(f32));
_mm_prefetch((const char*)(ptr + index * DF), _MM_HINT_NTA);
-_mm_prefetch((const char*)(buf + index * A), _MM_HINT_NTA);
+_mm_prefetch((const char*)(buf + index * F), _MM_HINT_NTA);
}

template<SimdConvolutionActivationType type> static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst, __mmask16 tail = __mmask16(-1))
Expand All @@ -920,6 +921,15 @@ namespace Simd
//_mm_prefetch((const char*)(src + offset), _MM_HINT_NTA);
//_mm_prefetch((const char*)(dst + offset * 2), _MM_HINT_NTA);
}

+template<SimdConvolutionActivationType type, int index> static SIMD_INLINE void Apply(uint8_t* ptr, float* buf, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+__m512 value = _mm512_maskz_loadu_ps(tail, buf + index * F);
+__m512 f32 = ActivateNchw<type>(_mm512_add_ps(value, _mm512_set1_ps(bias[offset])), params, offset);
+_mm256_mask_storeu_epi16((uint16_t*)ptr + index * F, tail, (__m256i)_mm512_cvtneps_pbh(f32));
+//_mm_prefetch((const char*)(ptr + index * DF), _MM_HINT_NTA);
+//_mm_prefetch((const char*)(buf + index * A), _MM_HINT_NTA);
+}
};

template <> struct Term16b<Term16bLast32f>
@@ -939,7 +949,7 @@
__m512 value = _mm512_maskz_loadu_ps(tail, buf + index * F);
_mm512_mask_storeu_ps((float*)ptr + index * F, tail, Activate<type>(_mm512_add_ps(value, bias[index]), params, index));
_mm_prefetch((const char*)(ptr + index * A), _MM_HINT_NTA);
-_mm_prefetch((const char*)(buf + index * A), _MM_HINT_NTA);
+_mm_prefetch((const char*)(buf + index * F), _MM_HINT_NTA);
}

template<SimdConvolutionActivationType type> static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst, __mmask16 tail = __mmask16(-1))
Expand All @@ -949,6 +959,15 @@ namespace Simd
//_mm_prefetch((const char*)(src + offset), _MM_HINT_NTA);
//_mm_prefetch((const char*)(dst + offset * 4), _MM_HINT_NTA);
}

+template<SimdConvolutionActivationType type, int index> static SIMD_INLINE void Apply(uint8_t* ptr, float* buf, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+__m512 value = _mm512_maskz_loadu_ps(tail, buf + index * F);
+__m512 f32 = ActivateNchw<type>(_mm512_add_ps(value, _mm512_set1_ps(bias[offset])), params, offset);
+_mm512_mask_storeu_ps((float*)ptr + index * F, tail, f32);
+//_mm_prefetch((const char*)(ptr + index * DF), _MM_HINT_NTA);
+//_mm_prefetch((const char*)(buf + index * A), _MM_HINT_NTA);
+}
};

template <> struct Term16b<Term16bInterim>
@@ -970,6 +989,10 @@
template<SimdConvolutionActivationType type> static SIMD_INLINE void Postprocess(const float* src, const float* bias, const float* params, size_t offset, uint8_t* dst, __mmask16 tail = __mmask16(-1))
{
}

+template<SimdConvolutionActivationType type, int index> static SIMD_INLINE void Apply(uint8_t* ptr, float* buf, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+}
};

//-------------------------------------------------------------------------------------------------
@@ -996,6 +1019,8 @@
Term16b<term>::template Save<type, 1>(ptr, buf, val1, bias, params, tail);
}

+//-------------------------------------------------------------------------------------------------

template<Term16bType term, SimdConvolutionActivationType type> SIMD_INLINE void Apply1(uint8_t* ptr, float* buf, const __m512* bias, const __m512* params, __mmask16 tail = __mmask16(-1))
{
Term16b<term>::template Apply<type, 0>(ptr, buf, bias, params, tail);
@@ -1031,10 +1056,49 @@
Apply2<term, type>(ptr + 7 * dP, buf + 7 * dB, bias, params, tail);
}

+//-------------------------------------------------------------------------------------------------

template<Term16bType term, SimdConvolutionActivationType type> SIMD_INLINE void Postprocess(const float* sum, const float* bias, const float* params, size_t offset, uint8_t* dst, __mmask16 tail = __mmask16(-1))
{
Term16b<term>::template Postprocess<type>(sum, bias, params, offset, dst, tail);
}

+//-------------------------------------------------------------------------------------------------

+template<Term16bType term, SimdConvolutionActivationType type> SIMD_INLINE void Apply1(uint8_t* ptr, float* buf, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+Term16b<term>::template Apply<type, 0>(ptr, buf, bias, params, offset, tail);
+}

+template<Term16bType term, SimdConvolutionActivationType type> SIMD_INLINE void Apply1x8(uint8_t* ptr, int dP, float* buf, int dB, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+Apply1<term, type>(ptr + 0 * dP, buf + 0 * dB, bias, params, offset + 0, tail);
+Apply1<term, type>(ptr + 1 * dP, buf + 1 * dB, bias, params, offset + 1, tail);
+Apply1<term, type>(ptr + 2 * dP, buf + 2 * dB, bias, params, offset + 2, tail);
+Apply1<term, type>(ptr + 3 * dP, buf + 3 * dB, bias, params, offset + 3, tail);
+Apply1<term, type>(ptr + 4 * dP, buf + 4 * dB, bias, params, offset + 4, tail);
+Apply1<term, type>(ptr + 5 * dP, buf + 5 * dB, bias, params, offset + 5, tail);
+Apply1<term, type>(ptr + 6 * dP, buf + 6 * dB, bias, params, offset + 6, tail);
+Apply1<term, type>(ptr + 7 * dP, buf + 7 * dB, bias, params, offset + 7, tail);
+}

+template<Term16bType term, SimdConvolutionActivationType type> SIMD_INLINE void Apply2(uint8_t* ptr, float* buf, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+Term16b<term>::template Apply<type, 0>(ptr, buf, bias, params, offset);
+Term16b<term>::template Apply<type, 1>(ptr, buf, bias, params, offset, tail);
+}

+template<Term16bType term, SimdConvolutionActivationType type> SIMD_INLINE void Apply2x8(uint8_t* ptr, int dP, float* buf, int dB, const float* bias, const float* params, size_t offset, __mmask16 tail = __mmask16(-1))
+{
+Apply2<term, type>(ptr + 0 * dP, buf + 0 * dB, bias, params, offset + 0, tail);
+Apply2<term, type>(ptr + 1 * dP, buf + 1 * dB, bias, params, offset + 1, tail);
+Apply2<term, type>(ptr + 2 * dP, buf + 2 * dB, bias, params, offset + 2, tail);
+Apply2<term, type>(ptr + 3 * dP, buf + 3 * dB, bias, params, offset + 3, tail);
+Apply2<term, type>(ptr + 4 * dP, buf + 4 * dB, bias, params, offset + 4, tail);
+Apply2<term, type>(ptr + 5 * dP, buf + 5 * dB, bias, params, offset + 5, tail);
+Apply2<term, type>(ptr + 6 * dP, buf + 6 * dB, bias, params, offset + 6, tail);
+Apply2<term, type>(ptr + 7 * dP, buf + 7 * dB, bias, params, offset + 7, tail);
+}
}
#endif
}
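The Apply overloads added above differ from the existing ones in one essential way: bias and params arrive as plain float pointers plus a channel offset, and the bias is broadcast with _mm512_set1_ps(bias[offset]) instead of being preloaded into a per-lane __m512 vector. That is the NCHW-specific twist: in NHWC, consecutive lanes of a vector belong to different channels, while in NCHW an entire row belongs to a single channel, so one scalar serves the whole vector. A scalar toy model of the two flavours (illustration only, not part of the commit):

#include <cstddef>

// NHWC: lane i of a 16-float vector belongs to channel (channelOffset + i),
// so the kernel adds a per-lane bias vector (bias[index] in the code above).
void ApplyNhwc(float* row, const float* bias, size_t channelOffset, size_t n)
{
    for (size_t i = 0; i < n; ++i)
        row[i] += bias[channelOffset + i];
}

// NCHW: the whole row is one channel, so a single scalar bias[channel] is
// broadcast across all lanes (_mm512_set1_ps(bias[offset]) in the code above).
void ApplyNchw(float* row, const float* bias, size_t channel, size_t n)
{
    for (size_t i = 0; i < n; ++i)
        row[i] += bias[channel];
}

This also explains why the new Apply2 passes the same offset to both 16-lane halves: a 32-wide row still belongs to one channel, so both halves share one bias scalar.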
src/Test/TestSynetConvolution16b.cpp (19 changes: 10 additions & 9 deletions)
@@ -256,17 +256,18 @@ namespace Test
result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 16, 320, 320, 32, _2, _1, _1, _0, _1, 1, aRe, tT, b16, f32), c, f1, f2);
#endif
#if 1
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, aPr, tF, f32, b16), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, aPr, tF, b16, f32), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, aPr, tT, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 48, 48, 256, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 48, 48, 256, _1, _1, _1, _0, _0, 1, aPr, tF, f32, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 48, 48, 256, _1, _1, _1, _0, _0, 1, aPr, tF, b16, f32), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 256, 48, 48, 256, _1, _1, _1, _0, _0, 1, aPr, tT, b16, b16), c, f1, f2);
#endif
#if 0
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 2255, 55, 55, 155, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 16, 55, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 56, 15, 15, 56, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 15, 56, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
-result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 15, 55, _1, _1, _1, _0, _0, 1, aId, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 2255, 55, 55, 155, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 16, 16, 55, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 16, 55, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 56, 15, 15, 56, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 15, 56, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
+result = result && SynetConvolution16bForwardAutoTest(eps, Param(1, 55, 15, 15, 55, _1, _1, _1, _0, _0, 1, aPr, tF, b16, b16), c, f1, f2);
#endif

#else
src/Test/TestSynetConvolution32f.cpp (2 changes: 1 addition & 1 deletion)
@@ -664,7 +664,7 @@ namespace Test
result = result && SynetConvolution32fForwardAutoTest(eps, Param(2, 192, 5, 5, 256, _3, _1, _1, _0, _0, 1, a, t), c, f1, f2);
#endif
#if 1
-result = result && SynetConvolution32fForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, a, t), c, f1, f2);
+result = result && SynetConvolution32fForwardAutoTest(eps, Param(1, 256, 48, 48, 256, _1, _1, _1, _0, _0, 1, a, t), c, f1, f2);
#endif
#else
result = result && SynetConvolution32fForwardAutoTest(eps, Param(1, 256, 44, 44, 256, _1, _1, _1, _0, _0, 1, a, t), c, f1, f2);
