Skip to content

Commit

Permalink
[software] Remove load of che inputs from inner loop
Browse files (browse the repository at this point in the history)
  • Loading branch information
mbertuletti committed Jul 5, 2024
1 parent f3f9212 commit 491a050
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 64 deletions.
172 changes: 116 additions & 56 deletions software/kernels/baremetal/mempool_chest_f16.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,7 +16,7 @@
out[i][j + 3] = a[i] / c[j + 3]*/

#ifdef __XDIVSQRT
#define DIV_LOOP(ab, ab_n, i) \
#define DIV_LOOP(ab, ab_n) \
{ \
re0 = 0; \
re1 = 0; \
Expand All @@ -30,10 +30,6 @@
D1 = 0; \
D2 = 0; \
D3 = 0; \
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j]; \
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)]; \
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)]; \
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)]; \
asm volatile("vfdotpex.s.h %[D0], %[cd0], %[cd0];" \
"vfdotpex.s.h %[D1], %[cd1], %[cd1];" \
"vfdotpex.s.h %[D2], %[cd2], %[cd2];" \
Expand Down Expand Up @@ -65,13 +61,9 @@
: [cd0] "r"(cd0), [cd1] "r"(cd1), [cd2] "r"(cd2), \
[cd3] "r"(cd3), [x] "r"(ab), [y] "r"(ab_n) \
:); \
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = re0; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = re1; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = re2; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = re3; \
}
#else
#define DIV_LOOP(ab, ab_n, i) \
#define DIV_LOOP(ab, ab_n) \
{ \
re0 = 0; \
re1 = 0; \
Expand All @@ -85,10 +77,6 @@
D1 = 0; \
D2 = 0; \
D3 = 0; \
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j]; \
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)]; \
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)]; \
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)]; \
asm volatile("vfdotpex.s.h %[D0], %[cd0], %[cd0];" \
"vfdotpex.s.h %[D1], %[cd1], %[cd1];" \
"vfdotpex.s.h %[D2], %[cd2], %[cd2];" \
Expand Down Expand Up @@ -126,10 +114,6 @@
: [cd0] "r"(cd0), [cd1] "r"(cd1), [cd2] "r"(cd2), \
[cd3] "r"(cd3), [x] "r"(ab), [y] "r"(ab_n) \
:); \
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = re0; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = re1; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = re2; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = re3; \
}
#endif

Expand All @@ -140,7 +124,7 @@
out[i][j + 2] = a[i] * c[j + 2]
out[i][j + 3] = a[i] * c[j + 3]*/

#define MUL_LOOP(ab, ab_n, i) \
#define MUL_LOOP(ab, ab_n) \
{ \
re0 = 0; \
re1 = 0; \
Expand All @@ -150,10 +134,6 @@
im1 = 0; \
im2 = 0; \
im3 = 0; \
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j]; \
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)]; \
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)]; \
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)]; \
asm volatile("vfdotpex.s.h %[re0], %[x], %[cd0];" \
"vfdotpex.s.h %[re1], %[x], %[cd1];" \
"vfdotpex.s.h %[re2], %[x], %[cd2];" \
Expand All @@ -178,22 +158,14 @@
[im2] "+&r"(im2), [im3] "+&r"(im3) \
: [cd0] "r"(cd0), [cd1] "r"(cd1), [cd2] "r"(cd2), [cd3] "r"(cd3) \
:); \
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = re0; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = re1; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = re2; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = re3; \
}

#define CMUL_LOOP(ab, i) \
#define CMUL_LOOP(ab) \
{ \
sum0 = 0; \
sum1 = 0; \
sum2 = 0; \
sum3 = 0; \
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j]; \
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)]; \
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)]; \
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)]; \
asm volatile("fcdotpex.s.h %[sum0], %[x], %[cd0];" \
"fcdotpex.s.h %[sum1], %[x], %[cd1];" \
"fcdotpex.s.h %[sum2], %[x], %[cd2];" \
Expand All @@ -203,10 +175,6 @@
: [cd0] "r"(cd0), [cd1] "r"(cd1), [cd2] "r"(cd2), \
[cd3] "r"(cd3), [x] "r"(ab) \
:); \
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = sum0; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = sum1; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = sum2; \
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = sum3; \
}

