Skip to content

Commit

Permalink
Merge pull request #3341 from coqui-ai/dev
Browse files Browse the repository at this point in the history
v0.21.2
  • Loading branch information
erogol authored Nov 30, 2023
2 parents 6189e2f + 6d1905c commit 2846640
Show file tree
Hide file tree
Showing 9 changed files with 183 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ Please use our dedicated channels for questions and discussion. Help is much mor
| Type | Links |
| ------------------------------- | --------------------------------------- |
| 💼 **Documentation** | [ReadTheDocs](https://tts.readthedocs.io/en/latest/)
| 💾 **Installation** | [TTS/README.md](https://github.com/coqui-ai/TTS/tree/dev#install-tts)|
| 💾 **Installation** | [TTS/README.md](https://github.com/coqui-ai/TTS/tree/dev#installation)|
| 👩‍💻 **Contributing** | [CONTRIBUTING.md](https://github.com/coqui-ai/TTS/blob/main/CONTRIBUTING.md)|
| 📌 **Road Map** | [Main Development Plans](https://github.com/coqui-ai/TTS/issues/378)
| 🚀 **Released Models** | [TTS Releases](https://github.com/coqui-ai/TTS/releases) and [Experimental Models](https://github.com/coqui-ai/TTS/wiki/Experimental-Released-Models)|
Expand Down
2 changes: 1 addition & 1 deletion TTS/VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.21.1
0.21.2
70 changes: 63 additions & 7 deletions TTS/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from TTS.utils.synthesizer import Synthesizer
from TTS.config import load_config


class TTS(nn.Module):
"""TODO: Add voice conversion and Capacitron support."""

Expand Down Expand Up @@ -75,11 +76,13 @@ def __init__(
if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")

if model_name is not None:
if model_name is not None and len(model_name) > 0:
if "tts_models" in model_name or "coqui_studio" in model_name:
self.load_tts_model_by_name(model_name, gpu)
elif "voice_conversion_models" in model_name:
self.load_vc_model_by_name(model_name, gpu)
else:
self.load_model_by_name(model_name, gpu)

if model_path:
self.load_tts_model_by_path(
Expand All @@ -105,8 +108,12 @@ def is_coqui_studio(self):
@property
def is_multi_lingual(self):
# Not sure what sets this to None, but applied a fix to prevent crashing.
if (isinstance(self.model_name, str) and "xtts" in self.model_name or
self.config and ("xtts" in self.config.model or len(self.config.languages) > 1)):
if (
isinstance(self.model_name, str)
and "xtts" in self.model_name
or self.config
and ("xtts" in self.config.model or len(self.config.languages) > 1)
):
return True
if hasattr(self.synthesizer.tts_model, "language_manager") and self.synthesizer.tts_model.language_manager:
return self.synthesizer.tts_model.language_manager.num_languages > 1
Expand Down Expand Up @@ -149,6 +156,15 @@ def download_model_by_name(self, model_name: str):
vocoder_path, vocoder_config_path, _ = self.manager.download_model(model_item["default_vocoder"])
return model_path, config_path, vocoder_path, vocoder_config_path, None

def load_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the 🐸TTS models by name.
Args:
model_name (str): Model name to load. You can list models by ```tts.models```.
gpu (bool, optional): Enable/disable GPU. Some models might be too slow on CPU. Defaults to False.
"""
self.load_tts_model_by_name(model_name, gpu)

def load_vc_model_by_name(self, model_name: str, gpu: bool = False):
"""Load one of the voice conversion models by name.
Expand Down Expand Up @@ -310,6 +326,7 @@ def tts(
speaker_wav: str = None,
emotion: str = None,
speed: float = None,
split_sentences: bool = True,
**kwargs,
):
"""Convert text to speech.
Expand All @@ -330,6 +347,12 @@ def tts(
speed (float, optional):
Speed factor to use for 🐸Coqui Studio models, between 0 and 2.0. If None, Studio models use 1.0.
Defaults to None.
split_sentences (bool, optional):
Split text into sentences, synthesize them separately and concatenate the file audio.
Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
kwargs (dict, optional):
Additional arguments for the model.
"""
self._check_arguments(
speaker=speaker, language=language, speaker_wav=speaker_wav, emotion=emotion, speed=speed, **kwargs
Expand All @@ -347,6 +370,7 @@ def tts(
style_wav=None,
style_text=None,
reference_speaker_name=None,
split_sentences=split_sentences,
**kwargs,
)
return wav
Expand All @@ -361,6 +385,7 @@ def tts_to_file(
speed: float = 1.0,
pipe_out=None,
file_path: str = "output.wav",
split_sentences: bool = True,
**kwargs,
):
"""Convert text to speech.
Expand All @@ -385,6 +410,10 @@ def tts_to_file(
Flag to stdout the generated TTS wav file for shell pipe.
file_path (str, optional):
Output file path. Defaults to "output.wav".
split_sentences (bool, optional):
Split text into sentences, synthesize them separately and concatenate the file audio.
Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
kwargs (dict, optional):
Additional arguments for the model.
"""
Expand All @@ -400,7 +429,14 @@ def tts_to_file(
file_path=file_path,
pipe_out=pipe_out,
)
wav = self.tts(text=text, speaker=speaker, language=language, speaker_wav=speaker_wav, **kwargs)
wav = self.tts(
text=text,
speaker=speaker,
language=language,
speaker_wav=speaker_wav,
split_sentences=split_sentences,
**kwargs,
)
self.synthesizer.save_wav(wav=wav, path=file_path, pipe_out=pipe_out)
return file_path

Expand Down Expand Up @@ -440,7 +476,14 @@ def voice_conversion_to_file(
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
return file_path

def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None, speaker: str = None):
def tts_with_vc(
self,
text: str,
language: str = None,
speaker_wav: str = None,
speaker: str = None,
split_sentences: bool = True,
):
"""Convert text to speech with voice conversion.
It combines tts with voice conversion to fake voice cloning.
Expand All @@ -460,10 +503,16 @@ def tts_with_vc(self, text: str, language: str = None, speaker_wav: str = None,
speaker (str, optional):
Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
`tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
split_sentences (bool, optional):
Split text into sentences, synthesize them separately and concatenate the file audio.
Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
"""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
# Lazy code... save it to a temp file to resample it while reading it for VC
self.tts_to_file(text=text, speaker=speaker, language=language, file_path=fp.name)
self.tts_to_file(
text=text, speaker=speaker, language=language, file_path=fp.name, split_sentences=split_sentences
)
if self.voice_converter is None:
self.load_vc_model_by_name("voice_conversion_models/multilingual/vctk/freevc24")
wav = self.voice_converter.voice_conversion(source_wav=fp.name, target_wav=speaker_wav)
Expand All @@ -476,6 +525,7 @@ def tts_with_vc_to_file(
speaker_wav: str = None,
file_path: str = "output.wav",
speaker: str = None,
split_sentences: bool = True,
):
"""Convert text to speech with voice conversion and save to file.
Expand All @@ -495,6 +545,12 @@ def tts_with_vc_to_file(
speaker (str, optional):
Speaker name for multi-speaker. You can check whether loaded model is multi-speaker by
`tts.is_multi_speaker` and list speakers by `tts.speakers`. Defaults to None.
split_sentences (bool, optional):
Split text into sentences, synthesize them separately and concatenate the file audio.
Setting it False uses more VRAM and possibly hit model specific text length or VRAM limits. Only
applicable to the 🐸TTS models. Defaults to True.
"""
wav = self.tts_with_vc(text=text, language=language, speaker_wav=speaker_wav, speaker=speaker)
wav = self.tts_with_vc(
text=text, language=language, speaker_wav=speaker_wav, speaker=speaker, split_sentences=split_sentences
)
save_wav(wav=wav, path=file_path, sample_rate=self.voice_converter.vc_config.audio.output_sample_rate)
2 changes: 1 addition & 1 deletion TTS/tts/layers/xtts/trainer/gpt_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def eval_step(self, batch, criterion):
return self.train_step(batch, criterion)

def on_train_epoch_start(self, trainer):
trainer.model.eval() # the whole model to eval
trainer.model.eval() # the whole model to eval
# put gpt model in training mode
trainer.model.xtts.gpt.train()

Expand Down
23 changes: 11 additions & 12 deletions TTS/tts/utils/text/punctuation.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ class PuncPosition(Enum):
BEGIN = 0
END = 1
MIDDLE = 2
ALONE = 3


class Punctuation:
Expand Down Expand Up @@ -92,7 +91,7 @@ def _strip_to_restore(self, text):
return [text], []
# the text is only punctuations
if len(matches) == 1 and matches[0].group() == text:
return [], [_PUNC_IDX(text, PuncPosition.ALONE)]
return [], [_PUNC_IDX(text, PuncPosition.BEGIN)]
# build a punctuation map to be used later to restore punctuations
puncs = []
for match in matches:
Expand All @@ -107,11 +106,14 @@ def _strip_to_restore(self, text):
for idx, punc in enumerate(puncs):
split = text.split(punc.punc)
prefix, suffix = split[0], punc.punc.join(split[1:])
text = suffix
if prefix == "":
# We don't want to insert an empty string in case of initial punctuation
continue
splitted_text.append(prefix)
# if the text does not end with a punctuation, add it to the last item
if idx == len(puncs) - 1 and len(suffix) > 0:
splitted_text.append(suffix)
text = suffix
return splitted_text, puncs

@classmethod
Expand All @@ -127,10 +129,10 @@ def restore(cls, text, puncs):
['This is', 'example'], ['.', '!'] -> "This is. example!"
"""
return cls._restore(text, puncs, 0)
return cls._restore(text, puncs)

@classmethod
def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statements
def _restore(cls, text, puncs): # pylint: disable=too-many-return-statements
"""Auxiliary method for Punctuation.restore()"""
if not puncs:
return text
Expand All @@ -142,21 +144,18 @@ def _restore(cls, text, puncs, num): # pylint: disable=too-many-return-statemen
current = puncs[0]

if current.position == PuncPosition.BEGIN:
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:], num)
return cls._restore([current.punc + text[0]] + text[1:], puncs[1:])

if current.position == PuncPosition.END:
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:], num + 1)

if current.position == PuncPosition.ALONE:
return [current.mark] + cls._restore(text, puncs[1:], num + 1)
return [text[0] + current.punc] + cls._restore(text[1:], puncs[1:])

# POSITION == MIDDLE
if len(text) == 1: # pragma: nocover
# a corner case where the final part of an intermediate
# mark (I) has not been phonemized
return cls._restore([text[0] + current.punc], puncs[1:], num)
return cls._restore([text[0] + current.punc], puncs[1:])

return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:], num)
return cls._restore([text[0] + current.punc + text[1]] + text[2:], puncs[1:])


# if __name__ == "__main__":
Expand Down
35 changes: 32 additions & 3 deletions TTS/utils/manage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import os
import re
import tarfile
import zipfile
from pathlib import Path
Expand All @@ -26,7 +27,6 @@
}



class ModelManager(object):
tqdm_progress = None
"""Manage TTS models defined in .models.json.
Expand Down Expand Up @@ -276,13 +276,15 @@ def set_model_url(model_item: Dict):
model_item["model_url"] = model_item["hf_url"]
elif "fairseq" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/fairseq/"
elif "xtts" in model_item["model_name"]:
model_item["model_url"] = "https://coqui.gateway.scarf.sh/xtts/"
return model_item

def _set_model_item(self, model_name):
# fetch model info from the dict
model_type, lang, dataset, model = model_name.split("/")
model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
if "fairseq" in model_name:
model_type = "tts_models"
lang = model_name.split("/")[1]
model_item = {
"model_type": "tts_models",
"license": "CC BY-NC 4.0",
Expand All @@ -291,10 +293,37 @@ def _set_model_item(self, model_name):
"description": "this model is released by Meta under Fairseq repo. Visit https://github.com/facebookresearch/fairseq/tree/main/examples/mms for more info.",
}
model_item["model_name"] = model_name
elif "xtts" in model_name and len(model_name.split("/")) != 4:
# loading xtts models with only model name (e.g. xtts_v2.0.2)
# check model name has the version number with regex
version_regex = r"v\d+\.\d+\.\d+"
if re.search(version_regex, model_name):
model_version = model_name.split("_")[-1]
else:
model_version = "main"
model_type = "tts_models"
lang = "multilingual"
dataset = "multi-dataset"
model = model_name
model_item = {
"default_vocoder": None,
"license": "CPML",
"contact": "info@coqui.ai",
"tos_required": True,
"hf_url": [
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/model.pth",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/config.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/vocab.json",
f"https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/{model_version}/hash.md5",
],
}
else:
# get model from models.json
model_type, lang, dataset, model = model_name.split("/")
model_item = self.models_dict[model_type][lang][dataset][model]
model_item["model_type"] = model_type

model_full_name = f"{model_type}--{lang}--{dataset}--{model}"
md5hash = model_item["model_hash"] if "model_hash" in model_item else None
model_item = self.set_model_url(model_item)
return model_item, model_full_name, model, md5hash
Expand Down
9 changes: 7 additions & 2 deletions TTS/utils/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def tts(
style_text=None,
reference_wav=None,
reference_speaker_name=None,
split_sentences: bool = True,
**kwargs,
) -> List[int]:
"""🐸 TTS magic. Run all the models and generate speech.
Expand All @@ -277,6 +278,8 @@ def tts(
style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None.
reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None.
reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None.
split_sentences (bool, optional): split the input text into sentences. Defaults to True.
**kwargs: additional arguments to pass to the TTS model.
Returns:
List[int]: [description]
"""
Expand All @@ -289,8 +292,10 @@ def tts(
)

if text:
sens = self.split_into_sentences(text)
print(" > Text splitted to sentences.")
sens = [text]
if split_sentences:
print(" > Text splitted to sentences.")
sens = self.split_into_sentences(text)
print(sens)

# handle multi-speaker
Expand Down
Loading

0 comments on commit 2846640

Please sign in to comment.