-
Notifications
You must be signed in to change notification settings - Fork 1
/
gemm.cve
49 lines (41 loc) · 1.21 KB
/
gemm.cve
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#define UNROLL_FACTOR 9
#include <ve_kernel.h>
/*
 * gemm_op_d - double-precision matrix product C = A * B on one 256 x 256 tile.
 *
 * Each of A, B and C is indexed as a row-major 256 x 256 matrix of doubles,
 * starting at an offset of __pg_ix * 256 * 256 elements from the pointer
 * passed in.  The tile of C is overwritten (not accumulated into).
 *
 * A  left-hand operand (read only)
 * B  right-hand operand (read only)
 * C  result
 *
 * UNROLL_FACTOR rows of C are kept in vector registers at once, so every row
 * of B that is loaded is reused for UNROLL_FACTOR multiply-adds.
 */
__ve_kernel__
void
gemm_op_d(
const double * A,
const double * B,
double * C)
{
    /* Matrix dimension.  It is also used as the vector length, so a single
       vector load/store covers one complete matrix row. */
    const int dim = 256;
    const int ve_lvl = dim;

    /* NOTE(review): the offset is computed in int and overflows once
       __pg_ix * 65536 exceeds INT_MAX - confirm the range of __pg_ix. */
    const int pg_off = __pg_ix * dim * dim;

    /* Re-computing these addresses every time is expensive, so do it once. */
    const double * my_A = A + pg_off;
    const double * my_B = B + pg_off;
    double * my_C = C + pg_off;

    /* One accumulator register per row of C in the current row group. */
    __vr accs[UNROLL_FACTOR];

    for(int i = 0; i < dim; i += UNROLL_FACTOR)
    {
        /* Clear the accumulators.  The broadcast value is integer 0; an
           all-zero bit pattern is +0.0 when read back as a double. */
#pragma unroll (UNROLL_FACTOR)
        for(int j = 0; j < UNROLL_FACTOR; ++j)
        {
            accs[j] = _vel_pvbrd_vsl(0, ve_lvl);
        }

        for(int ii = 0; ii < dim; ++ii)
        {
            /* Row ii of B, loaded with a stride of 8 bytes (one double). */
            __vr row_B = _vel_vld_vssl(8, my_B + ii * dim, ve_lvl);

#pragma unroll (UNROLL_FACTOR)
            for(int j = 0; j < UNROLL_FACTOR; ++j)
            {
                /* dim is not a multiple of UNROLL_FACTOR, so the last row
                   group is partial and rows past the end are skipped. */
                if(i + j < dim)
                {
                    /* accs[j] += A[i + j][ii] * B[ii][:] */
                    accs[j] = _vel_vfmadd_vvsvl(accs[j],
                        my_A[(i + j) * dim + ii], row_B, ve_lvl);
                }
            }
        }

        /* Write the finished rows of this group back to C. */
#pragma unroll (UNROLL_FACTOR)
        for(int j = 0; j < UNROLL_FACTOR; ++j)
        {
            if(i + j < dim)
            {
                _vel_vst_vssl(accs[j], 8, my_C + (i + j) * dim, ve_lvl);
            }
        }
    }
}