Skip to content

Commit

Permalink
feat: support fish audio TTS (#7982)
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue authored Sep 5, 2024
1 parent 3e7597f commit bd09922
Show file tree
Hide file tree
Showing 12 changed files with 433 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
28 changes: 28 additions & 0 deletions api/core/model_runtime/model_providers/fishaudio/fishaudio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import logging

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.model_provider import ModelProvider

logger = logging.getLogger(__name__)


class FishAudioProvider(ModelProvider):
def validate_provider_credentials(self, credentials: dict) -> None:
"""
Validate provider credentials
For debugging purposes, this method now always passes validation.
:param credentials: provider credentials, credentials form defined in `provider_credential_schema`.
"""
try:
model_instance = self.get_model_instance(ModelType.TTS)
model_instance.validate_credentials(
credentials=credentials
)
except CredentialsValidateFailedError as ex:
raise ex
except Exception as ex:
logger.exception(f'{self.get_provider_schema().provider} credentials validate failed')
raise ex
76 changes: 76 additions & 0 deletions api/core/model_runtime/model_providers/fishaudio/fishaudio.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
provider: fishaudio
label:
en_US: Fish Audio
description:
en_US: Models provided by Fish Audio, currently only support TTS.
zh_Hans: Fish Audio 提供的模型,目前仅支持 TTS。
icon_small:
en_US: fishaudio_s_en.svg
icon_large:
en_US: fishaudio_l_en.svg
background: "#E5E7EB"
help:
title:
en_US: Get your API key from Fish Audio
zh_Hans: 从 Fish Audio 获取你的 API Key
url:
en_US: https://fish.audio/go-api/
supported_model_types:
- tts
configurate_methods:
- predefined-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: api_base
label:
en_US: API URL
type: text-input
required: false
default: https://api.fish.audio
placeholder:
en_US: Enter your API URL
zh_Hans: 在此输入您的 API URL
- variable: use_public_models
label:
en_US: Use Public Models
type: select
required: false
default: "false"
placeholder:
en_US: Toggle to use public models
zh_Hans: 切换以使用公共模型
options:
- value: "true"
label:
en_US: Allow Public Models
zh_Hans: 使用公共模型
- value: "false"
label:
en_US: Private Models Only
zh_Hans: 仅使用私有模型
- variable: latency
label:
en_US: Latency
type: select
required: false
default: "normal"
placeholder:
en_US: Toggle to choice latency
zh_Hans: 切换以调整延迟
options:
- value: "balanced"
label:
en_US: Low (may affect quality)
zh_Hans: 低延迟 (可能降低质量)
- value: "normal"
label:
en_US: Normal
zh_Hans: 标准
Empty file.
174 changes: 174 additions & 0 deletions api/core/model_runtime/model_providers/fishaudio/tts/tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
from typing import Optional

import httpx

from core.model_runtime.errors.invoke import InvokeBadRequestError, InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.tts_model import TTSModel


class FishAudioText2SpeechModel(TTSModel):
"""
Model class for Fish.audio Text to Speech model.
"""

def get_tts_model_voices(
self, model: str, credentials: dict, language: Optional[str] = None
) -> list:
api_base = credentials.get("api_base", "https://api.fish.audio")
api_key = credentials.get("api_key")
use_public_models = credentials.get("use_public_models", "false") == "true"

params = {
"self": str(not use_public_models).lower(),
"page_size": "100",
}

if language is not None:
if "-" in language:
language = language.split("-")[0]
params["language"] = language

results = httpx.get(
f"{api_base}/model",
headers={"Authorization": f"Bearer {api_key}"},
params=params,
)

results.raise_for_status()
data = results.json()

return [{"name": i["title"], "value": i["_id"]} for i in data["items"]]

def _invoke(
self,
model: str,
tenant_id: str,
credentials: dict,
content_text: str,
voice: str,
user: Optional[str] = None,
) -> any:
"""
Invoke text2speech model
:param model: model name
:param tenant_id: user tenant id
:param credentials: model credentials
:param voice: model timbre
:param content_text: text content to be translated
:param user: unique user id
:return: generator yielding audio chunks
"""

return self._tts_invoke_streaming(
model=model,
credentials=credentials,
content_text=content_text,
voice=voice,
)

def validate_credentials(
self, credentials: dict, user: Optional[str] = None
) -> None:
"""
Validate credentials for text2speech model
:param credentials: model credentials
:param user: unique user id
"""

try:
self.get_tts_model_voices(
None,
credentials={
"api_key": credentials["api_key"],
"api_base": credentials["api_base"],
# Disable public models will trigger a 403 error if user is not logged in
"use_public_models": "false",
},
)
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

def _tts_invoke_streaming(
self, model: str, credentials: dict, content_text: str, voice: str
) -> any:
"""
Invoke streaming text2speech model
:param model: model name
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: ID of the reference audio (if any)
:return: generator yielding audio chunks
"""

try:
word_limit = self._get_model_word_limit(model, credentials)
if len(content_text) > word_limit:
sentences = self._split_text_into_sentences(
content_text, max_length=word_limit
)
else:
sentences = [content_text.strip()]

for i in range(len(sentences)):
yield from self._tts_invoke_streaming_sentence(
credentials=credentials, content_text=sentences[i], voice=voice
)

except Exception as ex:
raise InvokeBadRequestError(str(ex))

def _tts_invoke_streaming_sentence(
self, credentials: dict, content_text: str, voice: Optional[str] = None
) -> any:
"""
Invoke streaming text2speech model
:param credentials: model credentials
:param content_text: text content to be translated
:param voice: ID of the reference audio (if any)
:return: generator yielding audio chunks
"""
api_key = credentials.get("api_key")
api_url = credentials.get("api_base", "https://api.fish.audio")
latency = credentials.get("latency")

if not api_key:
raise InvokeBadRequestError("API key is required")

with httpx.stream(
"POST",
api_url + "/v1/tts",
json={
"text": content_text,
"reference_id": voice,
"latency": latency
},
headers={
"Authorization": f"Bearer {api_key}",
},
timeout=None,
) as response:
if response.status_code != 200:
raise InvokeBadRequestError(
f"Error: {response.status_code} - {response.text}"
)
yield from response.iter_bytes()

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeBadRequestError: [
httpx.HTTPStatusError,
],
}
5 changes: 5 additions & 0 deletions api/core/model_runtime/model_providers/fishaudio/tts/tts.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
model: tts-default
model_type: tts
model_properties:
word_limit: 1000
audio_type: 'mp3'
82 changes: 82 additions & 0 deletions api/tests/integration_tests/model_runtime/__mock/fishaudio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import os
from collections.abc import Callable
from typing import Literal

