Skip to content

Commit

Permalink
[http] Update http server (#189)
Browse files Browse the repository at this point in the history
Co-authored-by: 李盛强 <lichengqiang@lichengqiangdeMacBook-Pro.local>
  • Loading branch information
Shengqiang-Li and 李盛强 authored Jan 13, 2024
1 parent e1ebf15 commit 7687924
Show file tree
Hide file tree
Showing 7 changed files with 125 additions and 31 deletions.
43 changes: 38 additions & 5 deletions runtime/core/bin/http_server_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,36 +16,69 @@
#include "glog/logging.h"

#include "http/http_server.h"
#include "utils/log.h"
#include "processor/wetext_processor.h"

#include "frontend/g2p_en.h"
#include "frontend/g2p_prosody.h"
#include "frontend/wav.h"
#include "model/tts_model.h"
#include "utils/string.h"

// Flags
DEFINE_string(frontend_flags, "", "frontend flags file");
DEFINE_string(vits_flags, "", "vits flags file");

// Text Normalization
DEFINE_string(tagger, "", "tagger fst file");
DEFINE_string(verbalizer, "", "verbalizer fst file");

// Tokenizer
DEFINE_string(vocab, "", "tokenizer vocab file");

// G2P for English
DEFINE_string(cmudict, "", "cmudict for english words");
DEFINE_string(g2p_en_model, "", "english g2p fst model for oov");
DEFINE_string(g2p_en_sym, "", "english g2p symbol table for oov");

// G2P for Chinese
DEFINE_string(char2pinyin, "", "chinese character to pinyin");
DEFINE_string(pinyin2id, "", "pinyin to id");
DEFINE_string(pinyin2phones, "", "pinyin to phones");
DEFINE_string(g2p_prosody_model, "", "g2p prosody model file");

// VITS
DEFINE_string(speaker2id, "", "speaker to id");
DEFINE_string(phone2id, "", "phone to id");
DEFINE_string(vits_model, "", "e2e tts model file");
DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm");


// port
DEFINE_int32(port, 10086, "http listening port");

int main(int argc, char* argv[]) {
gflags::ParseCommandLineFlags(&argc, &argv, false);
google::InitGoogleLogging(argv[0]);
gflags::ReadFromFlagsFile(FLAGS_frontend_flags, "", false);
gflags::ReadFromFlagsFile(FLAGS_vits_flags, "", false);

auto tn = std::make_shared<wetext::Processor>(FLAGS_tagger, FLAGS_verbalizer);

bool has_en = !FLAGS_g2p_en_model.empty() && !FLAGS_g2p_en_sym.empty() &&
!FLAGS_g2p_en_sym.empty();
std::shared_ptr<wetts::G2pEn> g2p_en =
has_en ? std::make_shared<wetts::G2pEn>(FLAGS_cmudict, FLAGS_g2p_en_model,
FLAGS_g2p_en_sym)
: nullptr;

auto g2p_prosody = std::make_shared<wetts::G2pProsody>(
FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
FLAGS_pinyin2phones);
auto tts_model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, tn, g2p_prosody);
FLAGS_pinyin2phones, g2p_en);
auto model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
tn, g2p_prosody);

wetts::HttpServer server(FLAGS_port, tts_model);
wetts::HttpServer server(FLAGS_port, model);
LOG(INFO) << "Listening at port " << FLAGS_port;
server.Start();
return 0;
Expand Down
3 changes: 2 additions & 1 deletion runtime/core/bin/tts_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ int main(int argc, char* argv[]) {
FLAGS_g2p_prosody_model, FLAGS_vocab, FLAGS_char2pinyin, FLAGS_pinyin2id,
FLAGS_pinyin2phones, g2p_en);
auto model = std::make_shared<wetts::TtsModel>(
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, tn, g2p_prosody);
FLAGS_vits_model, FLAGS_speaker2id, FLAGS_phone2id, FLAGS_sampling_rate,
tn, g2p_prosody);

std::vector<float> audio;
int sid = model->GetSid(FLAGS_sname);
Expand Down
61 changes: 39 additions & 22 deletions runtime/core/http/http_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,25 @@

#include "frontend/wav.h"
#include "utils/string.h"
#include "utils/timer.h"

