Skip to content

Commit

Permalink
Multilingual Query Expansion (#737)
Browse files Browse the repository at this point in the history
  • Loading branch information
yuhongsun96 authored Nov 19, 2023
1 parent b258ec1 commit 6fb07d2
Show file tree
Hide file tree
Showing 9 changed files with 196 additions and 11 deletions.
4 changes: 3 additions & 1 deletion backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.6)))

# A list of languages passed to the LLM to rephase the query
# For example "English,French,Spanish", be sure to use the "," separator
MULTILINGUAL_QUERY_EXPANSION = os.environ.get("MULTILINGUAL_QUERY_EXPANSION") or None

#####
# Model Server Configs
Expand Down
24 changes: 18 additions & 6 deletions backend/danswer/direct_qa/qa_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage

from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
Expand All @@ -22,6 +23,7 @@
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
Expand Down Expand Up @@ -88,15 +90,20 @@ def is_json_output(self) -> bool:
return True

def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
self,
query: str,
context_chunks: list[InferenceChunk],
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
context_docs_str = "\n".join(
f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks
)

single_message = JSON_PROMPT.format(
context_docs_str=context_docs_str, user_query=query
)
context_docs_str=context_docs_str,
user_query=query,
language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
).strip()

prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
Expand All @@ -111,15 +118,20 @@ def is_json_output(self) -> bool:
return True

def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
self,
query: str,
context_chunks: list[InferenceChunk],
use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION),
) -> list[BaseMessage]:
context_docs_str = "\n".join(
f"\n{CODE_BLOCK_PAT.format(c.content)}\n" for c in context_chunks
)

single_message = COT_PROMPT.format(
context_docs_str=context_docs_str, user_query=query
)
context_docs_str=context_docs_str,
user_query=query,
language_hint_or_none=LANGUAGE_HINT if use_language_hint else "",
).strip()

prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
return prompt
Expand Down
6 changes: 6 additions & 0 deletions backend/danswer/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
Expand Down Expand Up @@ -175,6 +176,11 @@ def startup_event() -> None:
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)

if SKIP_RERANKING:
logger.info("Reranking step of search flow is disabled")

Expand Down
7 changes: 7 additions & 0 deletions backend/danswer/prompts/direct_qa_prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
""".strip()


LANGUAGE_HINT = """
IMPORTANT: Respond in the same language as my query!
""".strip()


# This has to be doubly escaped due to json containing { } which are also used for format strings
EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
Expand Down Expand Up @@ -54,6 +59,7 @@
```
{QUESTION_PAT} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()


Expand All @@ -75,6 +81,7 @@
{QUESTION_PAT} {{user_query}}
{JSON_HELPFUL_HINT}
{{language_hint_or_none}}
""".strip()


