diff --git a/src/math/matrix.c b/src/math/matrix.c index 364a0349507e..773e85730cc7 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -94,28 +94,27 @@ int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data; - int64_t p; + int32_t prod; int i; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows * a->columns; i++) { + /* Compute the total number of elements in the matrices */ + const int total_elements = a->rows * a->columns; + /* Compute the required bit shift based on the fractional part of each matrix */ + const int shift = a->fractions + b->fractions - c->fractions - 1; + + /* Perform multiplication with or without adjusting the fractional bits */ + if (shift == -1) { + /* Direct multiplication when no adjustment for fractional bits is needed */ + for (i = 0; i < total_elements; i++, x++, y++, z++) *z = *x * *y; - x++; - y++; - z++; + } else { + /* Multiplication with rounding to account for the fractional bits */ + for (i = 0; i < total_elements; i++, x++, y++, z++) { + /* Multiply elements as int32_t */ + prod = (int32_t)(*x) * *y; + /* Adjust and round the result */ + *z = (int16_t)(((prod >> shift) + 1) >> 1); } - - return 0; - } - - for (i = 0; i < a->rows * a->columns; i++) { - p = (int32_t)(*x) * *y; - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - x++; - y++; - z++; } return 0;