Skip to content

Commit

Permalink
Merge branch 'feat/new-account-page' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
douxc committed Oct 11, 2024
2 parents b4bd0d0 + 4e36833 commit 4256149
Show file tree
Hide file tree
Showing 38 changed files with 1,261 additions and 546 deletions.
4 changes: 4 additions & 0 deletions api/controllers/console/workspace/model_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,10 @@ class ModelProviderIconApi(Resource):
"""
Get model provider icon
"""

@setup_required
@login_required
@account_initialization_required
def get(self, provider: str, icon_type: str, lang: str):
model_provider_service = ModelProviderService()
icon, mimetype = model_provider_service.get_model_provider_icon(
Expand Down
62 changes: 61 additions & 1 deletion api/core/model_runtime/model_providers/siliconflow/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
from collections.abc import Generator
from typing import Optional, Union

from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool
from core.model_runtime.entities.model_entities import (
AIModelEntity,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
ParameterType,
)
from core.model_runtime.model_providers.openai_api_compatible.llm.llm import OAIAPICompatLargeLanguageModel


Expand All @@ -29,3 +39,53 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
def _add_custom_parameters(cls, credentials: dict) -> None:
credentials["mode"] = "chat"
credentials["endpoint_url"] = "https://api.siliconflow.cn/v1"

def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
return AIModelEntity(
model=model,
label=I18nObject(en_US=model, zh_Hans=model),
model_type=ModelType.LLM,
features=[ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL]
if credentials.get("function_calling_type") == "tool_call"
else [],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get("context_size", 8000)),
ModelPropertyKey.MODE: LLMMode.CHAT.value,
},
parameter_rules=[
ParameterRule(
name="temperature",
use_template="temperature",
label=I18nObject(en_US="Temperature", zh_Hans="温度"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="max_tokens",
use_template="max_tokens",
default=512,
min=1,
max=int(credentials.get("max_tokens", 1024)),
label=I18nObject(en_US="Max Tokens", zh_Hans="最大标记"),
type=ParameterType.INT,
),
ParameterRule(
name="top_p",
use_template="top_p",
label=I18nObject(en_US="Top P", zh_Hans="Top P"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="top_k",
use_template="top_k",
label=I18nObject(en_US="Top K", zh_Hans="Top K"),
type=ParameterType.FLOAT,
),
ParameterRule(
name="frequency_penalty",
use_template="frequency_penalty",
label=I18nObject(en_US="Frequency Penalty", zh_Hans="重复惩罚"),
type=ParameterType.FLOAT,
),
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ supported_model_types:
- speech2text
configurate_methods:
- predefined-model
- customizable-model
provider_credential_schema:
credential_form_schemas:
- variable: api_key
Expand All @@ -30,3 +31,57 @@ provider_credential_schema:
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
model_credential_schema:
model:
label:
en_US: Model Name
zh_Hans: 模型名称
placeholder:
en_US: Enter your model name
zh_Hans: 输入模型名称
credential_form_schemas:
- variable: api_key
label:
en_US: API Key
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 API Key
en_US: Enter your API Key
- variable: context_size
label:
zh_Hans: 模型上下文长度
en_US: Model context size
required: true
type: text-input
default: '4096'
placeholder:
zh_Hans: 在此输入您的模型上下文长度
en_US: Enter your Model context size
- variable: max_tokens
label:
zh_Hans: 最大 token 上限
en_US: Upper bound for max tokens
default: '4096'
type: text-input
show_on:
- variable: __model_type
value: llm
- variable: function_calling_type
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- value: function_call
label:
en_US: Support
zh_Hans: 支持
show_on:
- variable: __model_type
value: llm
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
- gte-rerank
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
model: gte-rerank
model_type: rerank
model_properties:
context_size: 4000
136 changes: 136 additions & 0 deletions api/core/model_runtime/model_providers/tongyi/rerank/rerank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from typing import Optional

import dashscope
from dashscope.common.error import (
AuthenticationError,
InvalidParameter,
RequestFailure,
ServiceUnavailableError,
UnsupportedHTTPMethod,
UnsupportedModel,
)

from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
InvokeBadRequestError,
InvokeConnectionError,
InvokeError,
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel


class GTERerankModel(RerankModel):
"""
Model class for GTE rerank model.
"""

def _invoke(
self,
model: str,
credentials: dict,
query: str,
docs: list[str],
score_threshold: Optional[float] = None,
top_n: Optional[int] = None,
user: Optional[str] = None,
) -> RerankResult:
"""
Invoke rerank model
:param model: model name
:param credentials: model credentials
:param query: search query
:param docs: docs for reranking
:param score_threshold: score threshold
:param top_n: top n
:param user: unique user id
:return: rerank result
"""
if len(docs) == 0:
return RerankResult(model=model, docs=docs)

# initialize client
dashscope.api_key = credentials["dashscope_api_key"]

response = dashscope.TextReRank.call(
query=query,
documents=docs,
model=model,
top_n=top_n,
return_documents=True,
)

rerank_documents = []
for _, result in enumerate(response.output.results):
# format document
rerank_document = RerankDocument(
index=result.index,
score=result.relevance_score,
text=result["document"]["text"],
)

# score threshold check
if score_threshold is not None:
if result.relevance_score >= score_threshold:
rerank_documents.append(rerank_document)
else:
rerank_documents.append(rerank_document)

return RerankResult(model=model, docs=rerank_documents)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials
:param model: model name
:param credentials: model credentials
:return:
"""
try:
self.invoke(
model=model,
credentials=credentials,
query="What is the capital of the United States?",
docs=[
"Carson City is the capital city of the American state of Nevada. At the 2010 United States "
"Census, Carson City had a population of 55,274.",
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that "
"are a political division controlled by the United States. Its capital is Saipan.",
],
score_threshold=0.8,
)
except Exception as ex:
print(ex)
raise CredentialsValidateFailedError(str(ex))

@property
def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]:
"""
Map model invoke error to unified error
The key is the error type thrown to the caller
The value is the error type thrown by the model,
which needs to be converted into a unified error type for the caller.
:return: Invoke error mapping
"""
return {
InvokeConnectionError: [
RequestFailure,
],
InvokeServerUnavailableError: [
ServiceUnavailableError,
],
InvokeRateLimitError: [],
InvokeAuthorizationError: [
AuthenticationError,
],
InvokeBadRequestError: [
InvalidParameter,
UnsupportedModel,
UnsupportedHTTPMethod,
],
}
1 change: 1 addition & 0 deletions api/core/model_runtime/model_providers/tongyi/tongyi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ supported_model_types:
- llm
- tts
- text-embedding
- rerank
configurate_methods:
- predefined-model
- customizable-model
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import logging
import math
from typing import Any, Optional
from urllib.parse import urlparse

Expand Down Expand Up @@ -112,7 +113,8 @@ def delete(self) -> None:

def search_by_vector(self, query_vector: list[float], **kwargs: Any) -> list[Document]:
top_k = kwargs.get("top_k", 10)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k}
num_candidates = math.ceil(top_k * 1.5)
knn = {"field": Field.VECTOR.value, "query_vector": query_vector, "k": top_k, "num_candidates": num_candidates}

results = self._client.search(index=self._collection_name, knn=knn, size=top_k)

Expand Down
4 changes: 3 additions & 1 deletion api/core/rag/entities/context_entities.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

from pydantic import BaseModel


Expand All @@ -7,4 +9,4 @@ class DocumentContext(BaseModel):
"""

content: str
score: float
score: Optional[float] = None
3 changes: 3 additions & 0 deletions api/core/rag/retrieval/dataset_retrieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ def retrieve(
source["content"] = segment.content
retrieval_resource_list.append(source)
if hit_callback and retrieval_resource_list:
retrieval_resource_list = sorted(retrieval_resource_list, key=lambda x: x.get("score"), reverse=True)
for position, item in enumerate(retrieval_resource_list, start=1):
item["position"] = position
hit_callback.return_retriever_resource_info(retrieval_resource_list)
if document_context_list:
document_context_list = sorted(document_context_list, key=lambda x: x.score, reverse=True)
Expand Down
6 changes: 3 additions & 3 deletions api/core/tools/provider/builtin/jina/jina.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ identity:
zh_Hans: Jina AI
pt_BR: Jina AI
description:
en_US: Convert any URL to an LLM-friendly input or perform searches on the web for grounding information. Experience improved output for your agent and RAG systems at no cost.
zh_Hans: 将任何URL转换为LLM易读的输入或在网页上搜索引擎上搜索引擎。
pt_BR: Converte qualquer URL em uma entrada LLm-fácil de ler ou realize pesquisas na web para obter informação de grounding. Tenha uma experiência melhor para seu agente e sistemas RAG sem custo.
en_US: Your Search Foundation, Supercharged!
zh_Hans: 您的搜索底座,从此不同!
pt_BR: Your Search Foundation, Supercharged!
icon: icon.svg
tags:
- search
Expand Down
7 changes: 4 additions & 3 deletions api/core/tools/provider/builtin/vanna/tools/vanna.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ def _invoke(
# with "visualize" set to True (default behavior) leads to remote code execution.
# Affected versions: <= 0.5.5
#########################################################################################
generate_chart = False
# generate_chart = tool_parameters.get("generate_chart", True)
res = vn.ask(prompt, False, True, generate_chart)
allow_llm_to_see_data = tool_parameters.get("allow_llm_to_see_data", False)
res = vn.ask(
prompt, print_results=False, auto_train=True, visualize=False, allow_llm_to_see_data=allow_llm_to_see_data
)

result = []

Expand Down
12 changes: 6 additions & 6 deletions api/core/tools/provider/builtin/vanna/tools/vanna.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -200,14 +200,14 @@ parameters:
en_US: If enabled, it will attempt to train on the metadata of that database
zh_Hans: 是否自动从数据库获取元数据来训练
form: form
- name: generate_chart
- name: allow_llm_to_see_data
type: boolean
required: false
default: True
default: false
label:
en_US: Generate Charts
zh_Hans: 生成图表
en_US: Whether to allow the LLM to see the data
zh_Hans: 是否允许LLM查看数据
human_description:
en_US: Generate Charts
zh_Hans: 是否生成图表
en_US: Whether to allow the LLM to see the data
zh_Hans: 是否允许LLM查看数据
form: form
Loading

0 comments on commit 4256149

Please sign in to comment.