From feb33f8c0e3cb43736a39ce04a84a9a241f79cc5 Mon Sep 17 00:00:00 2001 From: daniele Date: Wed, 9 Sep 2020 13:45:49 -0300 Subject: [PATCH] Loading model from both path or array of bytes --- native_client/Makefile | 2 +- native_client/alphabet.cc | 0 native_client/alphabet.h | 0 native_client/args.h | 8 ++ native_client/client.cc | 30 ++++++- native_client/ctcdecode/scorer.cpp | 85 +++++++++++-------- native_client/ctcdecode/scorer.h | 6 +- native_client/deepspeech.cc | 56 +++++++++++-- native_client/deepspeech.h | 35 +++++++- native_client/kenlm/lm/bhiksha.cc | 8 +- native_client/kenlm/lm/bhiksha.hh | 3 +- native_client/kenlm/lm/binary_format.cc | 104 +++++++++++++++++++++++- native_client/kenlm/lm/binary_format.hh | 12 ++- native_client/kenlm/lm/model.cc | 64 ++++++++++++++- native_client/kenlm/lm/model.hh | 14 +++- native_client/kenlm/lm/quantize.cc | 10 ++- native_client/kenlm/lm/quantize.hh | 4 +- native_client/kenlm/lm/search_hashed.hh | 2 +- native_client/kenlm/lm/search_trie.hh | 7 +- native_client/kenlm/lm/vocab.cc | 52 ++++++++++++ native_client/kenlm/lm/vocab.hh | 3 + native_client/kenlm/util/file.cc | 37 +++++++++ native_client/kenlm/util/file.hh | 1 + native_client/kenlm/util/mmap.cc | 82 ++++++++++++++++++- native_client/kenlm/util/mmap.hh | 22 +++-- native_client/modelstate.cc | 2 +- native_client/modelstate.h | 2 +- native_client/tflitemodelstate.cc | 30 +++++-- native_client/tflitemodelstate.h | 2 +- native_client/tfmodelstate.cc | 66 ++++++++++----- native_client/tfmodelstate.h | 2 +- 31 files changed, 647 insertions(+), 104 deletions(-) mode change 100644 => 100755 native_client/Makefile mode change 100644 => 100755 native_client/alphabet.cc mode change 100644 => 100755 native_client/alphabet.h mode change 100644 => 100755 native_client/args.h mode change 100644 => 100755 native_client/client.cc mode change 100644 => 100755 native_client/deepspeech.cc mode change 100644 => 100755 native_client/deepspeech.h mode change 100644 => 100755 native_client/modelstate.cc mode change 100644 => 100755 native_client/modelstate.h mode change 100644 => 100755 native_client/tflitemodelstate.cc mode change 100644 => 100755 native_client/tflitemodelstate.h mode change 100644 => 100755 native_client/tfmodelstate.cc mode change 100644 => 100755 native_client/tfmodelstate.h diff --git a/native_client/Makefile b/native_client/Makefile old mode 100644 new mode 100755 index b645499c28..7115f8ee04 --- a/native_client/Makefile +++ b/native_client/Makefile @@ -19,7 +19,7 @@ clean: rm -f deepspeech $(DEEPSPEECH_BIN): client.cc Makefile - $(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS) + $(CXX) $(CFLAGS) $(CFLAGS_DEEPSPEECH) $(SOX_CFLAGS) client.cc $(LDFLAGS) $(SOX_LDFLAGS) -llzma -lbz2 ifeq ($(OS),Darwin) install_name_tool -change bazel-out/local-opt/bin/native_client/libdeepspeech.so @rpath/libdeepspeech.so deepspeech endif diff --git a/native_client/alphabet.cc b/native_client/alphabet.cc old mode 100644 new mode 100755 diff --git a/native_client/alphabet.h b/native_client/alphabet.h old mode 100644 new mode 100755 diff --git a/native_client/args.h b/native_client/args.h old mode 100644 new mode 100755 index 856988dd4e..e4cafc6d4b --- a/native_client/args.h +++ b/native_client/args.h @@ -34,6 +34,8 @@ bool extended_metadata = false; bool json_output = false; +bool init_from_array_of_bytes = false; + int json_candidate_transcripts = 3; int stream_size = 0; @@ -59,6 +61,7 @@ void PrintHelp(const char* bin) "\t--candidate_transcripts NUMBER\tNumber of candidate transcripts to include in JSON output\n" "\t--stream size\t\t\tRun in stream mode, output intermediate results\n" "\t--hot_words\t\t\tHot-words and their boosts. Word:Boost pairs are comma-separated\n" + "\t--init_from_bytes\t\tTest init model and scorer from array of bytes\n" "\t--help\t\t\t\tShow help\n" "\t--version\t\t\tPrint version and exits\n"; char* version = DS_Version(); @@ -80,6 +83,7 @@ bool ProcessArgs(int argc, char** argv) {"t", no_argument, nullptr, 't'}, {"extended", no_argument, nullptr, 'e'}, {"json", no_argument, nullptr, 'j'}, + {"init_from_bytes", no_argument, nullptr, 'B'}, {"candidate_transcripts", required_argument, nullptr, 150}, {"stream", required_argument, nullptr, 's'}, {"hot_words", required_argument, nullptr, 'w'}, @@ -135,6 +139,10 @@ bool ProcessArgs(int argc, char** argv) case 'j': json_output = true; break; + + case 'B': + init_from_array_of_bytes = true; + break; case 150: json_candidate_transcripts = atoi(optarg); diff --git a/native_client/client.cc b/native_client/client.cc old mode 100644 new mode 100755 index 96e1ff3999..b88963f6f8 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -33,6 +33,8 @@ #include #endif // NO_DIR #include +#include +#include #include "deepspeech.h" #include "args.h" @@ -415,8 +417,21 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; + std::string buffer_model_str; // sphinx-doc: c_ref_model_start - int status = DS_CreateModel(model, &ctx); + int status; + if (init_from_array_of_bytes){ + // Reading model file to a char * buffer + std::ifstream is_model( model, std::ios::binary ); + std::stringstream buffer_model; + buffer_model << is_model.rdbuf(); + buffer_model_str = buffer_model.str(); + status = DS_CreateModelFromBuffer(buffer_model_str.c_str(), buffer_model_str.size(), &ctx); + }else { + // Keep old method due to backwards compatibility + status = DS_CreateModel(model, &ctx); + } + if (status != 0) { char* error = DS_ErrorCodeToErrorMessage(status); fprintf(stderr, "Could not create model: %s\n", error); @@ -433,7 +448,18 @@ main(int argc, char **argv) } if (scorer) { - status = DS_EnableExternalScorer(ctx, scorer); + if (init_from_array_of_bytes){ + // Reading scorer file to a string buffer + std::ifstream is_scorer(scorer, std::ios::binary ); + std::stringstream buffer_scorer; + buffer_scorer << is_scorer.rdbuf(); + std::string tmp_str_scorer = buffer_scorer.str(); + status = DS_EnableExternalScorerFromBuffer(ctx, tmp_str_scorer.c_str(), tmp_str_scorer.size()); + } else { + // Keep old method due to backwards compatibility + status = DS_EnableExternalScorer(ctx, scorer); + } + if (status != 0) { fprintf(stderr, "Could not enable external scorer.\n"); return 1; diff --git a/native_client/ctcdecode/scorer.cpp b/native_client/ctcdecode/scorer.cpp index 5f25a33596..7fc044b527 100644 --- a/native_client/ctcdecode/scorer.cpp +++ b/native_client/ctcdecode/scorer.cpp @@ -29,22 +29,24 @@ static const int32_t FILE_VERSION = 6; int Scorer::init(const std::string& lm_path, + bool load_from_bytes, const Alphabet& alphabet) { set_alphabet(alphabet); - return load_lm(lm_path); + return load_lm(lm_path, load_from_bytes); } int Scorer::init(const std::string& lm_path, + bool load_from_bytes, const std::string& alphabet_config_path) { - int err = alphabet_.init(alphabet_config_path.c_str()); + int err = alphabet_.init(alphabet_config_path.c_str()); // Do we need to make this initiable from bytes? if (err != 0) { return err; } setup_char_map(); - return load_lm(lm_path); + return load_lm(lm_path, load_from_bytes); } void @@ -69,45 +71,60 @@ void Scorer::setup_char_map() } } -int Scorer::load_lm(const std::string& lm_path) +int Scorer::load_lm(const std::string& lm_string, bool load_from_bytes) { - // Check if file is readable to avoid KenLM throwing an exception - const char* filename = lm_path.c_str(); - if (access(filename, R_OK) != 0) { - return DS_ERR_SCORER_UNREADABLE; - } - - // Check if the file format is valid to avoid KenLM throwing an exception - lm::ngram::ModelType model_type; - if (!lm::ngram::RecognizeBinary(filename, model_type)) { - return DS_ERR_SCORER_INVALID_LM; + if (!load_from_bytes){ + // Check if file is readable to avoid KenLM throwing an exception + const char* filename = lm_string.c_str(); + if (access(filename, R_OK) != 0) { + return DS_ERR_SCORER_UNREADABLE; + } + + // Check if the file format is valid to avoid KenLM throwing an exception + lm::ngram::ModelType model_type; + if (!lm::ngram::RecognizeBinary(filename, model_type)) { + return DS_ERR_SCORER_INVALID_LM; + } } // Load the LM lm::ngram::Config config; config.load_method = util::LoadMethod::LAZY; - language_model_.reset(lm::ngram::LoadVirtual(filename, config)); - max_order_ = language_model_->Order(); - - uint64_t package_size; - { - util::scoped_fd fd(util::OpenReadOrThrow(filename)); - package_size = util::SizeFile(fd.get()); + if (load_from_bytes){ + language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), lm_string.size(), config)); + } else { + language_model_.reset(lm::ngram::LoadVirtual(lm_string.c_str(), config)); } + + max_order_ = language_model_->Order(); + std::stringstream stst; uint64_t trie_offset = language_model_->GetEndOfSearchOffset(); - if (package_size <= trie_offset) { - // File ends without a trie structure - return DS_ERR_SCORER_NO_TRIE; + + if (!load_from_bytes){ + uint64_t package_size; + { + util::scoped_fd fd(util::OpenReadOrThrow(lm_string.c_str())); + package_size = util::SizeFile(fd.get()); + } + + if (package_size <= trie_offset) { + // File ends without a trie structure + return DS_ERR_SCORER_NO_TRIE; + } + // Read metadata and trie from file + std::ifstream fin(lm_string.c_str(), std::ios::binary); + stst<(&magic), sizeof(magic)); if (magic != MAGIC) { @@ -140,9 +157,13 @@ int Scorer::load_trie(std::ifstream& fin, const std::string& file_path) reset_params(alpha, beta); fst::FstReadOptions opt; - opt.mode = fst::FstReadOptions::MAP; - opt.source = file_path; - dictionary.reset(FstType::Read(fin, opt)); + if (load_from_bytes) { + dictionary.reset(fst::ConstFst::Read(fin, opt)); + } else { + opt.mode = fst::FstReadOptions::MAP; + opt.source = file_path; + dictionary.reset(FstType::Read(fin, opt)); + } return DS_ERR_OK; } diff --git a/native_client/ctcdecode/scorer.h b/native_client/ctcdecode/scorer.h index 5aee1046ff..08713e4e83 100644 --- a/native_client/ctcdecode/scorer.h +++ b/native_client/ctcdecode/scorer.h @@ -39,9 +39,11 @@ class Scorer { Scorer& operator=(const Scorer&) = delete; int init(const std::string &lm_path, + bool load_from_bytes, const Alphabet &alphabet); int init(const std::string &lm_path, + bool load_from_bytes, const std::string &alphabet_config_path); double get_log_cond_prob(const std::vector &words, @@ -84,7 +86,7 @@ class Scorer { void fill_dictionary(const std::unordered_set &vocabulary); // load language model from given path - int load_lm(const std::string &lm_path); + int load_lm(const std::string &lm_path, bool load_from_bytes=false); // language model weight double alpha = 0.; @@ -98,7 +100,7 @@ class Scorer { // necessary setup after setting alphabet void setup_char_map(); - int load_trie(std::ifstream& fin, const std::string& file_path); + int load_trie(std::stringstream& fin, const std::string& file_path, bool load_from_bytes=false); private: std::unique_ptr language_model_; diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc old mode 100644 new mode 100755 index 57f77ba119..72a8d1def1 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -263,8 +263,9 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) } int -DS_CreateModel(const char* aModelPath, - ModelState** retval) +DS_CreateModel_(const std::string &aModelString, + bool init_from_bytes, + ModelState** retval) { *retval = nullptr; @@ -277,7 +278,7 @@ DS_CreateModel(const char* aModelPath, LOGD("DeepSpeech: %s", ds_git_version()); #endif - if (!aModelPath || strlen(aModelPath) < 1) { + if ( (!init_from_bytes && (strlen(aModelString) < 1)) || (init_from_bytes && (bufferSize<1))) { std::cerr << "No model specified, cannot continue." << std::endl; return DS_ERR_NO_MODEL; } @@ -294,8 +295,8 @@ DS_CreateModel(const char* aModelPath, std::cerr << "Could not allocate model state." << std::endl; return DS_ERR_FAIL_CREATE_MODEL; } - - int err = model->init(aModelPath); + + int err = model->init(aModelString, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } @@ -304,6 +305,22 @@ DS_CreateModel(const char* aModelPath, return DS_ERR_OK; } +int +DS_CreateModel(const char* aModelPath, + ModelState** retval) +{ + return DS_CreateModel_(aModelPath, false, retval); +} + +int +DS_CreateModelFromBuffer(const char* aModelBuffer, + size_t bufferSize, + ModelState** retval) +{ + return DS_CreateModel_(aModelBuffer, true, retval, bufferSize); +} + + unsigned int DS_GetModelBeamWidth(const ModelState* aCtx) { @@ -330,11 +347,19 @@ DS_FreeModel(ModelState* ctx) } int -DS_EnableExternalScorer(ModelState* aCtx, - const char* aScorerPath) +DS_EnableExternalScorer_(ModelState* aCtx, + const std::string &aScorerString, + bool init_from_bytes) { std::unique_ptr scorer(new Scorer()); - int err = scorer->init(aScorerPath, aCtx->alphabet_); + + int err; + if (init_from_bytes) + err = scorer->init(std::string(aScorerString, bufferSize), init_from_bytes, aCtx->alphabet_); + else + err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_); + + if (err != 0) { return DS_ERR_INVALID_SCORER; } @@ -342,6 +367,21 @@ DS_EnableExternalScorer(ModelState* aCtx, return DS_ERR_OK; } +int +DS_EnableExternalScorer(ModelState* aCtx, + const char* aScorerPath) +{ + return DS_EnableExternalScorer_(aCtx, aScorerPath, false); +} + +int +DS_EnableExternalScorerFromBuffer(ModelState* aCtx, + const char* aScorerBuffer, + size_t bufferSize) +{ + return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true, bufferSize); +} + int DS_AddHotWord(ModelState* aCtx, const char* word, diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h old mode 100644 new mode 100755 index 35e9289a2e..13170dcc6f --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -1,6 +1,8 @@ #ifndef DEEPSPEECH_H #define DEEPSPEECH_H +#include + #ifdef __cplusplus extern "C" { #endif @@ -74,6 +76,7 @@ typedef struct Metadata { APPLY(DS_ERR_SCORER_NO_TRIE, 0x2007, "Reached end of scorer file before loading vocabulary trie.") \ APPLY(DS_ERR_SCORER_INVALID_TRIE, 0x2008, "Invalid magic in trie header.") \ APPLY(DS_ERR_SCORER_VERSION_MISMATCH, 0x2009, "Scorer file version does not match expected version.") \ + APPLY(DS_ERR_MODEL_NOT_SUP_BUFFER, 0x2010, "Load from buffer does not support memorymaped models.") \ APPLY(DS_ERR_FAIL_INIT_MMAP, 0x3000, "Failed to initialize memory mapped model.") \ APPLY(DS_ERR_FAIL_INIT_SESS, 0x3001, "Failed to initialize the session.") \ APPLY(DS_ERR_FAIL_INTERPRETER, 0x3002, "Interpreter failed.") \ @@ -96,7 +99,7 @@ DS_FOR_EACH_ERROR(DEFINE) }; /** - * @brief An object providing an interface to a trained DeepSpeech model. + * @brief An object providing an interface to a trained DeepSpeech model. Maintained for not breaking backwards compatibility * * @param aModelPath The path to the frozen model graph. * @param[out] retval a ModelState pointer @@ -107,6 +110,20 @@ DEEPSPEECH_EXPORT int DS_CreateModel(const char* aModelPath, ModelState** retval); +/** + * @brief An object providing an interface to a trained DeepSpeech model, loaded from buffer. + * + * @param aModelBuffer The buffer containing the content of frozen model graph. + * @param[out] retval a ModelState pointer + * + * @return Zero on success, non-zero on failure. + */ +DEEPSPEECH_EXPORT +int DS_CreateModelFromBuffer(const char* aModelBuffer, + size_t bufferSize, + ModelState** retval); + + /** * @brief Get beam width value used by the model. If {@link DS_SetModelBeamWidth} * was not called before, will return the default value loaded from the @@ -160,6 +177,20 @@ DEEPSPEECH_EXPORT int DS_EnableExternalScorer(ModelState* aCtx, const char* aScorerPath); +/** + * @brief Enable decoding using an external scorer loaded from buffer. + * + * @param aCtx The ModelState pointer for the model being changed. + * @param aScorerBuffer The buffer containing the content of an external-scorer file. + * + * @return Zero on success, non-zero on failure (invalid arguments). + */ +DEEPSPEECH_EXPORT +int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, + const char* aScorerBuffer, + size_t bufferSize); + + /** * @brief Add a hot-word and its boost. * @@ -196,6 +227,8 @@ int DS_EraseHotWord(ModelState* aCtx, DEEPSPEECH_EXPORT int DS_ClearHotWords(ModelState* aCtx); + + /** * @brief Disable decoding using an external scorer. * diff --git a/native_client/kenlm/lm/bhiksha.cc b/native_client/kenlm/lm/bhiksha.cc index 4262b615e5..e64c0a9276 100644 --- a/native_client/kenlm/lm/bhiksha.cc +++ b/native_client/kenlm/lm/bhiksha.cc @@ -17,9 +17,13 @@ DontBhiksha::DontBhiksha(const void * /*base*/, uint64_t /*max_offset*/, uint64_ const uint8_t kArrayBhikshaVersion = 0; // TODO: put this in binary file header instead when I change the binary file format again. -void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { +void ArrayBhiksha::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_bytes) { uint8_t buffer[2]; - file.ReadForConfig(buffer, 2, offset); + if(load_from_bytes){ + file.ReadForConfig(buffer, 2, offset, load_from_bytes); + } else { + file.ReadForConfig(buffer, 2, offset); + } uint8_t version = buffer[0]; uint8_t configured_bits = buffer[1]; if (version != kArrayBhikshaVersion) UTIL_THROW(FormatLoadException, "This file has sorted array compression version " << (unsigned) version << " but the code expects version " << (unsigned)kArrayBhikshaVersion); diff --git a/native_client/kenlm/lm/bhiksha.hh b/native_client/kenlm/lm/bhiksha.hh index 36438f1d29..bc6eb48bff 100644 --- a/native_client/kenlm/lm/bhiksha.hh +++ b/native_client/kenlm/lm/bhiksha.hh @@ -34,6 +34,7 @@ class DontBhiksha { static const ModelType kModelTypeAdd = static_cast(0); static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &/*config*/) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &, bool) {} static uint64_t Size(uint64_t /*max_offset*/, uint64_t /*max_next*/, const Config &/*config*/) { return 0; } @@ -65,7 +66,7 @@ class ArrayBhiksha { public: static const ModelType kModelTypeAdd = kArrayAdd; - static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_bytes); static uint64_t Size(uint64_t max_offset, uint64_t max_next, const Config &config); diff --git a/native_client/kenlm/lm/binary_format.cc b/native_client/kenlm/lm/binary_format.cc index 802943f572..ea64a8f97b 100644 --- a/native_client/kenlm/lm/binary_format.cc +++ b/native_client/kenlm/lm/binary_format.cc @@ -11,6 +11,7 @@ #include #include +#include namespace lm { namespace ngram { @@ -114,6 +115,48 @@ bool IsBinaryFormat(int fd) { return false; } +bool IsBinaryFormat(char *file_data, uint64_t size) { + char *file_data_temp = new char[size]; + memcpy(file_data_temp,file_data, size); + + if (size == util::kBadSize || (size <= static_cast(sizeof(Sanity)))) { + delete[] file_data_temp; + return false; + } + + // Try reading the header. + util::scoped_memory memory; + try { + util::MapRead(util::LAZY, file_data_temp, 0, sizeof(Sanity), memory); + } catch (const util::Exception &e) { + return false; + } + Sanity reference_header = Sanity(); + reference_header.SetToReference(); + + if (!std::memcmp(memory.get(), &reference_header, sizeof(Sanity))) { + return true; + } + if (!std::memcmp(memory.get(), kMagicIncomplete, strlen(kMagicIncomplete))) { + UTIL_THROW(FormatLoadException, "This binary file did not finish building"); + } + if (!std::memcmp(memory.get(), kMagicBeforeVersion, strlen(kMagicBeforeVersion))) { + char *end_ptr; + const char *begin_version = static_cast(memory.get()) + strlen(kMagicBeforeVersion); + long int version = std::strtol(begin_version, &end_ptr, 10); + if ((end_ptr != begin_version) && version != kMagicVersion) { + UTIL_THROW(FormatLoadException, "Binary file has version " << version << " but this implementation expects version " << kMagicVersion << " so you'll have to use the ARPA to rebuild your binary"); + } + OldSanity old_sanity = OldSanity(); + old_sanity.SetToReference(); + UTIL_THROW_IF(!std::memcmp(memory.get(), &old_sanity, sizeof(OldSanity)), FormatLoadException, "Looks like this is an old 32-bit format. The old 32-bit format has been removed so that 64-bit and 32-bit files are exchangeable."); + UTIL_THROW(FormatLoadException, "File looks like it should be loaded with mmap, but the test values don't match. Try rebuilding the binary format LM using the same code revision, compiler, and architecture"); + } + return false; +} + + + void ReadHeader(int fd, Parameters &out) { util::SeekOrThrow(fd, sizeof(Sanity)); util::ReadOrThrow(fd, &out.fixed, sizeof(out.fixed)); @@ -124,6 +167,23 @@ void ReadHeader(int fd, Parameters &out) { if (out.fixed.order) util::ReadOrThrow(fd, &*out.counts.begin(), sizeof(uint64_t) * out.fixed.order); } +void ReadHeader(char *file_data, Parameters &out) { + const char *file_data_tmp = file_data; + file_data_tmp += sizeof(Sanity); + std::memcpy(&out.fixed, file_data_tmp, sizeof(out.fixed)); + file_data_tmp += sizeof(out.fixed); + + if (out.fixed.probing_multiplier < 1.0) + UTIL_THROW(FormatLoadException, "Binary format claims to have a probing multiplier of " << out.fixed.probing_multiplier << " which is < 1.0."); + + out.counts.resize(static_cast(out.fixed.order)); + + if (out.fixed.order) { + std::memcpy(&*out.counts.begin(), file_data_tmp, sizeof(uint64_t) * out.fixed.order); + } +} + + void MatchCheck(ModelType model_type, unsigned int search_version, const Parameters ¶ms) { if (params.fixed.model_type != model_type) { if (static_cast(params.fixed.model_type) >= (sizeof(kModelNames) / sizeof(const char *))) @@ -147,11 +207,26 @@ void BinaryFormat::InitializeBinary(int fd, ModelType model_type, unsigned int s header_size_ = TotalHeaderSize(params.counts.size()); } +void BinaryFormat::InitializeBinary(char *file_data, ModelType model_type, unsigned int search_version, Parameters ¶ms) { + file_data_ = file_data; + write_mmap_ = NULL; // Ignore write requests; this is already in binary format. + ReadHeader(file_data, params); + MatchCheck(model_type, search_version, params); + header_size_ = TotalHeaderSize(params.counts.size()); +} + + void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const { assert(header_size_ != kInvalidSize); util::ErsatzPRead(file_.get(), to, amount, offset_excluding_header + header_size_); } +void BinaryFormat::ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header, bool load_from_memory) const { + assert(header_size_ != kInvalidSize); + util::ErsatzPRead(file_data_, to, amount, offset_excluding_header + header_size_); +} + + void *BinaryFormat::LoadBinary(std::size_t size) { assert(header_size_ != kInvalidSize); const uint64_t file_size = util::SizeFile(file_.get()); @@ -165,6 +240,17 @@ void *BinaryFormat::LoadBinary(std::size_t size) { return reinterpret_cast(mapping_.get()) + header_size_; } +void *BinaryFormat::LoadBinary(std::size_t size, const uint64_t file_size) { /* Loading the binary from memory */ + assert(header_size_ != kInvalidSize); + // The header is smaller than a page, so we have to map the whole header as well. + uint64_t total_map = static_cast(header_size_) + static_cast(size); + UTIL_THROW_IF(file_size != util::kBadSize && file_size < total_map, FormatLoadException, "Binary file has size " << file_size << " but the headers say it should be at least " << total_map); + util::MapRead(load_method_, file_data_, 0, util::CheckOverflow(total_map), mapping_); + vocab_string_offset_ = total_map; + return reinterpret_cast(mapping_.get()) + header_size_; +} + + void *BinaryFormat::SetupJustVocab(std::size_t memory_size, uint8_t order) { vocab_size_ = memory_size; if (!write_mmap_) { @@ -282,7 +368,7 @@ void BinaryFormat::FinishFile(const Config &config, ModelType model_type, unsign } void BinaryFormat::MapFile(void *&vocab_base, void *&search_base) { - mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); + mapping_.reset(util::MapOrThrow(vocab_string_offset_, true, util::kFileFlags, false, (int) file_.get()), vocab_string_offset_, util::scoped_memory::MMAP_ALLOCATED); vocab_base = reinterpret_cast(mapping_.get()) + header_size_; search_base = reinterpret_cast(mapping_.get()) + header_size_ + vocab_size_ + vocab_pad_; } @@ -298,5 +384,21 @@ bool RecognizeBinary(const char *file, ModelType &recognized) { return true; } +bool RecognizeBinary(const char *file_data, const uint64_t file_data_size, ModelType &recognized) { + + char *file_data_temp = new char[file_data_size]; + memcpy(file_data_temp, file_data, file_data_size); + + if (!IsBinaryFormat(file_data_temp, file_data_size)){ + return false; + } + + Parameters params; + ReadHeader(file_data_temp, params); + recognized = params.fixed.model_type; + return true; +} + + } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/binary_format.hh b/native_client/kenlm/lm/binary_format.hh index ff99b95741..46fa8d0e0c 100644 --- a/native_client/kenlm/lm/binary_format.hh +++ b/native_client/kenlm/lm/binary_format.hh @@ -24,6 +24,7 @@ extern const char *kModelNames[6]; * this header designed for use by decoder authors. */ bool RecognizeBinary(const char *file, ModelType &recognized); +bool RecognizeBinary(const char *file_data, const uint64_t file_data_size, ModelType &recognized); struct FixedWidthParameters { unsigned char order; @@ -48,13 +49,20 @@ class BinaryFormat { public: explicit BinaryFormat(const Config &config); + ~BinaryFormat(){ + file_data_ = NULL; + } + // Reading a binary file: // Takes ownership of fd void InitializeBinary(int fd, ModelType model_type, unsigned int search_version, Parameters ¶ms); + void InitializeBinary(char *file_data, ModelType model_type, unsigned int search_version, Parameters ¶ms); // Used to read parts of the file to update the config object before figuring out full size. void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header) const; + void ReadForConfig(void *to, std::size_t amount, uint64_t offset_excluding_header, bool useMemory) const; // Actually load the binary file and return a pointer to the beginning of the search area. void *LoadBinary(std::size_t size); + void *LoadBinary(std::size_t size, const uint64_t file_size); uint64_t VocabStringReadingOffset() const { assert(vocab_string_offset_ != kInvalidOffset); @@ -81,9 +89,10 @@ class BinaryFormat { // File behind memory, if any. util::scoped_fd file_; + char *file_data_; // If there is a file involved, a single mapping. - util::scoped_memory mapping_; + util::scoped_memory mapping_= new util::scoped_memory(true); // If the data is only in memory, separately allocate each because the trie // knows vocab's size before it knows search's size (because SRILM might @@ -100,6 +109,7 @@ class BinaryFormat { }; bool IsBinaryFormat(int fd); +bool IsBinaryFormat(char *file_data, uint64_t size); } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/model.cc b/native_client/kenlm/lm/model.cc index fc4e374c88..758cc05198 100644 --- a/native_client/kenlm/lm/model.cc +++ b/native_client/kenlm/lm/model.cc @@ -66,7 +66,7 @@ template GenericModel::Ge Config new_config(init_config); new_config.probing_multiplier = parameters.fixed.probing_multiplier; - Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config); + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config, false); UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config)), parameters.counts, new_config); @@ -89,6 +89,48 @@ template GenericModel::Ge P::Init(begin_sentence, null_context, vocab_, search_.Order()); } +template GenericModel::GenericModel(const char *file_data, const uint64_t file_data_size, const Config &init_config) : backing_(init_config) { + char *file_data_temp = new char[file_data_size]; + memcpy(file_data_temp, file_data, file_data_size); + + if (IsBinaryFormat(file_data_temp, file_data_size)) { + Parameters parameters; + backing_.InitializeBinary(file_data_temp, kModelType, kVersion, parameters); + CheckCounts(parameters.counts); + + Config new_config(init_config); + new_config.probing_multiplier = parameters.fixed.probing_multiplier; + Search::UpdateConfigFromBinary(backing_, parameters.counts, VocabularyT::Size(parameters.counts[0], new_config), new_config, true); + + UTIL_THROW_IF(new_config.enumerate_vocab && !parameters.fixed.has_vocabulary, FormatLoadException, "The decoder requested all the vocabulary strings, but this binary file does not have them. You may need to rebuild the binary file with an updated version of build_binary."); + + SetupMemory(backing_.LoadBinary(Size(parameters.counts, new_config), file_data_size), parameters.counts, new_config); + + vocab_.LoadedBinary(parameters.fixed.has_vocabulary, file_data_temp, new_config.enumerate_vocab, backing_.VocabStringReadingOffset(), true); + + delete[] file_data_temp; + } else { + std::cerr << "Fatal error: Not binary!" << std::endl; + delete[] file_data_temp; + return; + } + // g++ prints warnings unless these are fully initialized. + State begin_sentence = State(); + + begin_sentence.length = 1; + begin_sentence.words[0] = vocab_.BeginSentence(); + typename Search::Node ignored_node; + bool ignored_independent_left; + uint64_t ignored_extend_left; + + begin_sentence.backoff[0] = search_.LookupUnigram(begin_sentence.words[0], ignored_node, ignored_independent_left, ignored_extend_left).Backoff(); + + State null_context = State(); + null_context.length = 0; + P::Init(begin_sentence, null_context, vocab_, search_.Order()); +} + + template void GenericModel::InitializeFromARPA(int fd, const char *file, const Config &config) { // Backing file is the ARPA. util::FilePiece f(fd, file, config.ProgressMessages()); @@ -348,6 +390,26 @@ base::Model *LoadVirtual(const char *file_name, const Config &config, ModelType UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); } } +base::Model *LoadVirtual(const char *file_data, const uint64_t file_data_size, const Config &config, ModelType model_type) { + RecognizeBinary(file_data, file_data_size, model_type); + switch (model_type) { + case PROBING: + UTIL_THROW(FormatLoadException, "Probing without memory option " << model_type); + case REST_PROBING: + UTIL_THROW(FormatLoadException, "Rest Probing without memory option " << model_type); + case TRIE: + UTIL_THROW(FormatLoadException, "Trie without memory option " << model_type); + case QUANT_TRIE: + UTIL_THROW(FormatLoadException, "Quant Trie without memory option " << model_type); + case ARRAY_TRIE: + UTIL_THROW(FormatLoadException, "Array Trie without memory option " << model_type); + case QUANT_ARRAY_TRIE: + return new QuantArrayTrieModelMemory(file_data, file_data_size, config); + default: + UTIL_THROW(FormatLoadException, "Confused by model type " << model_type); + } +} + } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/model.hh b/native_client/kenlm/lm/model.hh index 9b7206e8d7..8c39b27a2f 100644 --- a/native_client/kenlm/lm/model.hh +++ b/native_client/kenlm/lm/model.hh @@ -49,7 +49,7 @@ template class GenericModel : public base::Mod * lm/binary_format.hh. */ explicit GenericModel(const char *file, const Config &config = Config()); - + explicit GenericModel(const char *file_data, const uint64_t file_data_size, const Config &config = Config()); /* Score p(new_word | in_state) and incorporate new_word into out_state. * Note that in_state and out_state must be different references: * &in_state != &out_state. @@ -133,23 +133,33 @@ template class GenericModel : public base::Mod class name : public from {\ public:\ name(const char *file, const Config &config = Config()) : from(file, config) {}\ + }; + +#define LM_NAME_MODEL_FROM_MEMORY(name, from)\ +class name : public from {\ + public:\ + name(const char *file, const Config &config = Config()) : from(file, config) {}\ + name(const char *file_data, size_t file_data_size, const Config &config = Config()) : from(file_data, file_data_size, config) {}\ }; + LM_NAME_MODEL(ProbingModel, detail::GenericModel LM_COMMA() ProbingVocabulary>); LM_NAME_MODEL(RestProbingModel, detail::GenericModel LM_COMMA() ProbingVocabulary>); LM_NAME_MODEL(TrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); LM_NAME_MODEL(ArrayTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); LM_NAME_MODEL(QuantTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); LM_NAME_MODEL(QuantArrayTrieModel, detail::GenericModel LM_COMMA() SortedVocabulary>); +LM_NAME_MODEL_FROM_MEMORY(QuantArrayTrieModelMemory, detail::GenericModel LM_COMMA() SortedVocabulary>); // Default implementation. No real reason for it to be the default. typedef ::lm::ngram::ProbingVocabulary Vocabulary; typedef ProbingModel Model; - +typedef QuantArrayTrieModelMemory ModelMemory; /* Autorecognize the file type, load, and return the virtual base class. Don't * use the virtual base class if you can avoid it. Instead, use the above * classes as template arguments to your own virtual feature function.*/ base::Model *LoadVirtual(const char *file_name, const Config &config = Config(), ModelType if_arpa = PROBING); +base::Model *LoadVirtual(const char *file_data, const uint64_t file_data_size, const Config &config = Config(), ModelType if_arpa = PROBING); } // namespace ngram } // namespace lm diff --git a/native_client/kenlm/lm/quantize.cc b/native_client/kenlm/lm/quantize.cc index 02b5dbc0e0..99b00b6043 100644 --- a/native_client/kenlm/lm/quantize.cc +++ b/native_client/kenlm/lm/quantize.cc @@ -38,9 +38,15 @@ const char kSeparatelyQuantizeVersion = 2; } // namespace -void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config) { +void SeparatelyQuantize::UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_memory) { unsigned char buffer[3]; - file.ReadForConfig(buffer, 3, offset); + if(load_from_memory){ + file.ReadForConfig(buffer, 3, offset, load_from_memory); + }else{ + file.ReadForConfig(buffer, 3, offset); + } + std::string strBuffer((char*)buffer,3); + char version = buffer[0]; config.prob_bits = buffer[1]; config.backoff_bits = buffer[2]; diff --git a/native_client/kenlm/lm/quantize.hh b/native_client/kenlm/lm/quantize.hh index 8500aceec0..457736eeaa 100644 --- a/native_client/kenlm/lm/quantize.hh +++ b/native_client/kenlm/lm/quantize.hh @@ -24,7 +24,7 @@ class BinaryFormat; class DontQuantize { public: static const ModelType kModelTypeAdd = static_cast(0); - static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &, bool) {} static uint64_t Size(uint8_t /*order*/, const Config &/*config*/) { return 0; } static uint8_t MiddleBits(const Config &/*config*/) { return 63; } static uint8_t LongestBits(const Config &/*config*/) { return 31; } @@ -137,7 +137,7 @@ class SeparatelyQuantize { public: static const ModelType kModelTypeAdd = kQuantAdd; - static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); + static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config, bool load_from_memory); static uint64_t Size(uint8_t order, const Config &config) { uint64_t longest_table = (static_cast(1) << static_cast(config.prob_bits)) * sizeof(float); diff --git a/native_client/kenlm/lm/search_hashed.hh b/native_client/kenlm/lm/search_hashed.hh index 9dc84454c9..42e051ffcd 100644 --- a/native_client/kenlm/lm/search_hashed.hh +++ b/native_client/kenlm/lm/search_hashed.hh @@ -72,7 +72,7 @@ template class HashedSearch { static const unsigned int kVersion = 0; // TODO: move probing_multiplier here with next binary file format update. - static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector &, uint64_t, Config &) {} + static void UpdateConfigFromBinary(const BinaryFormat &, const std::vector &, uint64_t, Config &, bool) {} static uint64_t Size(const std::vector &counts, const Config &config) { uint64_t ret = Unigram::Size(counts[0]); diff --git a/native_client/kenlm/lm/search_trie.hh b/native_client/kenlm/lm/search_trie.hh index 1adba6e5ea..95e8dafaba 100644 --- a/native_client/kenlm/lm/search_trie.hh +++ b/native_client/kenlm/lm/search_trie.hh @@ -38,11 +38,12 @@ template class TrieSearch { static const unsigned int kVersion = 1; - static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector &counts, uint64_t offset, Config &config) { - Quant::UpdateConfigFromBinary(file, offset, config); + static void UpdateConfigFromBinary(const BinaryFormat &file, const std::vector &counts, uint64_t offset, Config &config, bool load_from_memory) { + Quant::UpdateConfigFromBinary(file, offset, config, load_from_memory); + // Currently the unigram pointers are not compresssed, so there will only be a header for order > 2. if (counts.size() > 2) - Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config); + Bhiksha::UpdateConfigFromBinary(file, offset + Quant::Size(counts.size(), config) + Unigram::Size(counts[0]), config, load_from_memory); } static uint64_t Size(const std::vector &counts, const Config &config) { diff --git a/native_client/kenlm/lm/vocab.cc b/native_client/kenlm/lm/vocab.cc index 7996ec7e2b..ef2e78e3ee 100644 --- a/native_client/kenlm/lm/vocab.cc +++ b/native_client/kenlm/lm/vocab.cc @@ -14,6 +14,7 @@ #include #include +#include namespace lm { namespace ngram { @@ -51,6 +52,36 @@ void ReadWords(int fd, EnumerateVocab *enumerate, WordIndex expected_count, uint } UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); } +void ReadWords(char* file_data, EnumerateVocab *enumerate, WordIndex expected_count, uint64_t offset) { + const char *file_data_tmp = file_data; + file_data_tmp += offset; + + // Check that we're at the right place by reading which is always first. + char check_unk[6]; + std::memcpy(check_unk, file_data_tmp, 6); + file_data_tmp += 6; + + UTIL_THROW_IF( + memcmp(check_unk, "", 6), + FormatLoadException, + "Vocabulary words are in the wrong place. This could be because the binary file was built with stale gcc and old kenlm. Stale gcc, including the gcc distributed with RedHat and OS X, has a bug that ignores pragma pack for template-dependent types. New kenlm works around this, so you'll save memory but have to rebuild any binary files using the probing data structure."); + if (!enumerate) { + return; + } + enumerate->Add(0, ""); + + WordIndex index = 1; // Read already. + std::istringstream in(file_data_tmp); + + for (std::string line; std::getline(in, line); ) + { + // std::cerr << "LINHA -> " << line << std::endl; + enumerate->Add(index, line); + } + + UTIL_THROW_IF(expected_count != index, FormatLoadException, "The binary file has the wrong number of words at the end. This could be caused by a truncated binary file."); +} + // Constructor ordering madness. int SeekAndReturn(int fd, uint64_t start) { @@ -192,6 +223,16 @@ void SortedVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to, if (have_words) ReadWords(fd, to, bound_, offset); } +void SortedVocabulary::LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory) { + end_ = begin_ + *(reinterpret_cast(begin_) - 1); + SetSpecial(Index(""), Index(""), 0); + bound_ = end_ - begin_ + 1; + if (have_words) { + ReadWords(file_data, to, bound_, offset); + } +} + + template void SortedVocabulary::GenericFinished(T *reorder) { if (enumerate_) { if (!strings_to_enumerate_.empty()) { @@ -282,6 +323,17 @@ void ProbingVocabulary::LoadedBinary(bool have_words, int fd, EnumerateVocab *to if (have_words) ReadWords(fd, to, bound_, offset); } +void ProbingVocabulary::LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory) { + UTIL_THROW_IF(header_->version != kProbingVocabularyVersion, FormatLoadException, "The binary file has probing version " << header_->version << " but the code expects version " << kProbingVocabularyVersion << ". Please rerun build_binary using the same version of the code."); + bound_ = header_->bound; + + SetSpecial(Index(""), Index(""), 0); + if (have_words) { + ReadWords(file_data, to, bound_, offset); + } +} + + void MissingUnknown(const Config &config) { switch(config.unknown_missing) { case SILENT: diff --git a/native_client/kenlm/lm/vocab.hh b/native_client/kenlm/lm/vocab.hh index f36e62ca21..a3456103fe 100644 --- a/native_client/kenlm/lm/vocab.hh +++ b/native_client/kenlm/lm/vocab.hh @@ -14,6 +14,7 @@ #include #include #include +#include namespace lm { struct ProbBackoff; @@ -111,6 +112,7 @@ class SortedVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); + void LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory); uint64_t *&EndHack() { return end_; } @@ -190,6 +192,7 @@ class ProbingVocabulary : public base::Vocabulary { bool SawUnk() const { return saw_unk_; } void LoadedBinary(bool have_words, int fd, EnumerateVocab *to, uint64_t offset); + void LoadedBinary(bool have_words, char* file_data, EnumerateVocab *to, uint64_t offset, bool load_from_memory); private: void InternalFinishedLoading(); diff --git a/native_client/kenlm/util/file.cc b/native_client/kenlm/util/file.cc index 1a70387e48..aa76ae6929 100644 --- a/native_client/kenlm/util/file.cc +++ b/native_client/kenlm/util/file.cc @@ -236,6 +236,43 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size) { UTIL_THROW_IF(1 != std::fwrite(data, size, 1, to), ErrnoException, "Short write; requested size " << size); } +void ErsatzPRead(char *file_data, void *to_void, std::size_t size, uint64_t off) { + uint8_t *to = static_cast(to_void); + while (size) { +#if defined(_WIN32) || defined(_WIN64) + /* BROKEN: changes file pointer. Even if you save it and change it back, it won't be safe to use concurrently with write() or read() which lmplz does. */ + // size_t might be 64-bit. DWORD is always 32. + DWORD reading = static_cast(std::min(kMaxDWORD, size)); + DWORD ret; + OVERLAPPED overlapped; + memset(&overlapped, 0, sizeof(OVERLAPPED)); + overlapped.Offset = static_cast(off); + overlapped.OffsetHigh = static_cast(off >> 32); + UTIL_THROW_IF(!ReadFile((HANDLE)_get_osfhandle(fd), to, reading, &ret, &overlapped), WindowsException, "ReadFile failed for offset " << off); +#else + ssize_t ret; + errno = 0; + ret = GuardLarge(size); + + file_data += off; + #if defined(_WIN32) || defined(_WIN64) + CopyMemory(out.get(), file_data, size); + #else + std::memcpy(to, file_data, GuardLarge(size)); + #endif + + if (ret <= 0) { + if (ret == -1 && errno == EINTR) continue; + UTIL_THROW_IF(ret == 0, EndOfFileException, " for reading " << size << " bytes at " << off << " from buffer"); + } +#endif + size -= ret; + off += ret; + to += ret; + } +} + + void ErsatzPRead(int fd, void *to_void, std::size_t size, uint64_t off) { uint8_t *to = static_cast(to_void); while (size) { diff --git a/native_client/kenlm/util/file.hh b/native_client/kenlm/util/file.hh index 4a50e73060..f72eab54cc 100644 --- a/native_client/kenlm/util/file.hh +++ b/native_client/kenlm/util/file.hh @@ -134,6 +134,7 @@ void WriteOrThrow(FILE *to, const void *data, std::size_t size); * above. */ void ErsatzPRead(int fd, void *to, std::size_t size, uint64_t off); +void ErsatzPRead(char *file_data, void *to_void, std::size_t size, uint64_t off); void ErsatzPWrite(int fd, const void *data_void, std::size_t size, uint64_t off); void FSyncOrThrow(int fd); diff --git a/native_client/kenlm/util/mmap.cc b/native_client/kenlm/util/mmap.cc index 39b9cd598d..35de58c157 100644 --- a/native_client/kenlm/util/mmap.cc +++ b/native_client/kenlm/util/mmap.cc @@ -20,6 +20,10 @@ #if defined(_WIN32) || defined(_WIN64) #include #include +#include +#include +#include + #else #include #include @@ -56,17 +60,17 @@ template T RoundUpPow2(T value, T mult) { } } // namespace -scoped_memory::scoped_memory(std::size_t size, bool zeroed) : data_(NULL), size_(0), source_(NONE_ALLOCATED) { +scoped_memory::scoped_memory(std::size_t size, bool zeroed, bool load_from_memory) : data_(NULL), size_(0), source_(NONE_ALLOCATED), load_from_memory_(load_from_memory) { HugeMalloc(size, zeroed, *this); } void scoped_memory::reset(void *data, std::size_t size, Alloc source) { switch(source_) { case MMAP_ROUND_UP_ALLOCATED: - scoped_mmap(data_, RoundUpPow2(size_, (std::size_t)SizePage())); + scoped_mmap(data_, RoundUpPow2(size_, (std::size_t)SizePage()), load_from_memory_); break; case MMAP_ALLOCATED: - scoped_mmap(data_, size_); + scoped_mmap(data_, size_, load_from_memory_); break; case MALLOC_ALLOCATED: free(data_); @@ -130,6 +134,46 @@ void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int return ret; } +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, char *file_data, uint64_t offset) { +#ifdef MAP_POPULATE // Linux specific + if (prefault) { + flags |= MAP_POPULATE; + } +#endif +#if defined(_WIN32) || defined(_WIN64) + + TCHAR szName[]=TEXT("Global\\LanguageModelMapping"); + + int protectC = for_write ? PAGE_READWRITE : PAGE_READONLY; + int protectM = for_write ? FILE_MAP_WRITE : FILE_MAP_READ; + uint64_t total_size = size + offset; + // HANDLE hMapping = CreateFileMapping((HANDLE)_get_osfhandle(fd), NULL, protectC, total_size >> 32, static_cast(total_size), NULL); + + HANDLE hMapping = CreateFileMapping(INVALID_HANDLE_VALUE, NULL, PAGE_READWRITE, total_size >> 32, static_cast(total_size), szName); + UTIL_THROW_IF(!hMapping, ErrnoException, "CreateFileMapping failed"); + // LPVOID ret = MapViewOfFile(hMapping, protectM, offset >> 32, offset, size); + + LPVOID ret = MapViewOfFile(hMapping, FILE_MAP_ALL_ACCESS, offset >> 32, offset, size); + + if (ret == NULL){ + std::cerr<< "ret NULL" << GetLastError() << std::endl; + } + + CloseHandle(hMapping); + UTIL_THROW_IF(!ret, ErrnoException, "MapViewOfFile failed"); + CopyMemory((PVOID)ret, file_data, total_size); +#else + // int protect = for_write ? (PROT_READ | PROT_WRITE) : PROT_READ; + int protect = PROT_READ | PROT_WRITE; + void *ret; + flags = MAP_ANONYMOUS|MAP_SHARED; + UTIL_THROW_IF((ret = mmap(NULL, size, protect, flags, -1, offset)) == MAP_FAILED, ErrnoException, "mmap failed for size " << size << " at offset " << offset); + memcpy(ret, file_data, size); +#endif + return ret; +} + + void SyncOrThrow(void *start, size_t length) { #if defined(_WIN32) || defined(_WIN64) UTIL_THROW_IF(!::FlushViewOfFile(start, length), ErrnoException, "Failed to sync mmap"); @@ -323,6 +367,38 @@ void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scope } } +void MapRead(LoadMethod method, char *file_data, uint64_t offset, std::size_t size, scoped_memory &out) { + switch (method) { + case LAZY: + out.reset(MapOrThrow(size, false, kFileFlags, false, file_data, offset), size, scoped_memory::MMAP_ALLOCATED); + break; + case POPULATE_OR_LAZY: +#ifdef MAP_POPULATE + case POPULATE_OR_READ: +#endif + out.reset(MapOrThrow(size, false, kFileFlags, true, file_data, offset), size, scoped_memory::MMAP_ALLOCATED); + break; +#ifndef MAP_POPULATE + case POPULATE_OR_READ: +#endif + case READ: + HugeMalloc(size, false, out); + file_data += offset; + #if defined(_WIN32) || defined(_WIN64) + CopyMemory(out.get(), file_data, size); + #else + std::memcpy(out.get(), file_data, size); + #endif + break; + case PARALLEL_READ: + HugeMalloc(size, false, out); + file_data += offset; + std::memcpy(out.get(), file_data, size); + break; + } +} + + void *MapZeroedWrite(int fd, std::size_t size) { ResizeOrThrow(fd, 0); ResizeOrThrow(fd, size); diff --git a/native_client/kenlm/util/mmap.hh b/native_client/kenlm/util/mmap.hh index 21bca1ddc6..cc51c1b549 100644 --- a/native_client/kenlm/util/mmap.hh +++ b/native_client/kenlm/util/mmap.hh @@ -17,8 +17,8 @@ std::size_t SizePage(); // (void*)-1 is MAP_FAILED; this is done to avoid including the mmap header here. class scoped_mmap { public: - scoped_mmap() : data_((void*)-1), size_(0) {} - scoped_mmap(void *data, std::size_t size) : data_(data), size_(size) {} + scoped_mmap(bool load_from_memory=false) : data_((void*)-1), size_(0), load_from_memory_(load_from_memory) {} + scoped_mmap(void *data, std::size_t size, bool load_from_memory=false) : data_(data), size_(size), load_from_memory_(load_from_memory) {} ~scoped_mmap(); void *get() const { return data_; } @@ -29,6 +29,8 @@ class scoped_mmap { char *end() { return reinterpret_cast(data_) + size_; } std::size_t size() const { return size_; } + bool load_from_memory_; + void reset(void *data, std::size_t size) { scoped_mmap other(data_, size_); data_ = data; @@ -66,13 +68,16 @@ class scoped_memory { NONE_ALLOCATED // nothing to free (though there can be something here if it's owned by somebody else). } Alloc; - scoped_memory(void *data, std::size_t size, Alloc source) - : data_(data), size_(size), source_(source) {} + scoped_memory(void *data, std::size_t size, Alloc source, bool load_from_memory=false) + : data_(data), size_(size), source_(source), load_from_memory_(load_from_memory) {} + + + scoped_memory(bool load_from_memory_=false) : data_(NULL), size_(0), + source_(NONE_ALLOCATED), load_from_memory_(load_from_memory_) {} - scoped_memory() : data_(NULL), size_(0), source_(NONE_ALLOCATED) {} // Calls HugeMalloc - scoped_memory(std::size_t to, bool zero_new); + scoped_memory(std::size_t to, bool zero_new, bool load_from_memory_=false); #if __cplusplus >= 201103L scoped_memory(scoped_memory &&from) noexcept @@ -91,6 +96,9 @@ class scoped_memory { char *end() { return reinterpret_cast(data_) + size_; } std::size_t size() const { return size_; } + bool load_from_memory_; + + Alloc source() const { return source_; } void reset() { reset(NULL, 0, NONE_ALLOCATED); } @@ -119,6 +127,7 @@ extern const int kFileFlags; // Cross-platform, error-checking wrapper for mmap(). void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, int fd, uint64_t offset = 0); +void *MapOrThrow(std::size_t size, bool for_write, int flags, bool prefault, char *file_data, uint64_t offset = 0); // msync wrapper void SyncOrThrow(void *start, size_t length); @@ -153,6 +162,7 @@ enum LoadMethod { }; void MapRead(LoadMethod method, int fd, uint64_t offset, std::size_t size, scoped_memory &out); +void MapRead(LoadMethod method, char *file_data, uint64_t offset, std::size_t size, scoped_memory &out); // Open file name with mmap of size bytes, all of which are initially zero. void *MapZeroedWrite(int fd, std::size_t size); diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc old mode 100644 new mode 100755 index d8637c3656..dcbc904d8d --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -24,7 +24,7 @@ ModelState::~ModelState() } int -ModelState::init(const char* model_path) +ModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize) { return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h old mode 100644 new mode 100755 index 4beb78b472..41fe817e71 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -31,7 +31,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const char* model_path); + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc old mode 100644 new mode 100755 index 50a68a4b94..5823b3268e --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -156,23 +156,36 @@ getTfliteDelegates() } int -TFLiteModelState::init(const char* model_path) +TFLiteModelState::init(const char *model_string, bool init_from_bytes, size_t bufferSize) { - int err = ModelState::init(model_path); + int err = ModelState::init(model_string, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } - - fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_path); - if (!fbmodel_) { - std::cerr << "Error at reading model file " << model_path << std::endl; - return DS_ERR_FAIL_INIT_MMAP; + + if (init_from_bytes){ + fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(model_string, bufferSize); + if (!fbmodel_) { + std::cerr << "Error at reading model buffer " << std::endl; + return DS_ERR_FAIL_INIT_MMAP; + } + } else { + fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string); + if (!fbmodel_) { + std::cerr << "Error at reading model file " << model_string << std::endl; + return DS_ERR_FAIL_INIT_MMAP; + } } tflite::ops::builtin::BuiltinOpResolver resolver; tflite::InterpreterBuilder(*fbmodel_, resolver)(&interpreter_); if (!interpreter_) { - std::cerr << "Error at InterpreterBuilder for model file " << model_path << std::endl; + if (init_from_bytes) { + std::cerr << "Error at InterpreterBuilder for model buffer " << std::endl; + } else { + std::cerr << "Error at InterpreterBuilder for model file " << model_string << std::endl; + } + return DS_ERR_FAIL_INTERPRETER; } @@ -318,7 +331,6 @@ TFLiteModelState::init(const char* model_path) assert(dims_c->data[1] == dims_h->data[1]); assert(state_size_ > 0); state_size_ = dims_c->data[1]; - return DS_ERR_OK; } diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h old mode 100644 new mode 100755 index ace62ecf89..d2e95966e4 --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -31,7 +31,7 @@ struct TFLiteModelState : public ModelState TFLiteModelState(); virtual ~TFLiteModelState(); - virtual int init(const char* model_path) override; + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override; diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc old mode 100644 new mode 100755 index 65328e308a..d2871e21d8 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -22,10 +22,21 @@ TFModelState::~TFModelState() } } +int loadGraphFromBinaryData(Env* env, const char* data, size_t bufferSize, + ::tensorflow::protobuf::MessageLite* proto) { + + std::string dataString(data, bufferSize); + if (!proto->ParseFromString(dataString)) { + std::cerr << "Can't parse data as binary proto" << std::endl; + return -1; + } + return 0; +} + int -TFModelState::init(const char* model_path) +TFModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize) { - int err = ModelState::init(model_path); + int err = ModelState::init(model_string, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } @@ -34,21 +45,28 @@ TFModelState::init(const char* model_path) SessionOptions options; mmap_env_.reset(new MemmappedEnv(Env::Default())); - - bool is_mmap = std::string(model_path).find(".pbmm") != std::string::npos; - if (!is_mmap) { - std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl; - } else { - status = mmap_env_->InitializeFromFile(model_path); - if (!status.ok()) { - std::cerr << status << std::endl; - return DS_ERR_FAIL_INIT_MMAP; + bool is_mmap = false; + if (init_from_bytes) { + int loadGraphStatus = loadGraphFromBinaryData(mmap_env_.get(), model_string, bufferSize, &graph_def_); + if (loadGraphStatus != 0) { + return DS_ERR_FAIL_CREATE_SESS; } + } else { + is_mmap = std::string(model_string).find(".pbmm") != std::string::npos; + if (!is_mmap) { + std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl; + } else { + status = mmap_env_->InitializeFromFile(model_string); + if (!status.ok()) { + std::cerr << status << std::endl; + return DS_ERR_FAIL_INIT_MMAP; + } - options.config.mutable_graph_options() - ->mutable_optimizer_options() - ->set_opt_level(::OptimizerOptions::L0); - options.env = mmap_env_.get(); + options.config.mutable_graph_options() + ->mutable_optimizer_options() + ->set_opt_level(::OptimizerOptions::L0); + options.env = mmap_env_.get(); + } } Session* session; @@ -59,13 +77,21 @@ TFModelState::init(const char* model_path) } session_.reset(session); - if (is_mmap) { - status = ReadBinaryProto(mmap_env_.get(), - MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, - &graph_def_); + if (init_from_bytes){ + if( is_mmap) { + std::cerr << "Load from buffer does not support .pbmm models." << std::endl; + return DS_ERR_MODEL_NOT_SUP_BUFFER; + } } else { - status = ReadBinaryProto(Env::Default(), model_path, &graph_def_); + if (is_mmap) { + status = ReadBinaryProto(mmap_env_.get(), + MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, + &graph_def_); + } else { + status = ReadBinaryProto(Env::Default(), model_string, &graph_def_); + } } + if (!status.ok()) { std::cerr << status << std::endl; return DS_ERR_FAIL_READ_PROTOBUF; diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h old mode 100644 new mode 100755 index 2a8db699df..0308b94338 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -18,7 +18,7 @@ struct TFModelState : public ModelState TFModelState(); virtual ~TFModelState(); - virtual int init(const char* model_path) override; + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override; virtual void infer(const std::vector& mfcc, unsigned int n_frames,