Skip to content

Commit

Permalink
TTS-FIX
Browse files: browse the repository at this point in the history
  • Loading branch information
Dartvauder committed Sep 28, 2024
1 parent 98ebe84 commit 86e3c1a
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 6 deletions.
16 changes: 10 additions & 6 deletions LaunchFile/app.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -841,7 +841,7 @@ def get_languages():

def generate_text_and_speech(input_text, system_prompt, input_audio, llm_model_type, llm_model_name, llm_lora_model_name, enable_web_search, enable_libretranslate, target_lang, enable_openparse, pdf_file, enable_multimodal, input_image, enable_tts,
llm_settings_html, max_new_tokens, max_length, min_length, n_ctx, n_batch, temperature, top_p, min_p, typical_p, top_k,
do_sample, early_stopping, stopping, repetition_penalty, frequency_penalty, presence_penalty, length_penalty, no_repeat_ngram_size, num_beams, num_return_sequences, chat_history_format, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, output_format):
do_sample, early_stopping, stopping, repetition_penalty, frequency_penalty, presence_penalty, length_penalty, no_repeat_ngram_size, num_beams, num_return_sequences, chat_history_format, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, tts_repetition_penalty, tts_length_penalty, output_format):
global chat_history, chat_dir, tts_model, whisper_model

if 'chat_history' not in globals() or chat_history is None:
Expand Down Expand Up @@ -1134,11 +1134,9 @@ def image_to_base64_data_uri(image_path):
with open(chat_history_path, "w", encoding="utf-8") as f:
json.dump(chat_history_json, f, ensure_ascii=False, indent=4)
if enable_tts and text:
repetition_penalty = 2.0
length_penalty = 1.0
wav = tts_model.tts(text=text, speaker_wav=f"inputs/audio/voices/{speaker_wav}", language=language,
temperature=tts_temperature, top_p=tts_top_p, top_k=tts_top_k, speed=tts_speed,
repetition_penalty=repetition_penalty, length_penalty=length_penalty)
repetition_penalty=tts_repetition_penalty, length_penalty=tts_length_penalty)
now = datetime.now()
audio_filename = f"TTS_{now.strftime('%Y%m%d_%H%M%S')}.{output_format}"
audio_path = os.path.join(chat_dir, 'audio', audio_filename)
Expand Down Expand Up @@ -1167,7 +1165,7 @@ def image_to_base64_data_uri(image_path):
yield chat_history, audio_path, chat_dir, None


def generate_tts_stt(text, audio, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, tts_output_format, stt_output_format):
def generate_tts_stt(text, audio, tts_settings_html, speaker_wav, language, tts_temperature, tts_top_p, tts_top_k, tts_speed, tts_repetition_penalty, tts_length_penalty, tts_output_format, stt_output_format):
global tts_model, whisper_model

tts_output = None
Expand All @@ -1186,7 +1184,8 @@ def generate_tts_stt(text, audio, tts_settings_html, speaker_wav, language, tts_

try:
wav = tts_model.tts(text=text, speaker_wav=f"inputs/audio/voices/{speaker_wav}", language=language,
temperature=tts_temperature, top_p=tts_top_p, top_k=tts_top_k, speed=tts_speed)
temperature=tts_temperature, top_p=tts_top_p, top_k=tts_top_k, speed=tts_speed,
repetition_penalty=tts_repetition_penalty, length_penalty=tts_length_penalty )
except Exception as e:
return None, str(e)

Expand Down Expand Up @@ -8628,6 +8627,8 @@ def reload_interface():
gr.Slider(minimum=0.01, maximum=1.0, value=0.9, step=0.01, label=_("TTS Top P", lang), interactive=True),
gr.Slider(minimum=1, maximum=100, value=20, step=1, label=_("TTS Top K", lang), interactive=True),
gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label=_("TTS Speed", lang), interactive=True),
gr.Slider(minimum=0.1, maximum=2.0, value=2.0, step=0.1, label=_("TTS Repetition penalty", lang), interactive=True),
gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label=_("TTS Length penalty", lang), interactive=True),
gr.Radio(choices=["wav", "mp3", "ogg"], label=_("Select output format", lang), value="wav", interactive=True)
],
additional_inputs_accordion=gr.Accordion(label=_("LLM and TTS Settings", lang), open=False),
Expand Down Expand Up @@ -8657,6 +8658,9 @@ def reload_interface():
gr.Slider(minimum=0.01, maximum=1.0, value=0.9, step=0.01, label=_("TTS Top P", lang), interactive=True),
gr.Slider(minimum=1, maximum=100, value=20, step=1, label=_("TTS Top K", lang), interactive=True),
gr.Slider(minimum=0.5, maximum=2.0, value=1.0, step=0.1, label=_("TTS Speed", lang), interactive=True),
gr.Slider(minimum=0.1, maximum=2.0, value=2.0, step=0.1, label=_("TTS Repetition penalty", lang),
interactive=True),
gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label=_("TTS Length penalty", lang), interactive=True),
gr.Radio(choices=["wav", "mp3", "ogg"], label=_("Select TTS output format", lang), value="wav", interactive=True),
gr.Dropdown(choices=["txt", "json"], label=_("Select STT output format", lang), value="txt", interactive=True)
],
Expand Down
2 changes: 2 additions & 0 deletions translations/ru.json
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,8 @@
"Min P": "Минимальное P",
"Typical P": "Типичное P",
"Stop sequences (optional)": "Последовательности остановки (необязательно)",
"TTS Repetition penalty": "TTS Штраф за повторение",
"TTS Length penalty": "TTS Штраф за длину",
"Enable Do Sample": "Включить выборку",
"Enable Early Stopping": "Включить раннюю остановку",
"Repetition penalty": "Штраф за повторение",
Expand Down
2 changes: 2 additions & 0 deletions translations/zh.json
Original file line number | Diff line number | Diff line change
Expand Up @@ -13,6 +13,8 @@
"Min P": "最小P值",
"Typical P": "典型P值",
"Stop sequences (optional)": "停止序列(可选)",
"TTS Repetition penalty": "TTS重复惩罚",
"TTS Length penalty": "TTS长度惩罚",
"Enable Do Sample": "启用采样",
"Enable Early Stopping": "启用提前停止",
"Repetition penalty": "重复惩罚",
Expand Down

0 comments on commit 86e3c1a

Please sign in to comment.