

Refactor - code changes with best practices, add docstrings and improved error handling
hillaryke committed Aug 2, 2024
1 parent d06e544 commit 0db4ef5
Showing 13 changed files with 467 additions and 213 deletions.
71 changes: 65 additions & 6 deletions backend/app/main.py
@@ -1,54 +1,100 @@
from typing import Optional

from fastapi import FastAPI, HTTPException, Depends
-from pydantic import ValidationError
-from pydantic import BaseModel
+from pydantic import ValidationError, BaseModel
from config.utils import load_config
from src.rag_pipeline.rag_system import RAGSystem
from src.env_loader import load_api_keys

# Load API keys from environment
load_api_keys()

app = FastAPI()


class InitializeRequest(BaseModel):
"""
Request model for initializing the RAG system.
Attributes:
strategy_name (str): The name of the strategy to use for initialization.
split_docs (Optional[int]): The number of split documents to use. Defaults to None.
"""

strategy_name: str
split_docs: Optional[int] = None


class QueryRequest(BaseModel):
"""
Request model for querying the RAG system.
Attributes:
question (str): The question to query the RAG system with.
"""

question: str


# Global variable to store the initialized RAGSystem
rag_system_instance: Optional[RAGSystem] = None


# Dependency to get the initialized RAGSystem
-def get_rag_system():
+def get_rag_system() -> RAGSystem:
"""
Dependency to get the initialized RAGSystem.
Raises:
HTTPException: If the RAG system is not initialized.
Returns:
RAGSystem: The initialized RAG system instance.
"""
if rag_system_instance is None:
raise HTTPException(status_code=500, detail="RAG system is not initialized")
return rag_system_instance


@app.get("/")
def read_root():
"""
Root endpoint to check if the API is running.
Returns:
dict: A welcome message.
"""
return {"message": "Hello from FastAPI backend!"}


@app.get("/health")
def health_check():
"""
Health check endpoint to verify if the RAG system is initialized and running.
Returns:
dict: The status of the RAG system.
"""
try:
rag_system = get_rag_system()
print(rag_system)
return {"status": "RAG system is initialized and running"}
except HTTPException as e:
return {"status": "RAG system is not initialized", "detail": str(e)}


@app.post("/initialize") # New endpoint for initialization
@app.post("/initialize")
def initialize_rag_system(init_request: InitializeRequest):
"""
Endpoint to initialize the RAG system.
Args:
init_request (InitializeRequest): The initialization request containing strategy name and split_docs.
Returns:
dict: A message indicating the result of the initialization.
Raises:
HTTPException: If there is a configuration error or initialization fails.
"""
global rag_system_instance
try:
config = load_config(init_request.strategy_name)
@@ -69,6 +115,19 @@ def initialize_rag_system(init_request: InitializeRequest):
def query_rag_system(
query_request: QueryRequest, rag_system: RAGSystem = Depends(get_rag_system)
):
"""
Endpoint to query the RAG system.
Args:
query_request (QueryRequest): The query request containing the question.
rag_system (RAGSystem): The initialized RAG system instance.
Returns:
dict: The answer from the RAG system.
Raises:
HTTPException: If the query fails.
"""
try:
answer = rag_system.query(query_request.question)
return {"answer": answer}
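For orientation, here is a minimal sketch of how a client might exercise these two endpoints once the backend is running. The host/port, the "baseline" strategy name, and the /query path (inferred from the handler name, since its decorator is in the collapsed hunk) are assumptions for illustration, not part of the commit:

import requests

BASE_URL = "http://localhost:8000"  # assumed uvicorn default

# Initialize the RAG system with a named strategy (name is hypothetical).
resp = requests.post(
    f"{BASE_URL}/initialize",
    json={"strategy_name": "baseline", "split_docs": 100},
)
resp.raise_for_status()

# Query the initialized system.
resp = requests.post(
    f"{BASE_URL}/query",
    json={"question": "What does the article say about climate policy?"},
)
resp.raise_for_status()
print(resp.json()["answer"])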
48 changes: 48 additions & 0 deletions config/settings.py
@@ -2,18 +2,48 @@


class VectorStoreConfig(BaseModel):
"""
Configuration for the vector store.
Attributes:
collection_name (str): The name of the collection in the vector store.
clear_store (bool): Whether to clear the store before use.
use_existing_vectorstore (bool): Whether to use an existing vector store.
"""

collection_name: str = "cnn_dailymail"
clear_store: bool = True
use_existing_vectorstore: bool = False


class ChunkingConfig(BaseModel):
"""
Configuration for document chunking.
Attributes:
chunk_type (str): The type of chunking strategy.
chunk_size (int): The size of each chunk.
chunk_overlap (int): The overlap between chunks.
"""

chunk_type: str = "recursive"
chunk_size: int = 1000
chunk_overlap: int = 200


class RetrievalConfig(BaseModel):
"""
Configuration for document retrieval.
Attributes:
k_documents (int): The number of documents to retrieve.
use_ensemble (bool): Whether to use an ensemble retriever.
use_multiquery (bool): Whether to use a multi-query retriever.
use_reranker (bool): Whether to use a reranker.
use_cohere_reranker (bool): Whether to use the Cohere reranker.
top_n_ranked (int): The number of top-ranked documents to return.
"""

k_documents: int = 5
use_ensemble: bool = False
use_multiquery: bool = False
@@ -23,11 +53,29 @@ class RetrievalConfig(BaseModel):


class ModelsConfig(BaseModel):
"""
Configuration for models.
Attributes:
generator_model (str): The name of the generator model.
queries_generator_model (str): The name of the queries generator model.
"""

generator_model: str = "gpt-4o-mini"
queries_generator_model: str = "gpt-4o-mini"


class Config(BaseModel):
"""
Main configuration class that aggregates all other configurations.
Attributes:
vectorstore (VectorStoreConfig): Configuration for the vector store.
chunking (ChunkingConfig): Configuration for document chunking.
retrieval (RetrievalConfig): Configuration for document retrieval.
models (ModelsConfig): Configuration for models.
"""

vectorstore: VectorStoreConfig
chunking: ChunkingConfig
retrieval: RetrievalConfig
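As a quick illustration of how these Pydantic models compose, a sketch that builds a Config from nested dicts — the same coercion Pydantic performs on parsed YAML. The truncated class bodies are assumed to match their docstrings:

from config.settings import Config

config = Config(
    vectorstore={"collection_name": "cnn_dailymail", "clear_store": False},
    chunking={"chunk_size": 500, "chunk_overlap": 50},
    retrieval={"k_documents": 3},
    models={},  # falls back to the gpt-4o-mini defaults
)
print(config.chunking.chunk_size)    # 500
print(config.retrieval.k_documents)  # 3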
32 changes: 28 additions & 4 deletions config/utils.py
@@ -1,15 +1,39 @@
import yaml

from typing import List, Dict, Any
from loguru import logger
from config.settings import Config
from src.env_loader import load_api_keys

# Load API keys from environment
load_api_keys()


def load_config(strategy_name: str) -> Config:
"""
Load configuration for a given strategy from a YAML file.
Args:
strategy_name (str): The name of the strategy to load the configuration for.
Returns:
Config: The configuration object.
Raises:
FileNotFoundError: If the configuration file does not exist.
yaml.YAMLError: If there is an error parsing the YAML file.
TypeError: If the configuration data cannot be converted to a Config object.
"""
config_file = f"config/config_{strategy_name}.yaml"
-    with open(config_file, "r") as file:
-        config_data = yaml.safe_load(file)
-    return Config(**config_data)
+    try:
+        with open(config_file, "r") as file:
+            config_data = yaml.safe_load(file)
+        return Config(**config_data)
+    except FileNotFoundError as e:
+        logger.error(f"Configuration file not found: {config_file}")
+        raise
+    except yaml.YAMLError as e:
+        logger.error(f"Error parsing YAML file: {config_file}")
+        raise
+    except TypeError as e:
+        logger.error(f"Error converting configuration data to Config object: {e}")
+        raise
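A sketch of the round trip: the file name is derived from the strategy name, so a hypothetical config/config_baseline.yaml holding the sections from config/settings.py would load like this (strategy name and file contents are illustrative):

from config.utils import load_config

# config/config_baseline.yaml (hypothetical) might contain:
#   vectorstore:
#     collection_name: cnn_dailymail
#   chunking:
#     chunk_size: 1000
#     chunk_overlap: 200
#   retrieval:
#     k_documents: 5
#   models:
#     generator_model: gpt-4o-mini

config = load_config("baseline")  # reads config/config_baseline.yaml
print(config.models.generator_model)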
40 changes: 19 additions & 21 deletions src/benchmark_analysis/benchmarks_analysis.py
@@ -4,37 +4,37 @@
import altair as alt

class BenchmarkAnalysis:
-    def __init__(self, baseline_df, prompt_eng_df):
+    def __init__(self, baseline_df, optimized_df):
self.baseline_df = baseline_df.copy()
-        self.prompt_eng_df = prompt_eng_df.copy()
+        self.optimized_df = optimized_df.copy()
self._clean_data()
self.numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']

def _clean_data(self):
"""Drops unnamed index columns if they exist."""
self.baseline_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True)
-        self.prompt_eng_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True)
+        self.optimized_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True)

def calculate_summary_statistics(self):
"""Calculates summary statistics for the specified numeric columns."""
-        numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
summary_stats = {
'Metric': [],
'Baseline_Average': [],
-            'Prompt_eng_opt_Average': [],
+            'Optimized_Average': [],
'Baseline_Highest': [],
-            'Prompt_eng_opt_Highest': [],
+            'Optimized_Highest': [],
'Baseline_Lowest': [],
-            'Prompt_eng_opt_Lowest': []
+            'Optimized_Lowest': []
}

-        for column in numeric_columns:
+        for column in self.numeric_columns:
summary_stats['Metric'].append(column)
summary_stats['Baseline_Average'].append(self.baseline_df[column].mean())
-            summary_stats['Prompt_eng_opt_Average'].append(self.prompt_eng_df[column].mean())
+            summary_stats['Optimized_Average'].append(self.optimized_df[column].mean())
summary_stats['Baseline_Highest'].append(self.baseline_df[column].max())
-            summary_stats['Prompt_eng_opt_Highest'].append(self.prompt_eng_df[column].max())
+            summary_stats['Optimized_Highest'].append(self.optimized_df[column].max())
summary_stats['Baseline_Lowest'].append(self.baseline_df[column].min())
-            summary_stats['Prompt_eng_opt_Lowest'].append(self.prompt_eng_df[column].min())
+            summary_stats['Optimized_Lowest'].append(self.optimized_df[column].min())

summary_df = pd.DataFrame(summary_stats)
return summary_df
@@ -45,19 +45,19 @@ def visualize_summary_statistics(self, summary_df):

# Average comparison
plt.subplot(3, 1, 1)
-        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Average', 'Prompt_eng_opt_Average']))
+        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Average', 'Optimized_Average']))
plt.title('Average Comparison')
plt.xticks(rotation=45)

# Highest value comparison
plt.subplot(3, 1, 2)
-        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Highest', 'Prompt_eng_opt_Highest']))
+        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Highest', 'Optimized_Highest']))
plt.title('Highest Value Comparison')
plt.xticks(rotation=45)

# Lowest value comparison
plt.subplot(3, 1, 3)
-        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Lowest', 'Prompt_eng_opt_Lowest']))
+        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Lowest', 'Optimized_Lowest']))
plt.title('Lowest Value Comparison')
plt.xticks(rotation=45)

@@ -66,22 +66,20 @@ def visualize_summary_statistics(self, summary_df):

def calculate_deviations(self):
"""Calculates deviations between baseline and prompt engineering optimized DataFrames."""
-        numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
deviations = {
'question': self.baseline_df['question'],
'answer': self.baseline_df['answer']
}

-        for column in numeric_columns:
-            deviations[column + '_deviation'] = self.baseline_df[column] - self.prompt_eng_df[column]
+        for column in self.numeric_columns:
+            deviations[column + '_deviation'] = self.baseline_df[column] - self.optimized_df[column]

deviations_df = pd.DataFrame(deviations)
return deviations_df

def visualize_deviations(self, deviations_df):
"""Visualizes the deviations using Altair."""
-        numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
-        deviation_melted = deviations_df.melt(id_vars=['question', 'answer'], value_vars=[col + '_deviation' for col in numeric_columns], var_name='Metric', value_name='Deviation')
+        deviation_melted = deviations_df.melt(id_vars=['question', 'answer'], value_vars=[col + '_deviation' for col in self.numeric_columns], var_name='Metric', value_name='Deviation')

# Create the Altair plot
chart = alt.Chart(deviation_melted).mark_bar().encode(
@@ -99,8 +97,8 @@ def visualize_deviations(self, deviations_df):

# Example usage:
# baseline_df = pd.read_csv('path_to_baseline.csv')
-# prompt_eng_df = pd.read_csv('path_to_prompt_eng.csv')
-# benchmark = BenchmarkAnalysis(baseline_df, prompt_eng_df)
+# optimized_df = pd.read_csv('path_to_prompt_eng.csv')
+# benchmark = BenchmarkAnalysis(baseline_df, optimized_df)
# summary_df = benchmark.calculate_summary_statistics()
# print(summary_df)
# benchmark.visualize_summary_statistics(summary_df)
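The commented usage block above assumes CSVs on disk; a self-contained toy run with made-up scores, sketched under the assumption that the module imports as shown, looks like this:

import pandas as pd
from src.benchmark_analysis.benchmarks_analysis import BenchmarkAnalysis

metrics = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
base = {'question': ['Q1', 'Q2'], 'answer': ['A1', 'A2']}
baseline_df = pd.DataFrame({**base, **{m: [0.70, 0.60] for m in metrics}})
optimized_df = pd.DataFrame({**base, **{m: [0.80, 0.75] for m in metrics}})

benchmark = BenchmarkAnalysis(baseline_df, optimized_df)
print(benchmark.calculate_summary_statistics())
# Deviations are baseline minus optimized, so negative values mean the
# optimized run scored higher.
print(benchmark.calculate_deviations())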
9 changes: 6 additions & 3 deletions src/env_loader.py
@@ -1,6 +1,7 @@
import os
from dotenv import load_dotenv


def load_api_keys(key_name: str = None) -> str:
"""
Load an API key from environment variables. If key_name is provided, load the API key with that name.
@@ -20,6 +21,8 @@ def load_api_keys(key_name: str = None) -> str:
api_key = os.getenv(key_name)

    if api_key is None:
-        raise ValueError(f"API key '{key_name}' not found in environment variables.")
+        raise ValueError(
+            f"API key '{key_name}' not found in environment variables."
+        )

    return api_key
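Both call styles, sketched; the key name and placeholder value are illustrative, and the no-argument form matches how backend/app/main.py invokes it:

import os
from src.env_loader import load_api_keys

# With no argument, loads keys from the environment/.env file (assumed behavior).
load_api_keys()

# Or fetch one named key; raises ValueError if it is missing.
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")  # illustrative only
openai_key = load_api_keys("OPENAI_API_KEY")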
8 changes: 5 additions & 3 deletions src/rag_pipeline/chunking_strategies.py
@@ -1,8 +1,10 @@
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain_openai.embeddings import OpenAIEmbeddings

-def chunk_by_recursive_split(documents: list[Document], chunk_size: int = 1000, chunk_overlap: int = 200) -> list[Document]:
+def chunk_by_recursive_split(
+    documents: list[Document], chunk_size: int = 1000, chunk_overlap: int = 200
+) -> list[Document]:
"""
Splits a list of documents into chunks of a specified size using a recursive character-based approach. Splits are based purely on character count.
@@ -26,4 +28,4 @@ def chunk_by_recursive_split(documents: list[Document], chunk_size: int = 1000,
except Exception as e:
print(f"Error during recursive split: {e}")
chunks = [] # Ensure chunks is defined even in case of error
    return chunks
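And a short sketch of the splitter in use, with toy input:

from langchain.schema import Document
from src.rag_pipeline.chunking_strategies import chunk_by_recursive_split

# ~3000 characters of toy text in a single Document.
docs = [Document(page_content="word " * 600, metadata={"source": "toy"})]
chunks = chunk_by_recursive_split(docs, chunk_size=1000, chunk_overlap=200)
print(len(chunks))                  # a handful of chunks for ~3000 characters
print(len(chunks[0].page_content))  # each chunk is at most 1000 characters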
(The remaining seven changed files are not shown here.)
