This repository has been archived by the owner on Nov 13, 2024. It is now read-only.

Commit 690f813
add default LLM to chat engine
acatav committed Sep 14, 2023
1 parent 026793e
Showing 3 changed files with 6 additions and 5 deletions.
context_engine/chat_engine/chat_engine.py (4 additions, 3 deletions)
```diff
@@ -8,7 +8,7 @@
     FunctionCallingQueryGenerator, )
 from context_engine.context_engine import ContextEngine
 from context_engine.tokenizer import Tokenizer
-from context_engine.llm import BaseLLM
+from context_engine.llm import BaseLLM, OpenAILLM
 from context_engine.llm.models import ModelParams, SystemMessage
 from context_engine.models.api_models import (StreamingChatChunk, ChatResponse,
                                               StreamingChatResponse, )
@@ -54,11 +54,12 @@ async def aget_context(self, messages: Messages) -> Context:
 class ChatEngine(BaseChatEngine):
 
     DEFAULT_QUERY_GENERATOR = FunctionCallingQueryGenerator
+    DEFAULT_LLM = OpenAILLM
 
     def __init__(self,
                  *,
-                 llm: BaseLLM,
                  context_engine: ContextEngine,
+                 llm: Optional[BaseLLM] = None,
                  max_prompt_tokens: int = 4096,
                  max_generated_tokens: Optional[int] = None,
                  max_context_tokens: Optional[int] = None,
@@ -67,7 +68,7 @@ def __init__(self,
                  history_pruning: str = "recent",
                  min_history_messages: int = 1
                  ):
-        self.llm = llm
+        self.llm = llm if llm is not None else self.DEFAULT_LLM()
         self.context_engine = context_engine
         self.max_prompt_tokens = max_prompt_tokens
         self.max_generated_tokens = max_generated_tokens
```
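The net effect of this change: `llm` is now optional, and `ChatEngine` instantiates `OpenAILLM` only when the caller omits it. Below is a minimal, self-contained sketch of the pattern, with stub classes standing in for the real `context_engine` implementations (the stubs are illustrative, not the library's code):

```python
from typing import Optional


class BaseLLM:
    """Stub standing in for context_engine.llm.BaseLLM."""


class OpenAILLM(BaseLLM):
    """Stub standing in for context_engine.llm.OpenAILLM."""


class ChatEngine:
    # Class attribute holds the *class*, not an instance, so importing
    # ChatEngine never constructs an LLM (or needs API credentials).
    DEFAULT_LLM = OpenAILLM

    def __init__(self, *, llm: Optional[BaseLLM] = None):
        # Instantiate the default only when the caller passes nothing.
        self.llm = llm if llm is not None else self.DEFAULT_LLM()


engine = ChatEngine()               # falls back to OpenAILLM()
assert isinstance(engine.llm, OpenAILLM)

custom = ChatEngine(llm=BaseLLM())  # an explicit LLM always wins
assert not isinstance(custom.llm, OpenAILLM)
```

Storing the class rather than an instance in `DEFAULT_LLM` means the fallback is constructed lazily, and a subclass can swap the default by overriding a single attribute.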
service/app.py (1 addition, 1 deletion)
```diff
@@ -174,7 +174,7 @@ def _init_engines():
     context_engine = ContextEngine(knowledge_base=kb)
     llm = OpenAILLM()
 
-    chat_engine = ChatEngine(llm=llm, context_engine=context_engine)
+    chat_engine = ChatEngine(context_engine=context_engine, llm=llm)
 
 
 def start():
```
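With the default in place, the explicit `OpenAILLM()` in `_init_engines` is no longer strictly required. A hypothetical simplification (the import paths are inferred from the diff's file paths and may not match the package's public re-exports):

```python
# Hypothetical follow-up, not part of this commit: import paths are
# inferred from the diff's file layout.
from context_engine.context_engine import ContextEngine
from context_engine.chat_engine.chat_engine import ChatEngine


def init_engines(kb):
    context_engine = ContextEngine(knowledge_base=kb)
    # llm omitted: ChatEngine falls back to DEFAULT_LLM (OpenAILLM).
    return ChatEngine(context_engine=context_engine)
```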
tests/unit/chat_engine/test_chat_engine.py (1 addition, 1 deletion)
```diff
@@ -31,8 +31,8 @@ def _init_chat_engine(self,
                           max_generated_tokens: int = 200,
                           **kwargs):
         return ChatEngine(
-            llm=self.mock_llm,
             context_engine=self.mock_context_engine,
+            llm=self.mock_llm,
             query_builder=self.mock_query_builder,
             system_prompt=system_prompt,
             max_prompt_tokens=max_prompt_tokens,
```
