diff --git a/src/knn.c b/src/knn.c index 5d120fe..eb3e916 100644 --- a/src/knn.c +++ b/src/knn.c @@ -34,6 +34,7 @@ Nob_String_View deflate_sv(Arena *arena, Nob_String_View sv) typedef struct { size_t klass; Nob_String_View text; + float c; } Sample; typedef struct { @@ -74,10 +75,9 @@ typedef struct { size_t capacity; } NCDs; -float ncd(Arena *arena, Nob_String_View a, Nob_String_View b, float cb) +float ncd(Arena *arena, Nob_String_View a, float ca, Nob_String_View b, float cb) { Nob_String_View ab = nob_sv_from_cstr(arena_sprintf(arena, SV_Fmt" "SV_Fmt, SV_Arg(a), SV_Arg(b))); - float ca = deflate_sv(arena, a).count; float cab = deflate_sv(arena, ab).count; float mn = ca; if (mn > cb) mn = cb; float mx = ca; if (mx < cb) mx = cb; @@ -102,13 +102,28 @@ typedef struct { Arena arena; } Klassify_State; +void *precompute_train_c_thread(void *params) +{ + Klassify_State *state = params; + + float ca; + for (size_t i = 0; i < state->train_count; ++i) { + ca = deflate_sv(&state->arena, state->train[i].text).count; + state->train[i].c = ca; + } + + return NULL; +} + + void *klassify_thread(void *params) { Klassify_State *state = params; float cb = deflate_sv(&state->arena, state->text).count; for (size_t i = 0; i < state->train_count; ++i) { - float distance = ncd(&state->arena, state->train[i].text, state->text, cb); + float distance = ncd(&state->arena, state->train[i].text, state->train[i].c, + state->text, cb); arena_reset(&state->arena); nob_da_append(&state->ncds, ((NCD) { .distance = distance, @@ -147,6 +162,22 @@ void klass_predictor_init(Klass_Predictor *kp, Samples train_samples) memset(kp->states, 0, kp->nprocs*sizeof(Klassify_State)); } +void klass_predictor_precompute_train_c(Klass_Predictor *kp) +{ + for (size_t i = 0; i < kp->nprocs; ++i) { + kp->states[i].train = kp->train_samples.items + i*kp->chunk_size; + kp->states[i].train_count = kp->chunk_size; + if (i == kp->nprocs - 1) kp->states[i].train_count += kp->chunk_rem; + arena_reset(&kp->states[i].arena); + if (pthread_create(&kp->threads[i], NULL, precompute_train_c_thread, + &kp->states[i]) != 0) { + nob_log(NOB_ERROR, "Could not create thread"); + exit(1); + } + } +} + + size_t klass_predictor_predict(Klass_Predictor *kp, Nob_String_View text, size_t k) { for (size_t i = 0; i < kp->nprocs; ++i) { @@ -231,6 +262,12 @@ int main(int argc, char **argv) Klass_Predictor kp = {0}; klass_predictor_init(&kp, train_samples); + nob_log(NOB_INFO, "Pre-computing training lengths"); + double begin = clock_get_secs(); + klass_predictor_precompute_train_c(&kp); + double end = clock_get_secs(); + nob_log(NOB_INFO, "Elapsed Time: %.3lf secs", end - begin); + if (argc <= 0) { interactive_mode(&kp); } else {