Skip to content

Commit

Permalink
Hybrid Search Alpha Parameter (#714)
Browse the repository at this point in the history
Branch information:
yuhongsun96 authored Nov 10, 2023
1 parent 5a4820c commit 69644b2
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 8 deletions.
2 changes: 2 additions & 0 deletions backend/danswer/configs/app_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,8 @@
EDIT_KEYWORD_QUERY = os.environ.get("EDIT_KEYWORD_QUERY", "").lower() == "true"
else:
EDIT_KEYWORD_QUERY = not os.environ.get("DOCUMENT_ENCODER_MODEL")
# Weighting factor between Vector and Keyword Search, 1 for completely vector search
HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.6)))


#####
Expand Down
1 change: 1 addition & 0 deletions backend/danswer/document_index/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ def hybrid_retrieval(
filters: IndexFilters,
favor_recent: bool,
num_to_retrieve: int,
hybrid_alpha: float | None = None,
) -> list[InferenceChunk]:
raise NotImplementedError

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ schema danswer_chunk {
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
}

# Document score decays from 1 to 0.5 as age of last updated time increases
function inline recency_bias() {
# Cap the loss at 50% score reduction
expression: max(1 / (1 + query(decay_factor) * document_age), 0.5)
Expand Down Expand Up @@ -150,7 +151,7 @@ schema danswer_chunk {
}

global-phase {
expression: (normalize_linear(closeness(field, embeddings)) + normalize_linear(bm25(content))) / 2 * document_boost * recency_bias
expression: ((query(alpha) * normalize_linear(closeness(field, embeddings))) + ((1 - query(alpha)) * normalize_linear(bm25(content)))) * document_boost * recency_bias
rerank-count: 1000
}

Expand Down
9 changes: 7 additions & 2 deletions backend/danswer/document_index/vespa/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import EDIT_KEYWORD_QUERY
from danswer.configs.app_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
from danswer.configs.app_configs import VESPA_HOST
Expand Down Expand Up @@ -432,7 +433,7 @@ def _vespa_hit_to_inference_chunk(hit: dict[str, Any]) -> InferenceChunk:
)


def _query_vespa(query_params: Mapping[str, str | int]) -> list[InferenceChunk]:
def _query_vespa(query_params: Mapping[str, str | int | float]) -> list[InferenceChunk]:
if "query" in query_params and not cast(str, query_params["query"]).strip():
raise ValueError("No/empty query received")
response = requests.get(SEARCH_ENDPOINT, params=query_params)
Expand Down Expand Up @@ -669,6 +670,7 @@ def hybrid_retrieval(
filters: IndexFilters,
favor_recent: bool,
num_to_retrieve: int,
hybrid_alpha: float | None = HYBRID_ALPHA,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
edit_keyword_query: bool = EDIT_KEYWORD_QUERY,
) -> list[InferenceChunk]:
Expand All @@ -690,11 +692,14 @@ def hybrid_retrieval(
" ".join(remove_stop_words(query)) if edit_keyword_query else query
)

params: dict[str, str | int] = {
params: dict[str, str | int | float] = {
"yql": yql,
"query": query_keywords,
"input.query(query_embedding)": str(query_embedding),
"input.query(decay_factor)": str(DOC_TIME_DECAY * decay_multiplier),
"input.query(alpha)": hybrid_alpha
if hybrid_alpha is not None
else HYBRID_ALPHA,
"hits": num_to_retrieve,
"offset": 0,
"ranking.profile": "hybrid_search",
Expand Down
26 changes: 21 additions & 5 deletions backend/danswer/search/search_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from sentence_transformers import SentenceTransformer # type: ignore
from sqlalchemy.orm import Session

from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
Expand Down Expand Up @@ -104,21 +105,33 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc

@log_function_time()
def doc_index_retrieval(
query: SearchQuery, document_index: DocumentIndex
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA,
) -> list[InferenceChunk]:
if query.search_type == SearchType.KEYWORD:
top_chunks = document_index.keyword_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
query=query.query,
filters=query.filters,
favor_recent=query.favor_recent,
num_to_retrieve=query.num_hits,
)

elif query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
query=query.query,
filters=query.filters,
favor_recent=query.favor_recent,
num_to_retrieve=query.num_hits,
)

elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
query=query.query,
filters=query.filters,
favor_recent=query.favor_recent,
num_to_retrieve=query.num_hits,
hybrid_alpha=hybrid_alpha,
)

else:
Expand Down Expand Up @@ -282,6 +295,7 @@ def apply_boost(
def search_chunks(
query: SearchQuery,
document_index: DocumentIndex,
hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
Expand All @@ -293,7 +307,9 @@ def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

top_chunks = doc_index_retrieval(query=query, document_index=document_index)
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, hybrid_alpha=hybrid_alpha
)

if not top_chunks:
logger.info(
Expand Down
6 changes: 6 additions & 0 deletions deployment/docker_compose/docker-compose.dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ services:
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
# Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
# Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
Expand Down Expand Up @@ -86,6 +90,8 @@ services:
- NOTIFY_SLACKBOT_NO_ANSWER=${NOTIFY_SLACKBOT_NO_ANSWER:-}
# Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
# Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
- HYBRID_ALPHA=${HYBRID_ALPHA:-}
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
Expand Down

1 comment on commit 69644b2

@vercel
Copy link

@vercel vercel bot commented on 69644b2 Nov 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.