Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add docuemnts evaluator, refactor validators #151

Merged
merged 10 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions buster/busterbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,8 @@ class BusterConfig:

validator_cfg: dict = field(
default_factory=lambda: {
"unknown_prompts": [
"I Don't know how to answer your question.",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"validate_documents": False,
}
)
tokenizer_cfg: dict = field(
Expand Down
5 changes: 5 additions & 0 deletions buster/completers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,11 @@ def postprocess(self):
answer=self.answer_text, matched_documents=self.matched_documents
)

if self.validator.validate_documents:
self.matched_documents = self.validator.check_documents_relevance(
answer=self.answer_text, matched_documents=self.matched_documents
)

# access the property so it gets set if not computed alerady
self.answer_relevant

Expand Down
40 changes: 25 additions & 15 deletions buster/examples/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,18 @@
from buster.formatters.prompts import PromptFormatter
from buster.retriever import DeepLakeRetriever, Retriever
from buster.tokenizers import GPTTokenizer
from buster.validators import QuestionAnswerValidator, Validator
from buster.validators import Validator

buster_cfg = BusterConfig(
validator_cfg={
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
"embedding_model": "text-embedding-ada-002",
"use_reranking": True,
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"check_question_prompt": """You are an chatbot answering questions on artificial intelligence.

"question_validator_cfg": {
"invalid_question_response": "This question does not seem relevant to my current knowledge.",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"check_question_prompt": """You are a chatbot answering questions on artificial intelligence.
Your job is to determine wether or not a question is valid, and should be answered.
More general questions are not considered valid, even if you might know the response.
A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.
Expand All @@ -30,11 +29,22 @@
false

A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""",
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
"answer_validator_cfg": {
"unknown_response_templates": [
"I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
],
"unknown_threshold": 0.85,
},
"documents_validator_cfg": {
"completion_kwargs": {
"model": "gpt-3.5-turbo",
"stream": False,
"temperature": 0,
},
},
"use_reranking": True,
"validate_documents": True,
},
retriever_cfg={
"path": "deeplake_store",
Expand Down Expand Up @@ -98,6 +108,6 @@ def setup_buster(buster_cfg: BusterConfig):
prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
**buster_cfg.documents_answerer_cfg,
)
validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
validator: Validator = Validator(**buster_cfg.validator_cfg)
buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)
return buster
4 changes: 3 additions & 1 deletion buster/llm_utils/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from functools import lru_cache

import numpy as np
import pandas as pd
Expand All @@ -11,7 +12,8 @@
client = OpenAI()


def get_openai_embedding(text: str, model: str = "text-embedding-ada-002"):
@lru_cache
hbertrand marked this conversation as resolved.
Show resolved Hide resolved
def get_openai_embedding(text: str, model: str = "text-embedding-ada-002") -> np.array:
try:
text = text.replace("\n", " ")
response = client.embeddings.create(
Expand Down
3 changes: 1 addition & 2 deletions buster/validators/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from .base import Validator
from .question_answer_validator import QuestionAnswerValidator

__all__ = [Validator, QuestionAnswerValidator]
__all__ = [Validator]
57 changes: 32 additions & 25 deletions buster/validators/base.py
Original file line number Diff line number Diff line change
@@ -1,44 +1,53 @@
import logging
from abc import ABC, abstractmethod
from functools import lru_cache

import pandas as pd

from buster.llm_utils import cosine_similarity, get_openai_embedding
from buster.validators.validators import (
AnswerValidator,
DocumentsValidator,
QuestionValidator,
)

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class Validator(ABC):
class Validator:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there really a reason for this class now?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i wonder the same thing, but it helps keep the code more organized

def __init__(
self,
embedding_model: str,
unknown_threshold: float,
use_reranking: bool,
invalid_question_response: str = "This question is not relevant to my internal knowledge base.",
validate_documents: bool,
question_validator_cfg=None,
answer_validator_cfg=None,
documents_validator_cfg=None,
):
self.embedding_model = embedding_model
self.unknown_threshold = unknown_threshold
self.question_validator = (
QuestionValidator(**question_validator_cfg) if question_validator_cfg is not None else QuestionValidator()
)
self.answer_validator = (
AnswerValidator(**answer_validator_cfg) if answer_validator_cfg is not None else AnswerValidator()
)
self.documents_validator = (
DocumentsValidator(**documents_validator_cfg)
if documents_validator_cfg is not None
else DocumentsValidator()
)
self.use_reranking = use_reranking
self.invalid_question_response = invalid_question_response

@staticmethod
@lru_cache
def get_embedding(text: str, model: str):
"""Currently supports OpenAI embeddings, override to add your own."""
logger.info("generating embedding")
return get_openai_embedding(text, model)
self.validate_documents = validate_documents

@abstractmethod
def check_question_relevance(self, question: str) -> tuple[bool, str]:
...
return self.question_validator.check_question_relevance(question)

@abstractmethod
def check_answer_relevance(self, answer: str) -> bool:
...
return self.answer_validator.check_answer_relevance(answer)

def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
return self.documents_validator.check_documents_relevance(answer, matched_documents)

def rerank_docs(
self, answer: str, matched_documents: pd.DataFrame, embedding_fn=get_openai_embedding
) -> pd.DataFrame:
"""Here we re-rank matched documents according to the answer provided by the llm.

This score could be used to determine wether a document was actually relevant to generation.
Expand All @@ -48,10 +57,8 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr
return matched_documents
logger.info("Reranking documents based on answer similarity...")

answer_embedding = self.get_embedding(
answer,
model=self.embedding_model,
)
answer_embedding = embedding_fn(answer)

col = "similarity_to_answer"
matched_documents[col] = matched_documents.embedding.apply(lambda x: cosine_similarity(x, answer_embedding))

Expand Down
90 changes: 0 additions & 90 deletions buster/validators/question_answer_validator.py

This file was deleted.

Loading