Skip to content

Commit

Permalink
Fix retriever (#105)
Browse files Browse the repository at this point in the history
* update retriever API so that thresholding is done outside of retriever call
  • Loading branch information
jerpint authored Jun 6, 2023
1 parent a9c1cb4 commit fc97df5
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 20 deletions.
43 changes: 30 additions & 13 deletions buster/retriever/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,18 @@ def __init__(self, top_k, thresh, max_tokens, embedding_model, *args, **kwargs):
self.max_tokens = max_tokens
self.embedding_model = embedding_model

# Add your access to documents in your own init

@abstractmethod
def get_documents(self, source: str = None) -> pd.DataFrame:
"""Get all current documents from a given source."""
...

@abstractmethod
def get_source_display_name(self, source: str) -> str:
"""Get the display name of a source."""
"""Get the display name of a source.
If source is None, returns all documents. If source does not exist, returns empty dataframe."""
...

@staticmethod
Expand All @@ -36,27 +40,40 @@ def get_embedding(query: str, engine: str):
logger.info("generating embedding")
return get_embedding(query, engine=engine)

def retrieve(self, query: str, source: str = None) -> pd.DataFrame:
top_k = self.top_k
thresh = self.thresh
query_embedding = self.get_embedding(query, engine=self.embedding_model)
@abstractmethod
def get_topk_documents(self, query: str, source: str = None, top_k: int = None) -> pd.DataFrame:
"""Get the topk documents matching a user's query.
documents = self.get_documents(source)
If no matches are found, returns an empty dataframe."""
...

documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))
def threshold_documents(self, matched_documents, thresh: float) -> pd.DataFrame:
# filter out matched_documents using a threshold
return matched_documents[matched_documents.similarity > thresh]

# sort the matched_documents by score
matched_documents = documents.sort_values("similarity", ascending=False)
def retrieve(self, query: str, source: str = None, top_k: int = None, thresh: float = None) -> pd.DataFrame:
if top_k is None:
top_k = self.top_k
if thresh is None:
thresh = self.thresh

# limit search to top_k matched_documents.
top_k = len(matched_documents) if top_k == -1 else top_k
matched_documents = matched_documents.head(top_k)
matched_documents = self.get_topk_documents(query=query, source=source, top_k=top_k)

# log matched_documents to the console
logger.info(f"matched documents before thresh: {matched_documents}")

# No matches were found, simply return at this point
if len(matched_documents) == 0:
return matched_documents

# otherwise, make sure we have the minimum required fields
assert "similarity" in matched_documents.columns
assert "embedding" in matched_documents.columns
assert "content" in matched_documents.columns

# filter out matched_documents using a threshold
matched_documents = matched_documents[matched_documents.similarity > thresh]
matched_documents = self.threshold_documents(matched_documents, thresh)

logger.info(f"matched documents after thresh: {matched_documents}")

return matched_documents
9 changes: 3 additions & 6 deletions buster/retriever/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,17 +57,14 @@ def get_documents(self, source: str = None) -> pd.DataFrame:

def get_source_display_name(self, source: str) -> str:
"""Get the display name of a source."""
if source == "":
if source is None:
return ALL_SOURCES
else:
display_name = self.db.sources.find_one({"name": source})["display_name"]
return display_name

def retrieve(self, query: str, top_k: int = None, source: str = None) -> pd.DataFrame:
if top_k is None:
# use default top_k value
top_k = self.top_k
if source == "" or source is None:
def get_topk_documents(self, query: str, source: str, top_k: int) -> pd.DataFrame:
if source is None:
filter = None
else:
filter = {"source": {"$eq": source}}
Expand Down
19 changes: 18 additions & 1 deletion buster/retriever/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

import pandas as pd
from openai.embeddings_utils import cosine_similarity

import buster.documents.sqlite.schema as schema
from buster.retriever.base import ALL_SOURCES, Retriever
Expand Down Expand Up @@ -59,7 +60,7 @@ def get_documents(self, source: str = None) -> pd.DataFrame:

def get_source_display_name(self, source: str) -> str:
"""Get the display name of a source."""
if source == "":
if source is None:
return ALL_SOURCES
else:
cur = self.conn.execute("SELECT display_name FROM sources WHERE name = ?", (source,))
Expand All @@ -68,3 +69,19 @@ def get_source_display_name(self, source: str) -> str:
raise KeyError(f'"{source}" is not a known source')
(display_name,) = row
return display_name

def get_topk_documents(self, query: str, source: str = None, top_k: int = None) -> pd.DataFrame:
query_embedding = self.get_embedding(query, engine=self.embedding_model)

documents = self.get_documents(source)

documents["similarity"] = documents.embedding.apply(lambda x: cosine_similarity(x, query_embedding))

# sort the matched_documents by score
matched_documents = documents.sort_values("similarity", ascending=False)

# limit search to top_k matched_documents.
top_k = len(matched_documents) if top_k == -1 else top_k
matched_documents = matched_documents.head(top_k)

return matched_documents
6 changes: 6 additions & 0 deletions tests/test_chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def __init__(self, **kwargs):
def get_documents(self, source):
return self.documents

def get_topk_documents(self, query: str, source: str = None, top_k: int = None) -> pd.DataFrame:
documents = self.documents
documents["embedding"] = [get_fake_embedding() for _ in range(len(documents))]
documents["similarity"] = [np.random.random() for _ in range(len(documents))]
return documents

def get_embedding(self, query, engine):
return get_fake_embedding()

Expand Down

0 comments on commit fc97df5

Please sign in to comment.