From 3427edd0f98ec83e9722f5e0380b688566c67104 Mon Sep 17 00:00:00 2001 From: Shriram Shastry Date: Mon, 29 Apr 2024 17:36:55 +0530 Subject: [PATCH] Math: Optimise 16-bit matrix multiplication functions. Improve mat_multiply and mat_multiply_elementwise for 16-bit signed integers by refactoring operations and simplifying handling of Q0 data. mat_multiply now folds the x pointer increment into the multiply expression, and mat_multiply_elementwise now uses array indexing, for better legibility and potential compiler optimisation. Changes: - Improve x pointer increments in mat_multiply. - Integrate Q0 conditional logic into the main flow. - Clean up the mat_multiply_elementwise loop structure. - Round results by adding an offset before the fractional bit shift. These modifications aim to improve accuracy and efficiency in fixed-point arithmetic operations. Signed-off-by: Shriram Shastry --- src/math/matrix.c | 101 +++++++++++++++++++++------------------------- 1 file changed, 46 insertions(+), 55 deletions(-) diff --git a/src/math/matrix.c b/src/math/matrix.c index e4c4c1ecbfff..30149f4675fc 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -3,6 +3,7 @@ // Copyright(c) 2022 Intel Corporation. All rights reserved. 
// // Author: Seppo Ingalsuo +// Shriram Shastry #include #include @@ -10,86 +11,76 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) { - int64_t s; - int16_t *x; - int16_t *y; - int16_t *z = c->data; - int i, j, k; - int y_inc = b->columns; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; - if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) return -EINVAL; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows; i++) { - for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; - for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; - } - *z = (int16_t)s; /* For Q16.0 */ - z++; - } - } - return 0; - } + int64_t acc; + int16_t *x, *y, *z = c->data; + int i, j, k; + /* Increment for pointer y to jump to the next row in matrix B */ + int y_inc = b->columns; + const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; + const int shift = shift_minus_one + 1; + + /* Calculate offset for rounding. The offset is only needed when shift is + * non-negative (> 0). Since shift = shift_minus_one + 1, the check for non-negative + * value is (shift_minus_one >= -1) + */ + const int64_t offset = (shift_minus_one >= -1) ? 
(1LL << shift_minus_one) : 0; for (i = 0; i < a->rows; i++) { for (j = 0; j < b->columns; j++) { - s = 0; + acc = 0; + /* Position pointer x at the beginning of the ith row in matrix A */ x = a->data + a->columns * i; + /* Position pointer y at the beginning of the jth column in matrix B */ y = b->data + j; for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; + /* Multiply elements and add to sum */ + acc += (int64_t)(*x++) * (*y); + /* Move pointer y to the next element in the column */ y += y_inc; } - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - z++; + /* If shift == 0, then shift_minus_one == -1, which means the data is Q16.0 + * Otherwise, add the offset before the shift to round up if necessary + */ + *z++ = (shift == 0) ? (int16_t)acc : (int16_t)((acc + offset) >> shift); } } + return 0; } int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) -{ int64_t p; +{ + /* Validate dimensions match for elementwise multiplication */ + if (a->columns != b->columns || a->rows != b->rows) + return -EINVAL; + + const int total_elements = a->rows * a->columns; + /* Calculate shift for result Q format */ + const int shift = a->fractions + b->fractions - c->fractions; + /* Offset for rounding */ + const int64_t offset = (shift >= 0) ? 
(1LL << (shift - 1)) : 0; + int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data; + int64_t acc; int i; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; - if (a->columns != b->columns || b->columns != c->columns || - a->rows != b->rows || b->rows != c->rows) { - return -EINVAL; - } - - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows * a->columns; i++) { - *z = *x * *y; - x++; - y++; - z++; + /* When no shifting is required (shift -1 indicated), simply multiply */ + if (shift == -1) { + for (i = 0; i < total_elements; i++) + z[i] = x[i] * y[i]; /* Direct elementwise multiplication */ + } else { + /* General case with shifting and offset for rounding */ + for (i = 0; i < total_elements; i++) { + acc = (int64_t)x[i] * (int64_t)y[i]; /* Cast to int64_t to avoid overflow */ + z[i] = (int16_t)((acc + offset) >> shift); /* Apply shift and offset */ } - - return 0; - } - - for (i = 0; i < a->rows * a->columns; i++) { - p = (int32_t)(*x) * *y; - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - x++; - y++; - z++; } return 0;