Skip to content

Commit

Permalink
Add question pre-processing support (#108)
Browse files Browse the repository at this point in the history
* add support for verifying validity of questions before generating a response

* skip reranking when documents are empty

* pass message along with question when checking relevance

* add a default invalid question response
  • Loading branch information
jerpint authored Jun 27, 2023
1 parent 6c956ff commit 9b7de7c
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 11 deletions.
22 changes: 20 additions & 2 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from dataclasses import dataclass, field
from typing import Any

import pandas as pd

from buster.completers import Completer, Completion
from buster.retriever import Retriever
from buster.validators import Validator
Expand Down Expand Up @@ -81,13 +83,29 @@ def process_input(self, user_input: str, source: str = None) -> Completion:
if not user_input.endswith("\n"):
user_input += "\n"

matched_documents = self.retriever.retrieve(user_input, source=source)
# The returned message is either a generic invalid question message or an error handling message
question_relevant, irrelevant_question_message = self.validator.check_question_relevance(user_input)

if question_relevant:
# question is relevant, get completor to generate completion
matched_documents = self.retriever.retrieve(user_input, source=source)
completion = self.completer.get_completion(user_input=user_input, matched_documents=matched_documents)

completion = self.completer.get_completion(user_input=user_input, matched_documents=matched_documents)
else:
# question was determined irrelevant, so we instead return a generic response set by the user.
completion = Completion(
error=False,
user_input=user_input,
matched_documents=pd.DataFrame(),
completor=irrelevant_question_message,
answer_relevant=False,
question_relevant=False,
)

logger.info(f"Completion:\n{completion}")

return completion

def postprocess_completion(self, completion) -> Completion:
"""This will check if the answer is relevant, and rerank the sources by relevance too."""
return self.validator.validate(completion=completion)
31 changes: 23 additions & 8 deletions buster/completers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ class Completion:
matched_documents: pd.DataFrame
completor: Iterator | str
answer_relevant: bool = None
question_relevant: bool = None

# private property, should not be set at init
_completor: Iterator | str = field(init=False, repr=False) # e.g. a streamed response from openai.ChatCompletion
_text: str = None
_text: str = (
None # once the generator of the completor is exhausted, the text will be available in the self.text property
)

@property
def text(self):
Expand Down Expand Up @@ -101,6 +104,7 @@ def encode_df(df: pd.DataFrame) -> dict:
"text": self.text,
"matched_documents": self.matched_documents,
"answer_relevant": self.answer_relevant,
"question_relevant": self.question_relevant,
"error": self.error,
}
return jsonable_encoder(to_encode, custom_encoder=custom_encoder)
Expand Down Expand Up @@ -128,11 +132,13 @@ def __init__(
prompt_formatter: PromptFormatter,
completion_kwargs: dict,
no_documents_message: str = "No documents were found that match your question.",
completion_class: Completion = Completion,
):
self.completion_kwargs = completion_kwargs
self.documents_formatter = documents_formatter
self.prompt_formatter = prompt_formatter
self.no_documents_message = no_documents_message
self.completion_class = completion_class

@abstractmethod
def complete(self, prompt: str, user_input: str) -> Completion:
Expand All @@ -151,34 +157,43 @@ def prepare_prompt(self, matched_documents) -> str:
return prompt

def get_completion(self, user_input: str, matched_documents: pd.DataFrame) -> Completion:
# Call the API to generate a response
"""Generate a completion to a user's question based on matched documents."""

# The completor assumes a question was previously determined valid, otherwise it would not be called.
question_relevant = True

logger.info(f"{user_input=}")

if len(matched_documents) == 0:
logger.warning("no documents found...")
# no document was found, pass the appropriate message instead...
logger.warning("no documents found...")

# empty dataframe
matched_documents = pd.DataFrame(columns=matched_documents.columns)

completion = Completion(
# because we are proceeding with a completion, we assume the question is relevant.
completion = self.completion_class(
user_input=user_input,
completor=self.no_documents_message,
error=False,
matched_documents=matched_documents,
question_relevant=question_relevant,
)
return completion

# prepare the prompt
# prepare the prompt with matched documents
prompt = self.prepare_prompt(matched_documents)
logger.info(f"{prompt=}")

logger.info(f"querying model with parameters: {self.completion_kwargs}...")
completor = self.complete(prompt=prompt, user_input=user_input, **self.completion_kwargs)

completion = Completion(
completor=completor, error=self.error, matched_documents=matched_documents, user_input=user_input
completion = self.completion_class(
completor=completor,
error=self.error,
matched_documents=matched_documents,
user_input=user_input,
question_relevant=question_relevant,
)

return completion
Expand Down Expand Up @@ -224,13 +239,13 @@ def complete(self, prompt: str, user_input, **completion_kwargs) -> str | Iterat
{"role": "system", "content": prompt},
{"role": "user", "content": user_input},
]
self.error = False
try:
response = openai.ChatCompletion.create(
messages=messages,
**completion_kwargs,
)

self.error = False
if completion_kwargs.get("stream") is True:
# We are entering streaming mode, so here were just wrapping the streamed
# openai response to be easier to handle later
Expand Down
1 change: 1 addition & 0 deletions buster/examples/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
},
retriever_cfg={
"db_path": "documents.db",
Expand Down
22 changes: 21 additions & 1 deletion buster/validators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,36 @@


class Validator:
def __init__(self, embedding_model: str, unknown_threshold: float, unknown_prompt: str, use_reranking: bool):
def __init__(
self,
embedding_model: str,
unknown_threshold: float,
unknown_prompt: str,
use_reranking: bool,
invalid_question_response: str = "This question is not relevant to my knowledge.",
):
self.embedding_model = embedding_model
self.unknown_threshold = unknown_threshold
self.unknown_prompt = unknown_prompt
self.use_reranking = use_reranking
self.invalid_question_response = invalid_question_response

    @staticmethod
    @lru_cache
    def get_embedding(query: str, engine: str):
        """Return the embedding of `query` computed with `engine`.

        Results are memoized with `lru_cache`, so repeated identical
        (query, engine) pairs do not trigger extra embedding calls.
        """
        logger.info("generating embedding")
        # Delegates to the module-level `get_embedding` helper (which this
        # method's name shadows inside the class namespace).
        return get_embedding(query, engine=engine)

def check_question_relevance(self, question: str) -> tuple[bool, str]:
"""Determines wether a question is relevant or not for our given framework."""
# Override this method to suit your needs.
# By default, no checks happen.
# You could for example use a GPT call to check your question validity, at extra cost/latency.
# The message will be what's printed should question_relevant be False.
question_relevant = True
message: str = self.invalid_question_response
return question_relevant, message

def check_answer_relevance(self, answer: str, unknown_prompt: str = None) -> bool:
"""Check to see if a generated answer is relevant to the chatbot's knowledge or not.
Expand Down Expand Up @@ -60,6 +78,8 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr
This score could be used to determine whether a document was actually relevant to generation.
An extra column is added in-place for the similarity score.
"""
if len(matched_documents) == 0:
return matched_documents
logger.info("Reranking documents based on answer similarity...")

answer_embedding = self.get_embedding(
Expand Down

0 comments on commit 9b7de7c

Please sign in to comment.