Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable load_from_bytes of model and ExtScorer #3331

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion native_client/Makefile
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ clean:
rm -f deepspeech

$(DEEPSPEECH_BIN): client.cc Makefile
$(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS)
$(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS) -llzma -lbz2
ifeq ($(OS),Darwin)
install_name_tool -change bazel-out/local-opt/bin/native_client/libdeepspeech.so @rpath/libdeepspeech.so deepspeech
endif
Expand Down
Empty file modified native_client/alphabet.cc
100644 → 100755
Empty file.
Empty file modified native_client/alphabet.h
100644 → 100755
Empty file.
8 changes: 8 additions & 0 deletions native_client/args.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ bool extended_metadata = false;

bool json_output = false;

bool init_from_array_of_bytes = false;

int json_candidate_transcripts = 3;

int stream_size = 0;
Expand Down Expand Up @@ -62,6 +64,7 @@ void PrintHelp(const char* bin)
"\t--stream size\t\t\tRun in stream mode, output intermediate results\n"
"\t--extended_stream size\t\t\tRun in stream mode using metadata output, output intermediate results\n"
"\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n"
"\t--init_from_bytes\t\tTest init model and scorer from array of bytes\n"
"\t--help\t\t\t\tShow help\n"
"\t--version\t\t\tPrint version and exits\n";
char* version = DS_Version();
Expand All @@ -83,6 +86,7 @@ bool ProcessArgs(int argc, char** argv)
{"t", no_argument, nullptr, 't'},
{"extended", no_argument, nullptr, 'e'},
{"json", no_argument, nullptr, 'j'},
{"init_from_bytes", no_argument, nullptr, 'B'},
{"candidate_transcripts", required_argument, nullptr, 150},
{"stream", required_argument, nullptr, 's'},
{"extended_stream", required_argument, nullptr, 'S'},
Expand Down Expand Up @@ -139,6 +143,10 @@ bool ProcessArgs(int argc, char** argv)
case 'j':
json_output = true;
break;

case 'B':
init_from_array_of_bytes = true;
break;

case 150:
json_candidate_transcripts = atoi(optarg);
Expand Down
30 changes: 28 additions & 2 deletions native_client/client.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
#include <unistd.h>
#endif // NO_DIR
#include <vector>
#include <iostream>
#include <fstream>

#include "deepspeech.h"
#include "args.h"
Expand Down Expand Up @@ -452,8 +454,21 @@ main(int argc, char **argv)

// Initialise DeepSpeech
ModelState* ctx;
std::string buffer_model_str;
// sphinx-doc: c_ref_model_start
int status = DS_CreateModel(model, &ctx);
int status;
if (init_from_array_of_bytes){
// Reading model file to a char * buffer
std::ifstream is_model( model, std::ios::binary );
std::stringstream buffer_model;
buffer_model << is_model.rdbuf();
buffer_model_str = buffer_model.str();
status = DS_CreateModelFromBuffer(buffer_model_str.c_str(), buffer_model_str.size(), &ctx);
}else {
// Keep old method due to backwards compatibility
status = DS_CreateModel(model, &ctx);
}

if (status != 0) {
char* error = DS_ErrorCodeToErrorMessage(status);
fprintf(stderr, "Could not create model: %s\n", error);
Expand All @@ -470,7 +485,18 @@ main(int argc, char **argv)
}

