Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add customized score for hotwords & add Finalize to stream #281

Merged
merged 10 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion .github/scripts/run-test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,21 @@ for wave in ${waves[@]}; do
done
done

log "Start testing ${repo_url} with hotwords"

# Exercise the hotwords code path: same model files as the plain test, but
# with a hotwords file appended as the last argument.
# Positional arguments: tokens.txt, encoder/decoder/joiner param+bin pairs,
# the test wav, "2" (presumably num-threads -- confirm against the binary's
# usage message), the decoding method (hotwords require modified_beam_search),
# and the hotwords file.
time $EXE \
  $repo/tokens.txt \
  $repo/encoder_jit_trace-pnnx.ncnn.param \
  $repo/encoder_jit_trace-pnnx.ncnn.bin \
  $repo/decoder_jit_trace-pnnx.ncnn.param \
  $repo/decoder_jit_trace-pnnx.ncnn.bin \
  $repo/joiner_jit_trace-pnnx.ncnn.param \
  $repo/joiner_jit_trace-pnnx.ncnn.bin \
  $repo/test_wavs/1.wav \
  2 \
  modified_beam_search \
  $repo/test_wavs/hotwords.txt

rm -rf $repo

log "------------------------------------------------------------"
Expand Down Expand Up @@ -588,4 +603,4 @@ time $EXE \
modified_beam_search \
$repo/hotwords.txt 1.6

rm -rf $repo
rm -rf $repo
2 changes: 2 additions & 0 deletions sherpa-ncnn/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -77,4 +77,6 @@ endif()
if(SHERPA_NCNN_ENABLE_TEST)
add_executable(test-resample test-resample.cc)
target_link_libraries(test-resample sherpa-ncnn-core)
add_executable(test-context-graph test-context-graph.cc)
target_link_libraries(test-context-graph sherpa-ncnn-core)
endif()
81 changes: 74 additions & 7 deletions sherpa-ncnn/csrc/context-graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,67 @@

#include "sherpa-ncnn/csrc/context-graph.h"

#include <algorithm>
#include <cassert>
#include <queue>
#include <string>
#include <tuple>
#include <utility>

namespace sherpa_ncnn {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
// Builds the context (hotword) trie from tokenized phrases, then wires up
// the fail/output links via FillFailOutput().
//
// Parameters (the three optional vectors are either empty, meaning "use the
// graph-wide defaults", or parallel to token_ids):
//   token_ids     - one vector of token ids per hotword phrase.
//   scores        - per-phrase bonus score; 0.0f entries fall back to
//                   context_score_.
//   phrases       - human-readable text stored on each phrase's end node.
//   ac_thresholds - per-phrase acoustic threshold; 0.0f entries fall back to
//                   ac_threshold_.
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
                         const std::vector<float> &scores,
                         const std::vector<std::string> &phrases,
                         const std::vector<float> &ac_thresholds) const {
  // Optional inputs must be parallel to token_ids when provided.
  if (!scores.empty()) {
    assert(token_ids.size() == scores.size());
  }
  if (!phrases.empty()) {
    assert(token_ids.size() == phrases.size());
  }
  if (!ac_thresholds.empty()) {
    assert(token_ids.size() == ac_thresholds.size());
  }
  for (int32_t i = 0; i < token_ids.size(); ++i) {
    auto node = root_.get();
    // A 0.0f (or missing) per-phrase value means "unset": substitute the
    // graph-wide default.
    float score = scores.empty() ? 0.0f : scores[i];
    score = score == 0.0f ? context_score_ : score;
    float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
    ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
    std::string phrase = phrases.empty() ? std::string() : phrases[i];

    // Walk/extend the trie one token at a time.
    for (int32_t j = 0; j < token_ids[i].size(); ++j) {
      int32_t token = token_ids[i][j];
      if (0 == node->next.count(token)) {
        // New node. node_score accumulates along the path; output_score and
        // the phrase/ac_threshold metadata are only set on a phrase-final
        // node. `level` is the depth (j + 1 tokens consumed).
        bool is_end = j == token_ids[i].size() - 1;
        node->next[token] = std::make_unique<ContextState>(
            token, score, node->node_score + score,
            is_end ? node->node_score + score : 0, j + 1,
            is_end ? ac_threshold : 0.0f, is_end,
            is_end ? phrase : std::string());
      } else {
        // Shared prefix with an earlier phrase: merge by keeping the larger
        // per-token score, then refresh the accumulated scores. Order
        // matters: token_score must be updated before node_score is
        // recomputed from it.
        float token_score = std::max(score, node->next[token]->token_score);
        node->next[token]->token_score = token_score;
        float node_score = node->node_score + token_score;
        node->next[token]->node_score = node_score;
        // A node stays an end node once any phrase ends here.
        bool is_end =
            (j == token_ids[i].size() - 1) || node->next[token]->is_end;
        node->next[token]->output_score = is_end ? node_score : 0.0f;
        node->next[token]->is_end = is_end;
        // Only the phrase that actually ends on this node overwrites the
        // end-node metadata.
        if (j == token_ids[i].size() - 1) {
          node->next[token]->phrase = phrase;
          node->next[token]->ac_threshold = ac_threshold;
        }
      }
      node = node->next[token].get();
    }
  }
  // Build Aho-Corasick-style fail/output links over the finished trie.
  FillFailOutput();
}

