Skip to content

Commit

Permalink
Math: Optimise 16-bit matrix multiplication functions.
Browse files Browse the repository at this point in the history
Improve mat_multiply and mat_multiply_elementwise for 16-bit signed
integers by refactoring operations and simplifying handling of Q0 data.
Both functions now use incremented pointers directly in loop expressions,
improving legibility and enabling potential compiler optimisation.

Changes: - Improve the x pointer increments in mat_multiply.
- Integrate the Q0 conditional logic into the main flow.
- Clean up the mat_multiply_elementwise loop structure.
- Ensure precision with appropriate fractional bit shifts.

These modifications aim to improve accuracy and efficiency in
fixed-point arithmetic operations.

Signed-off-by: Shriram Shastry <malladi.sastry@intel.com>
  • Loading branch information
ShriramShastry committed Apr 29, 2024
1 parent 07b762e commit b3fa07c
Showing 1 changed file with 26 additions and 54 deletions.
80 changes: 26 additions & 54 deletions src/math/matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Copyright(c) 2022 Intel Corporation. All rights reserved.
//
// Author: Seppo Ingalsuo <seppo.ingalsuo@linux.intel.com>
// Shriram Shastry <malladi.sastry@linux.intel.com>

#include <sof/math/matrix.h>
#include <errno.h>
Expand All @@ -11,47 +12,24 @@
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
{
int64_t s;
int16_t *x;
int16_t *y;
int16_t *z = c->data;
int16_t *x, *y, *z = c->data;
int i, j, k;
int y_inc = b->columns;
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
const int shift = a->fractions + b->fractions - c->fractions - 1;

if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
return -EINVAL;

/* If all data is Q0 */
if (shift_minus_one == -1) {
for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
s = 0;
x = a->data + a->columns * i;
y = b->data + j;
for (k = 0; k < b->rows; k++) {
s += (int32_t)(*x) * (*y);
x++;
y += y_inc;
}
*z = (int16_t)s; /* For Q16.0 */
z++;
}
}

return 0;
}

for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
s = 0;
x = a->data + a->columns * i;
y = b->data + j;
s = 0; x = a->data + a->columns * i; y = b->data + j;
for (k = 0; k < b->rows; k++) {
s += (int32_t)(*x) * (*y);
x++;
y += y_inc;
s += (int32_t)(*x++) * (*y); y += y_inc;
}
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
if (shift == -1)
*z = (int16_t)s;
else
*z = (int16_t)(((s >> shift) + 1) >> 1);
z++;
}
}
Expand All @@ -60,36 +38,30 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_

int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
struct mat_matrix_16b *c)
{ int64_t p;
{
if (a->columns != b->columns || a->rows != b->rows)
return -EINVAL;

const int total_elements = a->rows * a->columns;
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;

int16_t *x = a->data;
int16_t *y = b->data;
int16_t *z = c->data;
int64_t p;
int i;
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;

if (a->columns != b->columns || b->columns != c->columns ||
a->rows != b->rows || b->rows != c->rows) {
return -EINVAL;
}

/* If all data is Q0 */
if (shift_minus_one == -1) {
for (i = 0; i < a->rows * a->columns; i++) {
*z = *x * *y;
x++;
y++;
z++;
}
// If all data is Q0
for (i = 0; i < total_elements; i++)
z[i] = x[i] * y[i];

return 0;
}

for (i = 0; i < a->rows * a->columns; i++) {
p = (int32_t)(*x) * *y;
*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
x++;
y++;
z++;
} else {
// General case
for (i = 0; i < total_elements; i++) {
p = (int32_t)x[i] * y[i];
z[i] = (int16_t)(((p >> shift_minus_one) + 1) >> 1); // Shift to Qx.y
}
}

return 0;
Expand Down

0 comments on commit b3fa07c

Please sign in to comment.