Skip to content

Commit

Permalink
chore: Add fix_utf8 and remove_leading_trailing_nonalpha functions
Browse files Browse the repository at this point in the history
  • Loading branch information
royshil committed May 19, 2024
1 parent 544d7ab commit 1481b5f
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 117 deletions.
4 changes: 2 additions & 2 deletions src/cleanstream-filter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ void cleanstream_defaults(obs_data_t *s)
obs_data_set_default_string(s, "initial_prompt", "");
obs_data_set_default_int(s, "n_threads", 4);
obs_data_set_default_int(s, "n_max_text_ctx", 16384);
obs_data_set_default_bool(s, "no_context", false);
obs_data_set_default_bool(s, "no_context", true);
obs_data_set_default_bool(s, "single_segment", true);
obs_data_set_default_bool(s, "print_special", false);
obs_data_set_default_bool(s, "print_progress", false);
Expand All @@ -409,7 +409,7 @@ void cleanstream_defaults(obs_data_t *s)
obs_data_set_default_bool(s, "speed_up", false);
obs_data_set_default_bool(s, "suppress_blank", true);
obs_data_set_default_bool(s, "suppress_non_speech_tokens", true);
obs_data_set_default_double(s, "temperature", 0.5);
obs_data_set_default_double(s, "temperature", 0.1);
obs_data_set_default_double(s, "max_initial_ts", 1.0);
obs_data_set_default_double(s, "length_penalty", -1.0);
}
Expand Down
138 changes: 23 additions & 115 deletions src/whisper-utils/whisper-processing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,110 +18,7 @@
#include <Windows.h>
#endif
#include "model-utils/model-downloader.h"

#define VAD_THOLD 0.0001f
#define FREQ_THOLD 100.0f

void high_pass_filter(float *pcmf32, size_t pcm32f_size, float cutoff, uint32_t sample_rate)
{
const float rc = 1.0f / (2.0f * (float)M_PI * cutoff);
const float dt = 1.0f / (float)sample_rate;
const float alpha = dt / (rc + dt);

float y = pcmf32[0];

for (size_t i = 1; i < pcm32f_size; i++) {
y = alpha * (y + pcmf32[i] - pcmf32[i - 1]);
pcmf32[i] = y;
}
}

// VAD (voice activity detection), return true if speech detected
bool vad_simple(float *pcmf32, size_t pcm32f_size, uint32_t sample_rate, float vad_thold,
float freq_thold, bool verbose)
{
const uint64_t n_samples = pcm32f_size;

if (freq_thold > 0.0f) {
high_pass_filter(pcmf32, pcm32f_size, freq_thold, sample_rate);
}

float energy_all = 0.0f;

for (uint64_t i = 0; i < n_samples; i++) {
energy_all += fabsf(pcmf32[i]);
}

energy_all /= (float)n_samples;

if (verbose) {
obs_log(LOG_INFO, "%s: energy_all: %f, vad_thold: %f, freq_thold: %f", __func__,
energy_all, vad_thold, freq_thold);
}

if (energy_all < vad_thold) {
return false;
}

return true;
}

float avg_energy_in_window(const float *pcmf32, size_t window_i, uint64_t n_samples_window)
{
float energy_in_window = 0.0f;
for (uint64_t j = 0; j < n_samples_window; j++) {
energy_in_window += fabsf(pcmf32[window_i + j]);
}
energy_in_window /= (float)n_samples_window;

return energy_in_window;
}

float max_energy_in_window(const float *pcmf32, size_t window_i, uint64_t n_samples_window)
{
float energy_in_window = 0.0f;
for (uint64_t j = 0; j < n_samples_window; j++) {
energy_in_window = std::max(energy_in_window, fabsf(pcmf32[window_i + j]));
}

return energy_in_window;
}

