Add documents evaluator, refactor validators #151

Merged (10 commits) on Nov 14, 2023
Changes from 1 commit
5 changes: 5 additions & 0 deletions buster/completers/base.py
@@ -137,6 +137,11 @@ def postprocess(self):
answer=self.answer_text, matched_documents=self.matched_documents
)

if self.validator.validate_documents:
self.matched_documents = self.validator.check_documents_relevance(
answer=self.answer_text, matched_documents=self.matched_documents
)

# access the property so it gets set if not computed already
self.answer_relevant

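For context, a minimal hedged sketch of the post-processing flow this hunk implements: documents are optionally reranked, and, new in this PR, optionally checked for relevance against the generated answer. The `CompletionSketch` class and the reranking gate are illustrative assumptions; only `validate_documents`, `rerank_docs`, and `check_documents_relevance` come from the diff itself.

```python
import pandas as pd


class CompletionSketch:
    """Illustrative stand-in for buster's completion object (not the real class)."""

    def __init__(self, answer_text: str, matched_documents: pd.DataFrame, validator):
        self.answer_text = answer_text
        self.matched_documents = matched_documents
        self.validator = validator

    def postprocess(self) -> None:
        # Optionally rerank the retrieved sources by similarity to the answer.
        if self.validator.use_reranking:
            self.matched_documents = self.validator.rerank_docs(
                answer=self.answer_text, matched_documents=self.matched_documents
            )
        # New in this PR: optionally flag each source as relevant (or not)
        # to the answer that was actually generated.
        if self.validator.validate_documents:
            self.matched_documents = self.validator.check_documents_relevance(
                answer=self.answer_text, matched_documents=self.matched_documents
            )
```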
40 changes: 25 additions & 15 deletions buster/examples/cfg.py
@@ -4,19 +4,18 @@
from buster.formatters.prompts import PromptFormatter
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers import GPTTokenizer
from buster.validators import QuestionAnswerValidator, Validator
from buster.validators import Validator

buster_cfg = BusterConfig(
validator_cfg={
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"check_question_prompt": """You are an chatbot answering questions on artificial intelligence.

"question_validator_cfg": {
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"check_question_prompt": """You are an chatbot answering questions on artificial intelligence.
Your job is to determine whether or not a question is valid and should be answered.
More general questions are not considered valid, even if you might know the response.
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.
@@ -30,11 +29,22 @@
false

A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"answer_validator_cfg": {
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
},
"documents_validator_cfg": {
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
},
"use_reranking": True,
"validate_documents": True,
},
retriever_cfg={
"path": "deeplake_store",
@@ -98,6 +108,6 @@ def setup_buster(buster_cfg: BusterConfig):
prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
**buster_cfg.documents_answerer_cfg,
)
validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)
return buster
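The example config is the clearest picture of the new validator layout: one nested dict per sub-validator plus two top-level switches. A hedged, trimmed-down sketch follows; the prompt string is a shortened placeholder, not the production prompt shown above.

```python
from buster.validators import Validator

validator_cfg = {
    "question_validator_cfg": {
        "invalid_question_response": "This question does not seem relevant to my current knowledge.",
        "completion_kwargs": {"model": "gpt-3.5-turbo", "stream": False, "temperature": 0},
        # Placeholder prompt; the real one is the multi-line prompt in cfg.py above.
        "check_question_prompt": "Respond 'true' if the question is about the library, otherwise 'false'.",
    },
    "answer_validator_cfg": {
        "unknown_response_templates": [
            "I'm sorry, but I cannot answer that question as it is not relevant to the library or its usage.",
        ],
        "unknown_threshold": 0.85,
    },
    "documents_validator_cfg": {
        "completion_kwargs": {"model": "gpt-3.5-turbo", "stream": False, "temperature": 0},
    },
    "use_reranking": True,
    "validate_documents": True,
}

# The validator is now built directly from the single top-level class.
validator = Validator(**validator_cfg)
```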
2 changes: 2 additions & 0 deletions buster/llm_utils/embeddings.py
@@ -1,4 +1,5 @@
import logging
from functools import lru_cache

import numpy as np
import pandas as pd
@@ -11,6 +12,7 @@
client = OpenAI()


@lru_cache
def get_openai_embedding(text: str, model: str = "text-embedding-ada-002"):
try:
text = text.replace("\n", " ")
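The only change here is memoization: `@lru_cache` keeps one embedding per unique `(text, model)` pair for the lifetime of the process, so validators that embed the same templates repeatedly avoid duplicate API calls. A self-contained sketch of the same pattern, with a fake embedding standing in for the OpenAI call so it runs offline:

```python
from functools import lru_cache


@lru_cache  # caches on the (text, model) argument tuple; both are hashable strings
def embed(text: str, model: str = "text-embedding-ada-002") -> tuple[float, ...]:
    print(f"computing embedding for {text!r}")  # printed only on a cache miss
    # Fake, deterministic vector; in buster this body is the OpenAI embeddings call.
    return tuple(float(ord(c)) for c in text[:8])


embed("unknown response template")  # miss: computes and caches
embed("unknown response template")  # hit: served from the cache, no print
print(embed.cache_info())           # CacheInfo(hits=1, misses=1, maxsize=128, currsize=1)
```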
3 changes: 1 addition & 2 deletions buster/validators/__init__.py
@@ -1,4 +1,3 @@
from .base import Validator
from .question_answer_validator import QuestionAnswerValidator

__all__ = [Validator, QuestionAnswerValidator]
__all__ = [Validator]
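With `QuestionAnswerValidator` gone from the public exports, `Validator` is the single entry point. A small sketch of how the module might read; the string form of `__all__` is an editorial assumption and not part of this commit.

```python
# buster/validators/__init__.py (sketch)
from .base import Validator

# Conventionally a list of name strings, so `from buster.validators import *`
# can resolve the exported attributes by name.
__all__ = ["Validator"]
```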
49 changes: 24 additions & 25 deletions buster/validators/base.py
@@ -1,44 +1,45 @@
import logging
from abc import ABC, abstractmethod
from functools import lru_cache

import pandas as pd

from buster.llm_utils import cosine_similarity, get_openai_embedding
from buster.validators.question_answer_validator import (
AnswerValidator,
DocumentsValidator,
QuestionValidator,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class Validator(ABC):
class Validator:
Collaborator: Is there really a reason for this class now?

Owner Author: I wonder the same thing, but it helps keep the code more organized.

def __init__(
self,
embedding_model: str,
unknown_threshold: float,
use_reranking: bool,
invalid_question_response: str = "This question is not relevant to my internal knowledge base.",
validate_documents: bool,
question_validator_cfg=None,
answer_validator_cfg=None,
documents_validator_cfg=None,
):
self.embedding_model = embedding_model
self.unknown_threshold = unknown_threshold
self.question_validator = QuestionValidator(**question_validator_cfg)
self.answer_validator = AnswerValidator(**answer_validator_cfg)
self.documents_validator = DocumentsValidator(**documents_validator_cfg)
self.use_reranking = use_reranking
self.invalid_question_response = invalid_question_response
self.validate_documents = validate_documents

@staticmethod
@lru_cache
def get_embedding(text: str, model: str):
"""Currently supports OpenAI embeddings, override to add your own."""
logger.info("generating embedding")
return get_openai_embedding(text, model)

@abstractmethod
def check_question_relevance(self, question: str) -> tuple[bool, str]:
...
return self.question_validator.check_question_relevance(question)

@abstractmethod
def check_answer_relevance(self, answer: str) -> bool:
...
return self.answer_validator.check_answer_relevance(answer)

def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
return self.documents_validator.check_documents_relevance(answer, matched_documents)

def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
def rerank_docs(
self, answer: str, matched_documents: pd.DataFrame, embedding_fn=get_openai_embedding
) -> pd.DataFrame:
"""Here we re-rank matched documents according to the answer provided by the llm.

This score could be used to determine whether a document was actually relevant to generation.
@@ -48,10 +49,8 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr
return matched_documents
logger.info("Reranking documents based on answer similarity...")

answer_embedding = self.get_embedding(
answer,
model=self.embedding_model,
)
answer_embedding = embedding_fn(answer)

col = "similarity_to_answer"
matched_documents[col] = matched_documents.embedding.apply(lambda x: cosine_similarity(x, answer_embedding))

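Taken together, the refactored `Validator` is now a thin facade: it builds the three sub-validators from their configs and delegates to them, and `rerank_docs` accepts a pluggable `embedding_fn`. A hedged usage sketch, reusing the `validator_cfg` from the cfg.py example above and a toy embedding function so the reranking step runs offline (the question check still needs OpenAI credentials):

```python
import numpy as np
import pandas as pd

from buster.validators import Validator


def toy_embedding_fn(text: str) -> np.ndarray:
    # Toy stand-in for get_openai_embedding; not a real embedding.
    rng = np.random.default_rng(abs(hash(text)) % (2**32))
    return rng.normal(size=16)


validator = Validator(**validator_cfg)  # validator_cfg as sketched earlier

matched_documents = pd.DataFrame(
    {
        "content": ["Doc about attention heads", "Doc about growing basil"],
        "embedding": [
            toy_embedding_fn("Doc about attention heads"),
            toy_embedding_fn("Doc about growing basil"),
        ],
    }
)

# Each call is forwarded to the corresponding sub-validator.
relevant, response = validator.check_question_relevance("What is an attention head?")
reranked = validator.rerank_docs(
    answer="Attention heads mix information between tokens.",
    matched_documents=matched_documents,
    embedding_fn=toy_embedding_fn,
)
```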
155 changes: 99 additions & 56 deletions buster/validators/question_answer_validator.py
@@ -1,90 +1,133 @@
import concurrent.futures
import logging

import pandas as pd

from buster.completers import ChatGPTCompleter
from buster.llm_utils import cosine_similarity
from buster.validators import Validator
from buster.llm_utils.embeddings import get_openai_embedding

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class QuestionAnswerValidator(Validator):
def __init__(
self, completion_kwargs: dict, check_question_prompt: str, unknown_response_templates: list[str], **kwargs
):
super().__init__(**kwargs)

class QuestionValidator:
def __init__(self, completion_kwargs: dict, check_question_prompt: str, invalid_question_response: str):
self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)
self.check_question_prompt = check_question_prompt
self.unknown_response_templates = unknown_response_templates
self.invalid_question_response = invalid_question_response

def check_question_relevance(self, question: str) -> tuple[bool, str]:
"""Determines wether a question is relevant or not for our given framework."""

def get_relevance(outputs: str) -> bool:
# remove trailing periods, happens sometimes...
outputs = outputs.strip(".").lower()

if outputs == "true":
relevance = True
elif outputs == "false":
relevance = False
else:
# Default assume it's no longer relevant if the detector didn't give one of [true, false]
logger.warning(f"the question validation returned an unexpeced value: {outputs}. Assuming Invalid...")
relevance = False
return relevance

response = self.invalid_question_response
"""Determines whether a question is relevant for our given framework."""
try:
outputs, error = self.completer.complete(self.check_question_prompt, user_input=question)
relevance = get_relevance(outputs)
outputs, _ = self.completer.complete(self.check_question_prompt, user_input=question)
outputs = outputs.strip(".").lower()
if outputs not in ["true", "false"]:
logger.warning(f"the question validation returned an unexpeced value: {outputs=}. Assuming Invalid...")
jerpint marked this conversation as resolved.
Show resolved Hide resolved
relevance = outputs.strip(".").lower() == "true"
response = self.invalid_question_response

except Exception as e:
# Something went wrong, assume immediately not relevant.
logger.exception("Something went wrong during question relevance detection. See traceback:")
logger.exception("Error during question relevance detection.")
relevance = False
response = "Unable to process your question at the moment, try again soon"

logger.info(f"Question {relevance=}")

return relevance, response

def check_answer_relevance(self, answer: str) -> bool:
"""Check to see if a generated answer is relevant to the chatbot's knowledge or not.

We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding.
class AnswerValidator:
def __init__(self, unknown_response_templates: list[str], unknown_threshold: float, embedding_fn: callable = None):
self.unknown_response_templates = unknown_response_templates
self.unknown_threshold = unknown_threshold

unk_threshold can be a value between [-1,1]. Usually, 0.85 is a good value.
"""
logger.info("Checking for answer relevance...")
# default to OpenAI embeddings when no embedding_fn is provided
if embedding_fn is None:
self.embedding_fn = get_openai_embedding
else:
self.embedding_fn = embedding_fn

def check_answer_relevance(self, answer: str) -> bool:
"""Check if a generated answer is relevant to the chatbot's knowledge."""
if answer == "":
raise ValueError("Cannot compute embedding of an empty string.")

# if unknown_prompt is None:
unknown_responses = self.unknown_response_templates

unknown_embeddings = [
self.get_embedding(
unknown_response,
model=self.embedding_model,
)
for unknown_response in unknown_responses
self.embedding_fn(unknown_response) for unknown_response in self.unknown_response_templates
]

answer_embedding = self.get_embedding(
answer,
model=self.embedding_model,
)
answer_embedding = self.embedding_fn(answer)
unknown_similarity_scores = [
cosine_similarity(answer_embedding, unknown_embedding) for unknown_embedding in unknown_embeddings
]
logger.info(f"{unknown_similarity_scores=}")

# Likely that the answer is meaningful, add the top sources
answer_relevant: bool = (
False if any(score > self.unknown_threshold for score in unknown_similarity_scores) else True
)
return answer_relevant
# If any score is above the threshold, the answer is considered not relevant
return not any(score > self.unknown_threshold for score in unknown_similarity_scores)


class DocumentsValidator:
def __init__(
self,
completion_kwargs: dict = None,
system_prompt: str = None,
user_input_formatter: str = None,
max_calls: int = 30,
):
if system_prompt is None:
system_prompt = """
Your goal is to determine if the contents of a document can be attributed to a provided answer.
This means that if information in the document is found in the answer, it is relevant. Otherwise it is not.
Your goal is to determine if the information contained in a document was used to generate an answer.
You will be comparing a document to an answer. If the answer can be inferred from the document, return 'true'. Otherwise return 'false'.
Only respond with 'true' or 'false'."""
self.system_prompt = system_prompt

if user_input_formatter is None:
user_input_formatter = """
answer: {answer}
document: {document}
"""
self.user_input_formatter = user_input_formatter

if completion_kwargs is None:
completion_kwargs = {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
}

self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)

self.max_calls = max_calls

def check_document_relevance(self, answer: str, document: str) -> bool:
user_input = self.user_input_formatter.format(answer=answer, document=document)
output, _ = self.completer.complete(prompt=self.system_prompt, user_input=user_input)

# remove trailing periods, happens sometimes...
output = output.strip(".").lower()

if output not in ["true", "false"]:
# Default assume it's relevant if the detector didn't give one of [true, false]
logger.warning(f"the validation returned an unexpeced value: {output}. Assuming valid...")
return True
return output == "true"

def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
"""Determines whether each matched document is relevant to the provided answer."""

logger.info(f"Checking document relevance of {len(matched_documents)} documents")

if len(matched_documents) > self.max_calls:
raise ValueError("Max calls exceeded, increase max_calls to allow this.")

# Here we parallelize the calls. We introduce a wrapper as a workaround.
def _check_documents(args):
"Thin wrapper so we can pass args as a Tuple and use ThreadPoolExecutor."
answer, document = args
return self.check_document_relevance(answer=answer, document=document)

args_list = [(answer, doc) for doc in matched_documents.content.to_list()]
with concurrent.futures.ThreadPoolExecutor() as executor:
relevance = list(executor.map(_check_documents, args_list))

logger.info(f"{relevance=}")
# add it back to the dataframe
matched_documents["relevance"] = relevance
return matched_documents
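Finally, a hedged sketch of `DocumentsValidator` used on its own, imported from the module as it stands in this commit: one true/false completion per document, fanned out over a thread pool (capped by `max_calls`), with the verdicts written back as a boolean `relevance` column. The documents and answer below are invented for illustration, and running it requires OpenAI credentials.

```python
import pandas as pd

from buster.validators.question_answer_validator import DocumentsValidator

documents_validator = DocumentsValidator(
    completion_kwargs={"model": "gpt-3.5-turbo", "stream": False, "temperature": 0},
    max_calls=30,
)

matched_documents = pd.DataFrame(
    {
        "content": [
            "Self-attention lets every token attend to every other token.",
            "Basil grows best in warm weather with plenty of sun.",
        ]
    }
)

answer = "Transformers rely on self-attention to mix information across tokens."

# One completion per document, run in parallel; adds a boolean 'relevance' column.
matched_documents = documents_validator.check_documents_relevance(answer, matched_documents)
print(matched_documents[["content", "relevance"]])
```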