forked from coqui-ai/TTS
-
Notifications
You must be signed in to change notification settings - Fork 1
/
bark_patch.patch
104 lines (99 loc) · 4.12 KB
/
bark_patch.patch
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
diff --git a/TTS/tts/layers/bark/inference_funcs.py b/TTS/tts/layers/bark/inference_funcs.py
index 616d20a4..da61eb1d 100644
--- a/TTS/tts/layers/bark/inference_funcs.py
+++ b/TTS/tts/layers/bark/inference_funcs.py
@@ -106,6 +106,7 @@ def generate_voice(
audio,
model,
output_path,
+ return_voice=False
):
"""Generate a new voice from a given audio and text prompt.
@@ -146,8 +147,42 @@ def generate_voice(
semantic_tokens = tokenizer.get_token(semantic_vectors)
semantic_tokens = semantic_tokens.cpu().numpy()
+ if return_voice:
+ fine_prompt = codes
+ coarse_prompt = codes[:2,:]
+ semantic_prompt = semantic_tokens
+ history_prompt = (semantic_prompt, coarse_prompt, fine_prompt)
+ return history_prompt
+
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
+def semantic_tokens_from_audio(audio, model):
+ """
+ Generate sloppy semantic tokens from a waveform using a model that some random guy on the internet has distilled.
+ Half of this is copy-pasted from generate_voice.
+ I hate my life.
+ """
+ if isinstance(audio, str):
+ audio, sr = torchaudio.load(audio)
+ audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels)
+ audio = audio.unsqueeze(0).to(model.device)
+
+ # generate semantic tokens
+ # Load the HuBERT model
+ hubert_manager = HubertManager()
+ # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
+ hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
+
+ hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
+
+ # Load the CustomTokenizer model
+ tokenizer = HubertTokenizer.load_from_checkpoint(
+ model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device
+ )
+
+ semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate)
+ semantic_tokens = tokenizer.get_token(semantic_vectors)
+ return semantic_tokens.cpu().numpy()
def generate_text_semantic(
text,
@@ -184,7 +219,7 @@ def generate_text_semantic(
Returns:
np.ndarray: The generated semantic tokens.
"""
- print(f"history_prompt in gen: {history_prompt}")
+ # print(f"history_prompt in gen: {history_prompt}")
assert isinstance(text, str)
text = _normalize_whitespace(text)
assert len(text.strip()) > 0
diff --git a/TTS/tts/models/bark.py b/TTS/tts/models/bark.py
index f198c3d5..e2cb132a 100644
--- a/TTS/tts/models/bark.py
+++ b/TTS/tts/models/bark.py
@@ -118,12 +118,13 @@ class Bark(BaseTTS):
history_prompt=history_prompt,
temp=temp,
base=base,
+ silent=True # panariel: i changed this to True
)
x_fine_gen = generate_fine(
x_coarse_gen,
self,
history_prompt=history_prompt,
- temp=0.5,
+ temp=0.5, # used to be 0.5
base=base,
)
audio_arr = codec_decode(x_fine_gen, self)
@@ -174,7 +175,10 @@ class Bark(BaseTTS):
if voice_dir is not None:
voice_dirs = [voice_dir]
try:
- _ = load_voice(speaker_id, voice_dirs)
+ # _ = load_voice(speaker_id, voice_dirs) # ma questi l'hanno mai provato il loro stesso modello? che cazzo
+ # import pdb
+ # pdb.set_trace()
+ _ = load_voice(self, speaker_id, voice_dirs)
except (KeyError, FileNotFoundError):
output_path = os.path.join(voice_dir, speaker_id + ".npz")
os.makedirs(voice_dir, exist_ok=True)
@@ -215,6 +219,8 @@ class Bark(BaseTTS):
"""
speaker_id = "random" if speaker_id is None else speaker_id
+ # import pdb
+ # pdb.set_trace()
voice_dirs = self._set_voice_dirs(voice_dirs)
history_prompt = load_voice(self, speaker_id, voice_dirs)
outputs = self.generate_audio(text, history_prompt=history_prompt, **kwargs)