Commit 74dba0f

Merge pull request #24 from rozukke/fix/softmax-performance
Improve softmax perf 4x
nhatdongdang authored Jul 4, 2024
2 parents 7bc5f27 + fae646b commit 74dba0f
Showing 2 changed files with 21 additions and 22 deletions.
8 changes: 4 additions & 4 deletions src/main.c
@@ -80,7 +80,7 @@ u8 infer(vector* input) {
// Somewhat experimental, minimum number of aligned_alloc calls without breaking things.
// This code fucking sucks but its fast so uhhhh
u8 infer_reuse_layers_thread(vector* input, matrix** weights, vector** biases) {
- vector* outputs[NUM_LAYERS];
+ vector* outputs[2];
outputs[0] = new_vec_aligned(98);
outputs[1] = new_vec_aligned(65);
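
The scratch vectors drop from NUM_LAYERS to two that alternate between layers. A minimal standalone sketch of that ping-pong pattern (not from this commit; the layer count, width, and apply_layer stand-in are hypothetical):

    #include <stdio.h>

    #define NUM_LAYERS 4   /* hypothetical depth, not the repo's value */
    #define WIDTH      8   /* hypothetical layer width */

    /* Hypothetical stand-in for one layer's work (the repo runs sgemv + bias + relu). */
    static void apply_layer(int layer, const float* in, float* out) {
        for (int i = 0; i < WIDTH; i++)
            out[i] = in[i] + (float)layer;
    }

    int main(void) {
        float buf_a[WIDTH] = {0}, buf_b[WIDTH] = {0};
        float* bufs[2] = {buf_a, buf_b};
        /* Buffer l % 2 holds the current input, (l + 1) % 2 receives the output,
           so two allocations cover any number of layers. */
        for (int l = 0; l < NUM_LAYERS; l++)
            apply_layer(l, bufs[l % 2], bufs[(l + 1) % 2]);
        printf("%f\n", bufs[NUM_LAYERS % 2][0]);
        return 0;
    }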

@@ -208,8 +208,8 @@ int main(int argc, char* argv[]) {
for (int i = 0; i < input_count; i++) {
// printf("Thread %d: Processing input %d\n", omp_get_thread_num(), i);

- vector* input = new_vec_aligned(TSIZE_ALGN_BYTES / sizeof(f32));
- memcpy(input->data, (f32*)&tensors[TSIZE_ALGN_BYTES / sizeof(f32) * i], TSIZE_ALGN_BYTES);
+ vector* input = new_vec_aligned(TENSOR_SIZE);
+ memcpy(input->data, (f32*)&tensors[TSIZE_ALGN_BYTES / sizeof(f32) * i], TENSOR_SIZE * sizeof(f32));

#pragma omp for
for (int j = 0; j < iter_per_in - 1; j++) {
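
The input vector is now sized and copied by the logical element count (TENSOR_SIZE) while the packed source buffer is still stepped by its aligned stride (TSIZE_ALGN_BYTES). A minimal standalone sketch of that stride-versus-size distinction (the macro names mirror the repo's, but the values here are made up):

    #include <stdio.h>
    #include <string.h>

    typedef float f32;

    #define TENSOR_SIZE      10   /* hypothetical logical element count */
    #define TSIZE_ALGN_BYTES 64   /* hypothetical aligned stride in bytes */

    int main(void) {
        /* Packed source: each input occupies a full aligned stride. */
        f32 tensors[2 * TSIZE_ALGN_BYTES / sizeof(f32)];
        for (size_t i = 0; i < sizeof(tensors) / sizeof(f32); i++)
            tensors[i] = (f32)i;

        for (int i = 0; i < 2; i++) {
            f32 input[TENSOR_SIZE];                    /* only the logical size is allocated... */
            memcpy(input, &tensors[TSIZE_ALGN_BYTES / sizeof(f32) * i],
                   TENSOR_SIZE * sizeof(f32));         /* ...and only that many floats copied */
            printf("input %d starts at %.0f\n", i, input[0]);
        }
        return 0;
    }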
@@ -223,7 +223,7 @@ int main(int argc, char* argv[]) {
}

free(results_local);
printf("Thread %d: %d\n", omp_get_thread_num(), force);
printf("Thread %2d: %d\n", omp_get_thread_num(), force);
}

// Output for csv
35 changes: 17 additions & 18 deletions src/matrix.c
@@ -64,7 +64,7 @@ void sgemv_t_tuned(const float* weights, const float* inputs, float* __restrict_
}

// TODO: SIMD tuned versions if these are a noticeable impact
- void vector_add_inplace(int len, const f32* src, f32* dest) {
+ void vector_add_inplace(int len, const f32* src, f32* __restrict__ dest) {
for (int i = 0; i < len; i++) {
dest[i] += src[i];
}
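
Adding __restrict__ to dest promises the compiler that src and dest never overlap, which can let it vectorize the += loop without emitting runtime overlap checks. A small standalone usage sketch (not from this commit) showing the contract: callers must pass distinct buffers.

    #include <stdio.h>

    typedef float f32;

    /* Same shape as the repo's helper: __restrict__ on dest asserts no overlap with src. */
    static void vector_add_inplace(int len, const f32* src, f32* __restrict__ dest) {
        for (int i = 0; i < len; i++)
            dest[i] += src[i];
    }

    int main(void) {
        f32 bias[4] = {0.5f, 0.5f, 0.5f, 0.5f};
        f32 out[4]  = {1.0f, 2.0f, 3.0f, 4.0f};
        vector_add_inplace(4, bias, out);   /* distinct buffers: the contract holds */
        /* vector_add_inplace(4, out, out) would violate the no-alias promise. */
        printf("%.1f %.1f %.1f %.1f\n", out[0], out[1], out[2], out[3]);
        return 0;
    }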
@@ -77,10 +77,9 @@ void relu_inplace(f32* dest, int len) {
}

// Hacky but fast and accurate for existing inputs
- static double fastexp(double x) {
- i64 tmp = (i64)(1512775 * x + 1072632447);
- tmp <<= 32;
- double result;
+ static inline float fastexp(float x) {
+ int tmp = (int)(1512775 * x + 1072632447);
+ float result;
memcpy(&result, &tmp, sizeof(result));
return result;
}
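
The new fastexp keeps the Schraudolph-style bit trick but drops the 64-bit shift, writing the 32-bit pattern straight into a float (the integer constants are carried over from the double version). The general idea: scale x so it lands in the exponent field of the IEEE-754 representation, then reinterpret the bits, giving a cheap approximation of 2^(x / ln 2) = e^x. A minimal standalone sketch with the textbook float constants (12102203 is roughly 2^23 / ln 2, 1064866805 is roughly 127 * 2^23 minus an error-centering offset), not the constants kept in this commit:

    #include <math.h>
    #include <stdio.h>
    #include <string.h>

    /* Schraudolph-style fast exp for 32-bit floats (textbook constants, a few percent error). */
    static inline float fastexp_sketch(float x) {
        int i = (int)(12102203.0f * x + 1064866805.0f);
        float r;
        memcpy(&r, &i, sizeof r);   /* reinterpret the integer bits as a float */
        return r;
    }

    int main(void) {
        for (float x = -2.0f; x <= 2.0f; x += 1.0f)
            printf("x=%5.1f  fast=%9.5f  libm=%9.5f\n", x, fastexp_sketch(x), expf(x));
        return 0;
    }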
@@ -95,6 +94,19 @@ void softmax_inplace(f32* dest, int len) {
}
}

+ // Get result from output layer
+ u8 argmax(f32* in, int len) {
+ int idx = 0;
+ float res = in[0];
+ for (int i = 0; i < len; i++) {
+ if (res < in[i]) {
+ res = in[i];
+ idx = i;
+ }
+ }
+ return idx;
+ }

void transpose_mat_inplace(matrix* in) {
int cols_before = in->cols;
int rows_before = in->rows;
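
argmax moves up to sit with the other inference helpers; it simply scans for the largest activation. Because it only compares values, the predicted class is unchanged by softmax or by any monotonically increasing stand-in for exp, which is presumably why the looser fastexp is acceptable here. A small standalone usage sketch (not from this commit; the logits are made up):

    #include <math.h>
    #include <stdio.h>

    typedef unsigned char u8;
    typedef float f32;

    /* Same logic as the argmax added in this commit. */
    static u8 argmax(f32* in, int len) {
        int idx = 0;
        f32 res = in[0];
        for (int i = 0; i < len; i++) {
            if (res < in[i]) {
                res = in[i];
                idx = i;
            }
        }
        return idx;
    }

    int main(void) {
        f32 logits[3] = {1.0f, 3.0f, 2.0f};            /* hypothetical output-layer values */
        u8 raw = argmax(logits, 3);                    /* index 1 */
        f32 sum = 0.0f;
        for (int i = 0; i < 3; i++) { logits[i] = expf(logits[i]); sum += logits[i]; }
        for (int i = 0; i < 3; i++) logits[i] /= sum;  /* softmax rescales, order unchanged */
        printf("before softmax: %d, after: %d\n", raw, argmax(logits, 3));  /* both print 1 */
        return 0;
    }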
@@ -118,16 +130,3 @@ void transpose_mat_inplace(matrix* in) {
in->cols = pad_w_width;
in->rows = cols_before;
}

- // Get result from output layer
- u8 argmax(f32* in, int len) {
- int idx = 0;
- float res = in[0];
- for (int i = 0; i < len; i++) {
- if (res < in[i]) {
- res = in[i];
- idx = i;
- }
- }
- return idx;
- }
