diff --git a/native_client/client.cc b/native_client/client.cc index 847777bb5e..b88963f6f8 100755 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -417,6 +417,7 @@ main(int argc, char **argv) // Initialise DeepSpeech ModelState* ctx; + std::string buffer_model_str; // sphinx-doc: c_ref_model_start int status; if (init_from_array_of_bytes){ @@ -424,7 +425,8 @@ main(int argc, char **argv) std::ifstream is_model( model, std::ios::binary ); std::stringstream buffer_model; buffer_model << is_model.rdbuf(); - status = DS_CreateModelFromBuffer(buffer_model.str(), &ctx); + buffer_model_str = buffer_model.str(); + status = DS_CreateModelFromBuffer(buffer_model_str.c_str(), buffer_model_str.size(), &ctx); }else { // Keep old method due to backwards compatibility status = DS_CreateModel(model, &ctx); @@ -451,7 +453,8 @@ main(int argc, char **argv) std::ifstream is_scorer(scorer, std::ios::binary ); std::stringstream buffer_scorer; buffer_scorer << is_scorer.rdbuf(); - status = DS_EnableExternalScorerFromBuffer(ctx, buffer_scorer.str()); + std::string tmp_str_scorer = buffer_scorer.str(); + status = DS_EnableExternalScorerFromBuffer(ctx, tmp_str_scorer.c_str(), tmp_str_scorer.size()); } else { // Keep old method due to backwards compatibility status = DS_EnableExternalScorer(ctx, scorer); diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index d7f26ea0c2..32468f3a70 100755 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -263,9 +263,10 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) } int -DS_CreateModel_(const std::string &aModelString, +DS_CreateModel_(const char* aModelString, bool init_from_bytes, - ModelState** retval) + ModelState** retval, + size_t bufferSize=0) { *retval = nullptr; @@ -278,7 +279,7 @@ DS_CreateModel_(const std::string &aModelString, LOGD("DeepSpeech: %s", ds_git_version()); #endif - if (aModelString.length() < 1) { + if ( (!init_from_bytes && (strlen(aModelString) < 1)) || (init_from_bytes && (bufferSize<1))) { std::cerr << "No model specified, cannot continue." << std::endl; return DS_ERR_NO_MODEL; } @@ -296,7 +297,7 @@ DS_CreateModel_(const std::string &aModelString, return DS_ERR_FAIL_CREATE_MODEL; } - int err = model->init(aModelString, init_from_bytes); + int err = model->init(aModelString, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } @@ -313,10 +314,11 @@ DS_CreateModel(const char* aModelPath, } int -DS_CreateModelFromBuffer(const std::string &aModelBuffer, +DS_CreateModelFromBuffer(const char* aModelBuffer, + size_t bufferSize, ModelState** retval) { - return DS_CreateModel_(aModelBuffer, true, retval); + return DS_CreateModel_(aModelBuffer, true, retval, bufferSize); } @@ -347,12 +349,18 @@ DS_FreeModel(ModelState* ctx) int DS_EnableExternalScorer_(ModelState* aCtx, - const std::string &aScorerString, - bool init_from_bytes) + const char* aScorerString, + bool init_from_bytes, + size_t bufferSize=0) { std::unique_ptr scorer(new Scorer()); - int err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_); + int err; + if (init_from_bytes) + err = scorer->init(std::string(aScorerString, bufferSize), init_from_bytes, aCtx->alphabet_); + else + err = scorer->init(aScorerString, init_from_bytes, aCtx->alphabet_); + if (err != 0) { return DS_ERR_INVALID_SCORER; @@ -370,9 +378,10 @@ DS_EnableExternalScorer(ModelState* aCtx, int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, - const std::string &aScorerBuffer) + const char* aScorerBuffer, + size_t bufferSize) { - return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true); + return DS_EnableExternalScorer_(aCtx, aScorerBuffer, true, bufferSize); } int diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 673e518e23..13170dcc6f 100755 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -76,6 +76,7 @@ typedef struct Metadata { APPLY(DS_ERR_SCORER_NO_TRIE, 0x2007, "Reached end of scorer file before loading vocabulary trie.") \ APPLY(DS_ERR_SCORER_INVALID_TRIE, 0x2008, "Invalid magic in trie header.") \ APPLY(DS_ERR_SCORER_VERSION_MISMATCH, 0x2009, "Scorer file version does not match expected version.") \ + APPLY(DS_ERR_MODEL_NOT_SUP_BUFFER, 0x2010, "Load from buffer does not support memorymaped models.") \ APPLY(DS_ERR_FAIL_INIT_MMAP, 0x3000, "Failed to initialize memory mapped model.") \ APPLY(DS_ERR_FAIL_INIT_SESS, 0x3001, "Failed to initialize the session.") \ APPLY(DS_ERR_FAIL_INTERPRETER, 0x3002, "Interpreter failed.") \ @@ -118,7 +119,8 @@ int DS_CreateModel(const char* aModelPath, * @return Zero on success, non-zero on failure. */ DEEPSPEECH_EXPORT -int DS_CreateModelFromBuffer(const std::string &aModelBuffer, +int DS_CreateModelFromBuffer(const char* aModelBuffer, + size_t bufferSize, ModelState** retval); @@ -185,7 +187,8 @@ int DS_EnableExternalScorer(ModelState* aCtx, */ DEEPSPEECH_EXPORT int DS_EnableExternalScorerFromBuffer(ModelState* aCtx, - const std::string &aScorerBuffer); + const char* aScorerBuffer, + size_t bufferSize); /** diff --git a/native_client/modelstate.cc b/native_client/modelstate.cc index 5f4ce2e274..dcbc904d8d 100755 --- a/native_client/modelstate.cc +++ b/native_client/modelstate.cc @@ -24,7 +24,7 @@ ModelState::~ModelState() } int -ModelState::init(const std::string &model_string, bool init_from_bytes) +ModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize) { return DS_ERR_OK; } diff --git a/native_client/modelstate.h b/native_client/modelstate.h index 166edff00e..41fe817e71 100755 --- a/native_client/modelstate.h +++ b/native_client/modelstate.h @@ -31,7 +31,7 @@ struct ModelState { ModelState(); virtual ~ModelState(); - virtual int init(const std::string &model_string, bool init_from_bytes); + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize); virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) = 0; diff --git a/native_client/tflitemodelstate.cc b/native_client/tflitemodelstate.cc index b72b1fa1ae..5823b3268e 100755 --- a/native_client/tflitemodelstate.cc +++ b/native_client/tflitemodelstate.cc @@ -156,24 +156,21 @@ getTfliteDelegates() } int -TFLiteModelState::init(const std::string &model_string, bool init_from_bytes) +TFLiteModelState::init(const char *model_string, bool init_from_bytes, size_t bufferSize) { - int err = ModelState::init(model_string, init_from_bytes); + int err = ModelState::init(model_string, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } - + if (init_from_bytes){ - char *tmp_buffer = new char[model_string.size()]; - std::copy(model_string.begin(), model_string.end(), tmp_buffer); - // Using c_str does not work - fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(tmp_buffer,model_string.size()); + fbmodel_ = tflite::FlatBufferModel::VerifyAndBuildFromBuffer(model_string, bufferSize); if (!fbmodel_) { std::cerr << "Error at reading model buffer " << std::endl; return DS_ERR_FAIL_INIT_MMAP; } } else { - fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string.c_str()); + fbmodel_ = tflite::FlatBufferModel::BuildFromFile(model_string); if (!fbmodel_) { std::cerr << "Error at reading model file " << model_string << std::endl; return DS_ERR_FAIL_INIT_MMAP; @@ -334,7 +331,6 @@ TFLiteModelState::init(const std::string &model_string, bool init_from_bytes) assert(dims_c->data[1] == dims_h->data[1]); assert(state_size_ > 0); state_size_ = dims_c->data[1]; - return DS_ERR_OK; } diff --git a/native_client/tflitemodelstate.h b/native_client/tflitemodelstate.h index 38de5b9d7d..d2e95966e4 100755 --- a/native_client/tflitemodelstate.h +++ b/native_client/tflitemodelstate.h @@ -31,7 +31,7 @@ struct TFLiteModelState : public ModelState TFLiteModelState(); virtual ~TFLiteModelState(); - virtual int init(const std::string &model_string, bool init_from_bytes) override; + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override; virtual void compute_mfcc(const std::vector& audio_buffer, std::vector& mfcc_output) override; diff --git a/native_client/tfmodelstate.cc b/native_client/tfmodelstate.cc index 7ec8a76777..d2871e21d8 100755 --- a/native_client/tfmodelstate.cc +++ b/native_client/tfmodelstate.cc @@ -22,10 +22,11 @@ TFModelState::~TFModelState() } } -int loadGraphFromBinaryData(Env* env, const std::string& data, +int loadGraphFromBinaryData(Env* env, const char* data, size_t bufferSize, ::tensorflow::protobuf::MessageLite* proto) { - if (!proto->ParseFromString(data)) { + std::string dataString(data, bufferSize); + if (!proto->ParseFromString(dataString)) { std::cerr << "Can't parse data as binary proto" << std::endl; return -1; } @@ -33,9 +34,9 @@ int loadGraphFromBinaryData(Env* env, const std::string& data, } int -TFModelState::init(const std::string &model_string, bool init_from_bytes) +TFModelState::init(const char* model_string, bool init_from_bytes, size_t bufferSize) { - int err = ModelState::init(model_string, init_from_bytes); + int err = ModelState::init(model_string, init_from_bytes, bufferSize); if (err != DS_ERR_OK) { return err; } @@ -46,16 +47,16 @@ TFModelState::init(const std::string &model_string, bool init_from_bytes) mmap_env_.reset(new MemmappedEnv(Env::Default())); bool is_mmap = false; if (init_from_bytes) { - int loadGraphStatus = loadGraphFromBinaryData(Env::Default(), model_string, &graph_def_); + int loadGraphStatus = loadGraphFromBinaryData(mmap_env_.get(), model_string, bufferSize, &graph_def_); if (loadGraphStatus != 0) { return DS_ERR_FAIL_CREATE_SESS; } } else { - is_mmap = model_string.find(".pbmm") != std::string::npos; + is_mmap = std::string(model_string).find(".pbmm") != std::string::npos; if (!is_mmap) { std::cerr << "Warning: reading entire model file into memory. Transform model file into an mmapped graph to reduce heap usage." << std::endl; } else { - status = mmap_env_->InitializeFromFile(model_string.c_str()); + status = mmap_env_->InitializeFromFile(model_string); if (!status.ok()) { std::cerr << status << std::endl; return DS_ERR_FAIL_INIT_MMAP; @@ -77,14 +78,17 @@ TFModelState::init(const std::string &model_string, bool init_from_bytes) session_.reset(session); if (init_from_bytes){ - // Need some help + if( is_mmap) { + std::cerr << "Load from buffer does not support .pbmm models." << std::endl; + return DS_ERR_MODEL_NOT_SUP_BUFFER; + } } else { if (is_mmap) { status = ReadBinaryProto(mmap_env_.get(), MemmappedFileSystem::kMemmappedPackageDefaultGraphDef, &graph_def_); } else { - status = ReadBinaryProto(Env::Default(), model_string.c_str(), &graph_def_); + status = ReadBinaryProto(Env::Default(), model_string, &graph_def_); } } diff --git a/native_client/tfmodelstate.h b/native_client/tfmodelstate.h index 85628acf3c..0308b94338 100755 --- a/native_client/tfmodelstate.h +++ b/native_client/tfmodelstate.h @@ -18,7 +18,7 @@ struct TFModelState : public ModelState TFModelState(); virtual ~TFModelState(); - virtual int init(const std::string &model_string, bool init_from_bytes) override; + virtual int init(const char* model_string, bool init_from_bytes, size_t bufferSize) override; virtual void infer(const std::vector& mfcc, unsigned int n_frames,