// Find a word boundary
size_t word_boundary_simple(const float *pcmf32, size_t pcm32f_size, uint32_t sample_rate,
float thold, bool verbose)
{
// scan the buffer with a window of 50ms
const uint64_t n_samples_window = (sample_rate * 50) / 1000;

float first_window_energy = avg_energy_in_window(pcmf32, 0, n_samples_window);
float last_window_energy =
avg_energy_in_window(pcmf32, pcm32f_size - n_samples_window, n_samples_window);
float max_energy_in_middle =
max_energy_in_window(pcmf32, n_samples_window, pcm32f_size - n_samples_window);

if (verbose) {
obs_log(LOG_INFO,
"%s: first_window_energy: %f, last_window_energy: %f, max_energy_in_middle: %f",
__func__, first_window_energy, last_window_energy, max_energy_in_middle);
// print avg energy in all windows in sample
for (uint64_t i = 0; i < pcm32f_size - n_samples_window; i += n_samples_window) {
obs_log(LOG_INFO, "%s: avg energy_in_window %llu: %f", __func__, i,
avg_energy_in_window(pcmf32, i, n_samples_window));
}
}

const float max_energy_thold = max_energy_in_middle * thold;
if (first_window_energy < max_energy_thold && last_window_energy < max_energy_thold) {
if (verbose) {
obs_log(LOG_INFO, "%s: word boundary found between %llu and %llu", __func__,
n_samples_window, pcm32f_size - n_samples_window);
}
return n_samples_window;
}

return 0;
}
#include "whisper-utils.h"

struct whisper_context *init_whisper_context(const std::string &model_path_in,
struct cleanstream_data *gf)
Expand Down Expand Up @@ -267,29 +164,40 @@ int run_whisper_inference(struct cleanstream_data *gf, const float *pcm32f_data,
}
sentence_p /= (float)n_tokens;

// convert text to lowercase
std::string text_lower(text);
std::transform(text_lower.begin(), text_lower.end(), text_lower.begin(), ::tolower);
// trim whitespace (use lambda)
text_lower.erase(std::find_if(text_lower.rbegin(), text_lower.rend(),
[](unsigned char ch) { return !std::isspace(ch); })
.base(),
text_lower.end());
std::string text_preproc = text;

if (text_preproc.empty()) {
return DETECTION_RESULT_SILENCE;
}

// if language is en convert text to lowercase
if (strcmp(gf->whisper_params.language, "en") == 0) {
std::string text_lower;
std::transform(text_preproc.begin(), text_preproc.end(), text_lower.begin(),
::tolower);
text_preproc = text_lower;
// remove leading and trailing non-alphanumeric characters
text_preproc = remove_leading_trailing_nonalpha(text_preproc);
} else {
// fix UTF8 encoding
std::string text_fixed = fix_utf8(text);
text_preproc = text_fixed;
}

if (gf->log_words) {
obs_log(LOG_INFO, "[%s --> %s] (%.3f) %s", to_timestamp(t0).c_str(),
to_timestamp(t1).c_str(), sentence_p, text_lower.c_str());
to_timestamp(t1).c_str(), sentence_p, text_preproc.c_str());
}

