From 1481b5fb1e1cdab38b5946a2d230b9eab2580449 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Sun, 19 May 2024 12:01:03 -0400 Subject: [PATCH 1/3] chore: Add fix_utf8 and remove_leading_trailing_nonalpha functions --- src/cleanstream-filter.cpp | 4 +- src/whisper-utils/whisper-processing.cpp | 138 ++++------------------- src/whisper-utils/whisper-utils.cpp | 110 ++++++++++++++++++ src/whisper-utils/whisper-utils.h | 2 + 4 files changed, 137 insertions(+), 117 deletions(-) diff --git a/src/cleanstream-filter.cpp b/src/cleanstream-filter.cpp index 5cd9159..964973e 100644 --- a/src/cleanstream-filter.cpp +++ b/src/cleanstream-filter.cpp @@ -394,7 +394,7 @@ void cleanstream_defaults(obs_data_t *s) obs_data_set_default_string(s, "initial_prompt", ""); obs_data_set_default_int(s, "n_threads", 4); obs_data_set_default_int(s, "n_max_text_ctx", 16384); - obs_data_set_default_bool(s, "no_context", false); + obs_data_set_default_bool(s, "no_context", true); obs_data_set_default_bool(s, "single_segment", true); obs_data_set_default_bool(s, "print_special", false); obs_data_set_default_bool(s, "print_progress", false); @@ -409,7 +409,7 @@ void cleanstream_defaults(obs_data_t *s) obs_data_set_default_bool(s, "speed_up", false); obs_data_set_default_bool(s, "suppress_blank", true); obs_data_set_default_bool(s, "suppress_non_speech_tokens", true); - obs_data_set_default_double(s, "temperature", 0.5); + obs_data_set_default_double(s, "temperature", 0.1); obs_data_set_default_double(s, "max_initial_ts", 1.0); obs_data_set_default_double(s, "length_penalty", -1.0); } diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 3322f43..50211ee 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -18,110 +18,7 @@ #include #endif #include "model-utils/model-downloader.h" - -#define VAD_THOLD 0.0001f -#define FREQ_THOLD 100.0f - -void high_pass_filter(float *pcmf32, size_t pcm32f_size, float cutoff, uint32_t sample_rate) -{ - const float rc = 1.0f / (2.0f * (float)M_PI * cutoff); - const float dt = 1.0f / (float)sample_rate; - const float alpha = dt / (rc + dt); - - float y = pcmf32[0]; - - for (size_t i = 1; i < pcm32f_size; i++) { - y = alpha * (y + pcmf32[i] - pcmf32[i - 1]); - pcmf32[i] = y; - } -} - -// VAD (voice activity detection), return true if speech detected -bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float vad_thold, - float freq_thold, bool verbose) -{ - const uint64_t n_samples = pcm32f_size; - - if (freq_thold > 0.0f) { - high_pass_filter(pcmf32, pcm32f_size, freq_thold, sample_rate); - } - - float energy_all = 0.0f; - - for (uint64_t i = 0; i < n_samples; i++) { - energy_all += fabsf(pcmf32[i]); - } - - energy_all /= (float)n_samples; - - if (verbose) { - obs_log(LOG_INFO, "%s: energy_all: %f, vad_thold: %f, freq_thold: %f", __func__, - energy_all, vad_thold, freq_thold); - } - - if (energy_all < vad_thold) { - return false; - } - - return true; -} - -float avg_energy_in_window(const float *pcmf32, size_t window_i, uint64_t n_samples_window) -{ - float energy_in_window = 0.0f; - for (uint64_t j = 0; j < n_samples_window; j++) { - energy_in_window += fabsf(pcmf32[window_i + j]); - } - energy_in_window /= (float)n_samples_window; - - return energy_in_window; -} - -float max_energy_in_window(const float *pcmf32, size_t window_i, uint64_t n_samples_window) -{ - float energy_in_window = 0.0f; - for (uint64_t j = 0; j < n_samples_window; j++) { - energy_in_window = std::max(energy_in_window, fabsf(pcmf32[window_i + j])); - } - - return energy_in_window; -} - -// Find a word boundary -size_t word_boundary_simple(const float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, - float thold, bool verbose) -{ - // scan the buffer with a window of 50ms - const uint64_t n_samples_window = (sample_rate * 50) / 1000; - - float first_window_energy = avg_energy_in_window(pcmf32, 0, n_samples_window); - float last_window_energy = - avg_energy_in_window(pcmf32, pcm32f_size - n_samples_window, n_samples_window); - float max_energy_in_middle = - max_energy_in_window(pcmf32, n_samples_window, pcm32f_size - n_samples_window); - - if (verbose) { - obs_log(LOG_INFO, - "%s: first_window_energy: %f, last_window_energy: %f, max_energy_in_middle: %f", - __func__, first_window_energy, last_window_energy, max_energy_in_middle); - // print avg energy in all windows in sample - for (uint64_t i = 0; i < pcm32f_size - n_samples_window; i += n_samples_window) { - obs_log(LOG_INFO, "%s: avg energy_in_window %llu: %f", __func__, i, - avg_energy_in_window(pcmf32, i, n_samples_window)); - } - } - - const float max_energy_thold = max_energy_in_middle * thold; - if (first_window_energy < max_energy_thold && last_window_energy < max_energy_thold) { - if (verbose) { - obs_log(LOG_INFO, "%s: word boundary found between %llu and %llu", __func__, - n_samples_window, pcm32f_size - n_samples_window); - } - return n_samples_window; - } - - return 0; -} +#include "whisper-utils.h" struct whisper_context *init_whisper_context(const std::string &model_path_in, struct cleanstream_data *gf) @@ -267,21 +164,32 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data, } sentence_p /= (float)n_tokens; - // convert text to lowercase - std::string text_lower(text); - std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower); - // trim whitespace (use lambda) - text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(), - [](unsigned char ch) { return !std::isspace(ch); }) - .base(), - text_lower.end()); + std::string text_preproc = text; + + if (text_preproc.empty()) { + return DETECTION_RESULT_SILENCE; + } + + // if language is en convert text to lowercase + if (strcmp(gf->whisper_params.language, "en") == 0) { + std::string text_lower; + std::transform(text_preproc.begin(), text_preproc.end(), text_lower.begin(), + ::tolower); + text_preproc = text_lower; + // remove leading and trailing non-alphanumeric characters + text_preproc = remove_leading_trailing_nonalpha(text_preproc); + } else { + // fix UTF8 encoding + std::string text_fixed = fix_utf8(text); + text_preproc = text_fixed; + } if (gf->log_words) { obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(), - to_timestamp(t1).c_str(), sentence_p, text_lower.c_str()); + to_timestamp(t1).c_str(), sentence_p, text_preproc.c_str()); } - if (text_lower.empty()) { + if (text_preproc.empty()) { return DETECTION_RESULT_SILENCE; } @@ -289,7 +197,7 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data, try { if (gf->detect_regex != nullptr && strlen(gf->detect_regex) > 0) { std::regex filler_regex(gf->detect_regex); - if (std::regex_search(text_lower, filler_regex, + if (std::regex_search(text_preproc, filler_regex, std::regex_constants::match_any)) { return DETECTION_RESULT_BEEP; } diff --git a/src/whisper-utils/whisper-utils.cpp b/src/whisper-utils/whisper-utils.cpp index 181136c..2a5808b 100644 --- a/src/whisper-utils/whisper-utils.cpp +++ b/src/whisper-utils/whisper-utils.cpp @@ -4,7 +4,13 @@ #include "whisper-processing.h" #include + #include +#include +#include +#include +#include +#include void update_whisper_model(struct cleanstream_data *gf, obs_data_t *s) { @@ -140,3 +146,107 @@ void start_whisper_thread_with_path(struct cleanstream_data *gf, const std::stri std::thread new_whisper_thread(whisper_loop, gf); gf->whisper_thread.swap(new_whisper_thread); } + +#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0) +#define is_trail_byte(c) (((c)&0xc0) == 0x80) + +inline int lead_byte_length(const uint8_t c) +{ + if ((c & 0xe0) == 0xc0) { + return 2; + } else if ((c & 0xf0) == 0xe0) { + return 3; + } else if ((c & 0xf8) == 0xf0) { + return 4; + } else { + return 1; + } +} + +inline bool is_valid_lead_byte(const uint8_t *c) +{ + const int length = lead_byte_length(c[0]); + if (length == 1) { + return true; + } + if (length == 2 && is_trail_byte(c[1])) { + return true; + } + if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) { + return true; + } + if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) { + return true; + } + return false; +} + +/* +* Fix UTF8 encoding issues on Windows. +*/ +std::string fix_utf8(const std::string &str) +{ +#ifdef _WIN32 + // Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs + // 0xf?, and 0xc? becomes 0xe?, so we need to fix it. + std::stringstream ss; + uint8_t *c_str = (uint8_t *)str.c_str(); + for (size_t i = 0; i < str.size(); ++i) { + if (is_lead_byte(c_str[i])) { + // this is a unicode leading byte + // if the next char is 0xff - it's a bug char, replace it with 0x9f + if (c_str[i + 1] == 0xff) { + c_str[i + 1] = 0x9f; + } + if (!is_valid_lead_byte(c_str + i)) { + // This is a bug lead byte, because it's length 3 and the i+2 byte is also + // a lead byte + c_str[i] = c_str[i] - 0x20; + } + } else { + if (c_str[i] >= 0xf8) { + // this may be a malformed lead byte. + // lets see if it becomes a valid lead byte if we "fix" it + uint8_t buf_[4]; + buf_[0] = c_str[i] - 0x20; + buf_[1] = c_str[i + 1]; + buf_[2] = c_str[i + 2]; + buf_[3] = c_str[i + 3]; + if (is_valid_lead_byte(buf_)) { + // this is a malformed lead byte, fix it + c_str[i] = c_str[i] - 0x20; + } + } + } + } + + return std::string((char *)c_str); +#else + return str; +#endif +} + +/* +* Remove leading and trailing non-alphabetic characters from a string. +* This function is used to remove leading and trailing spaces, newlines, tabs or punctuation. +* @param str: the string to remove leading and trailing non-alphabetic characters from. +* @return: the string with leading and trailing non-alphabetic characters removed. +*/ +std::string remove_leading_trailing_nonalpha(const std::string &str) +{ + std::string str_copy = str; + // remove trailing spaces, newlines, tabs or punctuation + auto last_non_space = + std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) { + return !std::isspace(ch) || !std::ispunct(ch); + }).base(); + str_copy.erase(last_non_space, str_copy.end()); + // remove leading spaces, newlines, tabs or punctuation + auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(), + [](unsigned char ch) { + return !std::isspace(ch) || !std::ispunct(ch); + }) + + 1; + str_copy.erase(str_copy.begin(), first_non_space); + return str_copy; +} diff --git a/src/whisper-utils/whisper-utils.h b/src/whisper-utils/whisper-utils.h index 9ec80a8..ead20e6 100644 --- a/src/whisper-utils/whisper-utils.h +++ b/src/whisper-utils/whisper-utils.h @@ -10,5 +10,7 @@ void update_whisper_model(struct cleanstream_data *gf, obs_data_t *s); void shutdown_whisper_thread(struct cleanstream_data *gf); void start_whisper_thread_with_path(struct cleanstream_data *gf, const std::string &path); +std::string fix_utf8(const std::string &str); +std::string remove_leading_trailing_nonalpha(const std::string &str); #endif /* WHISPER_UTILS_H */ From 69586b0b7c3843e605c5570eb7f5ca3d27f98a9f Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Sun, 19 May 2024 22:54:47 -0400 Subject: [PATCH 2/3] Refactor cleanstream-filter-data.h and clean up cleanstream-filter.cpp --- src/cleanstream-filter-data.h | 1 - src/cleanstream-filter.cpp | 33 ++--- src/whisper-utils/whisper-language.h | 198 +++++++++++++-------------- 3 files changed, 116 insertions(+), 116 deletions(-) diff --git a/src/cleanstream-filter-data.h b/src/cleanstream-filter-data.h index d7d1928..72e0c68 100644 --- a/src/cleanstream-filter-data.h +++ b/src/cleanstream-filter-data.h @@ -73,7 +73,6 @@ struct cleanstream_data { std::map audioFileCache; size_t audioFilePointer = 0; - float filler_p_threshold; bool vad_enabled; int log_level; const char *detect_regex; diff --git a/src/cleanstream-filter.cpp b/src/cleanstream-filter.cpp index 964973e..fd97995 100644 --- a/src/cleanstream-filter.cpp +++ b/src/cleanstream-filter.cpp @@ -228,7 +228,6 @@ void cleanstream_update(void *data, obs_data_t *s) gf->detect_regex = obs_data_get_string(s, "detect_regex"); gf->replace_sound = obs_data_get_int(s, "replace_sound"); - gf->filler_p_threshold = (float)obs_data_get_double(s, "filler_p_threshold"); gf->log_level = (int)obs_data_get_int(s, "log_level"); gf->vad_enabled = obs_data_get_bool(s, "vad_enabled"); gf->log_words = obs_data_get_bool(s, "log_words"); @@ -382,7 +381,6 @@ void cleanstream_defaults(obs_data_t *s) "(fuck)|(shit)|(bitch)|(cunt)|(pussy)|(dick)|(asshole)|(whore)|(cock)|(nigger)|(nigga)|(prick)"); obs_data_set_default_int(s, "replace_sound", REPLACE_SOUNDS_SILENCE); obs_data_set_default_bool(s, "advanced_settings", false); - obs_data_set_default_double(s, "filler_p_threshold", 0.75); obs_data_set_default_bool(s, "vad_enabled", true); obs_data_set_default_int(s, "log_level", LOG_DEBUG); obs_data_set_default_bool(s, "log_words", false); @@ -495,6 +493,22 @@ obs_properties_t *cleanstream_properties(void *data) } } + // Add language selector + obs_property_t *whisper_language_select_list = + obs_properties_add_list(ppts, "whisper_language_select", "Language", + OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING); + // get a sorted list of available languages + std::vector whisper_available_lang_keys; + for (auto const &pair : whisper_available_lang) { + whisper_available_lang_keys.push_back(pair.first); + } + std::sort(whisper_available_lang_keys.begin(), whisper_available_lang_keys.end()); + // iterate over all available languages in whisper_available_lang map + for (const std::string &key : whisper_available_lang_keys) { + obs_property_list_add_string(whisper_language_select_list, + whisper_available_lang.at(key).c_str(), key.c_str()); + } + // Add advanced settings checkbox obs_property_t *advanced_settings_prop = obs_properties_add_bool(ppts, "advanced_settings", MT_("advanced_settings")); @@ -505,16 +519,13 @@ obs_properties_t *cleanstream_properties(void *data) // If advanced settings is enabled, show the advanced settings group const bool show_hide = obs_data_get_bool(settings, "advanced_settings"); for (const std::string &prop_name : - {"whisper_params_group", "log_words", "filler_p_threshold", "vad_enabled", - "log_level"}) { + {"whisper_params_group", "log_words", "vad_enabled", "log_level"}) { obs_property_set_visible(obs_properties_get(props, prop_name.c_str()), show_hide); } return true; }); - obs_properties_add_float_slider(ppts, "filler_p_threshold", MT_("filler_p_threshold"), 0.0f, - 1.0f, 0.05f); obs_properties_add_bool(ppts, "vad_enabled", MT_("vad_enabled")); obs_property_t *list = obs_properties_add_list(ppts, "log_level", MT_("log_level"), OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); @@ -527,16 +538,6 @@ obs_properties_t *cleanstream_properties(void *data) obs_properties_add_group(ppts, "whisper_params_group", MT_("Whisper_Parameters"), OBS_GROUP_NORMAL, whisper_params_group); - // Add language selector - obs_property_t *whisper_language_select_list = - obs_properties_add_list(whisper_params_group, "whisper_language_select", "Language", - OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_STRING); - // iterate over all available languages in whisper_available_lang map - for (auto const &pair : whisper_available_lang) { - obs_property_list_add_string(whisper_language_select_list, pair.second.c_str(), - pair.first.c_str()); - } - obs_property_t *whisper_sampling_method_list = obs_properties_add_list( whisper_params_group, "whisper_sampling_method", MT_("whisper_sampling_method"), OBS_COMBO_TYPE_LIST, OBS_COMBO_FORMAT_INT); diff --git a/src/whisper-utils/whisper-language.h b/src/whisper-utils/whisper-language.h index 2d69c65..6882305 100644 --- a/src/whisper-utils/whisper-language.h +++ b/src/whisper-utils/whisper-language.h @@ -7,399 +7,399 @@ static const std::map whisper_available_lang = { { "en", - "english", + "English", }, { "zh", - "chinese", + "Chinese", }, { "de", - "german", + "German", }, { "es", - "spanish", + "Spanish", }, { "ru", - "russian", + "Russian", }, { "ko", - "korean", + "Korean", }, { "fr", - "french", + "French", }, { "ja", - "japanese", + "Japanese", }, { "pt", - "portuguese", + "Portuguese", }, { "tr", - "turkish", + "Turkish", }, { "pl", - "polish", + "Polish", }, { "ca", - "catalan", + "Catalan", }, { "nl", - "dutch", + "Dutch", }, { "ar", - "arabic", + "Arabic", }, { "sv", - "swedish", + "Swedish", }, { "it", - "italian", + "Italian", }, { "id", - "indonesian", + "Indonesian", }, { "hi", - "hindi", + "Hindi", }, { "fi", - "finnish", + "Finnish", }, { "vi", - "vietnamese", + "Vietnamese", }, { "he", - "hebrew", + "Hebrew", }, { "uk", - "ukrainian", + "Ukrainian", }, { "el", - "greek", + "Greek", }, { "ms", - "malay", + "Malay", }, { "cs", - "czech", + "Czech", }, { "ro", - "romanian", + "Romanian", }, { "da", - "danish", + "Danish", }, { "hu", - "hungarian", + "Hungarian", }, { "ta", - "tamil", + "Tamil", }, { "no", - "norwegian", + "Norwegian", }, { "th", - "thai", + "Thai", }, { "ur", - "urdu", + "Urdu", }, { "hr", - "croatian", + "Croatian", }, { "bg", - "bulgarian", + "Bulgarian", }, { "lt", - "lithuanian", + "Lithuanian", }, { "la", - "latin", + "Latin", }, { "mi", - "maori", + "Maori", }, { "ml", - "malayalam", + "Malayalam", }, { "cy", - "welsh", + "Welsh", }, { "sk", - "slovak", + "Slovak", }, { "te", - "telugu", + "Telugu", }, { "fa", - "persian", + "Persian", }, { "lv", - "latvian", + "Latvian", }, { "bn", - "bengali", + "Bengali", }, { "sr", - "serbian", + "Serbian", }, { "az", - "azerbaijani", + "Azerbaijani", }, { "sl", - "slovenian", + "Slovenian", }, { "kn", - "kannada", + "Kannada", }, { "et", - "estonian", + "Estonian", }, { "mk", - "macedonian", + "Macedonian", }, { "br", - "breton", + "Breton", }, { "eu", - "basque", + "Basque", }, { "is", - "icelandic", + "Icelandic", }, { "hy", - "armenian", + "Armenian", }, { "ne", - "nepali", + "Nepali", }, { "mn", - "mongolian", + "Mongolian", }, { "bs", - "bosnian", + "Bosnian", }, { "kk", - "kazakh", + "Kazakh", }, { "sq", - "albanian", + "Albanian", }, { "sw", - "swahili", + "Swahili", }, { "gl", - "galician", + "Galician", }, { "mr", - "marathi", + "Marathi", }, { "pa", - "punjabi", + "Punjabi", }, { "si", - "sinhala", + "Sinhala", }, { "km", - "khmer", + "Khmer", }, { "sn", - "shona", + "Shona", }, { "yo", - "yoruba", + "Yoruba", }, { "so", - "somali", + "Somali", }, { "af", - "afrikaans", + "Afrikaans", }, { "oc", - "occitan", + "Occitan", }, { "ka", - "georgian", + "Georgian", }, { "be", - "belarusian", + "Belarusian", }, { "tg", - "tajik", + "Tajik", }, { "sd", - "sindhi", + "Sindhi", }, { "gu", - "gujarati", + "Gujarati", }, { "am", - "amharic", + "Amharic", }, { "yi", - "yiddish", + "Yiddish", }, { "lo", - "lao", + "Lao", }, { "uz", - "uzbek", + "Uzbek", }, { "fo", - "faroese", + "Faroese", }, { "ht", - "haitian", + "Haitian", }, { "ps", - "pashto", + "Pashto", }, { "tk", - "turkmen", + "Turkmen", }, { "nn", - "nynorsk", + "Nynorsk", }, { "mt", - "maltese", + "Maltese", }, { "sa", - "sanskrit", + "Sanskrit", }, { "lb", - "luxembourgish", + "Luxembourgish", }, { "my", - "myanmar", + "Myanmar", }, { "bo", - "tibetan", + "Tibetan", }, { "tl", - "tagalog", + "Tagalog", }, { "mg", - "malagasy", + "Malagasy", }, { "as", - "assamese", + "Assamese", }, { "tt", - "tatar", + "Tatar", }, { "haw", - "hawaiian", + "Hawaiian", }, { "ln", - "lingala", + "Lingala", }, { "ha", - "hausa", + "Hausa", }, { "ba", - "bashkir", + "Bashkir", }, { "jw", - "javanese", + "Javanese", }, { "su", - "sundanese", + "Sundanese", }, }; From d3764f5bd7296ebfbbbe08ac06d3887623bfea95 Mon Sep 17 00:00:00 2001 From: Roy Shilkrot Date: Sun, 19 May 2024 23:07:55 -0400 Subject: [PATCH 3/3] Refactor whisper-processing.cpp to update duration_ms in run_whisper_inference --- src/whisper-utils/whisper-processing.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/whisper-utils/whisper-processing.cpp b/src/whisper-utils/whisper-processing.cpp index 50211ee..ecbab6e 100644 --- a/src/whisper-utils/whisper-processing.cpp +++ b/src/whisper-utils/whisper-processing.cpp @@ -139,6 +139,8 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data, // run the inference int whisper_full_result = -1; try { + gf->whisper_params.duration_ms = + (int)((float)pcm32f_size / WHISPER_SAMPLE_RATE * 1000.0f); whisper_full_result = whisper_full(gf->whisper_context, gf->whisper_params, pcm32f_data, (int)pcm32f_size); } catch (const std::exception &e) {