diff --git a/README.md b/README.md
index 342c422..43c4d96 100644
--- a/README.md
+++ b/README.md
@@ -40,12 +40,10 @@ We plan to support a variaty of open source TTS datasets, include but not limite
 
 ## Pretrained Models
 
-| Dataset | Language | Checkpoint Model | Runtime Model |
-| ------- | -------- | ---------------- | ------------- |
-| Baker | CN | [BERT](https://wenet.org.cn/downloads?models=wetts&version=baker_bert_exp.tar.gz) | [BERT](https://wenet.org.cn/downloads?models=wetts&version=baker_bert_onnx.tar.gz) |
-| Baker | CN | [VITS](https://wenet.org.cn/downloads?models=wetts&version=baker_vits_v1_exp.tar.gz) | [VITS](https://wenet.org.cn/downloads?models=wetts&version=baker_vits_v1_onnx.tar.gz) |
-
-English G2P model: [english_us_arpa v2.0.0a](https://wenet.org.cn/downloads?models=wetts&version=g2p_en.tar.gz), powered by [MFA](https://github.com/MontrealCorpusTools/mfa-models/releases/tag/g2p-english_us_arpa-v2.0.0a).
+| Dataset        | Language | Checkpoint Model | Runtime Model |
+| -------------- | -------- | ---------------- | ------------- |
+| Baker          | CN       | [BERT](https://wenet.org.cn/downloads?models=wetts&version=baker_bert_exp.tar.gz) | [BERT](https://wenet.org.cn/downloads?models=wetts&version=baker_bert_onnx.tar.gz) |
+| Multilingual   | CN       | [VITS](https://wenet.org.cn/downloads?models=wetts&version=multilingual_vits_v3_exp.tar.gz) | [VITS](https://wenet.org.cn/downloads?models=wetts&version=multilingual_vits_v3_onnx.tar.gz) |
 
 ## Runtime
 
@@ -64,21 +62,10 @@
 cd runtime/onnxruntime
 cmake -B build -DCMAKE_BUILD_TYPE=Release
 cmake --build build
 ./build/bin/tts_main \
-  --tagger baker_bert_onnx/zh_tn_tagger.fst \
-  --verbalizer baker_bert_onnx/zh_tn_verbalizer.fst \
-  --vocab baker_bert_onnx/vocab.txt \
-  --char2pinyin baker_bert_onnx/pinyin_dict.txt \
-  --pinyin2id baker_bert_onnx/polyphone_phone.txt \
-  --pinyin2phones baker_bert_onnx/lexicon.txt \
-  --g2p_prosody_model baker_bert_onnx/19.onnx \
-  --speaker2id baker_vits_v1_onnx/speaker.txt \
+  --frontend_flags baker_bert_onnx/frontend.flags \
+  --vits_flags multilingual_vits_v3_onnx/vits.flags \
   --sname baker \
-  --phone2id baker_vits_v1_onnx/phones.txt \
-  --vits_model baker_vits_v1_onnx/G_250000.onnx \
-  --text "你好,我是小明。" \
-  --cmudict g2p_en/cmudict.dict \ # optional
-  --g2p_en_model g2p_en/model.fst \ # optional
-  --g2p_en_sym g2p_en/phones.sym \ # optional
+  --text "hello我是小明。" \
   --wav_path audio.wav
 ```
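
A gflags flags file holds one `--flag=value` per line; lines starting with `#` are comments, and relative paths inside it are resolved against the process working directory, not against the flags file's own location. The actual contents of `frontend.flags` and `vits.flags` are not part of this diff, so the sketch below is only an assumption that they collect the per-model options previously passed on the command line (the VITS-side file names are illustrative):

```
# frontend.flags: assumed contents, shipped inside baker_bert_onnx/
--tagger=baker_bert_onnx/zh_tn_tagger.fst
--verbalizer=baker_bert_onnx/zh_tn_verbalizer.fst
--vocab=baker_bert_onnx/vocab.txt
--char2pinyin=baker_bert_onnx/pinyin_dict.txt
--pinyin2id=baker_bert_onnx/polyphone_phone.txt
--pinyin2phones=baker_bert_onnx/lexicon.txt
--g2p_prosody_model=baker_bert_onnx/19.onnx

# vits.flags: assumed contents; the model file name is illustrative
--speaker2id=multilingual_vits_v3_onnx/speaker.txt
--phone2id=multilingual_vits_v3_onnx/phones.txt
--vits_model=multilingual_vits_v3_onnx/G_250000.onnx
```

Since `ReadFromFlagsFile(path, prog_name, errors_are_fatal)` is called below with `errors_are_fatal` set to `false`, a missing or unreadable flags file is reported but does not abort the binary.
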
diff --git a/runtime/core/bin/tts_main.cc b/runtime/core/bin/tts_main.cc
index 01495b9..7a5f1fe 100644
--- a/runtime/core/bin/tts_main.cc
+++ b/runtime/core/bin/tts_main.cc
@@ -22,6 +22,10 @@
 #include "model/tts_model.h"
 #include "utils/string.h"
 
+// Flags
+DEFINE_string(frontend_flags, "", "frontend flags file");
+DEFINE_string(vits_flags, "", "vits flags file");
+
 // Text Normalization
 DEFINE_string(tagger, "", "tagger fst file");
 DEFINE_string(verbalizer, "", "verbalizer fst file");
@@ -43,16 +47,18 @@
 DEFINE_string(g2p_prosody_model, "", "g2p prosody model file");
 // VITS
 DEFINE_string(speaker2id, "", "speaker to id");
 DEFINE_string(phone2id, "", "phone to id");
-DEFINE_string(sname, "", "speaker name");
 DEFINE_string(vits_model, "", "e2e tts model file");
-
 DEFINE_int32(sampling_rate, 22050, "sampling rate of pcm");
+
+DEFINE_string(sname, "", "speaker name");
 DEFINE_string(text, "", "input text");
 DEFINE_string(wav_path, "", "output wave path");
 
 int main(int argc, char* argv[]) {
   gflags::ParseCommandLineFlags(&argc, &argv, false);
   google::InitGoogleLogging(argv[0]);
+  gflags::ReadFromFlagsFile(FLAGS_frontend_flags, "", false);
+  gflags::ReadFromFlagsFile(FLAGS_vits_flags, "", false);
 
   auto tn = std::make_shared(FLAGS_tagger, FLAGS_verbalizer);
diff --git a/wetts/vits/inference_onnx.py b/wetts/vits/inference_onnx.py
index 1913dff..9adeea6 100644
--- a/wetts/vits/inference_onnx.py
+++ b/wetts/vits/inference_onnx.py
@@ -62,7 +62,8 @@ def main():
     scales = scales.unsqueeze(0)
 
     for line in open(args.test_file):
-        audio_path, sid, text = line.strip().split("|")
+        audio_path, speaker, text = line.strip().split("|")
+        sid = speaker_dict[speaker]
         seq = [phone_dict[symbol] for symbol in text.split()]
         x = torch.LongTensor([seq])
 
diff --git a/wetts/vits/utils.py b/wetts/vits/utils.py
index d433eef..774008f 100644
--- a/wetts/vits/utils.py
+++ b/wetts/vits/utils.py
@@ -4,8 +4,6 @@
 import logging
 import os
 
-import numpy as np
-from scipy.io.wavfile import read
 import torch
 
 MATPLOTLIB_FLAG = False
@@ -152,11 +150,6 @@ def plot_alignment_to_numpy(alignment, info=None):
     return data
 
 
-def load_wav_to_torch(full_path):
-    sampling_rate, data = read(full_path)
-    return torch.FloatTensor(data.astype(np.float32)), sampling_rate
-
-
 def load_filepaths_and_text(filename, split="|"):
     with open(filename, encoding="utf-8") as f:
         filepaths_and_text = [line.strip().split(split) for line in f]
@@ -217,8 +210,7 @@ def get_hparams(init=True):
 def get_hparams_from_dir(model_dir):
     config_save_path = os.path.join(model_dir, "config.json")
     with open(config_save_path, "r") as f:
-        data = f.read()
-        config = json.loads(data)
+        config = json.load(f)
     hparams = HParams(**config)
     hparams.model_dir = model_dir
     return hparams
@@ -227,8 +219,7 @@ def get_hparams_from_file(config_path):
     with open(config_path, "r") as f:
-        data = f.read()
-        config = json.loads(data)
+        config = json.load(f)
     hparams = HParams(**config)
     return hparams
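
One loose end in the `inference_onnx.py` hunk: it indexes a `speaker_dict` that this diff never defines. Below is a minimal sketch of how such a table could be built, assuming a `speaker.txt`-style file with one `name id` pair per line; the helper name and the file layout are assumptions, not part of this change.

```python
# Hypothetical sketch: map speaker names to the integer ids the VITS
# model expects, from lines such as "baker 0".
def load_speaker_table(path):
    speaker_dict = {}
    with open(path, encoding="utf-8") as f:
        for line in f:
            name, idx = line.strip().split()
            speaker_dict[name] = int(idx)
    return speaker_dict
```

With such a mapping in place, each `test_file` line carries a speaker name instead of a raw id, i.e. `audio_path|speaker|text` rather than `audio_path|sid|text`, which is exactly what the rewritten `split("|")` line implements.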