import httpx
import pytest
from _pytest.monkeypatch import MonkeyPatch


def mock_get(*args, **kwargs):
if kwargs.get("headers", {}).get("Authorization") != "Bearer test":
raise httpx.HTTPStatusError(
"Invalid API key",
request=httpx.Request("GET", ""),
response=httpx.Response(401),
)

return httpx.Response(
200,
json={
"items": [
{"title": "Model 1", "_id": "model1"},
{"title": "Model 2", "_id": "model2"},
]
},
request=httpx.Request("GET", ""),
)


def mock_stream(*args, **kwargs):
class MockStreamResponse:
def __init__(self):
self.status_code = 200

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
pass

def iter_bytes(self):
yield b"Mocked audio data"

return MockStreamResponse()


def mock_fishaudio(
monkeypatch: MonkeyPatch,
methods: list[Literal["list-models", "tts"]],
) -> Callable[[], None]:
"""
mock fishaudio module
:param monkeypatch: pytest monkeypatch fixture
:return: unpatch function
"""

def unpatch() -> None:
monkeypatch.undo()

if "list-models" in methods:
monkeypatch.setattr(httpx, "get", mock_get)

if "tts" in methods:
monkeypatch.setattr(httpx, "stream", mock_stream)

return unpatch


MOCK = os.getenv("MOCK_SWITCH", "false").lower() == "true"


@pytest.fixture
def setup_fishaudio_mock(request, monkeypatch):
methods = request.param if hasattr(request, "param") else []
if MOCK:
unpatch = mock_fishaudio(monkeypatch, methods=methods)

yield

if MOCK:
unpatch()
Empty file.
Loading

0 comments on commit bd09922

Please sign in to comment.