diff --git a/src/knn.c b/src/knn.c index 5d120fe..83aca1d 100644 --- a/src/knn.c +++ b/src/knn.c @@ -11,26 +11,32 @@ #define K 2 +static __thread z_stream defstream = {0}; +static __thread bool defstream_init = false; + // Stolen from https://gist.github.com/arq5x/5315739 Nob_String_View deflate_sv(Arena *arena, Nob_String_View sv) { - size_t output_size = sv.count*2; - void *output = arena_alloc(arena, output_size); - - z_stream defstream = {0}; defstream.avail_in = (uInt)sv.count; defstream.next_in = (Bytef *)sv.data; + + int ret = defstream_init ? deflateReset(&defstream) : deflateInit(&defstream, Z_BEST_COMPRESSION); + defstream_init = true; + assert(ret == Z_OK && "Failed to initialize zlib deflate stream"); + size_t output_size = (size_t)deflateBound(&defstream, defstream.avail_in); + void *output = arena_alloc(arena, output_size); defstream.avail_out = (uInt)output_size; defstream.next_out = (Bytef *)output; - deflateInit(&defstream, Z_BEST_COMPRESSION); - int result = deflate(&defstream, Z_FINISH); - assert(result == Z_STREAM_END && "Probably not enough output buffer was allocated"); - deflateEnd(&defstream); + assert(deflate(&defstream, Z_FINISH) == Z_STREAM_END && "Probably `avail_in` is zero"); return nob_sv_from_parts(output, defstream.total_out); } +void deflate_end() { + deflateEnd(&defstream); +} + typedef struct { size_t klass; Nob_String_View text; @@ -115,6 +121,7 @@ void *klassify_thread(void *params) .klass = state->train[i].klass, })); } + deflate_end(); return NULL; }