if (scorer) {
status = DS_EnableExternalScorer(ctx, scorer);
if (init_from_array_of_bytes){
// Reading scorer file to a string buffer
std::ifstream is_scorer(scorer, std::ios::binary );
std::stringstream buffer_scorer;
buffer_scorer << is_scorer.rdbuf();
std::string tmp_str_scorer = buffer_scorer.str();
status = DS_EnableExternalScorerFromBuffer(ctx, tmp_str_scorer.c_str(), tmp_str_scorer.size());
} else {
// Keep old method due to backwards compatibility
status = DS_EnableExternalScorer(ctx, scorer);
}

if (status != 0) {
fprintf(stderr, "Could not enable external scorer.\n");
return 1;
Expand Down
85 changes: 53 additions & 32 deletions native_client/ctcdecode/scorer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,22 +29,24 @@ static const int32_t FILE_VERSION = 6;

int
Scorer::init(const std::string& lm_path,
bool load_from_bytes,
const Alphabet& alphabet)
{
set_alphabet(alphabet);
return load_lm(lm_path);
return load_lm(lm_path, load_from_bytes);
}

int
Scorer::init(const std::string& lm_path,
bool load_from_bytes,
const std::string& alphabet_config_path)
{
int err = alphabet_.init(alphabet_config_path.c_str());
int err = alphabet_.init(alphabet_config_path.c_str()); // Do we need to make this initiable from bytes?
if (err != 0) {
return err;
}
setup_char_map();
return load_lm(lm_path);
return load_lm(lm_path, load_from_bytes);
}

void
Expand All @@ -69,45 +71,60 @@ void Scorer::setup_char_map()
}
}

int Scorer::load_lm(const std::string& lm_path)
int Scorer::load_lm(const std::string& lm_string, bool load_from_bytes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Either provide a default value, split the implementation in two load_lm or update all call sites: this PR does not build because of load_lm calls in generate_scorer_package.cpp

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've provided a default value (false) for the boolean

{
// Check if file is readable to avoid KenLM throwing an exception
const char* filename = lm_path.c_str();
if (access(filename, R_OK) != 0) {
return DS_ERR_SCORER_UNREADABLE;
}

// Check if the file format is valid to avoid KenLM throwing an exception
lm::ngram::ModelType model_type;
if (!lm::ngram::RecognizeBinary(filename, model_type)) {
return DS_ERR_SCORER_INVALID_LM;
if (!load_from_bytes){
// Check if file is readable to avoid KenLM throwing an exception
const char* filename = lm_string.c_str();
if (access(filename, R_OK) != 0) {
return DS_ERR_SCORER_UNREADABLE;
}

// Check if the file format is valid to avoid KenLM throwing an exception
lm::ngram::ModelType model_type;
if (!lm::ngram::RecognizeBinary(filename, model_type)) {
return DS_ERR_SCORER_INVALID_LM;
}
}

// Load the LM
lm::ngram::Config config;
config.load_method = util::LoadMethod::LAZY;
language_model_.reset(lm::ngram::LoadVirtual(filename, config));
max_order_ = language_model_->Order();

uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(filename));
package_size = util::SizeFile(fd.get());
if (load_from_bytes){
language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), lm_string.size(), config));
} else {
language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), config));
}

max_order_ = language_model_->Order();
std::stringstream stst;
uint64_t trie_offset = language_model_->GetEndOfSearchOffset();
if (package_size <= trie_offset) {
// File ends without a trie structure
return DS_ERR_SCORER_NO_TRIE;

if (!load_from_bytes){
uint64_t package_size;
{
util::scoped_fd fd(util::OpenReadOrThrow(lm_string.c_str()));
package_size = util::SizeFile(fd.get());
}

if (package_size <= trie_offset) {
// File ends without a trie structure
return DS_ERR_SCORER_NO_TRIE;
}
// Read metadata and trie from file
std::ifstream fin(lm_string.c_str(), std::ios::binary);
stst<<fin.rdbuf();
} else {
stst = std::stringstream(lm_string);
}

// Read metadata and trie from file
std::ifstream fin(lm_path, std::ios::binary);
fin.seekg(trie_offset);
return load_trie(fin, lm_path);
stst.seekg(trie_offset);
return load_trie(stst, lm_string, load_from_bytes);
}

int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
int Scorer::load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stringstream is compatible with ifstream ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I tested yes

{

int magic;
fin.read(reinterpret_cast<char*>(&magic), sizeof(magic));
if (magic != MAGIC) {
Expand Down Expand Up @@ -140,9 +157,13 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path)
reset_params(alpha, beta);

fst::FstReadOptions opt;
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
if (load_from_bytes) {
dictionary.reset(fst::ConstFst<fst::StdArc>::Read(fin, opt));
} else {
opt.mode = fst::FstReadOptions::MAP;
opt.source = file_path;
dictionary.reset(FstType::Read(fin, opt));
}
return DS_ERR_OK;
}

