From 1481b5fb1e1cdab38b5946a2d230b9eab2580449 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Sun, 19 May 2024 12:01:03 -0400 Subject: [PATCH] chore: Add fix_utf8 and remove_leading_trailing_nonalpha functions --- src/cleanstream-filter.cpp | 4 +- src/whisper-utils/whisper-processing.cpp | 138 ++++------------------- src/whisper-utils/whisper-utils.cpp | 110 ++++++++++++++++++ src/whisper-utils/whisper-utils.h | 2 + 4 files changed, 137 insertions(+), 117 deletions(-) diff --git a/src/cleanstream-filter.cpp b/src/cleanstream-filter.cpp index 5cd9159..964973e 100644 --- a/src/cleanstream-filter.cpp +++ b/src/cleanstream-filter.cpp @@ -394,7 +394,7 @@ void cleanstream_defaults(obs_data_t *s) obs_data_set_default_string(s, "initial_prompt", ""); obs_data_set_default_int(s, "n_threads", 4); obs_data_set_default_int(s, "n_max_text_ctx", 16384); - obs_data_set_default_bool(s, "no_context", false); + obs_data_set_default_bool(s, "no_context", true); obs_data_set_default_bool(s, "single_segment", true); obs_data_set_default_bool(s, "print_special", false); obs_data_set_default_bool(s, "print_progress", false); @@ -409,7 +409,7 @@ void cleanstream_defaults(obs_data_t *s) obs_data_set_default_bool(s, "speed_up", false); obs_data_set_default_bool(s, "suppress_blank", true); obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); - obs_data_set_default_double(s, "temperature", 0.5); + obs_data_set_default_double(s, "temperature", 0.1); obs_data_set_default_double(s, "max_initial_ts", 1.0); obs_data_set_default_double(s, "length_penalty", -1.0); } diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 3322f43..50211ee 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -18,110 +18,7 @@ #include #endif #include "model-utils/model-downloader.h" - -#define VAD_THOLD 0.0001f -#define FREQ_THOLD 100.0f - -void high_pass_filter(float *pcmf32, size_t pcm32f_size, float cutoff, uint32_t sample_rate) -{ - const float rc = 1.0f / (2.0f * (float)M_PI * cutoff); - const float dt = 1.0f / (float)sample_rate; - const float alpha = dt / (rc + dt); - - float y = pcmf32[0]; - - for (size_t i = 1; i < pcm32f_size; i++) { - y = alpha * (y + pcmf32[i] - pcmf32[i - 1]); - pcmf32[i] = y; - } -} - -// VAD (voice activity detection), return true if speech detected -bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float vad_thold, - float freq_thold, bool verbose) -{ - const uint64_t n_samples = pcm32f_size; - - if (freq_thold > 0.0f) { - high_pass_filter(pcmf32, pcm32f_size, freq_thold, sample_rate); - } - - float energy_all = 0.0f; - - for (uint64_t i = 0; i < n_samples; i++) { - energy_all += fabsf(pcmf32[i]); - } - - energy_all /= (float)n_samples; - - if (verbose) { - obs_log(LOG_INFO, "%s: energy_all: %f, vad_thold: %f, freq_thold: %f", __func__, - energy_all, vad_thold, freq_thold); - } - - if (energy_all < vad_thold) { - return false; - } - - return true; -} - -float avg_energy_in_window(const float *pcmf32, size_t window_i, uint64_t n_samples_window) -{ - float energy_in_window = 0.0f; - for (uint64_t j = 0; j < n_samples_window; j++) { - energy_in_window += fabsf(pcmf32[window_i + j]); - } - energy_in_window /= (float)n_samples_window; - - return energy_in_window; -} - -float max_energy_in_window(const float *pcmf32, size_t window_i, uint64_t n_samples_window) -{ - float energy_in_window = 0.0f; - for (uint64_t j = 0; j < n_samples_window; j++) { - energy_in_window = std::max(energy_in_window, fabsf(pcmf32[window_i + j])); - } - - return energy_in_window; -} - -// Find a word boundary -size_t word_boundary_simple(const float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, - float thold, bool verbose) -{ - // scan the buffer with a window of 50ms - const uint64_t n_samples_window = (sample_rate * 50) / 1000; - - float first_window_energy = avg_energy_in_window(pcmf32, 0, n_samples_window); - float last_window_energy = - avg_energy_in_window(pcmf32, pcm32f_size - n_samples_window, n_samples_window); - float max_energy_in_middle = - max_energy_in_window(pcmf32, n_samples_window, pcm32f_size - n_samples_window); - - if (verbose) { - obs_log(LOG_INFO, - "%s: first_window_energy: %f, last_window_energy: %f, max_energy_in_middle: %f", - __func__, first_window_energy, last_window_energy, max_energy_in_middle); - // print avg energy in all windows in sample - for (uint64_t i = 0; i < pcm32f_size - n_samples_window; i += n_samples_window) { - obs_log(LOG_INFO, "%s: avg energy_in_window %llu: %f", __func__, i, - avg_energy_in_window(pcmf32, i, n_samples_window)); - } - } - - const float max_energy_thold = max_energy_in_middle * thold; - if (first_window_energy < max_energy_thold && last_window_energy < max_energy_thold) { - if (verbose) { - obs_log(LOG_INFO, "%s: word boundary found between %llu and %llu", __func__, - n_samples_window, pcm32f_size - n_samples_window); - } - return n_samples_window; - } - - return 0; -} +#include "whisper-utils.h" struct whisper_context *init_whisper_context(const std::string &model_path_in, struct cleanstream_data *gf) @@ -267,21 +164,32 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data, } sentence_p /= (float)n_tokens; - // convert text to lowercase - std::string text_lower(text); - std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower); - // trim whitespace (use lambda) - text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(), - [](unsigned char ch) { return !std::isspace(ch); }) - .base(), - text_lower.end()); + std::string text_preproc = text; + + if (text_preproc.empty()) { + return DETECTION_RESULT_SILENCE; + } + + // if language is en convert text to lowercase + if (strcmp(gf->whisper_params.language, "en") == 0) { + std::string text_lower; + std::transform(text_preproc.begin(), text_preproc.end(), text_lower.begin(), + ::tolower); + text_preproc = text_lower; + // remove leading and trailing non-alphanumeric characters + text_preproc = remove_leading_trailing_nonalpha(text_preproc); + } else { + // fix UTF8 encoding + std::string text_fixed = fix_utf8(text); + text_preproc = text_fixed; + } if (gf->log_words) { obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(), - to_timestamp(t1).c_str(), sentence_p, text_lower.c_str()); + to_timestamp(t1).c_str(), sentence_p, text_preproc.c_str()); } - if (text_lower.empty()) { + if (text_preproc.empty()) { return DETECTION_RESULT_SILENCE; } @@ -289,7 +197,7 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data, try { if (gf->detect_regex != nullptr && strlen(gf->detect_regex) > 0) { std::regex filler_regex(gf->detect_regex); - if (std::regex_search(text_lower, filler_regex, + if (std::regex_search(text_preproc, filler_regex, std::regex_constants::match_any)) { return DETECTION_RESULT_BEEP; } diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp index 181136c..2a5808b 100644 --- a/src/whisper-utils/whisper-utils.cpp +++ b/src/whisper-utils/whisper-utils.cpp @@ -4,7 +4,13 @@ #include "whisper-processing.h" #include + #include +#include +#include +#include +#include +#include void update_whisper_model(struct cleanstream_data *gf, obs_data_t *s) { @@ -140,3 +146,107 @@ void start_whisper_thread_with_path(struct cleanstream_data *gf, const std::stri std::thread new_whisper_thread(whisper_loop, gf); gf->whisper_thread.swap(new_whisper_thread); } + +#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0) +#define is_trail_byte(c) (((c)&0xc0) == 0x80) + +inline int lead_byte_length(const uint8_t c) +{ + if ((c & 0xe0) == 0xc0) { + return 2; + } else if ((c & 0xf0) == 0xe0) { + return 3; + } else if ((c & 0xf8) == 0xf0) { + return 4; + } else { + return 1; + } +} + +inline bool is_valid_lead_byte(const uint8_t *c) +{ + const int length = lead_byte_length(c[0]); + if (length == 1) { + return true; + } + if (length == 2 && is_trail_byte(c[1])) { + return true; + } + if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) { + return true; + } + if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) { + return true; + } + return false; +} + +/* +* Fix UTF8 encoding issues on Windows. +*/ +std::string fix_utf8(const std::string &str) +{ +#ifdef _WIN32 + // Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs + // 0xf?, and 0xc? becomes 0xe?, so we need to fix it. + std::stringstream ss; + uint8_t *c_str = (uint8_t *)str.c_str(); + for (size_t i = 0; i < str.size(); ++i) { + if (is_lead_byte(c_str[i])) { + // this is a unicode leading byte + // if the next char is 0xff - it's a bug char, replace it with 0x9f + if (c_str[i + 1] == 0xff) { + c_str[i + 1] = 0x9f; + } + if (!is_valid_lead_byte(c_str + i)) { + // This is a bug lead byte, because it's length 3 and the i+2 byte is also + // a lead byte + c_str[i] = c_str[i] - 0x20; + } + } else { + if (c_str[i] >= 0xf8) { + // this may be a malformed lead byte. + // lets see if it becomes a valid lead byte if we "fix" it + uint8_t buf_[4]; + buf_[0] = c_str[i] - 0x20; + buf_[1] = c_str[i + 1]; + buf_[2] = c_str[i + 2]; + buf_[3] = c_str[i + 3]; + if (is_valid_lead_byte(buf_)) { + // this is a malformed lead byte, fix it + c_str[i] = c_str[i] - 0x20; + } + } + } + } + + return std::string((char *)c_str); +#else + return str; +#endif +} + +/* +* Remove leading and trailing non-alphabetic characters from a string. +* This function is used to remove leading and trailing spaces, newlines, tabs or punctuation. +* @param str: the string to remove leading and trailing non-alphabetic characters from. +* @return: the string with leading and trailing non-alphabetic characters removed. +*/ +std::string remove_leading_trailing_nonalpha(const std::string &str) +{ + std::string str_copy = str; + // remove trailing spaces, newlines, tabs or punctuation + auto last_non_space = + std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) { + return !std::isspace(ch) || !std::ispunct(ch); + }).base(); + str_copy.erase(last_non_space, str_copy.end()); + // remove leading spaces, newlines, tabs or punctuation + auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(), + [](unsigned char ch) { + return !std::isspace(ch) || !std::ispunct(ch); + }) + + 1; + str_copy.erase(str_copy.begin(), first_non_space); + return str_copy; +} diff --git a/src/whisper-utils/whisper-utils.h b/src/whisper-utils/whisper-utils.h index 9ec80a8..ead20e6 100644 --- a/src/whisper-utils/whisper-utils.h +++ b/src/whisper-utils/whisper-utils.h @@ -10,5 +10,7 @@ void update_whisper_model(struct cleanstream_data *gf, obs_data_t *s); void shutdown_whisper_thread(struct cleanstream_data *gf); void start_whisper_thread_with_path(struct cleanstream_data *gf, const std::string &path); +std::string fix_utf8(const std::string &str); +std::string remove_leading_trailing_nonalpha(const std::string &str); #endif /* WHISPER_UTILS_H */