[software] Fix f16 matmul
mbertuletti committed Jan 8, 2024
1 parent d384b79 commit aa0f848
Showing 5 changed files with 93 additions and 154 deletions.
4 changes: 2 additions & 2 deletions software/apps/matmul_f16/main.c
@@ -44,9 +44,9 @@ int main() {
mempool_start_benchmark();
matmul_2x2_single_f16(matrix_a, matrix_b, matrix_c, matrix_M, matrix_N,
matrix_P);
mempool_barrier(num_cores);
mempool_stop_benchmark();
}
mempool_barrier(num_cores);
#endif

#if defined(PARALLEL)
@@ -58,7 +58,7 @@ int main() {
mempool_stop_benchmark();
#endif

mempool_check_f16(matrix_c, C, matrix_M * matrix_P, 0.1f, 0);
mempool_check_f16(matrix_c, C, matrix_M * matrix_P, 10.0f, 0);
mempool_barrier(num_cores);
return 0;
}
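
The check tolerance is relaxed from 0.1f to 10.0f, presumably to absorb the error of accumulating each dot product in half precision, and the barrier appears to move out of the timed region so it is no longer part of the measured cycles. A minimal NumPy sketch of the accumulation drift (illustrative sizes, with fp16 rounding forced after every step as an assumption about the accumulation behaviour, not the MemPool hardware):

import numpy as np

np.random.seed(0)
M, N, P = 16, 64, 16                      # illustrative sizes, not the app defaults
A = np.random.rand(M, N).astype(np.float16)
B = np.random.rand(N, P).astype(np.float16)

ref = A.astype(np.float64) @ B.astype(np.float64)     # exact product of the fp16 operands
acc = np.zeros((M, P), dtype=np.float16)
for k in range(N):                                    # force fp16 rounding after every MAC step
    acc = (acc + np.outer(A[:, k], B[k, :])).astype(np.float16)

print(np.abs(acc.astype(np.float64) - ref).max())     # worst-case gap grows with N
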
2 changes: 1 addition & 1 deletion software/apps/matmul_f32/main.c
@@ -2,7 +2,7 @@
// Licensed under the Apache License, Version 2.0, see LICENSE for details.
// SPDX-License-Identifier: Apache-2.0

// Author: Samuel Riedel, ETH Zurich
// Author: Marco Bertuletti, ETH Zurich

#include <stdint.h>
#include <string.h>
2 changes: 1 addition & 1 deletion software/runtime/data/data_matmul_f16.h.tpl
@@ -7,7 +7,7 @@
i = 0
out += '\n'
for a in array:
out += '(__fp16){}f, '.format(a)
out += '(__fp16){:.4f}f, '.format(a)
i += 1
if i % 8 == 0:
out += '\n'
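
The template now emits every coefficient with a fixed '{:.4f}' format instead of the value's default string representation, which keeps the generated header literals short and uniform. A small illustration of the difference (plain str.format, value chosen arbitrarily):

a = 1.0 / 3.0
print('(__fp16){}f, '.format(a))       # (__fp16)0.3333333333333333f,
print('(__fp16){:.4f}f, '.format(a))   # (__fp16)0.3333f,
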
4 changes: 2 additions & 2 deletions software/runtime/data/data_matmul_f16.py
@@ -87,8 +87,8 @@ def main():
matrix_P = args.dim_p

# Create sparse matrix
A = np.random.rand(matrix_M, matrix_N)
B = np.random.rand(matrix_N, matrix_P)
A = np.random.rand(matrix_M, matrix_N).astype(np.float16)
B = np.random.rand(matrix_N, matrix_P).astype(np.float16)
C = np.matmul(A, B)

A = np.reshape(A, (matrix_M * matrix_N), order='C').astype(np.float16)
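
Casting A and B to float16 before the reference np.matmul makes the golden output a product of the same rounded operands the kernel reads, instead of the original float64 samples. A short sketch of the gap that this closes (illustrative sizes):

import numpy as np

M, N, P = 16, 32, 16                       # illustrative sizes
A64 = np.random.rand(M, N)
B64 = np.random.rand(N, P)
A16 = A64.astype(np.float16)
B16 = B64.astype(np.float16)

C_old = np.matmul(A64, B64)                                         # reference from float64 inputs
C_new = np.matmul(A16.astype(np.float64), B16.astype(np.float64))   # reference from the rounded inputs
print(np.abs(C_old - C_new).max())         # operand-rounding gap the check no longer has to absorb
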
235 changes: 87 additions & 148 deletions software/runtime/kernel/matmul_f16.h
@@ -18,10 +18,10 @@ void matmul_2x2_single_f16(__fp16 const *__restrict__ A,
uint32_t M, uint32_t N, uint32_t P) {
for (uint32_t i = 0; i < M; i += 2) {
for (uint32_t j = 0; j < P; j += 2) {
__fp16 c00 = 0.0f;
__fp16 c01 = 0.0f;
__fp16 c10 = 0.0f;
__fp16 c11 = 0.0f;
__fp16 c00 = (__fp16)0.0f;
__fp16 c01 = (__fp16)0.0f;
__fp16 c10 = (__fp16)0.0f;
__fp16 c11 = (__fp16)0.0f;
for (uint32_t k = 0; k < N; k += 2) {
// Explicitly load the values first to help with scheduling
__fp16 val_a00 = A[(i + 0) * N + k + 0];
@@ -32,58 +32,28 @@ void matmul_2x2_single_f16(__fp16 const *__restrict__ A,
__fp16 val_b01 = B[(k + 0) * P + j + 1];
__fp16 val_b10 = B[(k + 1) * P + j + 0];
__fp16 val_b11 = B[(k + 1) * P + j + 1];
__fp16 mul00 = 0.0f;
__fp16 mul01 = 0.0f;
__fp16 mul10 = 0.0f;
__fp16 mul11 = 0.0f;
asm volatile(
"fmadd.h %[c00], %[val_a00], %[val_b00], %[c00];"
"fmadd.h %[c01], %[val_a00], %[val_b01], %[c01];"
"fmadd.h %[c10], %[val_a10], %[val_b00], %[c10];"
"fmadd.h %[c11], %[val_a10], %[val_b01], %[c11];"
"fmadd.h %[c00], %[val_a01], %[val_b10], %[c00];"
"fmadd.h %[c01], %[val_a01], %[val_b11], %[c01];"
"fmadd.h %[c10], %[val_a11], %[val_b10], %[c10];"
"fmadd.h %[c11], %[val_a11], %[val_b11], %[c11];"
: [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
[c11] "+&r"(c11), [mul00] "+&r"(mul00), [mul01] "+&r"(mul01),
[mul10] "+&r"(mul10), [mul11] "+&r"(mul11)
: [val_a00] "r"(val_a00), [val_a01] "r"(val_a01),
[val_a10] "r"(val_a10), [val_a11] "r"(val_a11),
[val_b00] "r"(val_b00), [val_b01] "r"(val_b01),
[val_b10] "r"(val_b10), [val_b11] "r"(val_b11));
// asm volatile(
// "fmul.h %[mul00], %[val_a00], %[val_b00];"
// "fmul.h %[mul01], %[val_a01], %[val_b10];"
// "fmul.h %[mul10], %[val_a00], %[val_b01];"
// "fmul.h %[mul11], %[val_a01], %[val_b11];"
// "fadd.h %[c00], %[c00], %[mul00];"
// "fadd.h %[c00], %[c00], %[mul01];"
// "fadd.h %[c01], %[c01], %[mul10];"
// "fadd.h %[c01], %[c01], %[mul11];"
// "fmul.h %[mul00], %[val_a10], %[val_b00];"
// "fmul.h %[mul01], %[val_a11], %[val_b10];"
// "fmul.h %[mul10], %[val_a10], %[val_b01];"
// "fmul.h %[mul11], %[val_a11], %[val_b11];"
// "fadd.h %[c10], %[c10], %[mul00];"
// "fadd.h %[c10], %[c10], %[mul01];"
// "fadd.h %[c11], %[c11], %[mul10];"
// "fadd.h %[c11], %[c11], %[mul11];"
// : [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
// [c11] "+&r"(c11),
// [mul00] "+&r"(mul00), [mul01] "+&r"(mul01), [mul10]
// "+&r"(mul10), [mul11] "+&r"(mul11)
// : [val_a00] "r"(val_a00), [val_a01] "r"(val_a01), [val_a10]
// "r"(val_a10), [val_a11] "r"(val_a11),
// [val_b00] "r"(val_b00), [val_b01] "r"(val_b01), [val_b10]
// "r"(val_b10), [val_b11] "r"(val_b11));
asm volatile("fmadd.h %[c00], %[val_a00], %[val_b00], %[c00];"
"fmadd.h %[c00], %[val_a01], %[val_b10], %[c00];"
"fmadd.h %[c01], %[val_a00], %[val_b01], %[c01];"
"fmadd.h %[c01], %[val_a01], %[val_b11], %[c01];"
"fmadd.h %[c10], %[val_a10], %[val_b00], %[c10];"
"fmadd.h %[c10], %[val_a11], %[val_b10], %[c10];"
"fmadd.h %[c11], %[val_a10], %[val_b01], %[c11];"
"fmadd.h %[c11], %[val_a11], %[val_b11], %[c11];"
: [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
[c11] "+&r"(c11)
: [val_a00] "r"(val_a00), [val_a01] "r"(val_a01),
[val_a10] "r"(val_a10), [val_a11] "r"(val_a11),
[val_b00] "r"(val_b00), [val_b01] "r"(val_b01),
[val_b10] "r"(val_b10), [val_b11] "r"(val_b11));
}
C[(i + 0) * P + j + 0] = c00;
C[(i + 0) * P + j + 1] = c01;
C[(i + 1) * P + j + 0] = c10;
C[(i + 1) * P + j + 1] = c11;
}
}
return;
}
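
The rewritten block keeps only the fmadd.h chain: the unused mul00..mul11 temporaries and the commented-out fmul/fadd variant are dropped, and the eight fused multiply-adds are grouped per output element. As a pure-Python golden model of the same 2x2 register blocking (a readability sketch assuming even M, N and P, not the device code):

def matmul_2x2_blocked(A, B, M, N, P):
    # A, B, C are flat row-major lists, mirroring the C kernel's indexing.
    C = [0.0] * (M * P)
    for i in range(0, M, 2):
        for j in range(0, P, 2):
            c00 = c01 = c10 = c11 = 0.0
            for k in range(0, N, 2):
                a00, a01 = A[i * N + k], A[i * N + k + 1]
                a10, a11 = A[(i + 1) * N + k], A[(i + 1) * N + k + 1]
                b00, b01 = B[k * P + j], B[k * P + j + 1]
                b10, b11 = B[(k + 1) * P + j], B[(k + 1) * P + j + 1]
                c00 += a00 * b00 + a01 * b10   # same eight MACs as the fmadd.h chain
                c01 += a00 * b01 + a01 * b11
                c10 += a10 * b00 + a11 * b10
                c11 += a10 * b01 + a11 * b11
            C[i * P + j], C[i * P + j + 1] = c00, c01
            C[(i + 1) * P + j], C[(i + 1) * P + j + 1] = c10, c11
    return C
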

void matmul_2x2_parallel_f16(__fp16 const *__restrict__ A,
@@ -96,10 +66,10 @@ void matmul_2x2_parallel_f16(__fp16 const *__restrict__ A,
uint32_t const c_end = (P / c) * ((id % c) + 1);
for (uint32_t i = 2 * (id / c); i < M; i += 2 * (numThreads / c)) {
for (uint32_t j = c_start; j < c_end; j += 2) {
__fp16 c00 = 0.0f;
__fp16 c01 = 0.0f;
__fp16 c10 = 0.0f;
__fp16 c11 = 0.0f;
__fp16 c00 = (__fp16)0.0f;
__fp16 c01 = (__fp16)0.0f;
__fp16 c10 = (__fp16)0.0f;
__fp16 c11 = (__fp16)0.0f;
for (uint32_t k = 0; k < N; k += 2) {
// Explicitly load the values first to help with scheduling
__fp16 val_a00 = A[(i + 0) * N + k + 0];
@@ -110,51 +80,20 @@ void matmul_2x2_parallel_f16(__fp16 const *__restrict__ A,
__fp16 val_b01 = B[(k + 0) * P + j + 1];
__fp16 val_b10 = B[(k + 1) * P + j + 0];
__fp16 val_b11 = B[(k + 1) * P + j + 1];
__fp16 mul00 = 0.0f;
__fp16 mul01 = 0.0f;
__fp16 mul10 = 0.0f;
__fp16 mul11 = 0.0f;
asm volatile(
"fmadd.h %[c00], %[val_a00], %[val_b00], %[c00];"
"fmadd.h %[c01], %[val_a00], %[val_b01], %[c01];"
"fmadd.h %[c10], %[val_a10], %[val_b00], %[c10];"
"fmadd.h %[c11], %[val_a10], %[val_b01], %[c11];"
"fmadd.h %[c00], %[val_a01], %[val_b10], %[c00];"
"fmadd.h %[c01], %[val_a01], %[val_b11], %[c01];"
"fmadd.h %[c10], %[val_a11], %[val_b10], %[c10];"
"fmadd.h %[c11], %[val_a11], %[val_b11], %[c11];"
: [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
[c11] "+&r"(c11), [mul00] "+&r"(mul00), [mul01] "+&r"(mul01),
[mul10] "+&r"(mul10), [mul11] "+&r"(mul11)
: [val_a00] "r"(val_a00), [val_a01] "r"(val_a01),
[val_a10] "r"(val_a10), [val_a11] "r"(val_a11),
[val_b00] "r"(val_b00), [val_b01] "r"(val_b01),
[val_b10] "r"(val_b10), [val_b11] "r"(val_b11));
// asm volatile(
// "fmul.h %[mul00], %[val_a00], %[val_b00];"
// "fmul.h %[mul01], %[val_a01], %[val_b10];"
// "fmul.h %[mul10], %[val_a00], %[val_b01];"
// "fmul.h %[mul11], %[val_a01], %[val_b11];"
// "fadd.h %[c00], %[c00], %[mul00];"
// "fadd.h %[c00], %[c00], %[mul01];"
// "fadd.h %[c01], %[c01], %[mul10];"
// "fadd.h %[c01], %[c01], %[mul11];"
// "fmul.h %[mul00], %[val_a10], %[val_b00];"
// "fmul.h %[mul01], %[val_a11], %[val_b10];"
// "fmul.h %[mul10], %[val_a10], %[val_b01];"
// "fmul.h %[mul11], %[val_a11], %[val_b11];"
// "fadd.h %[c10], %[c10], %[mul00];"
// "fadd.h %[c10], %[c10], %[mul01];"
// "fadd.h %[c11], %[c11], %[mul10];"
// "fadd.h %[c11], %[c11], %[mul11];"
// : [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
// [c11] "+&r"(c11),
// [mul00] "+&r"(mul00), [mul01] "+&r"(mul01), [mul10]
// "+&r"(mul10), [mul11] "+&r"(mul11)
// : [val_a00] "r"(val_a00), [val_a01] "r"(val_a01), [val_a10]
// "r"(val_a10), [val_a11] "r"(val_a11),
// [val_b00] "r"(val_b00), [val_b01] "r"(val_b01), [val_b10]
// "r"(val_b10), [val_b11] "r"(val_b11));
asm volatile("fmadd.h %[c00], %[val_a00], %[val_b00], %[c00];"
"fmadd.h %[c01], %[val_a00], %[val_b01], %[c01];"
"fmadd.h %[c10], %[val_a10], %[val_b00], %[c10];"
"fmadd.h %[c11], %[val_a10], %[val_b01], %[c11];"
"fmadd.h %[c00], %[val_a01], %[val_b10], %[c00];"
"fmadd.h %[c01], %[val_a01], %[val_b11], %[c01];"
"fmadd.h %[c10], %[val_a11], %[val_b10], %[c10];"
"fmadd.h %[c11], %[val_a11], %[val_b11], %[c11];"
: [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
[c11] "+&r"(c11)
: [val_a00] "r"(val_a00), [val_a01] "r"(val_a01),
[val_a10] "r"(val_a10), [val_a11] "r"(val_a11),
[val_b00] "r"(val_b00), [val_b01] "r"(val_b01),
[val_b10] "r"(val_b10), [val_b11] "r"(val_b11));
}
C[(i + 0) * P + j + 0] = c00;
C[(i + 0) * P + j + 1] = c01;
@@ -174,31 +113,31 @@ void matmul_4x2_parallel_f16(const __fp16 *__restrict__ pSrcA,
uint32_t j = 0; // loop counter for N
uint32_t k = 0; // loop counter for P

for (k = core_id; k < P / 2; k += numThreads) {
for (i = 0; i < M / 4; i++) {
__fp16 c00 = 0.0f;
__fp16 c01 = 0.0f;
__fp16 c10 = 0.0f;
__fp16 c11 = 0.0f;
__fp16 c20 = 0.0f;
__fp16 c21 = 0.0f;
__fp16 c30 = 0.0f;
__fp16 c31 = 0.0f;
for (j = 0; j < N / 2; j++) {

__fp16 a00 = pSrcA[(i * 4) * N + (j * 2)];
__fp16 a01 = pSrcA[(i * 4) * N + (j * 2) + 1];
__fp16 a10 = pSrcA[(i * 4 + 1) * N + (j * 2)];
__fp16 a11 = pSrcA[(i * 4 + 1) * N + (j * 2) + 1];
__fp16 a20 = pSrcA[(i * 4 + 2) * N + (j * 2)];
__fp16 a21 = pSrcA[(i * 4 + 2) * N + (j * 2) + 1];
__fp16 a30 = pSrcA[(i * 4 + 3) * N + (j * 2)];
__fp16 a31 = pSrcA[(i * 4 + 3) * N + (j * 2) + 1];

__fp16 b00 = pSrcB[(j * 2) * P + (k * 2)];
__fp16 b01 = pSrcB[(j * 2) * P + (k * 2) + 1];
__fp16 b10 = pSrcB[(j * 2 + 1) * P + (k * 2)];
__fp16 b11 = pSrcB[(j * 2 + 1) * P + (k * 2) + 1];
for (k = 2 * core_id; k < P; k += 2 * numThreads) {
for (i = 0; i < M; i += 4) {
__fp16 c00 = (__fp16)0.0f;
__fp16 c01 = (__fp16)0.0f;
__fp16 c10 = (__fp16)0.0f;
__fp16 c11 = (__fp16)0.0f;
__fp16 c20 = (__fp16)0.0f;
__fp16 c21 = (__fp16)0.0f;
__fp16 c30 = (__fp16)0.0f;
__fp16 c31 = (__fp16)0.0f;
for (j = 0; j < N; j += 2) {

__fp16 a00 = pSrcA[i * N + j];
__fp16 a01 = pSrcA[i * N + j + 1];
__fp16 a10 = pSrcA[(i + 1) * N + j];
__fp16 a11 = pSrcA[(i + 1) * N + j + 1];
__fp16 a20 = pSrcA[(i + 2) * N + j];
__fp16 a21 = pSrcA[(i + 2) * N + j + 1];
__fp16 a30 = pSrcA[(i + 3) * N + j];
__fp16 a31 = pSrcA[(i + 3) * N + j + 1];

__fp16 b00 = pSrcB[j * P + k];
__fp16 b01 = pSrcB[j * P + k + 1];
__fp16 b10 = pSrcB[(j + 1) * P + k];
__fp16 b11 = pSrcB[(j + 1) * P + k + 1];

asm volatile(

@@ -219,22 +158,22 @@
"fmadd.h %[c31], %[a30], %[b01], %[c31];"
"fmadd.h %[c31], %[a31], %[b11], %[c31];"

: [c00] "=r"(c00), [c01] "=r"(c01), [c10] "=r"(c10),
[c11] "=r"(c11), [c20] "=r"(c20), [c21] "=r"(c21),
[c30] "=r"(c30), [c31] "=r"(c31)
: [c00] "+&r"(c00), [c01] "+&r"(c01), [c10] "+&r"(c10),
[c11] "+&r"(c11), [c20] "+&r"(c20), [c21] "+&r"(c21),
[c30] "+&r"(c30), [c31] "+&r"(c31)
: [b00] "r"(b00), [b01] "r"(b01), [b10] "r"(b10), [b11] "r"(b11),
[a00] "r"(a00), [a01] "r"(a01), [a10] "r"(a10), [a11] "r"(a11),
[a20] "r"(a20), [a21] "r"(a21), [a30] "r"(a30), [a31] "r"(a31)
:);
}
pDstC[(i * 4) * P + (k * 2)] = c00;
pDstC[(i * 4) * P + (k * 2 + 1)] = c01;
pDstC[(i * 4 + 1) * P + (k * 2)] = c10;
pDstC[(i * 4 + 1) * P + (k * 2 + 1)] = c11;
pDstC[(i * 4 + 2) * P + (k * 2)] = c20;
pDstC[(i * 4 + 2) * P + (k * 2 + 1)] = c21;
pDstC[(i * 4 + 3) * P + (k * 2)] = c30;
pDstC[(i * 4 + 3) * P + (k * 2 + 1)] = c31;
pDstC[i * P + k] = c00;
pDstC[i * P + k + 1] = c01;
pDstC[(i + 1) * P + k] = c10;
pDstC[(i + 1) * P + k + 1] = c11;
pDstC[(i + 2) * P + k] = c20;
pDstC[(i + 2) * P + k + 1] = c21;
pDstC[(i + 3) * P + k] = c30;
pDstC[(i + 3) * P + k + 1] = c31;
}
}
}
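
The 4x2 kernel now steps i by 4 and j, k by 2 directly instead of looping up to M/4, N/2, P/2 and scaling the indices back, which removes the i*4 / j*2 / k*2 multiplications from every address computation; the accumulator constraints also change from write-only "=r" to read-modify-write "+&r", matching how the fmadd.h chain reads them. A quick sketch (hypothetical sizes) checking that both index schemes visit the same output tiles for one core:

M, P = 16, 24                              # hypothetical sizes (divisible by 4 and 2)
core_id, num_threads = 1, 4

old_tiles = {(i * 4, k * 2)
             for k in range(core_id, P // 2, num_threads)
             for i in range(M // 4)}
new_tiles = {(i, k)
             for k in range(2 * core_id, P, 2 * num_threads)
             for i in range(0, M, 4)}
assert old_tiles == new_tiles              # same (row, column) tiles, simpler index arithmetic
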
@@ -249,8 +188,8 @@ void matmul_4x2_parallel_f16vec(const __fp16 *__restrict__ pSrcA,
uint32_t j = 0; // loop counter for N
uint32_t k = 0; // loop counter for P

for (k = core_id; k < P / 2; k += numThreads) {
for (i = 0; i < M / 4; i++) {
for (k = core_id * 2; k < P; k += numThreads * 2) {
for (i = 0; i < M; i += 4) {
float volatile sum00 = 0.0f;
float volatile sum01 = 0.0f;
float volatile sum10 = 0.0f;
@@ -259,14 +198,14 @@ void matmul_4x2_parallel_f16vec(const __fp16 *__restrict__ pSrcA,
float volatile sum21 = 0.0f;
float volatile sum30 = 0.0f;
float volatile sum31 = 0.0f;
for (j = 0; j < N / 2; j++) {

v2h aVec0 = *(v2h *)&(pSrcA[(i * 4) * N + (j * 2)]);
v2h aVec1 = *(v2h *)&(pSrcA[(i * 4 + 1) * N + (j * 2)]);
v2h aVec2 = *(v2h *)&(pSrcA[(i * 4 + 2) * N + (j * 2)]);
v2h aVec3 = *(v2h *)&(pSrcA[(i * 4 + 3) * N + (j * 2)]);
v2h bVecTemp0 = *(v2h *)&(pSrcB[(j * 2) * P + (k * 2)]);
v2h bVecTemp1 = *(v2h *)&(pSrcB[(j * 2 + 1) * P + (k * 2)]);
for (j = 0; j < N; j += 2) {

v2h aVec0 = *(v2h *)&(pSrcA[i * N + j]);
v2h aVec1 = *(v2h *)&(pSrcA[(i + 1) * N + j]);
v2h aVec2 = *(v2h *)&(pSrcA[(i + 2) * N + j]);
v2h aVec3 = *(v2h *)&(pSrcA[(i + 3) * N + j]);
v2h bVecTemp0 = *(v2h *)&(pSrcB[j * P + k]);
v2h bVecTemp1 = *(v2h *)&(pSrcB[(j + 1) * P + k]);
v2h bVec0, bVec1;
unsigned TempH, TempL;
asm volatile(
@@ -304,10 +243,10 @@ void matmul_4x2_parallel_f16vec(const __fp16 *__restrict__ pSrcA,
[sum11] "r"(sum11), [sum20] "r"(sum20), [sum21] "r"(sum21),
[sum30] "r"(sum30), [sum31] "r"(sum31)
:);
(*(v2h *)&pDstC[(i * 4) * P + (k * 2)]) = res0;
(*(v2h *)&pDstC[(i * 4 + 1) * P + (k * 2)]) = res1;
(*(v2h *)&pDstC[(i * 4 + 2) * P + (k * 2)]) = res2;
(*(v2h *)&pDstC[(i * 4 + 3) * P + (k * 2)]) = res3;
(*(v2h *)&pDstC[i * P + k]) = res0;
(*(v2h *)&pDstC[(i + 1) * P + k]) = res1;
(*(v2h *)&pDstC[(i + 2) * P + k]) = res2;
(*(v2h *)&pDstC[(i + 3) * P + k]) = res3;
}
}
}
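
The vectorized variant gets the same index simplification: the v2h loads and stores now use i * N + j and j * P + k directly. The asm body is collapsed in this view, but judging from the float accumulators and the packed res0..res3 stores, each pair of halves is consumed as one 32-bit word and the two column sums per row are repacked at the end. A rough NumPy golden model of that intent (an assumption; in particular the accumulation precision is not visible here):

import numpy as np

def matmul_4x2_f16vec_model(A, B):
    # A: (M, N) float16, B: (N, P) float16; single-core view, no work split.
    M, N = A.shape
    _, P = B.shape
    C = np.zeros((M, P), dtype=np.float16)
    for k in range(0, P, 2):               # one pair of output columns (one v2h store per row)
        for i in range(0, M, 4):           # four output rows per tile
            sums = np.zeros((4, 2), dtype=np.float32)
            for j in range(0, N, 2):       # one packed pair of A/B elements per step
                a = A[i:i + 4, j:j + 2].astype(np.float32)
                b = B[j:j + 2, k:k + 2].astype(np.float32)
                sums += a @ b              # assumed wider accumulation
            C[i:i + 4, k:k + 2] = sums.astype(np.float16)   # repack two halves per row
    return C
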