#define SHUFFLE_A \
Expand Down Expand Up @@ -262,10 +230,30 @@ void mempool_chest_f16s_unrolled4(__fp16 *pH, __fp16 *pPilotRX,
ab3 = *(uint32_t *)&pPilotRX_itr[2U * (i + 3)];
SHUFFLE_A;
for (uint32_t j = 0; j < nTX; j += 4) {
DIV_LOOP(ab0, ab_n0, i);
DIV_LOOP(ab1, ab_n1, i + 1);
DIV_LOOP(ab2, ab_n2, i + 2);
DIV_LOOP(ab3, ab_n3, i + 3);
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j];
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)];
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)];
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)];
DIV_LOOP(ab0, ab_n0);
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = re3;
DIV_LOOP(ab1, ab_n1);
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 3)]) = re3;
DIV_LOOP(ab2, ab_n2);
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 3)]) = re3;
DIV_LOOP(ab3, ab_n3);
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 3)]) = re3;
}
}
}
Expand Down Expand Up @@ -320,21 +308,73 @@ void mempool_chest_f16p_unrolled4(__fp16 *pH, __fp16 *pPilotRX,
#endif

for (uint32_t j = 0; j < nTX; j += 4) {
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j];
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)];
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)];
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)];
#if (defined(__CDOTP) && defined(__MUL))
CMUL_LOOP(ab0, i);
CMUL_LOOP(ab1, i + 1);
CMUL_LOOP(ab2, i + 2);
CMUL_LOOP(ab3, i + 3);
CMUL_LOOP(ab0);
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = sum3;
CMUL_LOOP(ab1);
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 3)]) = sum3;
CMUL_LOOP(ab2);
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 3)]) = sum3;
CMUL_LOOP(ab3);
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 3)]) = sum3;
#elif (!defined(__CDOTP) && defined(__MUL))
MUL_LOOP(ab0, ab_n0, i);
MUL_LOOP(ab1, ab_n1, i + 1);
MUL_LOOP(ab2, ab_n2, i + 2);
MUL_LOOP(ab3, ab_n3, i + 3);
MUL_LOOP(ab0, ab_n0);
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = re3;
MUL_LOOP(ab1, ab_n1);
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 3)]) = re3;
MUL_LOOP(ab2, ab_n2);
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 3)]) = re3;
MUL_LOOP(ab3, ab_n3);
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 3)]) = re3;
#else
DIV_LOOP(ab0, ab_n0, i)
DIV_LOOP(ab1, ab_n1, i + 1)
DIV_LOOP(ab2, ab_n2, i + 2)
DIV_LOOP(ab3, ab_n3, i + 3)
DIV_LOOP(ab0, ab_n0);
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = re3;
DIV_LOOP(ab1, ab_n1);
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 3)]) = re3;
DIV_LOOP(ab2, ab_n2);
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 3)]) = re3;
DIV_LOOP(ab3, ab_n3);
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j)]) = re0;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 1)]) = re1;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 2)]) = re2;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 3)]) = re3;
#endif
}
}
Expand Down Expand Up @@ -371,10 +411,30 @@ void mempool_chest_f16p_unrolled4_local(__fp16 *volatile pH,
ab2 = *(uint32_t *)&pPilotRX_itr[2U * (i + 2)];
ab3 = *(uint32_t *)&pPilotRX_itr[2U * (i + 3)];
for (j = 0; j < nTX; j += 4) {
CMUL_LOOP(ab0, i);
CMUL_LOOP(ab1, i + 1);
CMUL_LOOP(ab2, i + 2);
CMUL_LOOP(ab3, i + 3);
cd0 = *(uint32_t *)&pPilotTX_itr[2U * j];
cd1 = *(uint32_t *)&pPilotTX_itr[2U * (j + 1)];
cd2 = *(uint32_t *)&pPilotTX_itr[2U * (j + 2)];
cd3 = *(uint32_t *)&pPilotTX_itr[2U * (j + 3)];
CMUL_LOOP(ab0);
*((uint32_t *)&pH_itr[2 * (i * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * (i * nTX + j + 3)]) = sum3;
CMUL_LOOP(ab1);
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * ((i + 1) * nTX + j + 3)]) = sum3;
CMUL_LOOP(ab2);
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * ((i + 2) * nTX + j + 3)]) = sum3;
CMUL_LOOP(ab3);
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j)]) = sum0;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 1)]) = sum1;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 2)]) = sum2;
*((uint32_t *)&pH_itr[2 * ((i + 3) * nTX + j + 3)]) = sum3;
}
}
mempool_barrier(nPE);
Expand Down
20 changes: 12 additions & 8 deletions software/kernels/baremetal/mempool_chest_q16.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -16,10 +16,6 @@

