From b3fa07cd874b43bc3619e890eb8ccfc5d06a2e0c Mon Sep 17 00:00:00 2001 From: Shriram Shastry Date: Mon, 29 Apr 2024 17:36:55 +0530 Subject: [PATCH] Math: Optimise 16-bit matrix multiplication functions. Improve mat_multiply and mat_multiply_elementwise for 16-bit signed integers by refactoring operations and simplifying handling of Q0 data. Both functions now use incremented pointers in loop expressions for better legibility and potential compiler optimisation. Changes: - Improved x pointer increments in mat_multiply. - Integrating Q0 conditional logic into the main flow. - Clean up the mat_multiply_elementwise loop structure. - Ensure precision with appropriate fractional bit shifts. These modifications aim to improve accuracy and efficiency in fixed-point arithmetic operations. Signed-off-by: Shriram Shastry --- src/math/matrix.c | 80 +++++++++++++++-------------------------------- 1 file changed, 26 insertions(+), 54 deletions(-) diff --git a/src/math/matrix.c b/src/math/matrix.c index e4c4c1ecbfff..2cdcdc5f02b0 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -3,6 +3,7 @@ // Copyright(c) 2022 Intel Corporation. All rights reserved. // // Author: Seppo Ingalsuo +// Shriram Shastry #include #include @@ -11,47 +12,24 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) { int64_t s; - int16_t *x; - int16_t *y; - int16_t *z = c->data; + int16_t *x, *y, *z = c->data; int i, j, k; int y_inc = b->columns; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; + const int shift = a->fractions + b->fractions - c->fractions - 1; if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) return -EINVAL; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows; i++) { - for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; - for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; - } - *z = (int16_t)s; /* For Q16.0 */ - z++; - } - } - - return 0; - } - for (i = 0; i < a->rows; i++) { for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; + s = 0; x = a->data + a->columns * i; y = b->data + j; for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; + s += (int32_t)(*x++) * (*y); y += y_inc; } - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ + if (shift == -1) + *z = (int16_t)s; + else + *z = (int16_t)(((s >> shift) + 1) >> 1); z++; } } @@ -60,36 +38,30 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_ int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) -{ int64_t p; +{ + if (a->columns != b->columns || a->rows != b->rows) + return -EINVAL; + + const int total_elements = a->rows * a->columns; + const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; + int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data; + int64_t p; int i; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; - - if (a->columns != b->columns || b->columns != c->columns || - a->rows != b->rows || b->rows != c->rows) { - return -EINVAL; - } - /* If all data is Q0 */ if (shift_minus_one == -1) { - for (i = 0; i < a->rows * a->columns; i++) { - *z = *x * *y; - x++; - y++; - z++; - } + // If all data is Q0 + for (i = 0; i < total_elements; i++) + z[i] = x[i] * y[i]; - return 0; - } - - for (i = 0; i < a->rows * a->columns; i++) { - p = (int32_t)(*x) * *y; - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - x++; - y++; - z++; + } else { + // General case + for (i = 0; i < total_elements; i++) { + p = (int32_t)x[i] * y[i]; + z[i] = (int16_t)(((p >> shift_minus_one) + 1) >> 1); // Shift to Qx.y + } } return 0;