From 06c23fe6d5dfad85f5c20449515bf273241c0d43 Mon Sep 17 00:00:00 2001 From: daniele Date: Wed, 9 Sep 2020 13:45:49 -0300 Subject: [PATCH 1/9] Loading model from both path or array of bytes --- native_client/Makefile | 2 +- native_client/alphabet.cc | 0 native_client/alphabet.h | 0 native_client/args.h | 8 ++++ native_client/client.cc | 16 ++++++- native_client/deepspeech.cc | 14 ++++-- native_client/deepspeech.h | 18 +++++++- native_client/modelstate.cc | 2 +- native_client/modelstate.h | 2 +- native_client/tflitemodelstate.cc | 31 ++++++++++--- native_client/tflitemodelstate.h | 2 +- native_client/tfmodelstate.cc | 75 ++++++++++++++++++++++--------- native_client/tfmodelstate.h | 2 +- 13 files changed, 135 insertions(+), 37 deletions(-) mode change 100644 => 100755 native_client/Makefile mode change 100644 => 100755 native_client/alphabet.cc mode change 100644 => 100755 native_client/alphabet.h mode change 100644 => 100755 native_client/args.h mode change 100644 => 100755 native_client/client.cc mode change 100644 => 100755 native_client/deepspeech.cc mode change 100644 => 100755 native_client/deepspeech.h mode change 100644 => 100755 native_client/modelstate.cc mode change 100644 => 100755 native_client/modelstate.h mode change 100644 => 100755 native_client/tflitemodelstate.cc mode change 100644 => 100755 native_client/tflitemodelstate.h mode change 100644 => 100755 native_client/tfmodelstate.cc mode change 100644 => 100755 native_client/tfmodelstate.h diff --git a/native_client/Makefile b/native_client/Makefile old mode 100644 new mode 100755 index b645499c28..7115f8ee04 --- a/native_client/Makefile +++ b/native_client/Makefile @@ -19,7 +19,7 @@ clean: rm -f deepspeech $(DEEPSPEECH_BIN): client.cc Makefile - $(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS) + $(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS) -llzma -lbz2 ifeq ($(OS),Darwin) install_name_tool -change bazel-out/local-opt/bin/native_client/libdeepspeech.so @rpath/libdeepspeech.so deepspeech endif diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc old mode 100644 new mode 100755 diff --git a/native_client/alphabet.h b/native_client/alphabet.h old mode 100644 new mode 100755 diff --git a/native_client/args.h b/native_client/args.h old mode 100644 new mode 100755 index 069347e09f..15d7b94ccc --- a/native_client/args.h +++ b/native_client/args.h @@ -34,6 +34,8 @@ bool extended_metadata = false; bool json_output = false; +bool init_from_array_of_bytes = false; + int json_candidate_transcripts = 3; int stream_size = 0; @@ -62,6 +64,7 @@ void PrintHelp(const char* bin) "\t--stream size\t\t\tRun in stream mode, output intermediate results\n" "\t--extended_stream size\t\t\tRun in stream mode using metadata output, output intermediate results\n" "\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n" + "\t--init_from_bytes\t\tTest init model and scorer from array of bytes\n" "\t--help\t\t\t\tShow help\n" "\t--version\t\t\tPrint version and exits\n"; char* version = DS_Version(); @@ -83,6 +86,7 @@ bool ProcessArgs(int argc, char** argv) {"t", no_argument, nullptr, 't'}, {"extended", no_argument, nullptr, 'e'}, {"json", no_argument, nullptr, 'j'}, + {"init_from_bytes", no_argument, nullptr, 'B'}, {"candidate_transcripts", required_argument, nullptr, 150}, {"stream", required_argument, nullptr, 's'}, {"extended_stream", required_argument, nullptr, 'S'}, @@ -139,6 +143,10 @@ bool ProcessArgs(int argc, char** argv) case 'j': json_output = true; break; + + case 'B': + init_from_array_of_bytes = true; + break; case 150: json_candidate_transcripts = atoi(optarg); diff --git a/native_client/client.cc b/native_client/client.cc old mode 100644 new mode 100755 index 7d88b4d6c8..ae413d18a5 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -33,6 +33,8 @@ #include #endif // NO_DIR #include +#include +#include #include "deepspeech.h" #include "args.h" @@ -453,7 +455,19 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; // sphinx-doc: c_ref_model_start - int status = DS_CreateModel(model, &ctx); + int status; + if (init_from_array_of_bytes){ + // Reading model file to a char * buffer + std::ifstream is( model, std::ios::binary ); + std::vector buffer(std::istreambuf_iterator(is), {}); + std::string bufferS(buffer.begin(), buffer.end()); + std::cout<<"Loading from buffer"<& buf, unsigned int n_steps) int DS_CreateModel(const char* aModelPath, ModelState** retval) +{ + return DS_CreateModel_(aModelPath, false, retval); +} + +int +DS_CreateModel_(const std::string &aModelString, + bool init_from_bytes, + ModelState** retval) { *retval = nullptr; @@ -277,7 +285,7 @@ DS_CreateModel(const char* aModelPath, LOGD("DeepSpeech: %s", ds_git_version()); #endif - if (!aModelPath || strlen(aModelPath) < 1) { + if (aModelString.length() < 1) { std::cerr << "No model specified, cannot continue." << std::endl; return DS_ERR_NO_MODEL; } @@ -294,8 +302,8 @@ DS_CreateModel(const char* aModelPath, std::cerr << "Could not allocate model state." << std::endl; return DS_ERR_FAIL_CREATE_MODEL; } - - int err = model->init(aModelPath); + + int err = model->init(aModelString, init_from_bytes); if (err != DS_ERR_OK) { return err; } diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h old mode 100644 new mode 100755 index 35e9289a2e..cc9425133d --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -1,6 +1,8 @@ #ifndef DEEPSPEECH_H #define DEEPSPEECH_H +#include + #ifdef __cplusplus extern "C" { #endif @@ -96,7 +98,7 @@ DS_FOR_EACH_ERROR(DEFINE) }; /** - * @brief An object providing an interface to a trained DeepSpeech model. + * @brief An object providing an interface to a trained DeepSpeech model. Maintained for not breaking backwards compatibility * * @param aModelPath The path to the frozen model graph. * @param[out] retval a ModelState pointer @@ -107,6 +109,20 @@ DEEPSPEECH_EXPORT int DS_CreateModel(const char* aModelPath, ModelState** retval); +/** + * @brief An object providing an interface to a trained DeepSpeech model. + * + * @param aModelString The path/string for initializing the model graph. + * @param init_from_bytes Wheter the model will be initialized using path or array of bytes. + * @param[out] retval a ModelState pointer + * + * @return Zero on success, non-zero on failure. + */ +DEEPSPEECH_EXPORT +int DS_CreateModel_(const std::string &aModelString, + bool init_from_bytes, + ModelState** retval); + /** * @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth} * was not called before, will return the default value loaded from the diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc old mode 100644 new mode 100755 index d8637c3656..5f4ce2e274 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -24,7 +24,7 @@ ModelState::~ModelState() } int -ModelState::init(const char* model_path) +ModelState::init(const std::string &model_string, bool init_from_bytes) { return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h old mode 100644 new mode 100755 index 4beb78b472..166edff00e --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -31,7 +31,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const char* model_path); + virtual int init(const std::string &model_string, bool init_from_bytes); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc old mode 100644 new mode 100755 index 50a68a4b94..065067c1ac --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -156,23 +156,40 @@ getTfliteDelegates() } int -TFLiteModelState::init(const char* model_path) +TFLiteModelState::init(const std::string &model_string, bool init_from_bytes) { - int err = ModelState::init(model_path); + int err = ModelState::init(model_string, init_from_bytes); if (err != DS_ERR_OK) { return err; } - fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_path); - if (!fbmodel_) { - std::cerr << "Error at reading model file " << model_path << std::endl; - return DS_ERR_FAIL_INIT_MMAP; + if (init_from_bytes){ + char *tmp_buffer = new char[model_string.size()]; + std::copy(model_string.begin(), model_string.end(), tmp_buffer); + // Using c_str does not work + fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(tmp_buffer,model_string.size()); + if (!fbmodel_) { + std::cerr << "Error at reading model buffer " << std::endl; + return DS_ERR_FAIL_INIT_MMAP; + } + } else { + fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string.c_str()); + if (!fbmodel_) { + std::cerr << "Error at reading model file " << model_string << std::endl; + return DS_ERR_FAIL_INIT_MMAP; + } } + tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_); if (!interpreter_) { - std::cerr << "Error at InterpreterBuilder for model file " << model_path << std::endl; + if (init_from_bytes) { + std::cerr << "Error at InterpreterBuilder for model buffer " << std::endl; + } else { + std::cerr << "Error at InterpreterBuilder for model file " << model_string << std::endl; + } + return DS_ERR_FAIL_INTERPRETER; } diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h old mode 100644 new mode 100755 index ace62ecf89..38de5b9d7d --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -31,7 +31,7 @@ struct TFLiteModelState : public ModelState TFLiteModelState(); virtual ~TFLiteModelState(); - virtual int init(const char* model_path) override; + virtual int init(const std::string &model_string, bool init_from_bytes) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override; diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc old mode 100644 new mode 100755 index 65328e308a..5e5bd45cea --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -22,10 +22,33 @@ TFModelState::~TFModelState() } } +int loadGraphFromBinaryData(Env* env, const std::string& data, + ::tensorflow::protobuf::MessageLite* proto) { + std::string model_buffer; + + std::ifstream graph_input_stream; + + graph_input_stream.open("../exported_model/output_graph.pb", std::ios::binary); + model_buffer = std::string((std::istreambuf_iterator(graph_input_stream)), + (std::istreambuf_iterator())); + graph_input_stream.close(); + + if (!proto->ParseFromString(model_buffer)) { + std::cerr << "Can't parse data as binary proto" << std::endl; + return -1; + } + return 0; +} + int -TFModelState::init(const char* model_path) +TFModelState::init(const std::string &model_string, bool init_from_bytes) { - int err = ModelState::init(model_path); + if (init_from_bytes){ + std::cerr << "=============== Init model from bytes"<InitializeFromFile(model_path); - if (!status.ok()) { - std::cerr << status << std::endl; - return DS_ERR_FAIL_INIT_MMAP; + bool is_mmap = false; + if (init_from_bytes) { + int loadGraphStatus = loadGraphFromBinaryData(Env::Default(), model_string, &graph_def_); + if (loadGraphStatus != 0) { + return DS_ERR_FAIL_CREATE_SESS; } + } else { + is_mmap = model_string.find(".pbmm") != std::string::npos; + if (!is_mmap) { + std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl; + } else { + status = mmap_env_->InitializeFromFile(model_string.c_str()); + if (!status.ok()) { + std::cerr << status << std::endl; + return DS_ERR_FAIL_INIT_MMAP; + } - options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_opt_level(::OptimizerOptions::L0); - options.env = mmap_env_.get(); + options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(::OptimizerOptions::L0); + options.env = mmap_env_.get(); + } } Session* session; @@ -59,13 +89,18 @@ TFModelState::init(const char* model_path) } session_.reset(session); - if (is_mmap) { - status = ReadBinaryProto(mmap_env_.get(), - MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, - &graph_def_); + if (init_from_bytes){ + // Need some help } else { - status = ReadBinaryProto(Env::Default(), model_path, &graph_def_); + if (is_mmap) { + status = ReadBinaryProto(mmap_env_.get(), + MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &graph_def_); + } else { + status = ReadBinaryProto(Env::Default(), model_string.c_str(), &graph_def_); + } } + if (!status.ok()) { std::cerr << status << std::endl; return DS_ERR_FAIL_READ_PROTOBUF; diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h old mode 100644 new mode 100755 index 2a8db699df..85628acf3c --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -18,7 +18,7 @@ struct TFModelState : public ModelState TFModelState(); virtual ~TFModelState(); - virtual int init(const char* model_path) override; + virtual int init(const std::string &model_string, bool init_from_bytes) override; virtual void infer(const std::vector& mfcc, unsigned int n_frames, From dc8553b7085d82667beb48c3a11062d9af58f46f Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Tue, 22 Sep 2020 08:24:51 -0300 Subject: [PATCH 2/9] Loading ExtScorer from array of bytes --- doc/BUILDING.rst | 4 +- native_client/client.cc | 21 +++-- native_client/ctcdecode/scorer.cpp | 85 +++++++++++-------- native_client/ctcdecode/scorer.h | 6 +- native_client/deepspeech.cc | 12 ++- native_client/deepspeech.h | 16 ++++ native_client/kenlm/lm/bhiksha.cc | 8 +- native_client/kenlm/lm/bhiksha.hh | 3 +- native_client/kenlm/lm/binary_format.cc | 104 +++++++++++++++++++++++- native_client/kenlm/lm/binary_format.hh | 12 ++- native_client/kenlm/lm/model.cc | 64 ++++++++++++++- native_client/kenlm/lm/model.hh | 14 +++- native_client/kenlm/lm/quantize.cc | 10 ++- native_client/kenlm/lm/quantize.hh | 4 +- native_client/kenlm/lm/search_hashed.hh | 2 +- native_client/kenlm/lm/search_trie.hh | 7 +- native_client/kenlm/lm/vocab.cc | 52 ++++++++++++ native_client/kenlm/lm/vocab.hh | 3 + native_client/kenlm/util/file.cc | 37 +++++++++ native_client/kenlm/util/file.hh | 1 + native_client/kenlm/util/mmap.cc | 82 ++++++++++++++++++- native_client/kenlm/util/mmap.hh | 22 +++-- 22 files changed, 501 insertions(+), 68 deletions(-) diff --git a/doc/BUILDING.rst b/doc/BUILDING.rst index 59f1a3b953..44adb4fbad 100644 --- a/doc/BUILDING.rst +++ b/doc/BUILDING.rst @@ -73,7 +73,7 @@ You can now use Bazel to build the main DeepSpeech library, ``libdeepspeech.so`` .. code-block:: - bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-fvisibility=hidden //native_client:libdeepspeech.so + bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=1" --copt=-fvisibility=hidden //native_client:libdeepspeech.so The generated binaries will be saved to ``bazel-bin/native_client/``. @@ -87,7 +87,7 @@ Using the example from above you can build the library and that binary at the sa .. code-block:: - bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-fvisibility=hidden //native_client:libdeepspeech.so //native_client:generate_scorer_package + bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=1" --copt=-fvisibility=hidden //native_client:libdeepspeech.so //native_client:generate_scorer_package The generated binaries will be saved to ``bazel-bin/native_client/``. diff --git a/native_client/client.cc b/native_client/client.cc index ae413d18a5..d60370bd19 100755 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -458,11 +458,10 @@ main(int argc, char **argv) int status; if (init_from_array_of_bytes){ // Reading model file to a char * buffer - std::ifstream is( model, std::ios::binary ); - std::vector buffer(std::istreambuf_iterator(is), {}); - std::string bufferS(buffer.begin(), buffer.end()); - std::cout<<"Loading from buffer"<Order(); - - uint64_t package_size; - { - util::scoped_fd fd(util::OpenReadOrThrow(filename)); - package_size = util::SizeFile(fd.get()); + if (load_from_bytes){ + language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), lm_string.size(), config)); + } else { + language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), config)); } + + max_order_ = language_model_->Order(); + std::stringstream stst; uint64_t trie_offset = language_model_->GetEndOfSearchOffset(); - if (package_size <= trie_offset) { - // File ends without a trie structure - return DS_ERR_SCORER_NO_TRIE; + + if (!load_from_bytes){ + uint64_t package_size; + { + util::scoped_fd fd(util::OpenReadOrThrow(lm_string.c_str())); + package_size = util::SizeFile(fd.get()); + } + + if (package_size <= trie_offset) { + // File ends without a trie structure + return DS_ERR_SCORER_NO_TRIE; + } + // Read metadata and trie from file + std::ifstream fin(lm_string.c_str(), std::ios::binary); + stst<(&magic), sizeof(magic)); if (magic != MAGIC) { @@ -140,9 +157,13 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path) reset_params(alpha, beta); fst::FstReadOptions opt; - opt.mode = fst::FstReadOptions::MAP; - opt.source = file_path; - dictionary.reset(FstType::Read(fin, opt)); + if (load_from_bytes) { + dictionary.reset(fst::ConstFst::Read(fin, opt)); + } else { + opt.mode = fst::FstReadOptions::MAP; + opt.source = file_path; + dictionary.reset(FstType::Read(fin, opt)); + } return DS_ERR_OK; } diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 5aee1046ff..500dd0b28a 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -39,9 +39,11 @@ class Scorer { Scorer& operator=(const Scorer&) = delete; int init(const std::string &lm_path, + bool load_from_bytes, const Alphabet &alphabet); int init(const std::string &lm_path, + bool load_from_bytes, const std::string &alphabet_config_path); double get_log_cond_prob(const std::vector &words, @@ -84,7 +86,7 @@ class Scorer { void fill_dictionary(const std::unordered_set &vocabulary); // load language model from given path - int load_lm(const std::string &lm_path); + int load_lm(const std::string &lm_path, bool load_from_bytes); // language model weight double alpha = 0.; @@ -98,7 +100,7 @@ class Scorer { // necessary setup after setting alphabet void setup_char_map(); - int load_trie(std::ifstream& fin, const std::string& file_path); + int load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes); private: std::unique_ptr language_model_; diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index e3d787b5f4..65d33f23c9 100755 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -340,9 +340,19 @@ DS_FreeModel(ModelState* ctx) int DS_EnableExternalScorer(ModelState* aCtx, const char* aScorerPath) +{ + return DS_EnableExternalScorer_(aCtx, aScorerPath, false); +} + +int +DS_EnableExternalScorer_(ModelState* aCtx, + const std::string &aScorerString, + bool init_from_bytes) { std::unique_ptr scorer(new Scorer()); - int err = scorer->init(aScorerPath, aCtx->alphabet_); + + int err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_); + if (err != 0) { return DS_ERR_INVALID_SCORER; } diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index cc9425133d..7e8fc62e4e 100755 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -176,6 +176,20 @@ DEEPSPEECH_EXPORT int DS_EnableExternalScorer(ModelState* aCtx, const char* aScorerPath); +/** + * @brief Enable decoding using an external scorer. + * + * @param aCtx The ModelState pointer for the model being changed. + * @param aScorerString The path/array_of_bytes to initialize the external scorer. + * @param init_from_bytes Wheter the scorer will be initialized by file or array of bytes. + * + * @return Zero on success, non-zero on failure (invalid arguments). + */ +DEEPSPEECH_EXPORT +int DS_EnableExternalScorer_(ModelState* aCtx, + const std::string &aScorerString, + bool init_from_bytes); + /** * @brief Add a hot-word and its boost. * @@ -212,6 +226,8 @@ int DS_EraseHotWord(ModelState* aCtx, DEEPSPEECH_EXPORT int DS_ClearHotWords(ModelState* aCtx); + + /** * @brief Disable decoding using an external scorer. * diff --git a/native_client/kenlm/lm/bhiksha.cc b/native_client/kenlm/lm/bhiksha.cc index 4262b615e5..e64c0a9276 100644 --- a/native_client/kenlm/lm/bhiksha.cc +++ b/native_client/kenlm/lm/bhiksha.cc @@ -17,9 +17,13 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; // TODO: put this in binary file header instead when I change the binary file format again. -void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { +void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_bytes) { uint8_t buffer[2]; - file.ReadForConfig(buffer, 2, offset); + if(load_from_bytes){ + file.ReadForConfig(buffer, 2, offset, load_from_bytes); + } else { + file.ReadForConfig(buffer, 2, offset); + } uint8_t version = buffer[0]; uint8_t configured_bits = buffer[1]; if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); diff --git a/native_client/kenlm/lm/bhiksha.hh b/native_client/kenlm/lm/bhiksha.hh index 36438f1d29..bc6eb48bff 100644 --- a/native_client/kenlm/lm/bhiksha.hh +++ b/native_client/kenlm/lm/bhiksha.hh @@ -34,6 +34,7 @@ class DontBhiksha { static const ModelType kModelTypeAdd = static_cast(0); static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &, bool) {} static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } @@ -65,7 +66,7 @@ class ArrayBhiksha { public: static const ModelType kModelTypeAdd = kArrayAdd; - static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_bytes); static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); diff --git a/native_client/kenlm/lm/binary_format.cc b/native_client/kenlm/lm/binary_format.cc index 802943f572..ea64a8f97b 100644 --- a/native_client/kenlm/lm/binary_format.cc +++ b/native_client/kenlm/lm/binary_format.cc @@ -11,6 +11,7 @@ #include #include +#include namespace lm { namespace ngram { @@ -114,6 +115,48 @@ bool IsBinaryFormat(int fd) { return false; } +bool IsBinaryFormat(char *file_data, uint64_t size) { + char *file_data_temp = new char[size]; + memcpy(file_data_temp,file_data, size); + + if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) { + delete[] file_data_temp; + return false; + } + + // Try reading the header. + util::scoped_memory memory; + try { + util::MapRead(util::LAZY, file_data_temp, 0, sizeof(Sanity), memory); + } catch (const util::Exception &e) { + return false; + } + Sanity reference_header = Sanity(); + reference_header.SetToReference(); + + if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) { + return true; + } + if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) { + UTIL_THROW(FormatLoadException, "This binary file did not finish building"); + } + if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { + char *end_ptr; + const char *begin_version = static_cast(memory.get()) + strlen(kMagicBeforeVersion); + long int version = std::strtol(begin_version, &end_ptr, 10); + if ((end_ptr != begin_version) && version != kMagicVersion) { + UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); + } + OldSanity old_sanity = OldSanity(); + old_sanity.SetToReference(); + UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); + UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); + } + return false; +} + + + void ReadHeader(int fd, Parameters &out) { util::SeekOrThrow(fd, sizeof(Sanity)); util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed)); @@ -124,6 +167,23 @@ void ReadHeader(int fd, Parameters &out) { if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } +void ReadHeader(char *file_data, Parameters &out) { + const char *file_data_tmp = file_data; + file_data_tmp += sizeof(Sanity); + std::memcpy(&out.fixed, file_data_tmp, sizeof(out.fixed)); + file_data_tmp += sizeof(out.fixed); + + if (out.fixed.probing_multiplier < 1.0) + UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); + + out.counts.resize(static_cast(out.fixed.order)); + + if (out.fixed.order) { + std::memcpy(&*out.counts.begin(), file_data_tmp, sizeof(uint64_t) * out.fixed.order); + } +} + + void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { if (params.fixed.model_type != model_type) { if (static_cast(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) @@ -147,11 +207,26 @@ void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int s header_size_ = TotalHeaderSize(params.counts.size()); } +void BinaryFormat::InitializeBinary(char *file_data, ModelType model_type, unsigned int search_version, Parameters ¶ms) { + file_data_ = file_data; + write_mmap_ = NULL; // Ignore write requests; this is already in binary format. + ReadHeader(file_data, params); + MatchCheck(model_type, search_version, params); + header_size_ = TotalHeaderSize(params.counts.size()); +} + + void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { assert(header_size_ != kInvalidSize); util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_); } +void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header, bool load_from_memory) const { + assert(header_size_ != kInvalidSize); + util::ErsatzPRead(file_data_, to, amount, offset_excluding_header + header_size_); +} + + void *BinaryFormat::LoadBinary(std::size_t size) { assert(header_size_ != kInvalidSize); const uint64_t file_size = util::SizeFile(file_.get()); @@ -165,6 +240,17 @@ void *BinaryFormat::LoadBinary(std::size_t size) { return reinterpret_cast(mapping_.get()) + header_size_; } +void *BinaryFormat::LoadBinary(std::size_t size, const uint64_t file_size) { /* Loading the binary from memory */ + assert(header_size_ != kInvalidSize); + // The header is smaller than a page, so we have to map the whole header as well. + uint64_t total_map = static_cast(header_size_) + static_cast(size); + UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + util::MapRead(load_method_, file_data_, 0, util::CheckOverflow(total_map), mapping_); + vocab_string_offset_ = total_map; + return reinterpret_cast(mapping_.get()) + header_size_; +} + + void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { vocab_size_ = memory_size; if (!write_mmap_) { @@ -282,7 +368,7 @@ void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsign } void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { - mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); + mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, (int) file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); vocab_base = reinterpret_cast(mapping_.get()) + header_size_; search_base = reinterpret_cast(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; } @@ -298,5 +384,21 @@ bool RecognizeBinary(const char *file, ModelType &recognized) { return true; } +bool RecognizeBinary(const char *file_data, const uint64_t file_data_size, ModelType &recognized) { + + char *file_data_temp = new char[file_data_size]; + memcpy(file_data_temp, file_data, file_data_size); + + if (!IsBinaryFormat(file_data_temp, file_data_size)){ + return false; + } + + Parameters params; + ReadHeader(file_data_temp, params); + recognized = params.fixed.model_type; + return true; +} + + } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/binary_format.hh b/native_client/kenlm/lm/binary_format.hh index ff99b95741..46fa8d0e0c 100644 --- a/native_client/kenlm/lm/binary_format.hh +++ b/native_client/kenlm/lm/binary_format.hh @@ -24,6 +24,7 @@ extern const char *kModelNames[6]; * this header designed for use by decoder authors. */ bool RecognizeBinary(const char *file, ModelType &recognized); +bool RecognizeBinary(const char *file_data, const uint64_t file_data_size, ModelType &recognized); struct FixedWidthParameters { unsigned char order; @@ -48,13 +49,20 @@ class BinaryFormat { public: explicit BinaryFormat(const Config &config); + ~BinaryFormat(){ + file_data_ = NULL; + } + // Reading a binary file: // Takes ownership of fd void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms); + void InitializeBinary(char *file_data, ModelType model_type, unsigned int search_version, Parameters ¶ms); // Used to read parts of the file to update the config object before figuring out full size. void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const; + void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header, bool useMemory) const; // Actually load the binary file and return a pointer to the beginning of the search area. void *LoadBinary(std::size_t size); + void *LoadBinary(std::size_t size, const uint64_t file_size); uint64_t VocabStringReadingOffset() const { assert(vocab_string_offset_ != kInvalidOffset); @@ -81,9 +89,10 @@ class BinaryFormat { // File behind memory, if any. util::scoped_fd file_; + char *file_data_; // If there is a file involved, a single mapping. - util::scoped_memory mapping_; + util::scoped_memory mapping_= new util::scoped_memory(true); // If the data is only in memory, separately allocate each because the trie // knows vocab's size before it knows search's size (because SRILM might @@ -100,6 +109,7 @@ class BinaryFormat { }; bool IsBinaryFormat(int fd); +bool IsBinaryFormat(char *file_data, uint64_t size); } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/model.cc b/native_client/kenlm/lm/model.cc index fc4e374c88..758cc05198 100644 --- a/native_client/kenlm/lm/model.cc +++ b/native_client/kenlm/lm/model.cc @@ -66,7 +66,7 @@ template GenericModel::Ge Config new_config(init_config); new_config.probing_multiplier = parameters.fixed.probing_multiplier; - Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config); + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config, false); UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config); @@ -89,6 +89,48 @@ template GenericModel::Ge P::Init(begin_sentence, null_context, vocab_, search_.Order()); } +template GenericModel::GenericModel(const char *file_data, const uint64_t file_data_size, const Config &init_config) : backing_(init_config) { + char *file_data_temp = new char[file_data_size]; + memcpy(file_data_temp, file_data, file_data_size); + + if (IsBinaryFormat(file_data_temp, file_data_size)) { + Parameters parameters; + backing_.InitializeBinary(file_data_temp, kModelType, kVersion, parameters); + CheckCounts(parameters.counts); + + Config new_config(init_config); + new_config.probing_multiplier = parameters.fixed.probing_multiplier; + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config, true); + + UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config), file_data_size), parameters.counts, new_config); + + vocab_.LoadedBinary(parameters.fixed.has_vocabulary, file_data_temp, new_config.enumerate_vocab, backing_.VocabStringReadingOffset(), true); + + delete[] file_data_temp; + } else { + std::cerr << "Fatal error: Not binary!" << std::endl; + delete[] file_data_temp; + return; + } + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); + + State null_context = State(); + null_context.length = 0; + P::Init(begin_sentence, null_context, vocab_, search_.Order()); +} + + template void GenericModel::InitializeFromARPA(int fd, const char *file, const Config &config) { // Backing file is the ARPA. util::FilePiece f(fd, file, config.ProgressMessages()); @@ -348,6 +390,26 @@ base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); } } +base::Model *LoadVirtual(const char *file_data, const uint64_t file_data_size, const Config &config, ModelType model_type) { + RecognizeBinary(file_data, file_data_size, model_type); + switch (model_type) { + case PROBING: + UTIL_THROW(FormatLoadException, "Probing without memory option " << model_type); + case REST_PROBING: + UTIL_THROW(FormatLoadException, "Rest Probing without memory option " << model_type); + case TRIE: + UTIL_THROW(FormatLoadException, "Trie without memory option " << model_type); + case QUANT_TRIE: + UTIL_THROW(FormatLoadException, "Quant Trie without memory option " << model_type); + case ARRAY_TRIE: + UTIL_THROW(FormatLoadException, "Array Trie without memory option " << model_type); + case QUANT_ARRAY_TRIE: + return new QuantArrayTrieModelMemory(file_data, file_data_size, config); + default: + UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); + } +} + } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/model.hh b/native_client/kenlm/lm/model.hh index 9b7206e8d7..8c39b27a2f 100644 --- a/native_client/kenlm/lm/model.hh +++ b/native_client/kenlm/lm/model.hh @@ -49,7 +49,7 @@ template class GenericModel : public base::Mod * lm/binary_format.hh. */ explicit GenericModel(const char *file, const Config &config = Config()); - + explicit GenericModel(const char *file_data, const uint64_t file_data_size, const Config &config = Config()); /* Score p(new_word | in_state) and incorporate new_word into out_state. * Note that in_state and out_state must be different references: * &in_state != &out_state. @@ -133,23 +133,33 @@ template class GenericModel : public base::Mod class name : public from {\ public:\ name(const char *file, const Config &config = Config()) : from(file, config) {}\ + }; + +#define LM_NAME_MODEL_FROM_MEMORY(name, from)\ +class name : public from {\ + public:\ + name(const char *file, const Config &config = Config()) : from(file, config) {}\ + name(const char *file_data, size_t file_data_size, const Config &config = Config()) : from(file_data, file_data_size, config) {}\ }; + LM_NAME_MODEL(ProbingModel, detail::GenericModel LM_COMMA() ProbingVocabulary>); LM_NAME_MODEL(RestProbingModel, detail::GenericModel LM_COMMA() ProbingVocabulary>); LM_NAME_MODEL(TrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); LM_NAME_MODEL(QuantTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL_FROM_MEMORY(QuantArrayTrieModelMemory, detail::GenericModel LM_COMMA() SortedVocabulary>); // Default implementation. No real reason for it to be the default. typedef ::lm::ngram::ProbingVocabulary Vocabulary; typedef ProbingModel Model; - +typedef QuantArrayTrieModelMemory ModelMemory; /* Autorecognize the file type, load, and return the virtual base class. Don't * use the virtual base class if you can avoid it. Instead, use the above * classes as template arguments to your own virtual feature function.*/ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING); +base::Model *LoadVirtual(const char *file_data, const uint64_t file_data_size, const Config &config = Config(), ModelType if_arpa = PROBING); } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/quantize.cc b/native_client/kenlm/lm/quantize.cc index 02b5dbc0e0..99b00b6043 100644 --- a/native_client/kenlm/lm/quantize.cc +++ b/native_client/kenlm/lm/quantize.cc @@ -38,9 +38,15 @@ const char kSeparatelyQuantizeVersion = 2; } // namespace -void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { +void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_memory) { unsigned char buffer[3]; - file.ReadForConfig(buffer, 3, offset); + if(load_from_memory){ + file.ReadForConfig(buffer, 3, offset, load_from_memory); + }else{ + file.ReadForConfig(buffer, 3, offset); + } + std::string strBuffer((char*)buffer,3); + char version = buffer[0]; config.prob_bits = buffer[1]; config.backoff_bits = buffer[2]; diff --git a/native_client/kenlm/lm/quantize.hh b/native_client/kenlm/lm/quantize.hh index 8500aceec0..457736eeaa 100644 --- a/native_client/kenlm/lm/quantize.hh +++ b/native_client/kenlm/lm/quantize.hh @@ -24,7 +24,7 @@ class BinaryFormat; class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast(0); - static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &, bool) {} static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -137,7 +137,7 @@ class SeparatelyQuantize { public: static const ModelType kModelTypeAdd = kQuantAdd; - static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_memory); static uint64_t Size(uint8_t order, const Config &config) { uint64_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); diff --git a/native_client/kenlm/lm/search_hashed.hh b/native_client/kenlm/lm/search_hashed.hh index 9dc84454c9..42e051ffcd 100644 --- a/native_client/kenlm/lm/search_hashed.hh +++ b/native_client/kenlm/lm/search_hashed.hh @@ -72,7 +72,7 @@ template class HashedSearch { static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. - static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector &, uint64_t, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector &, uint64_t, Config &, bool) {} static uint64_t Size(const std::vector &counts, const Config &config) { uint64_t ret = Unigram::Size(counts[0]); diff --git a/native_client/kenlm/lm/search_trie.hh b/native_client/kenlm/lm/search_trie.hh index 1adba6e5ea..95e8dafaba 100644 --- a/native_client/kenlm/lm/search_trie.hh +++ b/native_client/kenlm/lm/search_trie.hh @@ -38,11 +38,12 @@ template class TrieSearch { static const unsigned int kVersion = 1; - static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector &counts, uint64_t offset, Config &config) { - Quant::UpdateConfigFromBinary(file, offset, config); + static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector &counts, uint64_t offset, Config &config, bool load_from_memory) { + Quant::UpdateConfigFromBinary(file, offset, config, load_from_memory); + // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. if (counts.size() > 2) - Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config); + Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config, load_from_memory); } static uint64_t Size(const std::vector &counts, const Config &config) { diff --git a/native_client/kenlm/lm/vocab.cc b/native_client/kenlm/lm/vocab.cc index 7996ec7e2b..ef2e78e3ee 100644 --- a/native_client/kenlm/lm/vocab.cc +++ b/native_client/kenlm/lm/vocab.cc @@ -14,6 +14,7 @@ #include #include +#include namespace lm { namespace ngram { @@ -51,6 +52,36 @@ void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint } UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); } +void ReadWords(char* file_data, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { + const char *file_data_tmp = file_data; + file_data_tmp += offset; + + // Check that we're at the right place by reading which is always first. + char check_unk[6]; + std::memcpy(check_unk, file_data_tmp, 6); + file_data_tmp += 6; + + UTIL_THROW_IF( + memcmp(check_unk, "", 6), + FormatLoadException, + "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure."); + if (!enumerate) { + return; + } + enumerate->Add(0, ""); + + WordIndex index = 1; // Read already. + std::istringstream in(file_data_tmp); + + for (std::string line; std::getline(in, line); ) + { + // std::cerr << "LINHA -> " << line << std::endl; + enumerate->Add(index, line); + } + + UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); +} + // Constructor ordering madness. int SeekAndReturn(int fd, uint64_t start) { @@ -192,6 +223,16 @@ void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, if (have_words) ReadWords(fd, to, bound_, offset); } +void SortedVocabulary::LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory) { + end_ = begin_ + *(reinterpret_cast(begin_) - 1); + SetSpecial(Index(""), Index(""), 0); + bound_ = end_ - begin_ + 1; + if (have_words) { + ReadWords(file_data, to, bound_, offset); + } +} + + template void SortedVocabulary::GenericFinished(T *reorder) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { @@ -282,6 +323,17 @@ void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to if (have_words) ReadWords(fd, to, bound_, offset); } +void ProbingVocabulary::LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory) { + UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); + bound_ = header_->bound; + + SetSpecial(Index(""), Index(""), 0); + if (have_words) { + ReadWords(file_data, to, bound_, offset); + } +} + + void MissingUnknown(const Config &config) { switch(config.unknown_missing) { case SILENT: diff --git a/native_client/kenlm/lm/vocab.hh b/native_client/kenlm/lm/vocab.hh index f36e62ca21..a3456103fe 100644 --- a/native_client/kenlm/lm/vocab.hh +++ b/native_client/kenlm/lm/vocab.hh @@ -14,6 +14,7 @@ #include #include #include +#include namespace lm { struct ProbBackoff; @@ -111,6 +112,7 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); + void LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory); uint64_t *&EndHack() { return end_; } @@ -190,6 +192,7 @@ class ProbingVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); + void LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory); private: void InternalFinishedLoading(); diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc index 1a70387e48..aa76ae6929 100644 --- a/native_client/kenlm/util/file.cc +++ b/native_client/kenlm/util/file.cc @@ -236,6 +236,43 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) { UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size); } +void ErsatzPRead(char *file_data, void *to_void, std::size_t size, uint64_t off) { + uint8_t *to = static_cast(to_void); + while (size) { +#if defined(_WIN32) || defined(_WIN64) + /* BROKEN: changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() which lmplz does. */ + // size_t might be 64-bit. DWORD is always 32. + DWORD reading = static_cast(std::min(kMaxDWORD, size)); + DWORD ret; + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(OVERLAPPED)); + overlapped.Offset = static_cast(off); + overlapped.OffsetHigh = static_cast(off >> 32); + UTIL_THROW_IF(!ReadFile((HANDLE)_get_osfhandle(fd), to, reading, &ret, &overlapped), WindowsException, "ReadFile failed for offset " << off); +#else + ssize_t ret; + errno = 0; + ret = GuardLarge(size); + + file_data += off; + #if defined(_WIN32) || defined(_WIN64) + CopyMemory(out.get(), file_data, size); + #else + std::memcpy(to, file_data, GuardLarge(size)); + #endif + + if (ret <= 0) { + if (ret == -1 && errno == EINTR) continue; + UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from buffer"); + } +#endif + size -= ret; + off += ret; + to += ret; + } +} + + void ErsatzPRead(int fd, void *to_void, std::size_t size, uint64_t off) { uint8_t *to = static_cast(to_void); while (size) { diff --git a/native_client/kenlm/util/file.hh b/native_client/kenlm/util/file.hh index 4a50e73060..f72eab54cc 100644 --- a/native_client/kenlm/util/file.hh +++ b/native_client/kenlm/util/file.hh @@ -134,6 +134,7 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size); * above. */ void ErsatzPRead(int fd, void *to, std::size_t size, uint64_t off); +void ErsatzPRead(char *file_data, void *to_void, std::size_t size, uint64_t off); void ErsatzPWrite(int fd, const void *data_void, std::size_t size, uint64_t off); void FSyncOrThrow(int fd); diff --git a/native_client/kenlm/util/mmap.cc b/native_client/kenlm/util/mmap.cc index 39b9cd598d..35de58c157 100644 --- a/native_client/kenlm/util/mmap.cc +++ b/native_client/kenlm/util/mmap.cc @@ -20,6 +20,10 @@ #if defined(_WIN32) || defined(_WIN64) #include #include +#include +#include +#include + #else #include #include @@ -56,17 +60,17 @@ template T RoundUpPow2(T value, T mult) { } } // namespace -scoped_memory::scoped_memory(std::size_t size, bool zeroed) : data_(NULL), size_(0), source_(NONE_ALLOCATED) { +scoped_memory::scoped_memory(std::size_t size, bool zeroed, bool load_from_memory) : data_(NULL), size_(0), source_(NONE_ALLOCATED), load_from_memory_(load_from_memory) { HugeMalloc(size, zeroed, *this); } void scoped_memory::reset(void *data, std::size_t size, Alloc source) { switch(source_) { case MMAP_ROUND_UP_ALLOCATED: - scoped_mmap(data_, RoundUpPow2(size_, (std::size_t)SizePage())); + scoped_mmap(data_, RoundUpPow2(size_, (std::size_t)SizePage()), load_from_memory_); break; case MMAP_ALLOCATED: - scoped_mmap(data_, size_); + scoped_mmap(data_, size_, load_from_memory_); break; case MALLOC_ALLOCATED: free(data_); @@ -130,6 +134,46 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int return ret; } +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, char *file_data, uint64_t offset) { +#ifdef MAP_POPULATE // Linux specific + if (prefault) { + flags |= MAP_POPULATE; + } +#endif +#if defined(_WIN32) || defined(_WIN64) + + TCHAR szName[]=TEXT("Global\\LanguageModelMapping"); + + int protectC = for_write ? PAGE_READWRITE : PAGE_READONLY; + int protectM = for_write ? FILE_MAP_WRITE : FILE_MAP_READ; + uint64_t total_size = size + offset; + // HANDLE hMapping = CreateFileMapping((HANDLE)_get_osfhandle(fd), NULL, protectC, total_size >> 32, static_cast(total_size), NULL); + + HANDLE hMapping = CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, total_size >> 32, static_cast(total_size), szName); + UTIL_THROW_IF(!hMapping, ErrnoException, "CreateFileMapping failed"); + // LPVOID ret = MapViewOfFile(hMapping, protectM, offset >> 32, offset, size); + + LPVOID ret = MapViewOfFile(hMapping, FILE_MAP_ALL_ACCESS, offset >> 32, offset, size); + + if (ret == NULL){ + std::cerr<< "ret NULL" << GetLastError() << std::endl; + } + + CloseHandle(hMapping); + UTIL_THROW_IF(!ret, ErrnoException, "MapViewOfFile failed"); + CopyMemory((PVOID)ret, file_data, total_size); +#else + // int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; + int protect = PROT_READ | PROT_WRITE; + void *ret; + flags = MAP_ANONYMOUS|MAP_SHARED; + UTIL_THROW_IF((ret = mmap(NULL, size, protect, flags, -1, offset)) == MAP_FAILED, ErrnoException, "mmap failed for size " << size << " at offset " << offset); + memcpy(ret, file_data, size); +#endif + return ret; +} + + void SyncOrThrow(void *start, size_t length) { #if defined(_WIN32) || defined(_WIN64) UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap"); @@ -323,6 +367,38 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope } } +void MapRead(LoadMethod method, char *file_data, uint64_t offset, std::size_t size, scoped_memory &out) { + switch (method) { + case LAZY: + out.reset(MapOrThrow(size, false, kFileFlags, false, file_data, offset), size, scoped_memory::MMAP_ALLOCATED); + break; + case POPULATE_OR_LAZY: +#ifdef MAP_POPULATE + case POPULATE_OR_READ: +#endif + out.reset(MapOrThrow(size, false, kFileFlags, true, file_data, offset), size, scoped_memory::MMAP_ALLOCATED); + break; +#ifndef MAP_POPULATE + case POPULATE_OR_READ: +#endif + case READ: + HugeMalloc(size, false, out); + file_data += offset; + #if defined(_WIN32) || defined(_WIN64) + CopyMemory(out.get(), file_data, size); + #else + std::memcpy(out.get(), file_data, size); + #endif + break; + case PARALLEL_READ: + HugeMalloc(size, false, out); + file_data += offset; + std::memcpy(out.get(), file_data, size); + break; + } +} + + void *MapZeroedWrite(int fd, std::size_t size) { ResizeOrThrow(fd, 0); ResizeOrThrow(fd, size); diff --git a/native_client/kenlm/util/mmap.hh b/native_client/kenlm/util/mmap.hh index 21bca1ddc6..cc51c1b549 100644 --- a/native_client/kenlm/util/mmap.hh +++ b/native_client/kenlm/util/mmap.hh @@ -17,8 +17,8 @@ std::size_t SizePage(); // (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. class scoped_mmap { public: - scoped_mmap() : data_((void*)-1), size_(0) {} - scoped_mmap(void *data, std::size_t size) : data_(data), size_(size) {} + scoped_mmap(bool load_from_memory=false) : data_((void*)-1), size_(0), load_from_memory_(load_from_memory) {} + scoped_mmap(void *data, std::size_t size, bool load_from_memory=false) : data_(data), size_(size), load_from_memory_(load_from_memory) {} ~scoped_mmap(); void *get() const { return data_; } @@ -29,6 +29,8 @@ class scoped_mmap { char *end() { return reinterpret_cast(data_) + size_; } std::size_t size() const { return size_; } + bool load_from_memory_; + void reset(void *data, std::size_t size) { scoped_mmap other(data_, size_); data_ = data; @@ -66,13 +68,16 @@ class scoped_memory { NONE_ALLOCATED // nothing to free (though there can be something here if it's owned by somebody else). } Alloc; - scoped_memory(void *data, std::size_t size, Alloc source) - : data_(data), size_(size), source_(source) {} + scoped_memory(void *data, std::size_t size, Alloc source, bool load_from_memory=false) + : data_(data), size_(size), source_(source), load_from_memory_(load_from_memory) {} + + + scoped_memory(bool load_from_memory_=false) : data_(NULL), size_(0), + source_(NONE_ALLOCATED), load_from_memory_(load_from_memory_) {} - scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} // Calls HugeMalloc - scoped_memory(std::size_t to, bool zero_new); + scoped_memory(std::size_t to, bool zero_new, bool load_from_memory_=false); #if __cplusplus >= 201103L scoped_memory(scoped_memory &&from) noexcept @@ -91,6 +96,9 @@ class scoped_memory { char *end() { return reinterpret_cast(data_) + size_; } std::size_t size() const { return size_; } + bool load_from_memory_; + + Alloc source() const { return source_; } void reset() { reset(NULL, 0, NONE_ALLOCATED); } @@ -119,6 +127,7 @@ extern const int kFileFlags; // Cross-platform, error-checking wrapper for mmap(). void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, uint64_t offset = 0); +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, char *file_data, uint64_t offset = 0); // msync wrapper void SyncOrThrow(void *start, size_t length); @@ -153,6 +162,7 @@ enum LoadMethod { }; void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scoped_memory &out); +void MapRead(LoadMethod method, char *file_data, uint64_t offset, std::size_t size, scoped_memory &out); // Open file name with mmap of size bytes, all of which are initially zero. void *MapZeroedWrite(int fd, std::size_t size); From d6a4e374814f4675e1a441d9997083ceda208b14 Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Tue, 22 Sep 2020 09:21:32 -0300 Subject: [PATCH 3/9] removing debug info --- native_client/tfmodelstate.cc | 17 ++--------------- 1 file changed, 2 insertions(+), 15 deletions(-) diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc index 5e5bd45cea..7ec8a76777 100755 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -24,16 +24,8 @@ TFModelState::~TFModelState() int loadGraphFromBinaryData(Env* env, const std::string& data, ::tensorflow::protobuf::MessageLite* proto) { - std::string model_buffer; - - std::ifstream graph_input_stream; - - graph_input_stream.open("../exported_model/output_graph.pb", std::ios::binary); - model_buffer = std::string((std::istreambuf_iterator(graph_input_stream)), - (std::istreambuf_iterator())); - graph_input_stream.close(); - - if (!proto->ParseFromString(model_buffer)) { + + if (!proto->ParseFromString(data)) { std::cerr << "Can't parse data as binary proto" << std::endl; return -1; } @@ -43,11 +35,6 @@ int loadGraphFromBinaryData(Env* env, const std::string& data, int TFModelState::init(const std::string &model_string, bool init_from_bytes) { - if (init_from_bytes){ - std::cerr << "=============== Init model from bytes"< Date: Fri, 25 Sep 2020 12:09:37 -0300 Subject: [PATCH 4/9] Exposing methods DS_CreateModelFromBuffer and DS_EnableExternalScorerFromBuffer --- doc/BUILDING.rst | 4 ++-- native_client/client.cc | 4 ++-- native_client/deepspeech.cc | 14 ++++++++++++++ native_client/deepspeech.h | 25 +++++++++++++++++++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/doc/BUILDING.rst b/doc/BUILDING.rst index 44adb4fbad..59f1a3b953 100644 --- a/doc/BUILDING.rst +++ b/doc/BUILDING.rst @@ -73,7 +73,7 @@ You can now use Bazel to build the main DeepSpeech library, ``libdeepspeech.so`` .. code-block:: - bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=1" --copt=-fvisibility=hidden //native_client:libdeepspeech.so + bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-fvisibility=hidden //native_client:libdeepspeech.so The generated binaries will be saved to ``bazel-bin/native_client/``. @@ -87,7 +87,7 @@ Using the example from above you can build the library and that binary at the sa .. code-block:: - bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=1" --copt=-fvisibility=hidden //native_client:libdeepspeech.so //native_client:generate_scorer_package + bazel build --workspace_status_command="bash native_client/bazel_workspace_status_cmd.sh" --config=monolithic -c opt --copt=-O3 --copt="-D_GLIBCXX_USE_CXX11_ABI=0" --copt=-fvisibility=hidden //native_client:libdeepspeech.so //native_client:generate_scorer_package The generated binaries will be saved to ``bazel-bin/native_client/``. diff --git a/native_client/client.cc b/native_client/client.cc index d60370bd19..b5ef039ae3 100755 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -461,7 +461,7 @@ main(int argc, char **argv) std::ifstream is_model( model, std::ios::binary ); std::stringstream buffer_model; buffer_model << is_model.rdbuf(); - status = DS_CreateModel_(buffer_model.str(), true, &ctx); + status = DS_CreateModelFromBuffer(buffer_model.str(), &ctx); }else { // Keep old method due to backwards compatibility status = DS_CreateModel(model, &ctx); @@ -488,7 +488,7 @@ main(int argc, char **argv) std::ifstream is_scorer(scorer, std::ios::binary ); std::stringstream buffer_scorer; buffer_scorer << is_scorer.rdbuf(); - status = DS_EnableExternalScorer_(ctx, buffer_scorer.str(), true); + status = DS_EnableExternalScorerFromBuffer(ctx, buffer_scorer.str()); } else { // Keep old method due to backwards compatibility status = DS_EnableExternalScorer(ctx, scorer); diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 65d33f23c9..8b0605d220 100755 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -269,6 +269,13 @@ DS_CreateModel(const char* aModelPath, return DS_CreateModel_(aModelPath, false, retval); } +int +DS_CreateModelFromBuffer(const std::string &aModelBuffer, + ModelState** retval) +{ + return DS_CreateModel_(aModelBuffer, true, retval); +} + int DS_CreateModel_(const std::string &aModelString, bool init_from_bytes, @@ -344,6 +351,13 @@ DS_EnableExternalScorer(ModelState* aCtx, return DS_EnableExternalScorer_(aCtx, aScorerPath, false); } +int +DS_EnableExternalScorerFromBuffer(ModelState* aCtx, + const std::string &aScorerBuffer) +{ + return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true); +} + int DS_EnableExternalScorer_(ModelState* aCtx, const std::string &aScorerString, diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 7e8fc62e4e..10f53775b6 100755 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -109,6 +109,19 @@ DEEPSPEECH_EXPORT int DS_CreateModel(const char* aModelPath, ModelState** retval); +/** + * @brief An object providing an interface to a trained DeepSpeech model, loaded from buffer. + * + * @param aModelBuffer The buffer containing the content of frozen model graph. + * @param[out] retval a ModelState pointer + * + * @return Zero on success, non-zero on failure. + */ +DEEPSPEECH_EXPORT +int DS_CreateModelFromBuffer(const std::string &aModelBuffer, + ModelState** retval); + + /** * @brief An object providing an interface to a trained DeepSpeech model. * @@ -176,6 +189,18 @@ DEEPSPEECH_EXPORT int DS_EnableExternalScorer(ModelState* aCtx, const char* aScorerPath); +/** + * @brief Enable decoding using an external scorer loaded from buffer. + * + * @param aCtx The ModelState pointer for the model being changed. + * @param aScorerBuffer The buffer containing the content of an external-scorer file. + * + * @return Zero on success, non-zero on failure (invalid arguments). + */ +DEEPSPEECH_EXPORT +int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, + const std::string &aScorerBuffer); + /** * @brief Enable decoding using an external scorer. * From dd73ec8711c38089b10614e54dd517815b1f5630 Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Mon, 28 Sep 2020 11:17:30 -0300 Subject: [PATCH 5/9] Loading tflite buffer from buffer working --- native_client/tflitemodelstate.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc index 065067c1ac..b72b1fa1ae 100755 --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -180,7 +180,6 @@ TFLiteModelState::init(const std::string &model_string, bool init_from_bytes) } } - tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_); if (!interpreter_) { From 9faeb51f23bd9187316c09c429d08aaf0ec33dba Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Mon, 28 Sep 2020 11:37:49 -0300 Subject: [PATCH 6/9] Fixing API functions exposition --- native_client/deepspeech.cc | 57 +++++++++++++++++++------------------ native_client/deepspeech.h | 27 ------------------ 2 files changed, 29 insertions(+), 55 deletions(-) diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index 8b0605d220..d7f26ea0c2 100755 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -262,20 +262,6 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) num_classes); } -int -DS_CreateModel(const char* aModelPath, - ModelState** retval) -{ - return DS_CreateModel_(aModelPath, false, retval); -} - -int -DS_CreateModelFromBuffer(const std::string &aModelBuffer, - ModelState** retval) -{ - return DS_CreateModel_(aModelBuffer, true, retval); -} - int DS_CreateModel_(const std::string &aModelString, bool init_from_bytes, @@ -319,6 +305,21 @@ DS_CreateModel_(const std::string &aModelString, return DS_ERR_OK; } +int +DS_CreateModel(const char* aModelPath, + ModelState** retval) +{ + return DS_CreateModel_(aModelPath, false, retval); +} + +int +DS_CreateModelFromBuffer(const std::string &aModelBuffer, + ModelState** retval) +{ + return DS_CreateModel_(aModelBuffer, true, retval); +} + + unsigned int DS_GetModelBeamWidth(const ModelState* aCtx) { @@ -344,20 +345,6 @@ DS_FreeModel(ModelState* ctx) delete ctx; } -int -DS_EnableExternalScorer(ModelState* aCtx, - const char* aScorerPath) -{ - return DS_EnableExternalScorer_(aCtx, aScorerPath, false); -} - -int -DS_EnableExternalScorerFromBuffer(ModelState* aCtx, - const std::string &aScorerBuffer) -{ - return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true); -} - int DS_EnableExternalScorer_(ModelState* aCtx, const std::string &aScorerString, @@ -374,6 +361,20 @@ DS_EnableExternalScorer_(ModelState* aCtx, return DS_ERR_OK; } +int +DS_EnableExternalScorer(ModelState* aCtx, + const char* aScorerPath) +{ + return DS_EnableExternalScorer_(aCtx, aScorerPath, false); +} + +int +DS_EnableExternalScorerFromBuffer(ModelState* aCtx, + const std::string &aScorerBuffer) +{ + return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true); +} + int DS_AddHotWord(ModelState* aCtx, const char* word, diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 10f53775b6..673e518e23 100755 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -122,20 +122,6 @@ int DS_CreateModelFromBuffer(const std::string &aModelBuffer, ModelState** retval); -/** - * @brief An object providing an interface to a trained DeepSpeech model. - * - * @param aModelString The path/string for initializing the model graph. - * @param init_from_bytes Wheter the model will be initialized using path or array of bytes. - * @param[out] retval a ModelState pointer - * - * @return Zero on success, non-zero on failure. - */ -DEEPSPEECH_EXPORT -int DS_CreateModel_(const std::string &aModelString, - bool init_from_bytes, - ModelState** retval); - /** * @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth} * was not called before, will return the default value loaded from the @@ -201,19 +187,6 @@ DEEPSPEECH_EXPORT int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, const std::string &aScorerBuffer); -/** - * @brief Enable decoding using an external scorer. - * - * @param aCtx The ModelState pointer for the model being changed. - * @param aScorerString The path/array_of_bytes to initialize the external scorer. - * @param init_from_bytes Wheter the scorer will be initialized by file or array of bytes. - * - * @return Zero on success, non-zero on failure (invalid arguments). - */ -DEEPSPEECH_EXPORT -int DS_EnableExternalScorer_(ModelState* aCtx, - const std::string &aScorerString, - bool init_from_bytes); /** * @brief Add a hot-word and its boost. From 71e19bc5ba2d8fb5887a075dd4e2abf60e0a450a Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Mon, 28 Sep 2020 13:05:29 -0300 Subject: [PATCH 7/9] load_trie with default value for load_from_bytes --- native_client/ctcdecode/scorer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 500dd0b28a..7be312bb9d 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -100,7 +100,7 @@ class Scorer { // necessary setup after setting alphabet void setup_char_map(); - int load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes); + int load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes=false); private: std::unique_ptr language_model_; From 3b9d2b00bb953b5f44b81143e41b425f4e86682c Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Tue, 29 Sep 2020 09:17:32 -0300 Subject: [PATCH 8/9] Seting a default value for load_from_bytes in load_lm --- native_client/ctcdecode/scorer.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 7be312bb9d..08713e4e83 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -86,7 +86,7 @@ class Scorer { void fill_dictionary(const std::unordered_set &vocabulary); // load language model from given path - int load_lm(const std::string &lm_path, bool load_from_bytes); + int load_lm(const std::string &lm_path, bool load_from_bytes=false); // language model weight double alpha = 0.; From df40e22ee3474096a08acab85fe3d84997720185 Mon Sep 17 00:00:00 2001 From: Bernardo Henz Date: Wed, 30 Sep 2020 12:18:25 -0300 Subject: [PATCH 9/9] -Change string to char* in the API --- native_client/client.cc | 7 +++++-- native_client/deepspeech.cc | 31 ++++++++++++++++++++----------- native_client/deepspeech.h | 7 +++++-- native_client/modelstate.cc | 2 +- native_client/modelstate.h | 2 +- native_client/tflitemodelstate.cc | 14 +++++--------- native_client/tflitemodelstate.h | 2 +- native_client/tfmodelstate.cc | 22 +++++++++++++--------- native_client/tfmodelstate.h | 2 +- 9 files changed, 52 insertions(+), 37 deletions(-) diff --git a/native_client/client.cc b/native_client/client.cc index b5ef039ae3..5b2240f868 100755 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -454,6 +454,7 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; + std::string buffer_model_str; // sphinx-doc: c_ref_model_start int status; if (init_from_array_of_bytes){ @@ -461,7 +462,8 @@ main(int argc, char **argv) std::ifstream is_model( model, std::ios::binary ); std::stringstream buffer_model; buffer_model << is_model.rdbuf(); - status = DS_CreateModelFromBuffer(buffer_model.str(), &ctx); + buffer_model_str = buffer_model.str(); + status = DS_CreateModelFromBuffer(buffer_model_str.c_str(), buffer_model_str.size(), &ctx); }else { // Keep old method due to backwards compatibility status = DS_CreateModel(model, &ctx); @@ -488,7 +490,8 @@ main(int argc, char **argv) std::ifstream is_scorer(scorer, std::ios::binary ); std::stringstream buffer_scorer; buffer_scorer << is_scorer.rdbuf(); - status = DS_EnableExternalScorerFromBuffer(ctx, buffer_scorer.str()); + std::string tmp_str_scorer = buffer_scorer.str(); + status = DS_EnableExternalScorerFromBuffer(ctx, tmp_str_scorer.c_str(), tmp_str_scorer.size()); } else { // Keep old method due to backwards compatibility status = DS_EnableExternalScorer(ctx, scorer); diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index d7f26ea0c2..32468f3a70 100755 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -263,9 +263,10 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) } int -DS_CreateModel_(const std::string &aModelString, +DS_CreateModel_(const char* aModelString, bool init_from_bytes, - ModelState** retval) + ModelState** retval, + size_t bufferSize=0) { *retval = nullptr; @@ -278,7 +279,7 @@ DS_CreateModel_(const std::string &aModelString, LOGD("DeepSpeech: %s", ds_git_version()); #endif - if (aModelString.length() < 1) { + if ( (!init_from_bytes && (strlen(aModelString) < 1)) || (init_from_bytes && (bufferSize<1))) { std::cerr << "No model specified, cannot continue." << std::endl; return DS_ERR_NO_MODEL; } @@ -296,7 +297,7 @@ DS_CreateModel_(const std::string &aModelString, return DS_ERR_FAIL_CREATE_MODEL; } - int err = model->init(aModelString, init_from_bytes); + int err = model->init(aModelString, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } @@ -313,10 +314,11 @@ DS_CreateModel(const char* aModelPath, } int -DS_CreateModelFromBuffer(const std::string &aModelBuffer, +DS_CreateModelFromBuffer(const char* aModelBuffer, + size_t bufferSize, ModelState** retval) { - return DS_CreateModel_(aModelBuffer, true, retval); + return DS_CreateModel_(aModelBuffer, true, retval, bufferSize); } @@ -347,12 +349,18 @@ DS_FreeModel(ModelState* ctx) int DS_EnableExternalScorer_(ModelState* aCtx, - const std::string &aScorerString, - bool init_from_bytes) + const char* aScorerString, + bool init_from_bytes, + size_t bufferSize=0) { std::unique_ptr scorer(new Scorer()); - int err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_); + int err; + if (init_from_bytes) + err = scorer->init(std::string(aScorerString, bufferSize), init_from_bytes, aCtx->alphabet_); + else + err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_); + if (err != 0) { return DS_ERR_INVALID_SCORER; @@ -370,9 +378,10 @@ DS_EnableExternalScorer(ModelState* aCtx, int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, - const std::string &aScorerBuffer) + const char* aScorerBuffer, + size_t bufferSize) { - return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true); + return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true, bufferSize); } int diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 673e518e23..13170dcc6f 100755 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -76,6 +76,7 @@ typedef struct Metadata { APPLY(DS_ERR_SCORER_NO_TRIE, 0x2007, "Reached end of scorer file before loading vocabulary trie.") \ APPLY(DS_ERR_SCORER_INVALID_TRIE, 0x2008, "Invalid magic in trie header.") \ APPLY(DS_ERR_SCORER_VERSION_MISMATCH, 0x2009, "Scorer file version does not match expected version.") \ + APPLY(DS_ERR_MODEL_NOT_SUP_BUFFER, 0x2010, "Load from buffer does not support memorymaped models.") \ APPLY(DS_ERR_FAIL_INIT_MMAP, 0x3000, "Failed to initialize memory mapped model.") \ APPLY(DS_ERR_FAIL_INIT_SESS, 0x3001, "Failed to initialize the session.") \ APPLY(DS_ERR_FAIL_INTERPRETER, 0x3002, "Interpreter failed.") \ @@ -118,7 +119,8 @@ int DS_CreateModel(const char* aModelPath, * @return Zero on success, non-zero on failure. */ DEEPSPEECH_EXPORT -int DS_CreateModelFromBuffer(const std::string &aModelBuffer, +int DS_CreateModelFromBuffer(const char* aModelBuffer, + size_t bufferSize, ModelState** retval); @@ -185,7 +187,8 @@ int DS_EnableExternalScorer(ModelState* aCtx, */ DEEPSPEECH_EXPORT int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, - const std::string &aScorerBuffer); + const char* aScorerBuffer, + size_t bufferSize); /** diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index 5f4ce2e274..dcbc904d8d 100755 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -24,7 +24,7 @@ ModelState::~ModelState() } int -ModelState::init(const std::string &model_string, bool init_from_bytes) +ModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize) { return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h index 166edff00e..41fe817e71 100755 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -31,7 +31,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const std::string &model_string, bool init_from_bytes); + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc index b72b1fa1ae..5823b3268e 100755 --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -156,24 +156,21 @@ getTfliteDelegates() } int -TFLiteModelState::init(const std::string &model_string, bool init_from_bytes) +TFLiteModelState::init(const char *model_string, bool init_from_bytes, size_t bufferSize) { - int err = ModelState::init(model_string, init_from_bytes); + int err = ModelState::init(model_string, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } - + if (init_from_bytes){ - char *tmp_buffer = new char[model_string.size()]; - std::copy(model_string.begin(), model_string.end(), tmp_buffer); - // Using c_str does not work - fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(tmp_buffer,model_string.size()); + fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(model_string, bufferSize); if (!fbmodel_) { std::cerr << "Error at reading model buffer " << std::endl; return DS_ERR_FAIL_INIT_MMAP; } } else { - fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string.c_str()); + fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string); if (!fbmodel_) { std::cerr << "Error at reading model file " << model_string << std::endl; return DS_ERR_FAIL_INIT_MMAP; @@ -334,7 +331,6 @@ TFLiteModelState::init(const std::string &model_string, bool init_from_bytes) assert(dims_c->data[1] == dims_h->data[1]); assert(state_size_ > 0); state_size_ = dims_c->data[1]; - return DS_ERR_OK; } diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h index 38de5b9d7d..d2e95966e4 100755 --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -31,7 +31,7 @@ struct TFLiteModelState : public ModelState TFLiteModelState(); virtual ~TFLiteModelState(); - virtual int init(const std::string &model_string, bool init_from_bytes) override; + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override; diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc index 7ec8a76777..d2871e21d8 100755 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -22,10 +22,11 @@ TFModelState::~TFModelState() } } -int loadGraphFromBinaryData(Env* env, const std::string& data, +int loadGraphFromBinaryData(Env* env, const char* data, size_t bufferSize, ::tensorflow::protobuf::MessageLite* proto) { - if (!proto->ParseFromString(data)) { + std::string dataString(data, bufferSize); + if (!proto->ParseFromString(dataString)) { std::cerr << "Can't parse data as binary proto" << std::endl; return -1; } @@ -33,9 +34,9 @@ int loadGraphFromBinaryData(Env* env, const std::string& data, } int -TFModelState::init(const std::string &model_string, bool init_from_bytes) +TFModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize) { - int err = ModelState::init(model_string, init_from_bytes); + int err = ModelState::init(model_string, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } @@ -46,16 +47,16 @@ TFModelState::init(const std::string &model_string, bool init_from_bytes) mmap_env_.reset(new MemmappedEnv(Env::Default())); bool is_mmap = false; if (init_from_bytes) { - int loadGraphStatus = loadGraphFromBinaryData(Env::Default(), model_string, &graph_def_); + int loadGraphStatus = loadGraphFromBinaryData(mmap_env_.get(), model_string, bufferSize, &graph_def_); if (loadGraphStatus != 0) { return DS_ERR_FAIL_CREATE_SESS; } } else { - is_mmap = model_string.find(".pbmm") != std::string::npos; + is_mmap = std::string(model_string).find(".pbmm") != std::string::npos; if (!is_mmap) { std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl; } else { - status = mmap_env_->InitializeFromFile(model_string.c_str()); + status = mmap_env_->InitializeFromFile(model_string); if (!status.ok()) { std::cerr << status << std::endl; return DS_ERR_FAIL_INIT_MMAP; @@ -77,14 +78,17 @@ TFModelState::init(const std::string &model_string, bool init_from_bytes) session_.reset(session); if (init_from_bytes){ - // Need some help + if( is_mmap) { + std::cerr << "Load from buffer does not support .pbmm models." << std::endl; + return DS_ERR_MODEL_NOT_SUP_BUFFER; + } } else { if (is_mmap) { status = ReadBinaryProto(mmap_env_.get(), MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, &graph_def_); } else { - status = ReadBinaryProto(Env::Default(), model_string.c_str(), &graph_def_); + status = ReadBinaryProto(Env::Default(), model_string, &graph_def_); } } diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h index 85628acf3c..0308b94338 100755 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -18,7 +18,7 @@ struct TFModelState : public ModelState TFModelState(); virtual ~TFModelState(); - virtual int init(const std::string &model_string, bool init_from_bytes) override; + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override; virtual void infer(const std::vector& mfcc, unsigned int n_frames,