Skip to content

Commit

Permalink
feat: openai-like embedding model support (#375)
Browse files Browse the repository at this point in the history
Support OpenAI-like embedding models, such as ZhipuAI.
#373
  • Loading branch information
Icemap authored Nov 13, 2024
1 parent 19a5c7c commit c9176b0
Show file tree
Hide file tree
Showing 5 changed files with 112 additions and 0 deletions.
9 changes: 9 additions & 0 deletions backend/app/rag/chat_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from google.oauth2 import service_account
from google.auth.transport.requests import Request

from app.rag.embeddings.openai_like_embedding import OpenAILikeEmbedding
from app.rag.node_postprocessor import MetadataPostFilter
from app.rag.node_postprocessor.metadata_post_filter import MetadataFilters
from app.rag.node_postprocessor.baisheng_reranker import BaishengRerank
Expand Down Expand Up @@ -290,6 +291,14 @@ def get_embedding_model(
model=model,
**config,
)
case EmbeddingProvider.OPENAI_LIKE:
api_base = config.pop("api_base", "https://open.bigmodel.cn/api/paas/v4")
return OpenAILikeEmbedding(
model=model,
api_base=api_base,
api_key=credentials,
**config,
)
case _:
raise ValueError(f"Got unknown embedding provider: {provider}")

Expand Down
12 changes: 12 additions & 0 deletions backend/app/rag/embed_model_option.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,4 +88,16 @@ class EmbeddingModelOption(BaseModel):
credentials_type="str",
default_credentials="dummy",
),
EmbeddingModelOption(
provider=EmbeddingProvider.OPENAI_LIKE,
provider_display_name="OpenAI Like",
provider_description="OpenAI-Like is a set of platforms that provide text embeddings similar to OpenAI. Such as ZhiPuAI.",
provider_url="https://open.bigmodel.cn/dev/api/vector/embedding-3",
default_embedding_model="embedding-3",
embedding_model_description=f"Please select a text embedding model with {settings.EMBEDDING_DIMS} dimensions.",
credentials_display_name="OpenAI Like API Key",
credentials_description="The API key of OpenAI Like. For ZhipuAI, you can find it in https://open.bigmodel.cn/usercenter/apikeys",
credentials_type="str",
default_credentials="dummy",
),
]
85 changes: 85 additions & 0 deletions backend/app/rag/embeddings/openai_like_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from typing import Any, List, Optional

from llama_index.core.base.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
from llama_index.core.bridge.pydantic import Field, PrivateAttr
from llama_index.core.callbacks.base import CallbackManager
from llama_index.core.embeddings import BaseEmbedding
from openai import OpenAI, AsyncOpenAI


class OpenAILikeEmbedding(BaseEmbedding):
    """Embedding model for OpenAI-compatible embedding APIs (e.g. ZhipuAI).

    We cannot directly call llama-index's OpenAI embedding class because it
    validates the model name against OpenAI's catalog, and models such as
    'embedding-2' or 'embedding-3' are not OpenAI model names.
    """

    # Model name sent verbatim to the OpenAI-compatible endpoint.
    model: str = Field(
        default="embedding-3",
        description="The model to use when calling Zhipu AI API",
    )
    # Sync/async OpenAI clients pointed at the OpenAI-compatible base URL.
    _client: OpenAI = PrivateAttr()
    _aclient: AsyncOpenAI = PrivateAttr()

    def __init__(
        self,
        api_key: str,
        model: str = "embedding-3",
        api_base: str = "https://open.bigmodel.cn/api/paas/v4/",
        embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
        callback_manager: Optional[CallbackManager] = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the embedding model.

        Args:
            api_key: API key for the OpenAI-compatible service.
            model: Embedding model name (e.g. "embedding-3" for ZhipuAI).
            api_base: Base URL of the OpenAI-compatible API.
            embed_batch_size: Batch size used by BaseEmbedding when batching.
            callback_manager: Optional llama-index callback manager.
            **kwargs: Extra keyword arguments forwarded to BaseEmbedding.
        """
        super().__init__(
            embed_batch_size=embed_batch_size,
            callback_manager=callback_manager,
            model=model,
            **kwargs,
        )
        # NOTE: `model` is already assigned by the pydantic base __init__
        # above, so no explicit re-assignment is needed here.
        self._client = OpenAI(api_key=api_key, base_url=api_base)
        self._aclient = AsyncOpenAI(api_key=api_key, base_url=api_base)

    def get_embeddings(self, sentences: List[str]) -> List[List[float]]:
        """Get embeddings for a batch of texts via the OpenAI-style API."""
        embedding_objs = self._client.embeddings.create(
            input=sentences, model=self.model
        ).data
        return [obj.embedding for obj in embedding_objs]

    async def aget_embeddings(self, sentences: List[str]) -> List[List[float]]:
        """Asynchronously get embeddings for a batch of texts."""
        result = await self._aclient.embeddings.create(
            input=sentences, model=self.model
        )
        return [obj.embedding for obj in result.data]

    @classmethod
    def class_name(cls) -> str:
        """Class name used by llama-index for (de)serialization."""
        return "OpenAILikeEmbedding"

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        return self.get_embeddings([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """The asynchronous version of _get_query_embedding."""
        result = await self.aget_embeddings([query])
        return result[0]

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        return self._get_text_embeddings([text])[0]

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        result = await self._aget_text_embeddings([text])
        return result[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Batch text embeddings; delegates to the public sync API."""
        return self.get_embeddings(texts)

    async def _aget_text_embeddings(
        self,
        texts: List[str],
    ) -> List[List[float]]:
        """Batch text embeddings; delegates to the public async API."""
        return await self.aget_embeddings(texts)
1 change: 1 addition & 0 deletions backend/app/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class EmbeddingProvider(str, enum.Enum):
COHERE = "cohere"
OLLAMA = "ollama"
LOCAL = "local"
OPENAI_LIKE = "openai_like"


class RerankerProvider(str, enum.Enum):
Expand Down
5 changes: 5 additions & 0 deletions frontend/app/src/pages/docs/deploy-with-docker.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,11 @@ This document provides instructions for deploying the entire RAG using Docker Co
- EMBEDDING_DIMS: 768
- EMBEDDING_MAX_TOKENS: 8192
- find more models in https://jina.ai/embeddings/
- ZhipuAI
- embedding-3
- EMBEDDING_DIMS: 2048
- EMBEDDING_MAX_TOKENS: 8192
- Find more details in https://open.bigmodel.cn/dev/api/vector/embedding-3
- Local Embedding Server
- BAAI/bge-m3
- EMBEDDING_DIMS: 1024
Expand Down

0 comments on commit c9176b0

Please sign in to comment.