
Add customized score for hotwords & add Finalize to stream (#281)
pkufool authored Mar 7, 2024
1 parent 884ce6d commit 3c7724c
Showing 12 changed files with 316 additions and 109 deletions.
17 changes: 16 additions & 1 deletion .github/scripts/run-test.sh
@@ -52,6 +52,21 @@ for wave in ${waves[@]}; do
done
done

log "Start testing ${repo_url} with hotwords"

time $EXE \
$repo/tokens.txt \
$repo/encoder_jit_trace-pnnx.ncnn.param \
$repo/encoder_jit_trace-pnnx.ncnn.bin \
$repo/decoder_jit_trace-pnnx.ncnn.param \
$repo/decoder_jit_trace-pnnx.ncnn.bin \
$repo/joiner_jit_trace-pnnx.ncnn.param \
$repo/joiner_jit_trace-pnnx.ncnn.bin \
$repo/test_wavs/1.wav \
2 \
modified_beam_search \
$repo/test_wavs/hotwords.txt

rm -rf $repo

log "------------------------------------------------------------"
@@ -588,4 +603,4 @@ time $EXE \
modified_beam_search \
$repo/hotwords.txt 1.6

rm -rf $repo
rm -rf $repo
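A note on the two test invocations above (an inference from the argument layout in this script, not something the diff spells out): the positional arguments after the joiner files are the wave file, the thread count (2), the decoding method (modified_beam_search), and the hotwords file; the trailing 1.6 in the second invocation appears to be a global hotwords score. Per the ContextGraph::Build changes below, a customized per-phrase score, when present and non-zero, takes precedence over that global default.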
2 changes: 2 additions & 0 deletions sherpa-ncnn/csrc/CMakeLists.txt
@@ -77,4 +77,6 @@ endif()
if(SHERPA_NCNN_ENABLE_TEST)
add_executable(test-resample test-resample.cc)
target_link_libraries(test-resample sherpa-ncnn-core)
add_executable(test-context-graph test-context-graph.cc)
target_link_libraries(test-context-graph sherpa-ncnn-core)
endif()
81 changes: 74 additions & 7 deletions sherpa-ncnn/csrc/context-graph.cc
@@ -4,31 +4,67 @@

#include "sherpa-ncnn/csrc/context-graph.h"

#include <algorithm>
#include <cassert>
#include <queue>
#include <string>
#include <tuple>
#include <utility>

namespace sherpa_ncnn {
void ContextGraph::Build(
const std::vector<std::vector<int32_t>> &token_ids) const {
void ContextGraph::Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const {
if (!scores.empty()) {
assert(token_ids.size() == scores.size());
}
if (!phrases.empty()) {
assert(token_ids.size() == phrases.size());
}
if (!ac_thresholds.empty()) {
assert(token_ids.size() == ac_thresholds.size());
}
for (int32_t i = 0; i < token_ids.size(); ++i) {
auto node = root_.get();
float score = scores.empty() ? 0.0f : scores[i];
score = score == 0.0f ? context_score_ : score;
float ac_threshold = ac_thresholds.empty() ? 0.0f : ac_thresholds[i];
ac_threshold = ac_threshold == 0.0f ? ac_threshold_ : ac_threshold;
std::string phrase = phrases.empty() ? std::string() : phrases[i];

for (int32_t j = 0; j < token_ids[i].size(); ++j) {
int32_t token = token_ids[i][j];
if (0 == node->next.count(token)) {
bool is_end = j == token_ids[i].size() - 1;
node->next[token] = std::make_unique<ContextState>(
token, context_score_, node->node_score + context_score_,
is_end ? node->node_score + context_score_ : 0, is_end);
token, score, node->node_score + score,
is_end ? node->node_score + score : 0, j + 1,
is_end ? ac_threshold : 0.0f, is_end,
is_end ? phrase : std::string());
} else {
float token_score = std::max(score, node->next[token]->token_score);
node->next[token]->token_score = token_score;
float node_score = node->node_score + token_score;
node->next[token]->node_score = node_score;
bool is_end =
(j == token_ids[i].size() - 1) || node->next[token]->is_end;
node->next[token]->output_score = is_end ? node_score : 0.0f;
node->next[token]->is_end = is_end;
if (j == token_ids[i].size() - 1) {
node->next[token]->phrase = phrase;
node->next[token]->ac_threshold = ac_threshold;
}
}
node = node->next[token].get();
}
}
FillFailOutput();
}

std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
const ContextState *state, int32_t token) const {
std::tuple<float, const ContextState *, const ContextState *>
ContextGraph::ForwardOneStep(const ContextState *state, int32_t token,
bool strict_mode /*= true*/) const {
const ContextState *node;
float score;
if (1 == state->next.count(token)) {
@@ -45,7 +81,22 @@ std::pair<float, const ContextState *> ContextGraph::ForwardOneStep(
}
score = node->node_score - state->node_score;
}
return std::make_pair(score + node->output_score, node);

assert(nullptr != node);

const ContextState *matched_node =
node->is_end ? node : (node->output != nullptr ? node->output : nullptr);

if (!strict_mode && node->output_score != 0) {
assert(nullptr != matched_node);
float output_score =
node->is_end ? node->node_score
: (node->output != nullptr ? node->output->node_score
: node->node_score);
return std::make_tuple(score + output_score - node->node_score, root_.get(),
matched_node);
}
return std::make_tuple(score + node->output_score, node, matched_node);
}

std::pair<float, const ContextState *> ContextGraph::Finalize(
@@ -54,6 +105,22 @@ std::pair<float, const ContextState *> ContextGraph::Finalize(
return std::make_pair(score, root_.get());
}

std::pair<bool, const ContextState *> ContextGraph::IsMatched(
const ContextState *state) const {
bool status = false;
const ContextState *node = nullptr;
if (state->is_end) {
status = true;
node = state;
} else {
if (state->output != nullptr) {
status = true;
node = state->output;
}
}
return std::make_pair(status, node);
}

void ContextGraph::FillFailOutput() const {
std::queue<const ContextState *> node_queue;
for (auto &kv : root_->next) {
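To make the new scoring concrete, a small worked example based only on the code shown here: suppose a single two-token hotword is built with a customized per-phrase score of 3.0. Build() gives both trie nodes a token_score of 3.0 and cumulative node_scores of 3.0 and 6.0. With strict_mode == false, ForwardOneStep() returns a boost of 3.0 for the first token and, for the final token, 3.0 + 6.0 - 6.0 = 3.0, resetting the returned state to the root and reporting the end node as the matched state (the new third element of the tuple), so a complete match contributes exactly its node_score in total. A partial match that never completes is presumably cancelled by Finalize(), whose body is collapsed in this view.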
46 changes: 36 additions & 10 deletions sherpa-ncnn/csrc/context-graph.h
@@ -6,11 +6,12 @@
#define SHERPA_NCNN_CSRC_CONTEXT_GRAPH_H_

#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>


namespace sherpa_ncnn {

class ContextGraph;
@@ -21,43 +22,68 @@ struct ContextState {
float token_score;
float node_score;
float output_score;
int32_t level;
float ac_threshold;
bool is_end;
std::string phrase;
std::unordered_map<int32_t, std::unique_ptr<ContextState>> next;
const ContextState *fail = nullptr;
const ContextState *output = nullptr;

ContextState() = default;
ContextState(int32_t token, float token_score, float node_score,
float output_score, bool is_end)
float output_score, int32_t level = 0, float ac_threshold = 0.0f,
bool is_end = false, const std::string &phrase = {})
: token(token),
token_score(token_score),
node_score(node_score),
output_score(output_score),
is_end(is_end) {}
level(level),
ac_threshold(ac_threshold),
is_end(is_end),
phrase(phrase) {}
};

class ContextGraph {
public:
ContextGraph() = default;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float hotwords_score)
: context_score_(hotwords_score) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0, false);
float context_score, float ac_threshold,
const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {},
const std::vector<float> &ac_thresholds = {})
: context_score_(context_score), ac_threshold_(ac_threshold) {
root_ = std::make_unique<ContextState>(-1, 0, 0, 0);
root_->fail = root_.get();
Build(token_ids);
Build(token_ids, scores, phrases, ac_thresholds);
}

std::pair<float, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id) const;
ContextGraph(const std::vector<std::vector<int32_t>> &token_ids,
float context_score, const std::vector<float> &scores = {},
const std::vector<std::string> &phrases = {})
: ContextGraph(token_ids, context_score, 0.0f, scores, phrases,
std::vector<float>()) {}

std::tuple<float, const ContextState *, const ContextState *> ForwardOneStep(
const ContextState *state, int32_t token_id,
bool strict_mode = true) const;

std::pair<bool, const ContextState *> IsMatched(
const ContextState *state) const;

std::pair<float, const ContextState *> Finalize(
const ContextState *state) const;

const ContextState *Root() const { return root_.get(); }

private:
float context_score_;
float ac_threshold_;
std::unique_ptr<ContextState> root_;
void Build(const std::vector<std::vector<int32_t>> &token_ids) const;
void Build(const std::vector<std::vector<int32_t>> &token_ids,
const std::vector<float> &scores,
const std::vector<std::string> &phrases,
const std::vector<float> &ac_thresholds) const;
void FillFailOutput() const;
};

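As an illustration of the updated ContextGraph API, here is a minimal sketch (the token ids and scores are invented, and the program must be linked against sherpa-ncnn-core; it is not taken from this commit's test-context-graph.cc):

#include <cstdint>
#include <iostream>
#include <tuple>
#include <vector>

#include "sherpa-ncnn/csrc/context-graph.h"

int main() {
  // Two hotword phrases given as token-id sequences. The second phrase
  // carries a customized score of 3.0 per token; 0.0 means "fall back to
  // the default context score" (2.0 here), mirroring ContextGraph::Build.
  std::vector<std::vector<int32_t>> token_ids = {{1, 2, 3}, {4, 5}};
  std::vector<float> scores = {0.0f, 3.0f};

  sherpa_ncnn::ContextGraph graph(token_ids, /*context_score=*/2.0f, scores);

  // Feed the tokens of the second phrase and accumulate the boost.
  const sherpa_ncnn::ContextState *state = graph.Root();
  float boost = 0.0f;
  for (int32_t token : {4, 5}) {
    auto res = graph.ForwardOneStep(state, token, /*strict_mode=*/false);
    boost += std::get<0>(res);
    state = std::get<1>(res);  // resets to the root once the phrase matches
  }
  std::cout << "accumulated boost: " << boost << "\n";  // 6.0, i.e. 2 * 3.0
  return 0;
}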
85 changes: 5 additions & 80 deletions sherpa-ncnn/csrc/modified-beam-search-decoder.cc
@@ -117,82 +117,7 @@ ncnn::Mat ModifiedBeamSearchDecoder::BuildDecoderInput(

void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out,
DecoderResult *result) {
int32_t context_size = model_->ContextSize();
Hypotheses cur = std::move(result->hyps);
/* encoder_out.w == encoder_out_dim, encoder_out.h == num_frames. */
for (int32_t t = 0; t != encoder_out.h; ++t) {
std::vector<Hypothesis> prev = cur.GetTopK(num_active_paths_, true);
cur.Clear();

ncnn::Mat decoder_input = BuildDecoderInput(prev);
ncnn::Mat decoder_out;
if (t == 0 && prev.size() == 1 && prev[0].ys.size() == context_size &&
!result->decoder_out.empty()) {
// When an endpoint is detected, we keep the decoder_out
decoder_out = result->decoder_out;
} else {
decoder_out = RunDecoder2D(model_, decoder_input);
}

// decoder_out.w == decoder_dim
// decoder_out.h == num_active_paths
ncnn::Mat encoder_out_t(encoder_out.w, 1, encoder_out.row(t));
// Note: encoder_out_t.h == 1, we rely on the binary op broadcasting
// in ncnn
// See https://github.com/Tencent/ncnn/wiki/binaryop-broadcasting
// broadcast B for outer axis, type 14
ncnn::Mat joiner_out = model_->RunJoiner(encoder_out_t, decoder_out);

// joiner_out.w == vocab_size
// joiner_out.h == num_active_paths
LogSoftmax(&joiner_out);

float *p_joiner_out = joiner_out;

for (int32_t i = 0; i != joiner_out.h; ++i) {
float prev_log_prob = prev[i].log_prob;
for (int32_t k = 0; k != joiner_out.w; ++k, ++p_joiner_out) {
*p_joiner_out += prev_log_prob;
}
}

auto topk = TopkIndex(static_cast<float *>(joiner_out),
joiner_out.w * joiner_out.h, num_active_paths_);

int32_t frame_offset = result->frame_offset;
for (auto i : topk) {
int32_t hyp_index = i / joiner_out.w;
int32_t new_token = i % joiner_out.w;

const float *p = joiner_out.row(hyp_index);

Hypothesis new_hyp = prev[hyp_index];

// blank id is fixed to 0
if (new_token != 0 && new_token != 2) {
new_hyp.ys.push_back(new_token);
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
} else {
++new_hyp.num_trailing_blanks;
}
// We have already added prev[hyp_index].log_prob to p[new_token]
new_hyp.log_prob = p[new_token];

cur.Add(std::move(new_hyp));
}
}

result->hyps = std::move(cur);
result->frame_offset += encoder_out.h;
auto hyp = result->hyps.GetMostProbable(true);

// set decoder_out in case of endpointing
ncnn::Mat decoder_input = BuildDecoderInput({hyp});
result->decoder_out = model_->RunDecoder(decoder_input);

result->tokens = std::move(hyp.ys);
result->num_trailing_blanks = hyp.num_trailing_blanks;
Decode(encoder_out, nullptr, result);
}

void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
@@ -252,10 +177,10 @@ void ModifiedBeamSearchDecoder::Decode(ncnn::Mat encoder_out, Stream *s,
new_hyp.num_trailing_blanks = 0;
new_hyp.timestamps.push_back(t + frame_offset);
if (s && s->GetContextGraph()) {
auto context_res =
s->GetContextGraph()->ForwardOneStep(context_state, new_token);
context_score = context_res.first;
new_hyp.context_state = context_res.second;
auto context_res = s->GetContextGraph()->ForwardOneStep(
context_state, new_token, false /*strict_mode*/);
context_score = std::get<0>(context_res);
new_hyp.context_state = std::get<1>(context_res);
}
} else {
++new_hyp.num_trailing_blanks;
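The decoder above consumes only the first two elements of the returned tuple; the third element, the matched context state, lets a caller find out which hotword actually fired. A hypothetical helper, not part of this commit (the name and signature are invented for illustration):

#include <cstdint>
#include <string>
#include <tuple>

#include "sherpa-ncnn/csrc/context-graph.h"

// Advances the context state by one token and returns the matched hotword
// text, or an empty string if no phrase ended at this step.
std::string StepAndGetMatchedPhrase(const sherpa_ncnn::ContextGraph &graph,
                                    const sherpa_ncnn::ContextState **state,
                                    int32_t token) {
  auto res = graph.ForwardOneStep(*state, token, /*strict_mode=*/false);
  *state = std::get<1>(res);
  const sherpa_ncnn::ContextState *matched = std::get<2>(res);
  return matched != nullptr ? matched->phrase : std::string();
}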
(Diffs for the remaining 7 changed files are not rendered in this view.)
