Skip to content

Commit

Permalink
Optimize multiplication with maddubs
Browse files Browse the repository at this point in the history
  • Loading branch information
Andrew Evstyukhin committed Nov 18, 2022
1 parent 1b84011 commit 4278993
Showing 1 changed file with 71 additions and 26 deletions.
97 changes: 71 additions & 26 deletions src/SnippetLevelsBufferHalf.h
Original file line number Diff line number Diff line change
Expand Up @@ -229,16 +229,15 @@ class LevelsBufferHalf final
{
const __m512i wtop = _mm512_broadcastw_epi16(mtop);

const __m128i mshift = _mm_cvtsi32_si128(pL << 2);
const __m256i vmask = _mm256_set1_epi8(0xF);
const __m512i wshift = _mm512_broadcastw_epi16(_mm_cvtsi32_si128(pL << 2));
const __m512i wmask = _mm512_set1_epi8(0xF);

__m512i wzeroes = _mm512_setzero_si512();

int c = cH;
for (;;)
{
__m512i wsum = _mm512_setzero_si512();
__m512i wsum1 = _mm512_setzero_si512();

int k = static_cast<int>(count);
const uint8_t** p = values;
Expand All @@ -254,23 +253,17 @@ class LevelsBufferHalf final
__m256i vdelta = _mm256_load_si256(p0);
__m256i vdelta1 = _mm256_load_si256(p1);

vdelta = _mm256_and_si256(_mm256_srl_epi16(vdelta, mshift), vmask);
vdelta1 = _mm256_and_si256(_mm256_srl_epi16(vdelta1, mshift), vmask);
__m512i wdelta = _mm512_or_epi64(_mm512_cvtepu8_epi16(vdelta), _mm512_slli_epi16(_mm512_cvtepu8_epi16(vdelta1), 8));

__m512i wadd = _mm512_cvtepu8_epi16(vdelta);
__m512i wadd1 = _mm512_cvtepu8_epi16(vdelta1);
wdelta = _mm512_and_epi32(_mm512_srlv_epi16(wdelta, wshift), wmask);

wadd = _mm512_mullo_epi16(wadd, wadd);
wadd1 = _mm512_mullo_epi16(wadd1, wadd1);
__m512i wadd = _mm512_maddubs_epi16(wdelta, wdelta);

wsum = _mm512_add_epi16(wsum, wadd);
wsum1 = _mm512_add_epi16(wsum1, wadd1);

p += 2;
}

wsum = _mm512_add_epi16(wsum, wsum1);

if (k & 1)
{
auto value = p[0];
Expand All @@ -279,7 +272,7 @@ class LevelsBufferHalf final

__m256i vdelta = _mm256_load_si256(p0);

vdelta = _mm256_and_si256(_mm256_srl_epi16(vdelta, mshift), vmask);
vdelta = _mm256_and_si256(_mm256_srlv_epi16(vdelta, _mm512_castsi512_si256(wshift)), _mm512_castsi512_si256(wmask));

__m512i wadd = _mm512_cvtepu8_epi16(vdelta);

Expand Down Expand Up @@ -347,8 +340,8 @@ class LevelsBufferHalf final
{
const __m256i vtop = _mm256_broadcastw_epi16(mtop);

const __m128i mshift = _mm_cvtsi32_si128(pL << 2);
const __m128i mmask = _mm_set1_epi8(0xF);
const __m256i vshift = _mm256_broadcastd_epi32(_mm_cvtsi32_si128(pL << 2));
const __m256i vmask = _mm256_set1_epi8(0xF);

__m256i vzeroes = _mm256_setzero_si256();

Expand All @@ -357,14 +350,40 @@ class LevelsBufferHalf final
{
__m256i vsum = _mm256_setzero_si256();

for (size_t i = 0; i < count; i++)
int k = static_cast<int>(count);
const uint8_t** p = values;

while ((k -= 2) >= 0)
{
auto value = values[i];
auto value = p[0];
auto value1 = p[1];

const __m128i* p = (const __m128i*)&value[c >> 1];
const __m128i* p0 = (const __m128i*)&value[c >> 1];
const __m128i* p1 = (const __m128i*)&value1[c >> 1];

__m128i mdelta = _mm_load_si128(p);
mdelta = _mm_and_si128(_mm_srl_epi16(mdelta, mshift), mmask);
__m128i mdelta = _mm_load_si128(p0);
__m128i mdelta1 = _mm_load_si128(p1);

__m256i vdelta = _mm256_or_si256(_mm256_cvtepu8_epi16(mdelta), _mm256_slli_epi16(_mm256_cvtepu8_epi16(mdelta1), 8));

vdelta = _mm256_and_si256(_mm256_srlv_epi32(vdelta, vshift), vmask);

__m256i vadd = _mm256_maddubs_epi16(vdelta, vdelta);

vsum = _mm256_add_epi16(vsum, vadd);

p += 2;
}

if (k & 1)
{
auto value = p[0];

const __m128i* p0 = (const __m128i*)&value[c >> 1];

__m128i mdelta = _mm_load_si128(p0);

mdelta = _mm_and_si128(_mm_srlv_epi32(mdelta, _mm256_castsi256_si128(vshift)), _mm256_castsi256_si128(vmask));

__m256i vadd = _mm256_cvtepu8_epi16(mdelta);

Expand Down Expand Up @@ -435,18 +454,44 @@ class LevelsBufferHalf final
{
__m128i msum = _mm_setzero_si128();

for (size_t i = 0; i < count; i++)
int k = static_cast<int>(count);
const uint8_t** p = values;

while ((k -= 2) >= 0)
{
auto value = values[i];
auto value = p[0];
auto value1 = p[1];

const __m128i* p = (const __m128i*)&value[c >> 1];
const __m128i* p0 = (const __m128i*)&value[c >> 1];
const __m128i* p1 = (const __m128i*)&value1[c >> 1];

__m128i mdelta = _mm_loadl_epi64(p0);
__m128i mdelta1 = _mm_loadl_epi64(p1);

mdelta = _mm_unpacklo_epi8(mdelta, mdelta1);

__m128i mdelta = _mm_loadl_epi64(p);
mdelta = _mm_and_si128(_mm_srl_epi16(mdelta, mshift), mmask);

__m128i madd = _mm_cvtepu8_epi16(mdelta);
__m128i madd = _mm_maddubs_epi16(mdelta, mdelta);

msum = _mm_add_epi16(msum, madd);

p += 2;
}

if (k & 1)
{
auto value = p[0];

const __m128i* p0 = (const __m128i*)&value[c >> 1];

__m128i mdelta = _mm_loadl_epi64(p0);

mdelta = _mm_cvtepu8_epi16(mdelta);

mdelta = _mm_and_si128(_mm_srl_epi16(mdelta, mshift), mmask);

madd = _mm_mullo_epi16(madd, madd);
__m128i madd = _mm_mullo_epi16(mdelta, mdelta);

msum = _mm_add_epi16(msum, madd);
}
Expand Down

0 comments on commit 4278993

Please sign in to comment.