if (text_lower.empty()) {
if (text_preproc.empty()) {
return DETECTION_RESULT_SILENCE;
}

// use a regular expression to detect filler words with a word boundary
try {
if (gf->detect_regex != nullptr && strlen(gf->detect_regex) > 0) {
std::regex filler_regex(gf->detect_regex);
if (std::regex_search(text_lower, filler_regex,
if (std::regex_search(text_preproc, filler_regex,
std::regex_constants::match_any)) {
return DETECTION_RESULT_BEEP;
}
Expand Down
110 changes: 110 additions & 0 deletions src/whisper-utils/whisper-utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
#include "whisper-processing.h"

#include <obs-module.h>

#include <filesystem>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <thread>

void update_whisper_model(struct cleanstream_data *gf, obs_data_t *s)
{
Expand Down Expand Up @@ -140,3 +146,107 @@ void start_whisper_thread_with_path(struct cleanstream_data *gf, const std::stri
std::thread new_whisper_thread(whisper_loop, gf);
gf->whisper_thread.swap(new_whisper_thread);
}

#define is_lead_byte(c) (((c)&0xe0) == 0xc0 || ((c)&0xf0) == 0xe0 || ((c)&0xf8) == 0xf0)
#define is_trail_byte(c) (((c)&0xc0) == 0x80)

inline int lead_byte_length(const uint8_t c)
{
if ((c & 0xe0) == 0xc0) {
return 2;
} else if ((c & 0xf0) == 0xe0) {
return 3;
} else if ((c & 0xf8) == 0xf0) {
return 4;
} else {
return 1;
}
}

inline bool is_valid_lead_byte(const uint8_t *c)
{
const int length = lead_byte_length(c[0]);
if (length == 1) {
return true;
}
if (length == 2 && is_trail_byte(c[1])) {
return true;
}
if (length == 3 && is_trail_byte(c[1]) && is_trail_byte(c[2])) {
return true;
}
if (length == 4 && is_trail_byte(c[1]) && is_trail_byte(c[2]) && is_trail_byte(c[3])) {
return true;
}
return false;
}

/*
* Fix UTF8 encoding issues on Windows.
*/
std::string fix_utf8(const std::string &str)
{
#ifdef _WIN32
// Some UTF8 charsets on Windows output have a bug, instead of 0xd? it outputs
// 0xf?, and 0xc? becomes 0xe?, so we need to fix it.
std::stringstream ss;
uint8_t *c_str = (uint8_t *)str.c_str();
for (size_t i = 0; i < str.size(); ++i) {
if (is_lead_byte(c_str[i])) {
// this is a unicode leading byte
// if the next char is 0xff - it's a bug char, replace it with 0x9f
if (c_str[i + 1] == 0xff) {
c_str[i + 1] = 0x9f;
}
if (!is_valid_lead_byte(c_str + i)) {
// This is a bug lead byte, because it's length 3 and the i+2 byte is also
// a lead byte
c_str[i] = c_str[i] - 0x20;
}
} else {
if (c_str[i] >= 0xf8) {
// this may be a malformed lead byte.
// lets see if it becomes a valid lead byte if we "fix" it
uint8_t buf_[4];
buf_[0] = c_str[i] - 0x20;
buf_[1] = c_str[i + 1];
buf_[2] = c_str[i + 2];
buf_[3] = c_str[i + 3];
if (is_valid_lead_byte(buf_)) {
// this is a malformed lead byte, fix it
c_str[i] = c_str[i] - 0x20;
}
}
}
}

return std::string((char *)c_str);
#else
return str;
#endif
}

/*
* Remove leading and trailing non-alphabetic characters from a string.
* This function is used to remove leading and trailing spaces, newlines, tabs or punctuation.
* @param str: the string to remove leading and trailing non-alphabetic characters from.
* @return: the string with leading and trailing non-alphabetic characters removed.
*/
std::string remove_leading_trailing_nonalpha(const std::string &str)
{
std::string str_copy = str;
// remove trailing spaces, newlines, tabs or punctuation
auto last_non_space =
std::find_if(str_copy.rbegin(), str_copy.rend(), [](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
}).base();
str_copy.erase(last_non_space, str_copy.end());
// remove leading spaces, newlines, tabs or punctuation
auto first_non_space = std::find_if(str_copy.begin(), str_copy.end(),
[](unsigned char ch) {
return !std::isspace(ch) || !std::ispunct(ch);
}) +
1;
str_copy.erase(str_copy.begin(), first_non_space);
return str_copy;
}
2 changes: 2 additions & 0 deletions src/whisper-utils/whisper-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,7 @@
void update_whisper_model(struct cleanstream_data *gf, obs_data_t *s);
void shutdown_whisper_thread(struct cleanstream_data *gf);
void start_whisper_thread_with_path(struct cleanstream_data *gf, const std::string &path);
std::string fix_utf8(const std::string &str);
std::string remove_leading_trailing_nonalpha(const std::string &str);

#endif /* WHISPER_UTILS_H */

0 comments on commit 1481b5f

Please sign in to comment.