Skip to content

Commit

Permalink
Merge branch 'refs/heads/main' into feature/update_feishu_base_tools
Browse files Browse the repository at this point in the history
  • Loading branch information
黎斌 committed Oct 11, 2024
2 parents 4124e9e + 42b02b3 commit 32d6199
Show file tree
Hide file tree
Showing 40 changed files with 744 additions and 206 deletions.
17 changes: 17 additions & 0 deletions api/configs/middleware/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,22 @@ def BROKER_USE_SSL(self) -> bool:
return self.CELERY_BROKER_URL.startswith("rediss://") if self.CELERY_BROKER_URL else False


class InternalTestConfig(BaseSettings):
    """
    Settings consumed by the internal test tooling.

    Both fields are optional and fall back to ``None`` when not provided.
    """

    # Secret half of the AWS key pair used by internal tests.
    AWS_SECRET_ACCESS_KEY: Optional[str] = Field(
        default=None,
        description="Internal test AWS secret access key",
    )

    # Public identifier of the AWS key pair used by internal tests.
    AWS_ACCESS_KEY_ID: Optional[str] = Field(
        default=None,
        description="Internal test AWS access key ID",
    )


class MiddlewareConfig(
# place the configs in alphabet order
CeleryConfig,
Expand Down Expand Up @@ -224,5 +240,6 @@ class MiddlewareConfig(
TiDBVectorConfig,
WeaviateConfig,
ElasticsearchConfig,
InternalTestConfig,
):
pass
24 changes: 24 additions & 0 deletions api/controllers/console/datasets/external.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from services.dataset_service import DatasetService
from services.external_knowledge_service import ExternalDatasetService
from services.hit_testing_service import HitTestingService
from services.knowledge_service import ExternalDatasetTestService


def _validate_name(name):
Expand Down Expand Up @@ -232,8 +233,31 @@ def post(self, dataset_id):
raise InternalServerError(str(e))


class BedrockRetrievalApi(Resource):
    """Retrieval endpoint that delegates to ``ExternalDatasetTestService``; internal testing only."""

    def post(self):
        """
        Parse the request arguments and run a knowledge retrieval.

        Required, non-null arguments:
        ``retrieval_setting`` (dict, read from the JSON body), ``query`` (str)
        and ``knowledge_id`` (str).

        :return: tuple of the service result and HTTP status 200
        """
        request_parser = reqparse.RequestParser()
        request_parser.add_argument("retrieval_setting", nullable=False, required=True, type=dict, location="json")
        request_parser.add_argument("query", nullable=False, required=True, type=str)
        request_parser.add_argument("knowledge_id", nullable=False, required=True, type=str)
        payload = request_parser.parse_args()

        # Hand the validated arguments straight to the retrieval service.
        return (
            ExternalDatasetTestService.knowledge_retrieval(
                payload["retrieval_setting"], payload["query"], payload["knowledge_id"]
            ),
            200,
        )


# Register the external-knowledge resources on `api`, each bound to its URL route.
api.add_resource(ExternalKnowledgeHitTestingApi, "/datasets/<uuid:dataset_id>/external-hit-testing")
api.add_resource(ExternalDatasetCreateApi, "/datasets/external")
api.add_resource(ExternalApiTemplateListApi, "/datasets/external-knowledge-api")
api.add_resource(ExternalApiTemplateApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>")
api.add_resource(ExternalApiUseCheckApi, "/datasets/external-knowledge-api/<uuid:external_knowledge_api_id>/use-check")
# This API is only for internal testing.
api.add_resource(BedrockRetrievalApi, "/test/retrieval")
7 changes: 3 additions & 4 deletions api/controllers/console/workspace/model_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,12 @@ class ModelProviderIconApi(Resource):
Get model provider icon
"""

@setup_required
@login_required
@account_initialization_required
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
provider=provider, icon_type=icon_type, lang=lang
provider=provider,
icon_type=icon_type,
lang=lang,
)

return send_file(io.BytesIO(icon), mimetype=mimetype)
Expand Down
62 changes: 61 additions & 1 deletion api/core/model_runtime/model_providers/siliconflow/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


Expand All @@ -29,3 +39,53 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
def _add_custom_parameters(cls, credentials: dict) -> None:
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
    """
    Build the schema for a user-defined (customizable) SiliconFlow chat model.

    :param model: model name; also used verbatim as the display label
    :param credentials: model credentials; ``function_calling_type``,
        ``context_size`` and ``max_tokens`` are read from it
    :return: model entity with features, properties and parameter rules
    """
    # The provider's credential form exposes "function_call" as its "Support"
    # option, so that value must enable tool calling here; "tool_call" is still
    # accepted for credentials saved with that value.
    supports_tool_call = credentials.get("function_calling_type") in ("tool_call", "function_call")
    features = (
        [ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
        if supports_tool_call
        else []
    )

    return AIModelEntity(
        model=model,
        label=I18nObject(en_US=model, zh_Hans=model),
        model_type=ModelType.LLM,
        features=features,
        fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
        model_properties={
            ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)),
            ModelPropertyKey.MODE: LLMMode.CHAT.value,
        },
        parameter_rules=[
            ParameterRule(
                name="temperature",
                use_template="temperature",
                label=I18nObject(en_US="Temperature", zh_Hans="温度"),
                type=ParameterType.FLOAT,
            ),
            ParameterRule(
                name="max_tokens",
                use_template="max_tokens",
                default=512,
                min=1,
                # Upper bound comes from the credentials, falling back to 1024.
                max=int(credentials.get("max_tokens", 1024)),
                label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
                type=ParameterType.INT,
            ),
            ParameterRule(
                name="top_p",
                use_template="top_p",
                label=I18nObject(en_US="Top P", zh_Hans="Top P"),
                type=ParameterType.FLOAT,
            ),
            ParameterRule(
                name="top_k",
                use_template="top_k",
                label=I18nObject(en_US="Top K", zh_Hans="Top K"),
                type=ParameterType.FLOAT,
            ),
            ParameterRule(
                name="frequency_penalty",
                use_template="frequency_penalty",
                label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"),
                type=ParameterType.FLOAT,
            ),
        ],
    )
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ supported_model_types:
- speech2text
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
Expand All @@ -30,3 +31,57 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '4096'
type: text-input
show_on:
- variable: __model_type
value: llm
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- value: function_call
label:
en_US: Support
zh_Hans: 支持
show_on:
- variable: __model_type
value: llm
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- gte-rerank
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model: gte-rerank
model_type: rerank
model_properties:
context_size: 4000
136 changes: 136 additions & 0 deletions api/core/model_runtime/model_providers/tongyi/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Optional

import dashscope
from dashscope.common.error import (
AuthenticationError,
InvalidParameter,
RequestFailure,
ServiceUnavailableError,
UnsupportedHTTPMethod,
UnsupportedModel,
)

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class GTERerankModel(RerankModel):
    """
    Model class for GTE rerank model.
    """

    def _invoke(
        self,
        model: str,
        credentials: dict,
        query: str,
        docs: list[str],
        score_threshold: Optional[float] = None,
        top_n: Optional[int] = None,
        user: Optional[str] = None,
    ) -> RerankResult:
        """
        Invoke rerank model

        :param model: model name
        :param credentials: model credentials; ``dashscope_api_key`` is required
        :param query: search query
        :param docs: docs for reranking
        :param score_threshold: when given, results scoring below it are dropped
        :param top_n: top n, forwarded to the DashScope call
        :param user: unique user id (not used by this implementation)
        :return: rerank result
        """
        # Nothing to rank: skip the remote call entirely.
        if not docs:
            return RerankResult(model=model, docs=[])

        # Pass the key with the request instead of assigning the module-level
        # ``dashscope.api_key``: that global is shared by every caller in the
        # process, so concurrent requests with different credentials could end
        # up using each other's key.
        response = dashscope.TextReRank.call(
            query=query,
            documents=docs,
            model=model,
            top_n=top_n,
            return_documents=True,
            api_key=credentials["dashscope_api_key"],
        )

        rerank_documents = [
            RerankDocument(
                index=result.index,
                score=result.relevance_score,
                text=result["document"]["text"],
            )
            for result in response.output.results
            # Keep everything when no threshold is set, otherwise filter by score.
            if score_threshold is None or result.relevance_score >= score_threshold
        ]

        return RerankResult(model=model, docs=rerank_documents)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials by running a small rerank request.

        :param model: model name
        :param credentials: model credentials
        :raises CredentialsValidateFailedError: if the test invocation fails for any reason
        :return:
        """
        try:
            self.invoke(
                model=model,
                credentials=credentials,
                query="What is the capital of the United States?",
                docs=[
                    "Carson City is the capital city of the American state of Nevada. At the 2010 United States "
                    "Census, Carson City had a population of 55,274.",
                    "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
                    "are a political division controlled by the United States. Its capital is Saipan.",
                ],
                score_threshold=0.8,
            )
        except Exception as ex:
            # Any failure means the credentials cannot be used; keep the original
            # exception attached as the cause for debugging.
            raise CredentialsValidateFailedError(str(ex)) from ex

    @property
    def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
        """
        Map model invoke error to unified error
        The key is the error type thrown to the caller
        The value is the error type thrown by the model,
        which needs to be converted into a unified error type for the caller.

        :return: Invoke error mapping
        """
        return {
            InvokeConnectionError: [
                RequestFailure,
            ],
            InvokeServerUnavailableError: [
                ServiceUnavailableError,
            ],
            InvokeRateLimitError: [],
            InvokeAuthorizationError: [
                AuthenticationError,
            ],
            InvokeBadRequestError: [
                InvalidParameter,
                UnsupportedModel,
                UnsupportedHTTPMethod,
            ],
        }
1 change: 1 addition & 0 deletions api/core/model_runtime/model_providers/tongyi/tongyi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ supported_model_types:
- llm
- tts
- text-embedding
- rerank
configurate_methods:
- predefined-model
- customizable-model
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import math
from typing import Any, Optional
from urllib.parse import urlparse

Expand Down Expand Up @@ -112,7 +113,8 @@ def delete(self) -> None:

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}

results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

Expand Down
Loading

0 comments on commit 32d6199

Please sign in to comment.