Expand Down
6 changes: 4 additions & 2 deletions native_client/ctcdecode/scorer.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,11 @@ class Scorer {
Scorer& operator=(const Scorer&) = delete;

int init(const std::string &lm_path,
bool load_from_bytes,
const Alphabet &alphabet);

int init(const std::string &lm_path,
bool load_from_bytes,
const std::string &alphabet_config_path);

double get_log_cond_prob(const std::vector<std::string> &words,
Expand Down Expand Up @@ -84,7 +86,7 @@ class Scorer {
void fill_dictionary(const std::unordered_set<std::string> &vocabulary);

// load language model from given path
int load_lm(const std::string &lm_path);
int load_lm(const std::string &lm_path, bool load_from_bytes=false);

// language model weight
double alpha = 0.;
Expand All @@ -98,7 +100,7 @@ class Scorer {
// necessary setup after setting alphabet
void setup_char_map();

int load_trie(std::ifstream& fin, const std::string& file_path);
int load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes=false);

private:
std::unique_ptr<lm::base::Model> language_model_;
Expand Down
58 changes: 50 additions & 8 deletions native_client/deepspeech.cc
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,10 @@ StreamingState::processBatch(const vector<float>& buf, unsigned int n_steps)
}

int
DS_CreateModel(const char* aModelPath,
ModelState** retval)
DS_CreateModel_(const char* aModelString,
bool init_from_bytes,
ModelState** retval,
size_t bufferSize=0)
{
*retval = nullptr;

Expand All @@ -277,7 +279,7 @@ DS_CreateModel(const char* aModelPath,
LOGD("DeepSpeech: %s", ds_git_version());
#endif

if (!aModelPath || strlen(aModelPath) < 1) {
if ( (!init_from_bytes && (strlen(aModelString) < 1)) || (init_from_bytes && (bufferSize<1))) {
std::cerr << "No model specified, cannot continue." << std::endl;
return DS_ERR_NO_MODEL;
}
Expand All @@ -294,8 +296,8 @@ DS_CreateModel(const char* aModelPath,
std::cerr << "Could not allocate model state." << std::endl;
return DS_ERR_FAIL_CREATE_MODEL;
}

int err = model->init(aModelPath);
int err = model->init(aModelString, init_from_bytes, bufferSize);
if (err != DS_ERR_OK) {
return err;
}
Expand All @@ -304,6 +306,22 @@ DS_CreateModel(const char* aModelPath,
return DS_ERR_OK;
}

int
DS_CreateModel(const char* aModelPath,
ModelState** retval)
{
return DS_CreateModel_(aModelPath, false, retval);
}

int
DS_CreateModelFromBuffer(const char* aModelBuffer,
size_t bufferSize,
ModelState** retval)
{
return DS_CreateModel_(aModelBuffer, true, retval, bufferSize);
}


unsigned int
DS_GetModelBeamWidth(const ModelState* aCtx)
{
Expand All @@ -330,18 +348,42 @@ DS_FreeModel(ModelState* ctx)
}

int
DS_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath)
DS_EnableExternalScorer_(ModelState* aCtx,
const char* aScorerString,
bool init_from_bytes,
size_t bufferSize=0)
{
std::unique_ptr<Scorer> scorer(new Scorer());
int err = scorer->init(aScorerPath, aCtx->alphabet_);

int err;
if (init_from_bytes)
err = scorer->init(std::string(aScorerString, bufferSize), init_from_bytes, aCtx->alphabet_);
else
err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_);


if (err != 0) {
return DS_ERR_INVALID_SCORER;
}
aCtx->scorer_ = std::move(scorer);
return DS_ERR_OK;
}

int
DS_EnableExternalScorer(ModelState* aCtx,
const char* aScorerPath)
{
return DS_EnableExternalScorer_(aCtx, aScorerPath, false);
}

int
DS_EnableExternalScorerFromBuffer(ModelState* aCtx,
const char* aScorerBuffer,
size_t bufferSize)
{
return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true, bufferSize);
}

int
DS_AddHotWord(ModelState* aCtx,
const char* word,
Expand Down
Loading