diff --git a/dbgpt/_private/config.py b/dbgpt/_private/config.py
index d430a870d..5c5d13f6a 100644
--- a/dbgpt/_private/config.py
+++ b/dbgpt/_private/config.py
@@ -151,6 +151,15 @@ def __init__(self) -> None:
             os.environ["siliconflow_proxyllm_api_base"] = os.getenv(
                 "SILICONFLOW_API_BASE", "https://api.siliconflow.cn/v1"
             )
+        self.gitee_proxy_api_key = os.getenv("GITEE_API_KEY")
+        if self.gitee_proxy_api_key:
+            os.environ["gitee_proxyllm_proxy_api_key"] = self.gitee_proxy_api_key
+            os.environ["gitee_proxyllm_proxyllm_backend"] = os.getenv(
+                "GITEE_MODEL_VERSION", "Qwen2.5-72B-Instruct"
+            )
+            os.environ["gitee_proxyllm_api_base"] = os.getenv(
+                "GITEE_API_BASE", "https://ai.gitee.com/v1"
+            )
 
         self.proxy_server_url = os.getenv("PROXY_SERVER_URL")
 
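The new block follows the existing `*_proxyllm` convention: a single upstream `GITEE_API_KEY` gates the whole mapping, and the derived lowercase keys are what the proxy adapter later reads. A minimal sketch of how the mapping behaves (assumes an installed DB-GPT environment; the key value is a placeholder):

```python
import os

from dbgpt._private.config import Config

# Placeholder key; only GITEE_API_KEY is required, the other two variables
# fall back to the defaults shown in the hunk above.
os.environ["GITEE_API_KEY"] = "sk-placeholder"

Config()  # __init__ copies the values into the derived lowercase keys

assert os.environ["gitee_proxyllm_proxy_api_key"] == "sk-placeholder"
assert os.environ["gitee_proxyllm_proxyllm_backend"] == "Qwen2.5-72B-Instruct"
assert os.environ["gitee_proxyllm_api_base"] == "https://ai.gitee.com/v1"
```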
diff --git a/dbgpt/configs/model_config.py b/dbgpt/configs/model_config.py
index 062e3b5b2..88e72430c 100644
--- a/dbgpt/configs/model_config.py
+++ b/dbgpt/configs/model_config.py
@@ -81,6 +81,7 @@ def get_device() -> str:
     "deepseek_proxyllm": "deepseek_proxyllm",
     # https://docs.siliconflow.cn/quickstart
     "siliconflow_proxyllm": "siliconflow_proxyllm",
+    "gitee_proxyllm": "gitee_proxyllm",
     "llama-2-7b": os.path.join(MODEL_PATH, "Llama-2-7b-chat-hf"),
     "llama-2-13b": os.path.join(MODEL_PATH, "Llama-2-13b-chat-hf"),
     "llama-2-70b": os.path.join(MODEL_PATH, "Llama-2-70b-chat-hf"),
@@ -307,6 +308,7 @@ def get_device() -> str:
     "bge-base-zh": os.path.join(MODEL_PATH, "bge-base-zh"),
     # https://huggingface.co/BAAI/bge-m3, beg need normalize_embeddings=True
     "bge-m3": os.path.join(MODEL_PATH, "bge-m3"),
+    "bge-large-zh-v1.5": os.path.join(MODEL_PATH, "bge-large-zh-v1.5"),
     "gte-large-zh": os.path.join(MODEL_PATH, "gte-large-zh"),
     "gte-base-zh": os.path.join(MODEL_PATH, "gte-base-zh"),
     "sentence-transforms": os.path.join(MODEL_PATH, "all-MiniLM-L6-v2"),
diff --git a/dbgpt/model/adapter/proxy_adapter.py b/dbgpt/model/adapter/proxy_adapter.py
index 30a4a0494..5b3328862 100644
--- a/dbgpt/model/adapter/proxy_adapter.py
+++ b/dbgpt/model/adapter/proxy_adapter.py
@@ -364,6 +364,31 @@ def get_async_generate_stream_function(self, model, model_path: str):
         return siliconflow_generate_stream
 
 
+class GiteeProxyLLMModelAdapter(ProxyLLMModelAdapter):
+    """Gitee proxy LLM model adapter.
+
+    See Also: `Gitee Documentation <https://ai.gitee.com>`_
+    """
+
+    def support_async(self) -> bool:
+        return True
+
+    def do_match(self, lower_model_name_or_path: Optional[str] = None):
+        return lower_model_name_or_path == "gitee_proxyllm"
+
+    def get_llm_client_class(
+        self, params: ProxyModelParameters
+    ) -> Type[ProxyLLMClient]:
+        from dbgpt.model.proxy.llms.gitee import GiteeLLMClient
+
+        return GiteeLLMClient
+
+    def get_async_generate_stream_function(self, model, model_path: str):
+        from dbgpt.model.proxy.llms.gitee import gitee_generate_stream
+
+        return gitee_generate_stream
+
+
 register_model_adapter(OpenAIProxyLLMModelAdapter)
 register_model_adapter(ClaudeProxyLLMModelAdapter)
 register_model_adapter(TongyiProxyLLMModelAdapter)
@@ -378,3 +403,4 @@ def get_async_generate_stream_function(self, model, model_path: str):
 register_model_adapter(MoonshotProxyLLMModelAdapter)
 register_model_adapter(DeepseekProxyLLMModelAdapter)
 register_model_adapter(SiliconFlowProxyLLMModelAdapter)
+register_model_adapter(GiteeProxyLLMModelAdapter)
diff --git a/dbgpt/model/proxy/__init__.py b/dbgpt/model/proxy/__init__.py
index fa0044722..88973a727 100644
--- a/dbgpt/model/proxy/__init__.py
+++ b/dbgpt/model/proxy/__init__.py
@@ -7,6 +7,7 @@
     from dbgpt.model.proxy.llms.claude import ClaudeLLMClient
     from dbgpt.model.proxy.llms.deepseek import DeepseekLLMClient
     from dbgpt.model.proxy.llms.gemini import GeminiLLMClient
+    from dbgpt.model.proxy.llms.gitee import GiteeLLMClient
     from dbgpt.model.proxy.llms.moonshot import MoonshotLLMClient
     from dbgpt.model.proxy.llms.ollama import OllamaLLMClient
     from dbgpt.model.proxy.llms.siliconflow import SiliconFlowLLMClient
@@ -31,6 +32,7 @@ def __lazy_import(name):
         "MoonshotLLMClient": "dbgpt.model.proxy.llms.moonshot",
         "OllamaLLMClient": "dbgpt.model.proxy.llms.ollama",
         "DeepseekLLMClient": "dbgpt.model.proxy.llms.deepseek",
+        "GiteeLLMClient": "dbgpt.model.proxy.llms.gitee",
     }
 
     if name in module_path:
@@ -57,4 +59,5 @@ def __getattr__(name):
     "MoonshotLLMClient",
     "OllamaLLMClient",
     "DeepseekLLMClient",
+    "GiteeLLMClient",
 ]
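The adapter only matches the exact `gitee_proxyllm` alias registered in `LLM_MODEL_CONFIG` above. A quick sanity-check sketch, using only methods defined in this diff:

```python
from dbgpt.model.adapter.proxy_adapter import GiteeProxyLLMModelAdapter

adapter = GiteeProxyLLMModelAdapter()

# do_match is an exact comparison against the lowered model name, so only
# the "gitee_proxyllm" alias routes to this adapter.
assert adapter.do_match("gitee_proxyllm")
assert not adapter.do_match("gitee")
assert adapter.support_async()  # streaming goes through gitee_generate_stream
```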
diff --git a/dbgpt/model/proxy/llms/gitee.py b/dbgpt/model/proxy/llms/gitee.py
new file mode 100644
index 000000000..b39d094c9
--- /dev/null
+++ b/dbgpt/model/proxy/llms/gitee.py
@@ -0,0 +1,83 @@
+import os
+from typing import TYPE_CHECKING, Any, Dict, Optional, Union
+
+from dbgpt.model.proxy.llms.proxy_model import ProxyModel, parse_model_request
+
+from .chatgpt import OpenAILLMClient
+
+if TYPE_CHECKING:
+    from httpx._types import ProxiesTypes
+    from openai import AsyncAzureOpenAI, AsyncOpenAI
+
+    ClientType = Union[AsyncAzureOpenAI, AsyncOpenAI]
+
+
+_GITEE_DEFAULT_MODEL = "Qwen2.5-72B-Instruct"
+
+
+async def gitee_generate_stream(
+    model: ProxyModel, tokenizer, params, device, context_len=2048
+):
+    client: GiteeLLMClient = model.proxy_llm_client
+    request = parse_model_request(params, client.default_model, stream=True)
+    async for r in client.generate_stream(request):
+        yield r
+
+
+class GiteeLLMClient(OpenAILLMClient):
+    """Gitee LLM Client.
+
+    Gitee's API is compatible with OpenAI's API, so we inherit from OpenAILLMClient.
+    """
+
+    def __init__(
+        self,
+        api_key: Optional[str] = None,
+        api_base: Optional[str] = None,
+        api_type: Optional[str] = None,
+        api_version: Optional[str] = None,
+        model: Optional[str] = None,
+        proxies: Optional["ProxiesTypes"] = None,
+        timeout: Optional[int] = 240,
+        model_alias: Optional[str] = "gitee_proxyllm",
+        context_length: Optional[int] = None,
+        openai_client: Optional["ClientType"] = None,
+        openai_kwargs: Optional[Dict[str, Any]] = None,
+        **kwargs
+    ):
+        api_base = api_base or os.getenv("GITEE_API_BASE") or "https://ai.gitee.com/v1"
+        api_key = api_key or os.getenv("GITEE_API_KEY")
+        model = model or _GITEE_DEFAULT_MODEL
+        if not context_length:
+            if "200k" in model:
+                context_length = 200 * 1024
+            else:
+                context_length = 4096
+
+        if not api_key:
+            raise ValueError(
+                "Gitee API key is required, please set 'GITEE_API_KEY' in environment "
+                "or pass it as an argument."
+            )
+
+        super().__init__(
+            api_key=api_key,
+            api_base=api_base,
+            api_type=api_type,
+            api_version=api_version,
+            model=model,
+            proxies=proxies,
+            timeout=timeout,
+            model_alias=model_alias,
+            context_length=context_length,
+            openai_client=openai_client,
+            openai_kwargs=openai_kwargs,
+            **kwargs
+        )
+
+    @property
+    def default_model(self) -> str:
+        model = self._model
+        if not model:
+            model = _GITEE_DEFAULT_MODEL
+        return model
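For completeness, a minimal async usage sketch of the new client. `GiteeLLMClient()` resolves the key, base URL, and default model exactly as in `__init__` above; the `ModelRequest._build(model, prompt)` helper is the pattern DB-GPT's other OpenAI-compatible clients are demonstrated with, and is an assumption here rather than part of this diff:

```python
import asyncio

from dbgpt.core import ModelRequest
from dbgpt.model.proxy import GiteeLLMClient  # resolved via the lazy import above


async def main():
    # Assumes GITEE_API_KEY is exported; the model defaults to
    # Qwen2.5-72B-Instruct and the base URL to https://ai.gitee.com/v1.
    client = GiteeLLMClient()
    # ModelRequest._build(model, prompt) is the simple request constructor
    # used in DB-GPT's own proxy-client examples (an assumption here).
    request = ModelRequest._build(client.default_model, "Hello, Gitee AI!")
    output = await client.generate(request)
    print(output.text)


asyncio.run(main())
```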