namespace wetts {

namespace urls = boost::urls;
namespace uuids = boost::uuids;

http::message_generator ConnectionHandler::handle_request(
const std::string& wav_path) {
// Attempt to open the file
http::message_generator ConnectionHandler::HandleRequest(
char* wav_data, int data_size) {
beast::error_code ec;
http::file_body::value_type body;
body.open(wav_path.c_str(), beast::file_mode::scan, ec);

// Cache the size since we need it after the move
auto const size = body.size();
// Respond to GET request
http::response<http::file_body> res{
std::piecewise_construct, std::make_tuple(std::move(body)),
std::make_tuple(http::status::ok, request_.version())};
http::response<http::buffer_body> res;
res.result(http::status::ok);
res.set(http::field::server, BOOST_BEAST_VERSION_STRING);
res.set(http::field::content_type, "audio/wav");
res.content_length(size);
res.keep_alive(request_.keep_alive());
res.body().data = wav_data;
res.body().size = data_size;
res.body().more = false;
res.prepare_payload();
return res;
}

Expand Down Expand Up @@ -80,16 +76,37 @@ void ConnectionHandler::operator()() {
std::string name = (*params.find("name")).value;
int sid = tts_model_->GetSid(name);
// 2. Synthesis audio from text
std::vector<float> audio;
tts_model_->Synthesis(text, sid, &audio);
wetts::WavWriter wav_writer(audio.data(), audio.size(), 1, 22050, 16);
// 3. Write samples to file named uuid.wav
std::string wav_path =
uuids::to_string(uuids::random_generator()()) + ".wav";
wav_writer.Write(wav_path);

int sample_rate = tts_model_->sampling_rate();
int num_channels = 1;
int bits_per_sample = 16;
LOG(INFO) << "Sample rate: " << sample_rate;
LOG(INFO) << "Num of channels: " << num_channels;
LOG(INFO) << "Bit per sample: " << bits_per_sample;
int extract_time = 0;
wetts::Timer timer;
std::vector<float> pcm;
tts_model_->Synthesis(text, sid, &pcm);
int pcm_size = pcm.size();
extract_time = timer.Elapsed();
LOG(INFO) << "TTS pcm duration: "
<< pcm_size * 1000 / num_channels / sample_rate << "ms";
LOG(INFO) << "Cost time: " << static_cast<float>(extract_time) << "ms";
// 3. Convert pcm to wav
std::vector<int16_t> audio(pcm_size);
for (int i = 0; i < pcm_size; ++i) {
audio[i] = static_cast<int16_t>(pcm[i]);
}
int audio_size = pcm_size * sizeof(int16_t);
int data_size = audio_size + 44;
WavHeader header(pcm_size, num_channels, sample_rate, bits_per_sample);
std::vector<char> wav_data;
wav_data.insert(wav_data.end(), reinterpret_cast<char*>(&header),
reinterpret_cast<char*>(&header) + 44);
wav_data.insert(wav_data.end(), reinterpret_cast<char*>(audio.data()),
reinterpret_cast<char*>(audio.data()) + audio_size);
// Handle request
http::message_generator msg = handle_request(wav_path);
http::message_generator msg =
HandleRequest(wav_data.data(), data_size);
// Determine if we should close the connection
bool keep_alive = msg.keep_alive();
// Send the response
Expand Down
2 changes: 1 addition & 1 deletion runtime/core/http/http_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ConnectionHandler {
ConnectionHandler(tcp::socket&& socket, std::shared_ptr<TtsModel> tts_model)
: socket_(std::move(socket)), tts_model_(std::move(tts_model)) {}
void operator()();
http::message_generator handle_request(const std::string& wav_path);
http::message_generator HandleRequest(char* wav_data, int data_size);

private:
tcp::socket socket_;
Expand Down
3 changes: 2 additions & 1 deletion runtime/core/model/tts_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@
namespace wetts {

TtsModel::TtsModel(const std::string& model_path, const std::string& speaker2id,
const std::string& phone2id,
const std::string& phone2id, const int sampling_rate,
std::shared_ptr<wetext::Processor> tn,
std::shared_ptr<G2pProsody> g2p_prosody)
: OnnxModel(model_path),
tn_(std::move(tn)),
g2p_prosody_(std::move(g2p_prosody)) {
sampling_rate_ = sampling_rate;
ReadTableFile(phone2id, &phone2id_);
ReadTableFile(speaker2id, &speaker2id_);
}
Expand Down
5 changes: 4 additions & 1 deletion runtime/core/model/tts_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,17 @@ namespace wetts {
class TtsModel : public OnnxModel {
public:
explicit TtsModel(const std::string& model_path,
const std::string& speaker2id, const std::string& phone2id,
const std::string& speaker2id,
const std::string& phone2id,
const int sampling_rate,
std::shared_ptr<wetext::Processor> processor,
std::shared_ptr<G2pProsody> g2p_prosody);
void Forward(const std::vector<int64_t>& phonemes, const int sid,
std::vector<float>* audio);
void Synthesis(const std::string& text, const int sid,
std::vector<float>* audio);
int GetSid(const std::string& name);
int sampling_rate() const { return sampling_rate_; }

private:
int sampling_rate_;
Expand Down
39 changes: 39 additions & 0 deletions runtime/core/utils/timer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright (c) 2021 Mobvoi Inc (Binbin Zhang)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef UTILS_TIMER_H_
#define UTILS_TIMER_H_

#include <chrono>

namespace wetts {

class Timer {
public:
Timer() : time_start_(std::chrono::steady_clock::now()) {}
void Reset() { time_start_ = std::chrono::steady_clock::now(); }
// return int in milliseconds
int Elapsed() const {
auto time_now = std::chrono::steady_clock::now();
return std::chrono::duration_cast<std::chrono::milliseconds>(time_now -
time_start_)
.count();
}

private:
std::chrono::time_point<std::chrono::steady_clock> time_start_;
};
} // namespace wetts

#endif // UTILS_TIMER_H_

0 comments on commit 7687924

Please sign in to comment.