Expand Down
12 changes: 12 additions & 0 deletions backend/danswer/prompts/secondary_llm_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,18 @@
""".strip()


LANGUAGE_REPHRASE_PROMPT = """
Rephrase the query in {target_language}.
If the query is already in the correct language, \
simply repeat the ORIGINAL query back to me, EXACTLY as is with no rephrasing.
NEVER change proper nouns, technical terms, acronyms, or terms you are not familiar with.
IMPORTANT, if the query is already in the target language, DO NOT REPHRASE OR EDIT the query!
Query:
{query}
""".strip()


# User the following for easy viewing of prompts
if __name__ == "__main__":
print(ANSWERABLE_PROMPT)
56 changes: 52 additions & 4 deletions backend/danswer/search/search_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from collections.abc import Callable
from copy import deepcopy
from typing import Any
from typing import cast

import numpy
Expand All @@ -10,6 +12,7 @@

from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
Expand All @@ -36,11 +39,13 @@
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.query_expansion import rephrase_query
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from danswer.utils.timing import log_function_time


Expand Down Expand Up @@ -108,6 +113,30 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
return search_docs


def combine_retrieval_results(
chunk_sets: list[list[InferenceChunk]],
) -> list[InferenceChunk]:
all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set]

unique_chunks: dict[tuple[str, int], InferenceChunk] = {}
for chunk in all_chunks:
key = (chunk.document_id, chunk.chunk_id)
if key not in unique_chunks:
unique_chunks[key] = chunk
continue

stored_chunk_score = unique_chunks[key].score or 0
this_chunk_score = chunk.score or 0
if stored_chunk_score < this_chunk_score:
unique_chunks[key] = chunk

sorted_chunks = sorted(
unique_chunks.values(), key=lambda x: x.score or 0, reverse=True
)

return sorted_chunks


@log_function_time()
def doc_index_retrieval(
query: SearchQuery,
Expand Down Expand Up @@ -313,6 +342,7 @@ def search_chunks(
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
multilingual_query_expansion: str | None = MULTILINGUAL_QUERY_EXPANSION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
Expand All @@ -331,9 +361,25 @@ def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

top_chunks = doc_index_retrieval(
query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_query_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
)
else:
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = rephrase_query(query.query, multilingual_query_expansion)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)
for rephrase in set(query_rephrases):
q_copy = deepcopy(query)
q_copy.query = rephrase
run_queries.append(
(doc_index_retrieval, (q_copy, document_index, hybrid_alpha))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)

if not top_chunks:
logger.info(
Expand Down Expand Up @@ -384,7 +430,9 @@ def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None
functions_to_run.append(run_llm_filter)
run_llm_filter_id = run_llm_filter.result_id

parallel_results = run_functions_in_parallel(functions_to_run)
parallel_results: dict[str, Any] = {}
if functions_to_run:
parallel_results = run_functions_in_parallel(functions_to_run)

ranked_results = parallel_results.get(str(run_rerank_id))
if ranked_results is None:
Expand Down
48 changes: 48 additions & 0 deletions backend/danswer/secondary_llm_flows/query_expansion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from collections.abc import Callable

from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.secondary_llm_flows import LANGUAGE_REPHRASE_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def llm_rephrase_query(query: str, language: str) -> str:
def _get_rephrase_messages() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": LANGUAGE_REPHRASE_PROMPT.format(
query=query, target_language=language
),
},
]

return messages

messages = _get_rephrase_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = get_default_llm().invoke(filled_llm_prompt)
logger.debug(model_output)

return model_output


def rephrase_query(
query: str,
multilingual_query_expansion: str,
use_threads: bool = True,
) -> list[str]:
languages = multilingual_query_expansion.split(",")
languages = [language.strip() for language in languages]
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_rephrase_query, (query, language)) for language in languages
]

return run_functions_tuples_in_parallel(functions_with_args)

else:
return [llm_rephrase_query(query, language) for language in languages]
9 changes: 9 additions & 0 deletions deployment/docker_compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@ services:
- SKIP_RERANKING=${SKIP_RERANKING:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
# Log all of the prompts to the LLM
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info}
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
Expand Down Expand Up @@ -106,11 +109,17 @@ services:
- SKIP_RERANKING=${SKIP_RERANKING:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MULTILINGUAL_QUERY_EXPANSION=${MULTILINGUAL_QUERY_EXPANSION:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry
- DISABLE_TELEMETRY=${DISABLE_TELEMETRY:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
# Log all of the prompts to the LLM
- LOG_ALL_MODEL_INTERACTIONS=${LOG_ALL_MODEL_INTERACTIONS:-info}
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
Expand Down
41 changes: 41 additions & 0 deletions deployment/docker_compose/env.multilingual.template
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# This env template shows how to configure Danswer for multilingual use
# In this case, it is configured for French and English
# To use it, copy it to .env in the docker_compose directory.
# Feel free to combine it with the other templates to suit your needs


# A recent MIT license multilingual model: https://huggingface.co/intfloat/multilingual-e5-small
DOCUMENT_ENCODER_MODEL="intfloat/multilingual-e5-small"

# The model above is trained with the following prefix for queries and passages to improve retrieval
# by letting the model know which of the two type is currently being embedded
ASYM_QUERY_PREFIX="query: "
ASYM_PASSAGE_PREFIX="passage: "

# Depends model by model, this one is tuned with this as True
NORMALIZE_EMBEDDINGS="True"

# Due to the loss function used in training, this model outputs similarity scores from range ~0.6 to 1
SIM_SCORE_RANGE_LOW="0.6"
SIM_SCORE_RANGE_LOW="0.8"

# No recent multilingual reranking models small enough to run on CPU, so turning it off
SKIP_RERANKING="True"

# Use LLM to determine if chunks are relevant to the query
# may not work well for languages that do not have much training data in the LLM training set
DISABLE_LLM_CHUNK_FILTER="True"

# Rephrase the user query in specified languages using LLM, use comma separated values
MULTILINGUAL_QUERY_EXPANSION="English, French"

# Enables fine-grained embeddings for better retrieval
# At the cost of indexing speed (~5x slower), query time is same speed
ENABLE_MINI_CHUNK="True"

# Stronger model will help with multilingual tasks
GEN_AI_MODEL_VERSION="gpt-4"
GEN_AI_API_KEY=<provide your api key>

# More verbose logging if desired
LOG_LEVEL="debug"

1 comment on commit 6fb07d2

@vercel
Copy link

@vercel vercel bot commented on 6fb07d2 Nov 19, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.