Improve model token limit detection (#3292)
* Properly find context window for ollama llama

* Better ollama support + upgrade litellm

* Upgrade OpenAI as well

* Fix mypy
Weves authored Nov 30, 2024
1 parent 63d1eef commit 16863de
Showing 5 changed files with 100 additions and 23 deletions.
4 changes: 3 additions & 1 deletion backend/danswer/configs/model_configs.py
@@ -70,7 +70,9 @@
)

# Typically, GenAI models nowadays are at least 4K tokens
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
)

# Number of tokens from chat history to include at maximum
# 3000 should be enough context regardless of use, no need to include as much as possible
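As an aside (not part of the diff): the "or 4096" form also falls back when the variable is set but empty, which a plain default argument to os.environ.get would not. A minimal check, illustration only:

import os

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = ""
assert int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096) == 4096

os.environ["GEN_AI_MODEL_FALLBACK_MAX_TOKENS"] = "8192"
assert int(os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096) == 8192
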
11 changes: 8 additions & 3 deletions backend/danswer/llm/chat_llm.py
@@ -26,7 +26,9 @@
from langchain_core.prompt_values import PromptValue

from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
from danswer.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
from danswer.llm.interfaces import LLM
@@ -161,7 +163,9 @@ def _convert_delta_to_message_chunk(

if role == "user":
return HumanMessageChunk(content=content)
elif role == "assistant":
# NOTE: if tool calls are present, then it's an assistant.
# In Ollama, the role will be None for tool-calls
elif role == "assistant" or tool_calls:
if tool_calls:
tool_call = tool_calls[0]
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
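The rule this hunk encodes, as a standalone sketch (function name and shapes are illustrative, not the project's): a streamed delta that carries tool calls is treated as an assistant message even when the provider, e.g. Ollama, reports no role.

def resolve_chunk_role(role, tool_calls):
    # Sketch only: tool calls imply "assistant", since Ollama streams
    # tool-call deltas with role=None.
    if role == "user":
        return "user"
    if role == "assistant" or tool_calls:
        return "assistant"
    return role

assert resolve_chunk_role(None, [{"function": {"name": "run_search"}}]) == "assistant"
assert resolve_chunk_role("user", None) == "user"
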
@@ -236,6 +240,7 @@ def __init__(
custom_config: dict[str, str] | None = None,
extra_headers: dict[str, str] | None = None,
extra_body: dict | None = LITELLM_EXTRA_BODY,
model_kwargs: dict[str, Any] | None = None,
long_term_logger: LongTermLogger | None = None,
):
self._timeout = timeout
@@ -268,7 +273,7 @@ def __init__(
for k, v in custom_config.items():
os.environ[k] = v

model_kwargs: dict[str, Any] = {}
model_kwargs = model_kwargs or {}
if extra_headers:
model_kwargs.update({"extra_headers": extra_headers})
if extra_body:
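A hedged sketch of the merge order the constructor now follows (the helper name and example values are assumptions, not project code): caller-supplied model_kwargs such as Ollama's num_ctx come first, then extra_headers and extra_body are layered on top before everything is handed to litellm.

from typing import Any

def merge_model_kwargs(model_kwargs, extra_headers=None, extra_body=None):
    # Start from whatever the caller passed in (e.g. {"num_ctx": 4096}),
    # then attach headers/body the same way the diff above does.
    merged: dict[str, Any] = dict(model_kwargs or {})
    if extra_headers:
        merged["extra_headers"] = extra_headers
    if extra_body:
        merged["extra_body"] = extra_body
    return merged

assert merge_model_kwargs({"num_ctx": 4096}, extra_body={"metadata": {"source": "test"}}) == {
    "num_ctx": 4096,
    "extra_body": {"metadata": {"source": "test"}},
}
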
13 changes: 13 additions & 0 deletions backend/danswer/llm/factory.py
@@ -1,5 +1,8 @@
from typing import Any

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
@@ -13,6 +16,15 @@
from danswer.utils.long_term_log import LongTermLogger


def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
"""Ollama requires us to specify the max context window.
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
TODO: allow model-specific values to be configured via the UI.
"""
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}


def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
@@ -132,5 +144,6 @@ def get_llm(
temperature=temperature,
custom_config=custom_config,
extra_headers=build_llm_extra_headers(additional_headers),
model_kwargs=_build_extra_model_kwargs(provider),
long_term_logger=long_term_logger,
)
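For illustration, assuming GEN_AI_MODEL_FALLBACK_MAX_TOKENS is left at its 4096 default, the new helper resolves like this (only the "ollama" provider gets the extra kwarg):

from danswer.llm.factory import _build_extra_model_kwargs

assert _build_extra_model_kwargs("ollama") == {"num_ctx": 4096}
assert _build_extra_model_kwargs("openai") == {}
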
91 changes: 74 additions & 17 deletions backend/danswer/llm/utils.py
@@ -1,3 +1,4 @@
import copy
import io
import json
from collections.abc import Callable
@@ -385,6 +386,62 @@ def test_llm(llm: LLM) -> str | None:
return error_msg


def get_model_map() -> dict:
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))

# NOTE: we could add additional models here in the future,
# but for now there is no point. Ollama allows the user to
# specify their desired max context window, and it's
# unlikely to be standard across users even for the same model
# (it heavily depends on their hardware). For now, we'll just
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
# for model_name in [
# "llama3.2",
# "llama3.2:1b",
# "llama3.2:3b",
# "llama3.2:11b",
# "llama3.2:90b",
# ]:
# starting_map[f"ollama/{model_name}"] = {
# "max_tokens": 128000,
# "max_input_tokens": 128000,
# "max_output_tokens": 128000,
# }

return starting_map


def _strip_extra_provider_from_model_name(model_name: str) -> str:
return model_name.split("/")[1] if "/" in model_name else model_name


def _strip_colon_from_model_name(model_name: str) -> str:
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name


def _find_model_obj(
model_map: dict, provider: str, model_names: list[str | None]
) -> dict | None:
# Filter out None values and deduplicate model names
filtered_model_names = [name for name in model_names if name]

# First try all model names with provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(f"{provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {provider}/{model_name}")
return model_obj

# Then try all model names without provider prefix
for model_name in filtered_model_names:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")
return model_obj

return None
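A hedged walk-through with a made-up model name: the two helpers generate the extra candidates, and _find_model_obj tries every candidate with the provider prefix before retrying the bare names (the candidate list itself is built in get_llm_max_tokens below).

from danswer.llm.utils import (
    _strip_colon_from_model_name,
    _strip_extra_provider_from_model_name,
)

name = "my-proxy/llama3.2:3b"  # hypothetical, e.g. behind a custom model proxy
assert _strip_extra_provider_from_model_name(name) == "llama3.2:3b"
assert _strip_colon_from_model_name(name) == "my-proxy/llama3.2"
assert _strip_colon_from_model_name("llama3.2:3b") == "llama3.2"

# With provider="ollama", lookups therefore proceed as:
#   ollama/my-proxy/llama3.2:3b, ollama/llama3.2:3b,
#   ollama/my-proxy/llama3.2,    ollama/llama3.2,
# followed by the same four names without the "ollama/" prefix.
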


def get_llm_max_tokens(
model_map: dict,
model_name: str,
@@ -397,22 +454,22 @@ def get_llm_max_tokens(
return GEN_AI_MAX_TOKENS

try:
model_obj = model_map.get(f"{model_provider}/{model_name}")
if model_obj:
logger.debug(f"Using model object for {model_provider}/{model_name}")

if not model_obj:
model_obj = model_map.get(model_name)
if model_obj:
logger.debug(f"Using model object for {model_name}")

if not model_obj:
model_name_split = model_name.split("/")
if len(model_name_split) > 1:
model_obj = model_map.get(model_name_split[1])
if model_obj:
logger.debug(f"Using model object for {model_name_split[1]}")

extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
model_name
)
model_obj = _find_model_obj(
model_map,
model_provider,
[
model_name,
# Remove leading extra provider. Usually for cases where the user has a
# custom model proxy which prepends another prefix
extra_provider_stripped_model_name,
# remove :XXXX from the end, if present. Needed for ollama.
_strip_colon_from_model_name(model_name),
_strip_colon_from_model_name(extra_provider_stripped_model_name),
],
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
@@ -488,7 +545,7 @@ def get_max_input_tokens(
# `model_cost` dict is a named public interface:
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
# model_map is litellm.model_cost
litellm_model_map = litellm.model_cost
litellm_model_map = get_model_map()

input_toks = (
get_llm_max_tokens(
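If a model-specific limit is ever needed, it could be layered onto get_model_map() following the commented-out pattern above; a sketch with a made-up model name and limits (real values depend on the user's hardware and Ollama configuration):

from danswer.llm.utils import get_model_map

model_map = get_model_map()
model_map["ollama/my-local-model"] = {  # hypothetical entry
    "max_tokens": 32768,
    "max_input_tokens": 32768,
    "max_output_tokens": 32768,
}
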
4 changes: 2 additions & 2 deletions backend/requirements/default.txt
@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.50.2
litellm==1.53.1
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45
@@ -38,7 +38,7 @@ msal==1.28.0
nltk==3.8.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.52.2
openai==1.55.3
openpyxl==3.1.2
playwright==1.41.2
psutil==5.9.5