std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
std::tuple<float, const ContextState *, const ContextState *>
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
bool strict_mode /*= true*/) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
Expand All @@ -45,7 +81,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
}
score = node->node_score - state->node_score;
}
return std::make_pair(score + node->output_score, node);

assert(nullptr != node);

const ContextState *matched_node =
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);

if (!strict_mode && node->output_score != 0) {
assert(nullptr != matched_node);
float output_score =
node->is_end ? node->node_score
: (node->output != nullptr ? node->output->node_score
: node->node_score);
return std::make_tuple(score + output_score - node->node_score, root_.get(),
matched_node);
}
return std::make_tuple(score + node->output_score, node, matched_node);
}

std::pair<float, const ContextState *> ContextGraph::Finalize(
Expand All @@ -54,6 +105,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
return std::make_pair(score, root_.get());
}

std::pair<bool, const ContextState *> ContextGraph::IsMatched(
const ContextState *state) const {
bool status = false;
const ContextState *node = nullptr;
if (state->is_end) {
status = true;
node = state;
} else {
if (state->output != nullptr) {
status = true;
node = state->output;
}
}
return std::make_pair(status, node);
}

void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
Expand Down
46 changes: 36 additions & 10 deletions sherpa-ncnn/csrc/context-graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@
#define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_

#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>


namespace sherpa_ncnn {

class ContextGraph;
Expand All @@ -21,43 +22,68 @@ struct ContextState {
float token_score;
float node_score;
float output_score;
int32_t level;
float ac_threshold;
bool is_end;
std::string phrase;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;

ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float output_score, bool is_end)
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
bool is_end = false, const std::string &phrase = {})
: token(token),
token_score(token_score),
node_score(node_score),
output_score(output_score),
is_end(is_end) {}
level(level),
ac_threshold(ac_threshold),
is_end(is_end),
phrase(phrase) {}
};

class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float hotwords_score)
: context_score_(hotwords_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
float context_score, float ac_threshold,
const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {},
const std::vector<float> &ac_thresholds = {})
: context_score_(context_score), ac_threshold_(ac_threshold) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
root_->fail = root_.get();
Build(token_ids);
Build(token_ids, scores, phrases, ac_thresholds);
}

std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score, const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {})
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
std::vector<float>()) {}

std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id,
bool strict_mode = true) const;

std::pair<bool, const ContextState *> IsMatched(
const ContextState *state) const;

std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;

const ContextState *Root() const { return root_.get(); }

private:
float context_score_;
float ac_threshold_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const;
void FillFailOutput() const;
};

Expand Down
85 changes: 5 additions & 80 deletions sherpa-ncnn/csrc/modified-beam-search-decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,82 +117,7 @@ ncnn::Mat ModifiedBeamSearchDecoder::BuildDecoderInput(

void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
DecoderResult *result) {
int32_t context_size = model_->ContextSize();
Hypotheses cur = std::move(result->hyps);
/* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */
for (int32_t t = 0; t != encoder_out.h; ++t) {
std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
cur.Clear();

ncnn::Mat decoder_input = BuildDecoderInput(prev);
ncnn::Mat decoder_out;
if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
!result->decoder_out.empty()) {
// When an endpoint is detected, we keep the decoder_out
decoder_out = result->decoder_out;
} else {
decoder_out = RunDecoder2D(model_, decoder_input);
}

// decoder_out.w == decoder_dim
// decoder_out.h == num_active_paths
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
// Note: encoder_out_t.h == 1, we rely on the binary op broadcasting
// in ncnn
// See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
// broadcast B for outer axis, type 14
ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);

// joiner_out.w == vocab_size
// joiner_out.h == num_active_paths
LogSoftmax(&joiner_out);

float *p_joiner_out = joiner_out;

for (int32_t i = 0; i != joiner_out.h; ++i) {
float prev_log_prob = prev[i].log_prob;
for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
*p_joiner_out += prev_log_prob;
}
}

auto topk = TopkIndex(static_cast<float *>(joiner_out),
joiner_out.w * joiner_out.h, num_active_paths_);

int32_t frame_offset = result->frame_offset;
for (auto i : topk) {
int32_t hyp_index = i / joiner_out.w;
int32_t new_token = i % joiner_out.w;

const float *p = joiner_out.row(hyp_index);

Hypothesis new_hyp = prev[hyp_index];

// blank id is fixed to 0
if (new_token != 0 && new_token != 2) {
new_hyp.ys.push_back(new_token);
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
} else {
++new_hyp.num_trailing_blanks;
}
// We have already added prev[hyp_index].log_prob to p[new_token]
new_hyp.log_prob = p[new_token];

cur.Add(std::move(new_hyp));
}
}

result->hyps = std::move(cur);
result->frame_offset += encoder_out.h;
auto hyp = result->hyps.GetMostProbable(true);

// set decoder_out in case of endpointing
ncnn::Mat decoder_input = BuildDecoderInput({hyp});
result->decoder_out = model_->RunDecoder(decoder_input);

result->tokens = std::move(hyp.ys);
result->num_trailing_blanks = hyp.num_trailing_blanks;
Decode(encoder_out, nullptr, result);
}

void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
Expand Down Expand Up @@ -252,10 +177,10 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
if (s && s->GetContextGraph()) {
auto context_res =
s->GetContextGraph()->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
auto context_res = s->GetContextGraph()->ForwardOneStep(
context_state, new_token, false /*strict_mode*/);
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
} else {
++new_hyp.num_trailing_blanks;
Expand Down
Loading
Loading