#define DIV_LOOP(ab, ab_n) \
{ \
cd0 = *(v2s *)&pPilotTX_itr[2U * j]; \
cd1 = *(v2s *)&pPilotTX_itr[2U * (j + 1)]; \
cd2 = *(v2s *)&pPilotTX_itr[2U * (j + 2)]; \
cd3 = *(v2s *)&pPilotTX_itr[2U * (j + 3)]; \
D0 = (1 << 16U) / __DOTP2(cd0, cd0); \
D1 = (1 << 16U) / __DOTP2(cd1, cd1); \
D2 = (1 << 16U) / __DOTP2(cd2, cd2); \
Expand Down Expand Up @@ -54,10 +50,6 @@

#define MUL_LOOP(ab, ab_n) \
{ \
cd0 = *(v2s *)&pPilotTX_itr[2U * j]; \
cd1 = *(v2s *)&pPilotTX_itr[2U * (j + 1)]; \
cd2 = *(v2s *)&pPilotTX_itr[2U * (j + 2)]; \
cd3 = *(v2s *)&pPilotTX_itr[2U * (j + 3)]; \
re0 = __DOTP2(ab_n, cd0); \
re1 = __DOTP2(ab_n, cd1); \
re2 = __DOTP2(ab_n, cd2); \
Expand Down Expand Up @@ -173,6 +165,10 @@ void mempool_chest_q16s_unrolled4(int16_t *pH, int16_t *pPilotRX,
ab3 = *(v2s *)&pPilotRX_itr[2U * (i + 3)];
SHUFFLE_A;
for (j = 0; j < nTX; j += 4) {
cd0 = *(v2s *)&pPilotTX_itr[2U * j];
cd1 = *(v2s *)&pPilotTX_itr[2U * (j + 1)];
cd2 = *(v2s *)&pPilotTX_itr[2U * (j + 2)];
cd3 = *(v2s *)&pPilotTX_itr[2U * (j + 3)];
#ifdef __MUL
MUL_LOOP(ab0, ab_n0);
*((v2s *)&pH_itr[2 * (i * nTX + j)]) = (v2s)re0;
Expand Down Expand Up @@ -261,6 +257,10 @@ void mempool_chest_q16p_unrolled4(int16_t *volatile pH,
ab3 = *(v2s *)&pPilotRX_itr[2U * (i + 3)];
SHUFFLE_A;
for (uint32_t j = 0; j < nTX; j += 4) {
cd0 = *(v2s *)&pPilotTX_itr[2U * j];
cd1 = *(v2s *)&pPilotTX_itr[2U * (j + 1)];
cd2 = *(v2s *)&pPilotTX_itr[2U * (j + 2)];
cd3 = *(v2s *)&pPilotTX_itr[2U * (j + 3)];
#ifdef __MUL
MUL_LOOP(ab0, ab_n0);
*((v2s *)&pH_itr[2 * (i * nTX + j)]) = (v2s)re0;
Expand Down Expand Up @@ -343,6 +343,10 @@ void mempool_chest_q16p_unrolled4_local(int16_t *volatile pH,
ab3 = *(v2s *)&pPilotRX_itr[2U * (i + 3)];
SHUFFLE_A;
for (j = 0; j < nTX; j += 4) {
cd0 = *(v2s *)&pPilotTX_itr[2U * j];
cd1 = *(v2s *)&pPilotTX_itr[2U * (j + 1)];
cd2 = *(v2s *)&pPilotTX_itr[2U * (j + 2)];
cd3 = *(v2s *)&pPilotTX_itr[2U * (j + 3)];
MUL_LOOP(ab0, ab_n0);
*((v2s *)&pH_itr[2 * (i * nTX + j)]) = (v2s)re0;
*((v2s *)&pH_itr[2 * (i * nTX + j + 1)]) = (v2s)re1;
Expand Down

0 comments on commit 491a050

Please sign in to comment.