diff --git a/backend/app/main.py b/backend/app/main.py
index 16aaeaf..6d3ea79 100644
--- a/backend/app/main.py
+++ b/backend/app/main.py
@@ -1,23 +1,38 @@
 from typing import Optional
 from fastapi import FastAPI, HTTPException, Depends
-from pydantic import ValidationError
-from pydantic import BaseModel
+from pydantic import ValidationError, BaseModel
 from config.utils import load_config
 from src.rag_pipeline.rag_system import RAGSystem
 from src.env_loader import load_api_keys

+# Load API keys from environment
 load_api_keys()

 app = FastAPI()


 class InitializeRequest(BaseModel):
+    """
+    Request model for initializing the RAG system.
+
+    Attributes:
+        strategy_name (str): The name of the strategy to use for initialization.
+        split_docs (Optional[int]): The number of split documents to use. Defaults to None.
+    """
+
     strategy_name: str
     split_docs: Optional[int] = None


 class QueryRequest(BaseModel):
+    """
+    Request model for querying the RAG system.
+
+    Attributes:
+        question (str): The question to query the RAG system with.
+    """
+
     question: str


@@ -25,8 +40,16 @@ class QueryRequest(BaseModel):
 rag_system_instance: Optional[RAGSystem] = None


-# Dependency to get the initialized RAGSystem
-def get_rag_system():
+def get_rag_system() -> RAGSystem:
+    """
+    Dependency to get the initialized RAGSystem.
+
+    Raises:
+        HTTPException: If the RAG system is not initialized.
+
+    Returns:
+        RAGSystem: The initialized RAG system instance.
+    """
     if rag_system_instance is None:
         raise HTTPException(status_code=500, detail="RAG system is not initialized")
     return rag_system_instance
@@ -34,21 +57,44 @@ def get_rag_system():

 @app.get("/")
 def read_root():
+    """
+    Root endpoint to check if the API is running.
+
+    Returns:
+        dict: A welcome message.
+    """
     return {"message": "Hello from FastAPI backend!"}


 @app.get("/health")
 def health_check():
+    """
+    Health check endpoint to verify if the RAG system is initialized and running.
+
+    Returns:
+        dict: The status of the RAG system.
+    """
     try:
         rag_system = get_rag_system()
-        print(rag_system)
         return {"status": "RAG system is initialized and running"}
     except HTTPException as e:
         return {"status": "RAG system is not initialized", "detail": str(e)}


-@app.post("/initialize")  # New endpoint for initialization
+@app.post("/initialize")
 def initialize_rag_system(init_request: InitializeRequest):
+    """
+    Endpoint to initialize the RAG system.
+
+    Args:
+        init_request (InitializeRequest): The initialization request containing strategy name and split_docs.
+
+    Returns:
+        dict: A message indicating the result of the initialization.
+
+    Raises:
+        HTTPException: If there is a configuration error or initialization fails.
+    """
     global rag_system_instance
     try:
         config = load_config(init_request.strategy_name)
@@ -69,6 +115,19 @@ def initialize_rag_system(init_request: InitializeRequest):
 def query_rag_system(
     query_request: QueryRequest, rag_system: RAGSystem = Depends(get_rag_system)
 ):
+    """
+    Endpoint to query the RAG system.
+
+    Args:
+        query_request (QueryRequest): The query request containing the question.
+        rag_system (RAGSystem): The initialized RAG system instance.
+
+    Returns:
+        dict: The answer from the RAG system.
+
+    Raises:
+        HTTPException: If the query fails.
+    """
     try:
         answer = rag_system.query(query_request.question)
         return {"answer": answer}
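
For orientation, a minimal sketch of how these two endpoints could be exercised once the backend is running; the uvicorn command, port, and the "baseline" strategy name are illustrative assumptions, not part of this diff:

    # Assumes the app is served locally, e.g. `uvicorn backend.app.main:app --port 8000`,
    # and that a config/config_baseline.yaml exists (hypothetical strategy name).
    import httpx

    # One-time initialization of the global RAG system instance
    init = httpx.post(
        "http://localhost:8000/initialize",
        json={"strategy_name": "baseline", "split_docs": 50},
        timeout=None,  # document loading and embedding can take a while
    )
    print(init.json())

    # Query the initialized system
    answer = httpx.post(
        "http://localhost:8000/query",
        json={"question": "What does the article say about the verdict?"},
        timeout=None,
    )
    print(answer.json()["answer"])
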
diff --git a/config/settings.py b/config/settings.py
index 185cc3c..d8360b8 100644
--- a/config/settings.py
+++ b/config/settings.py
@@ -2,18 +2,48 @@

 class VectorStoreConfig(BaseModel):
+    """
+    Configuration for the vector store.
+
+    Attributes:
+        collection_name (str): The name of the collection in the vector store.
+        clear_store (bool): Whether to clear the store before use.
+        use_existing_vectorstore (bool): Whether to use an existing vector store.
+    """
+
     collection_name: str = "cnn_dailymail"
     clear_store: bool = True
     use_existing_vectorstore: bool = False


 class ChunkingConfig(BaseModel):
+    """
+    Configuration for document chunking.
+
+    Attributes:
+        chunk_type (str): The type of chunking strategy.
+        chunk_size (int): The size of each chunk.
+        chunk_overlap (int): The overlap between chunks.
+    """
+
     chunk_type: str = "recursive"
     chunk_size: int = 1000
     chunk_overlap: int = 200


 class RetrievalConfig(BaseModel):
+    """
+    Configuration for document retrieval.
+
+    Attributes:
+        k_documents (int): The number of documents to retrieve.
+        use_ensemble (bool): Whether to use an ensemble retriever.
+        use_multiquery (bool): Whether to use a multi-query retriever.
+        use_reranker (bool): Whether to use a reranker.
+        use_cohere_reranker (bool): Whether to use the Cohere reranker.
+        top_n_ranked (int): The number of top-ranked documents to return.
+    """
+
     k_documents: int = 5
     use_ensemble: bool = False
     use_multiquery: bool = False
@@ -23,11 +53,29 @@ class RetrievalConfig(BaseModel):


 class ModelsConfig(BaseModel):
+    """
+    Configuration for models.
+
+    Attributes:
+        generator_model (str): The name of the generator model.
+        queries_generator_model (str): The name of the queries generator model.
+    """
+
     generator_model: str = "gpt-4o-mini"
     queries_generator_model: str = "gpt-4o-mini"


 class Config(BaseModel):
+    """
+    Main configuration class that aggregates all other configurations.
+
+    Attributes:
+        vectorstore (VectorStoreConfig): Configuration for the vector store.
+        chunking (ChunkingConfig): Configuration for document chunking.
+        retrieval (RetrievalConfig): Configuration for document retrieval.
+        models (ModelsConfig): Configuration for models.
+    """
+
     vectorstore: VectorStoreConfig
     chunking: ChunkingConfig
     retrieval: RetrievalConfig
+ """ config_file = f"config/config_{strategy_name}.yaml" - with open(config_file, "r") as file: - config_data = yaml.safe_load(file) - return Config(**config_data) + try: + with open(config_file, "r") as file: + config_data = yaml.safe_load(file) + return Config(**config_data) + except FileNotFoundError as e: + logger.error(f"Configuration file not found: {config_file}") + raise + except yaml.YAMLError as e: + logger.error(f"Error parsing YAML file: {config_file}") + raise + except TypeError as e: + logger.error(f"Error converting configuration data to Config object: {e}") + raise diff --git a/src/benchmark_analysis/benchmarks_analysis.py b/src/benchmark_analysis/benchmarks_analysis.py index b9147a0..357e3a3 100644 --- a/src/benchmark_analysis/benchmarks_analysis.py +++ b/src/benchmark_analysis/benchmarks_analysis.py @@ -4,37 +4,37 @@ import altair as alt class BenchmarkAnalysis: - def __init__(self, baseline_df, prompt_eng_df): + def __init__(self, baseline_df, optimized_df): self.baseline_df = baseline_df.copy() - self.prompt_eng_df = prompt_eng_df.copy() + self.optimized_df = optimized_df.copy() self._clean_data() + self.numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision'] def _clean_data(self): """Drops unnamed index columns if they exist.""" self.baseline_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True) - self.prompt_eng_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True) + self.optimized_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True) def calculate_summary_statistics(self): """Calculates summary statistics for the specified numeric columns.""" - numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision'] summary_stats = { 'Metric': [], 'Baseline_Average': [], - 'Prompt_eng_opt_Average': [], + 'Optimized_Average': [], 'Baseline_Highest': [], - 'Prompt_eng_opt_Highest': [], + 'Optimized_Highest': [], 'Baseline_Lowest': [], - 'Prompt_eng_opt_Lowest': [] + 'Optimized_Lowest': [] } - for column in numeric_columns: + for column in self.numeric_columns: summary_stats['Metric'].append(column) summary_stats['Baseline_Average'].append(self.baseline_df[column].mean()) - summary_stats['Prompt_eng_opt_Average'].append(self.prompt_eng_df[column].mean()) + summary_stats['Optimized_Average'].append(self.optimized_df[column].mean()) summary_stats['Baseline_Highest'].append(self.baseline_df[column].max()) - summary_stats['Prompt_eng_opt_Highest'].append(self.prompt_eng_df[column].max()) + summary_stats['Optimized_Highest'].append(self.optimized_df[column].max()) summary_stats['Baseline_Lowest'].append(self.baseline_df[column].min()) - summary_stats['Prompt_eng_opt_Lowest'].append(self.prompt_eng_df[column].min()) + summary_stats['Optimized_Lowest'].append(self.optimized_df[column].min()) summary_df = pd.DataFrame(summary_stats) return summary_df @@ -45,19 +45,19 @@ def visualize_summary_statistics(self, summary_df): # Average comparison plt.subplot(3, 1, 1) - sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Average', 'Prompt_eng_opt_Average'])) + sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Average', 'Optimized_Average'])) plt.title('Average Comparison') plt.xticks(rotation=45) # Highest value comparison plt.subplot(3, 1, 2) - sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], 
diff --git a/src/benchmark_analysis/benchmarks_analysis.py b/src/benchmark_analysis/benchmarks_analysis.py
index b9147a0..357e3a3 100644
--- a/src/benchmark_analysis/benchmarks_analysis.py
+++ b/src/benchmark_analysis/benchmarks_analysis.py
@@ -4,37 +4,37 @@
 import altair as alt


 class BenchmarkAnalysis:
-    def __init__(self, baseline_df, prompt_eng_df):
+    def __init__(self, baseline_df, optimized_df):
         self.baseline_df = baseline_df.copy()
-        self.prompt_eng_df = prompt_eng_df.copy()
+        self.optimized_df = optimized_df.copy()
         self._clean_data()
+        self.numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']

     def _clean_data(self):
         """Drops unnamed index columns if they exist."""
         self.baseline_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True)
-        self.prompt_eng_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True)
+        self.optimized_df.drop(columns=['Unnamed: 0'], errors='ignore', inplace=True)

     def calculate_summary_statistics(self):
         """Calculates summary statistics for the specified numeric columns."""
-        numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
         summary_stats = {
             'Metric': [],
             'Baseline_Average': [],
-            'Prompt_eng_opt_Average': [],
+            'Optimized_Average': [],
             'Baseline_Highest': [],
-            'Prompt_eng_opt_Highest': [],
+            'Optimized_Highest': [],
             'Baseline_Lowest': [],
-            'Prompt_eng_opt_Lowest': []
+            'Optimized_Lowest': []
         }

-        for column in numeric_columns:
+        for column in self.numeric_columns:
             summary_stats['Metric'].append(column)
             summary_stats['Baseline_Average'].append(self.baseline_df[column].mean())
-            summary_stats['Prompt_eng_opt_Average'].append(self.prompt_eng_df[column].mean())
+            summary_stats['Optimized_Average'].append(self.optimized_df[column].mean())
             summary_stats['Baseline_Highest'].append(self.baseline_df[column].max())
-            summary_stats['Prompt_eng_opt_Highest'].append(self.prompt_eng_df[column].max())
+            summary_stats['Optimized_Highest'].append(self.optimized_df[column].max())
             summary_stats['Baseline_Lowest'].append(self.baseline_df[column].min())
-            summary_stats['Prompt_eng_opt_Lowest'].append(self.prompt_eng_df[column].min())
+            summary_stats['Optimized_Lowest'].append(self.optimized_df[column].min())

         summary_df = pd.DataFrame(summary_stats)
         return summary_df
@@ -45,19 +45,19 @@ def visualize_summary_statistics(self, summary_df):

         # Average comparison
         plt.subplot(3, 1, 1)
-        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Average', 'Prompt_eng_opt_Average']))
+        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Average', 'Optimized_Average']))
         plt.title('Average Comparison')
         plt.xticks(rotation=45)

         # Highest value comparison
         plt.subplot(3, 1, 2)
-        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Highest', 'Prompt_eng_opt_Highest']))
+        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Highest', 'Optimized_Highest']))
         plt.title('Highest Value Comparison')
         plt.xticks(rotation=45)

         # Lowest value comparison
         plt.subplot(3, 1, 3)
-        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Lowest', 'Prompt_eng_opt_Lowest']))
+        sns.barplot(x='Metric', y='value', hue='variable', data=pd.melt(summary_df, id_vars=['Metric'], value_vars=['Baseline_Lowest', 'Optimized_Lowest']))
         plt.title('Lowest Value Comparison')
         plt.xticks(rotation=45)

@@ -66,22 +66,20 @@ def visualize_summary_statistics(self, summary_df):

     def calculate_deviations(self):
         """Calculates deviations between baseline and prompt engineering optimized DataFrames."""
-        numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
         deviations = {
             'question': self.baseline_df['question'],
             'answer': self.baseline_df['answer']
         }

-        for column in numeric_columns:
-            deviations[column + '_deviation'] = self.baseline_df[column] - self.prompt_eng_df[column]
+        for column in self.numeric_columns:
+            deviations[column + '_deviation'] = self.baseline_df[column] - self.optimized_df[column]

         deviations_df = pd.DataFrame(deviations)
         return deviations_df

     def visualize_deviations(self, deviations_df):
         """Visualizes the deviations using Altair."""
-        numeric_columns = ['answer_correctness', 'faithfulness', 'answer_relevancy', 'context_precision']
-        deviation_melted = deviations_df.melt(id_vars=['question', 'answer'], value_vars=[col + '_deviation' for col in numeric_columns], var_name='Metric', value_name='Deviation')
+        deviation_melted = deviations_df.melt(id_vars=['question', 'answer'], value_vars=[col + '_deviation' for col in self.numeric_columns], var_name='Metric', value_name='Deviation')

         # Create the Altair plot
         chart = alt.Chart(deviation_melted).mark_bar().encode(
@@ -99,8 +97,8 @@ def visualize_deviations(self, deviations_df):

 # Example usage:
 # baseline_df = pd.read_csv('path_to_baseline.csv')
-# prompt_eng_df = pd.read_csv('path_to_prompt_eng.csv')
-# benchmark = BenchmarkAnalysis(baseline_df, prompt_eng_df)
+# optimized_df = pd.read_csv('path_to_prompt_eng.csv')
+# benchmark = BenchmarkAnalysis(baseline_df, optimized_df)
 # summary_df = benchmark.calculate_summary_statistics()
 # print(summary_df)
 # benchmark.visualize_summary_statistics(summary_df)
diff --git a/src/env_loader.py b/src/env_loader.py
index e0cfc97..bb92cc4 100644
--- a/src/env_loader.py
+++ b/src/env_loader.py
@@ -1,6 +1,7 @@
 import os
 from dotenv import load_dotenv

+
 def load_api_keys(key_name: str = None) -> str:
     """
     Load an API keys from environment variables. If key_name is provided, load the API key with that name.
@@ -20,6 +21,8 @@ def load_api_keys(key_name: str = None) -> str:
         api_key = os.getenv(key_name)

     if api_key is None:
-        raise ValueError(f"API key '{key_name}' not found in environment variables.")
-
-    return api_key
\ No newline at end of file
+        raise ValueError(
+            f"API key '{key_name}' not found in environment variables."
+        )
+
+    return api_key
diff --git a/src/rag_pipeline/chunking_strategies.py b/src/rag_pipeline/chunking_strategies.py
index 71d9e56..5761c91 100644
--- a/src/rag_pipeline/chunking_strategies.py
+++ b/src/rag_pipeline/chunking_strategies.py
@@ -1,8 +1,10 @@
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain.schema import Document
-from langchain_openai.embeddings import OpenAIEmbeddings

-def chunk_by_recursive_split(documents: list[Document], chunk_size: int = 1000, chunk_overlap: int = 200) -> list[Document]:
+
+def chunk_by_recursive_split(
+    documents: list[Document], chunk_size: int = 1000, chunk_overlap: int = 200
+) -> list[Document]:
     """
     Splits a list documents into chunks of a specified size using a recursive character-based approach.
     Splits are based purely on character count.
@@ -26,4 +28,4 @@ def chunk_by_recursive_split(documents: list[Document], chunk_size: int = 1000,
     except Exception as e:
         print(f"Error during recursive split: {e}")
         chunks = []  # Ensure chunks is defined even in case of error
-    return chunks
\ No newline at end of file
+    return chunks
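
A small, self-contained sketch of the chunking helper on its own; the sample text is made up:

    from langchain.schema import Document
    from src.rag_pipeline.chunking_strategies import chunk_by_recursive_split

    docs = [Document(page_content="Lorem ipsum dolor sit amet. " * 200)]
    # Purely character-based splitting: ~1000-character chunks with 200 characters of overlap
    chunks = chunk_by_recursive_split(docs, chunk_size=1000, chunk_overlap=200)
    print(len(chunks), len(chunks[0].page_content))
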
+ """ self.config = config self.generator_model = config.models.generator_model self.llm_queries_generator = ChatOpenAI( @@ -40,7 +47,7 @@ def __init__( self.llm = None self.source_file_path = source_file_path self.documents = [] - self.split_docs = List[Document] + self.split_docs = [] self.collection_name = config.vectorstore.collection_name self.embeddings = embeddings if embeddings else OpenAIEmbeddings() self.vectorstore = None @@ -61,118 +68,185 @@ def __init__( self.top_n_ranked = config.retrieval.top_n_ranked def load_documents(self): - documents = load_docs_from_csv(as_document=True) - self.documents = documents - - def prepare_documents(self, len_split_docs: int = 0): - split_docs = chunk_by_recursive_split( - self.documents, chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap - ) - if len_split_docs: - split_docs = split_docs[:len_split_docs] - print(f"--documents_no: {len(split_docs)}") - return split_docs + """Load documents from the source file.""" + try: + self.documents = load_docs_from_csv(as_document=True) + except Exception as e: + print(f"Error loading documents: {e}") + raise + + def prepare_documents(self, len_split_docs: int = 0) -> List[Document]: + """ + Prepare documents by chunking them. + + Args: + len_split_docs (int, optional): Number of split documents to return for testing purposes. Defaults to 0. + + Returns: + List[Document]: List of split documents. + """ + try: + split_docs = chunk_by_recursive_split( + self.documents, + chunk_size=self.chunk_size, + chunk_overlap=self.chunk_overlap, + ) + if len_split_docs: + split_docs = split_docs[:len_split_docs] + print(f"--documents_no: {len(split_docs)}") + return split_docs + except Exception as e: + print(f"Error preparing documents: {e}") + raise def initialize_vectorstore(self): - self.vectorstore = PGVector( - embeddings=self.embeddings, - collection_name=self.collection_name, - connection=PG_CONNECTION_STRING, - use_jsonb=True, - ) - return self.vectorstore + """Initialize the vectorstore.""" + try: + self.vectorstore = PGVector( + embeddings=self.embeddings, + collection_name=self.collection_name, + connection=PG_CONNECTION_STRING, + use_jsonb=True, + ) + except Exception as e: + print(f"Error initializing vectorstore: {e}") + raise def setup_vectorstore(self): - # Initialize the vectorstore - this could be an existing collection + """Setup the vectorstore, optionally clearing it first.""" self.initialize_vectorstore() if self.clear_store: - self.vectorstore.drop_tables() - # Reinitialize the vectorstore once the tables have been dropped - self.initialize_vectorstore() - - # TODO - calculate the embedding cost if using openai embeddings - # check instance of embeddings if OpenAIEmbeddings - # if isinstance(self.embeddings, OpenAIEmbeddings): - # calculate cost here - - # Get the existing vectorstore collection - - # Add documents to the vectorstore - - def setup_bm25_retriever(self, split_docs: List[str]): - self.bm25_retriever = BM25Retriever.from_documents(split_docs) - self.bm25_retriever.k = self.k_documents + try: + self.vectorstore.drop_tables() + self.initialize_vectorstore() + except Exception as e: + print(f"Error clearing vectorstore: {e}") + raise + + def setup_bm25_retriever(self, split_docs: List[Document]): + """Setup the BM25 retriever.""" + try: + self.bm25_retriever = BM25Retriever.from_documents(split_docs) + self.bm25_retriever.k = self.k_documents + except Exception as e: + print(f"Error setting up BM25 retriever: {e}") + raise def setup_base_retriever(self): - 

     def setup_base_retriever(self):
-        self.base_retriever = self.vectorstore.as_retriever(
-            search_kwargs={"k": self.k_documents}
-        )
-        self.final_retriever = self.base_retriever
+        """Set up the base retriever."""
+        try:
+            self.base_retriever = self.vectorstore.as_retriever(
+                search_kwargs={"k": self.k_documents}
+            )
+            self.final_retriever = self.base_retriever
+        except Exception as e:
+            print(f"Error setting up base retriever: {e}")
+            raise

     def setup_ensemble_retriever(self):
-        base_retriever = self.vectorstore.as_retriever(
-            search_kwargs={"k": self.k_documents}
-        )
-        self.ensemble_retriever = EnsembleRetriever(
-            retrievers=[self.bm25_retriever, base_retriever], weights=[0.5, 0.5]
-        )
-        self.final_retriever = self.ensemble_retriever
+        """Set up the ensemble retriever."""
+        try:
+            base_retriever = self.vectorstore.as_retriever(
+                search_kwargs={"k": self.k_documents}
+            )
+            self.ensemble_retriever = EnsembleRetriever(
+                retrievers=[self.bm25_retriever, base_retriever], weights=[0.5, 0.5]
+            )
+            self.final_retriever = self.ensemble_retriever
+        except Exception as e:
+            print(f"Error setting up ensemble retriever: {e}")
+            raise

     def setup_multiquery_retriever(self, retriever):
-        self.final_retriever = MultiQueryRetriever.from_llm(
-            retriever=retriever,
-            llm=self.llm_queries_generator,
-        )
+        """Set up the multi-query retriever."""
+        try:
+            self.final_retriever = MultiQueryRetriever.from_llm(
+                retriever=retriever,
+                llm=self.llm_queries_generator,
+            )
+        except Exception as e:
+            print(f"Error setting up multi-query retriever: {e}")
+            raise

     def setup_reranker(self):
-        print("--SETUP RERANKER--")
-        my_reranker = Reranker(
-            retriever=self.final_retriever,
-            top_n=self.top_n_ranked,
-            use_cohere_reranker=self.use_cohere_reranker,
-        )
-        self.final_retriever = my_reranker.initialize()
+        """Set up the reranker."""
+        try:
+            print("--SETUP RERANKER--")
+            my_reranker = Reranker(
+                retriever=self.final_retriever,
+                top_n=self.top_n_ranked,
+                use_cohere_reranker=self.use_cohere_reranker,
+            )
+            self.final_retriever = my_reranker.compression_retriever  # built in Reranker.__init__
+        except Exception as e:
+            print(f"Error setting up reranker: {e}")
+            raise

     def setup_llm(self):
-        self.llm = ChatOpenAI(model_name=self.generator_model, temperature=0)
-
-        return self.llm
+        """Set up the language model."""
+        try:
+            self.llm = ChatOpenAI(model_name=self.generator_model, temperature=0)
+            return self.llm
+        except Exception as e:
+            print(f"Error setting up LLM: {e}")
+            raise
+ """ + try: + result = self.rag_chain.invoke(question) + return result["answer"] + except Exception as e: + print(f"Error querying RAG system: {e}") + raise def initialize(self, len_split_docs: int = 0): - self.load_documents() - self.setup_vectorstore() - self.setup_base_retriever() - - if not self.use_existing_vectorstore: - print("--SETUP NEW VECTORSTORE--") - # Set up a new vectorstore - self.split_docs = self.prepare_documents(len_split_docs) - - self.vectorstore.add_documents(self.split_docs) - - if self.use_ensemble: - print("--USING ENSEMBLE RETRIEVER--") - self.setup_bm25_retriever(self.split_docs) - self.setup_ensemble_retriever() - elif self.use_multiquery: - print("--USING MULTIQUERY RETRIEVER--") - self.setup_multiquery_retriever(self.base_retriever) - else: - print("--USING BASE RETRIEVER--") - self.setup_base_retriever() - - if self.use_reranker: - self.setup_reranker() - - self.setup_rag_chain() + """Initialize the RAG system.""" + try: + self.load_documents() + self.setup_vectorstore() + self.setup_base_retriever() + + if not self.use_existing_vectorstore: + print("--SETUP NEW VECTORSTORE--") + self.split_docs = self.prepare_documents(len_split_docs) + self.vectorstore.add_documents(self.split_docs) + + if self.use_ensemble: + print("--USING ENSEMBLE RETRIEVER--") + self.setup_bm25_retriever(self.split_docs) + self.setup_ensemble_retriever() + elif self.use_multiquery: + print("--USING MULTIQUERY RETRIEVER--") + self.setup_multiquery_retriever(self.base_retriever) + else: + print("--USING BASE RETRIEVER--") + self.setup_base_retriever() + + if self.use_reranker: + self.setup_reranker() + + self.setup_rag_chain() + except Exception as e: + print(f"Error initializing RAG system: {e}") + raise diff --git a/src/rag_pipeline/rag_utils.py b/src/rag_pipeline/rag_utils.py index 69835a1..90b78bc 100644 --- a/src/rag_pipeline/rag_utils.py +++ b/src/rag_pipeline/rag_utils.py @@ -1,9 +1,13 @@ -from typing import Dict, List +from typing import List from langchain.prompts import PromptTemplate from langchain.schema import StrOutputParser from langchain.docstore.document import Document -from langchain.schema.runnable import RunnablePassthrough, RunnableParallel, RunnableLambda +from langchain.schema.runnable import ( + RunnablePassthrough, + RunnableParallel, + RunnableLambda, +) from misc import Settings @@ -44,15 +48,13 @@ def rag_chain_setup(retriever, llm) -> RunnableParallel: filter_langsmith_dataset = RunnableLambda( lambda x: x["question"] if isinstance(x, dict) else x ) - + rag_chain = RunnableParallel( { "question": filter_langsmith_dataset, "answer": filter_langsmith_dataset | context_retriever | generator, - "contexts": filter_langsmith_dataset - | retriever - | ragas_output_parser, + "contexts": filter_langsmith_dataset | retriever | ragas_output_parser, } ) - return rag_chain \ No newline at end of file + return rag_chain diff --git a/src/rag_pipeline/reranker.py b/src/rag_pipeline/reranker.py index 2ed8812..bab8a48 100644 --- a/src/rag_pipeline/reranker.py +++ b/src/rag_pipeline/reranker.py @@ -1,39 +1,61 @@ +from typing import Optional +import logging from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.document_compressors import CrossEncoderReranker from langchain_community.cross_encoders import HuggingFaceCrossEncoder from langchain_cohere import CohereRerank -from langchain_community.llms import Cohere + +# No need to import Cohere if not directly used within the class + class Reranker: - def __init__(self, 
diff --git a/src/rag_pipeline/rag_utils.py b/src/rag_pipeline/rag_utils.py
index 69835a1..90b78bc 100644
--- a/src/rag_pipeline/rag_utils.py
+++ b/src/rag_pipeline/rag_utils.py
@@ -1,9 +1,13 @@
-from typing import Dict, List
+from typing import List

 from langchain.prompts import PromptTemplate
 from langchain.schema import StrOutputParser
 from langchain.docstore.document import Document
-from langchain.schema.runnable import RunnablePassthrough, RunnableParallel, RunnableLambda
+from langchain.schema.runnable import (
+    RunnablePassthrough,
+    RunnableParallel,
+    RunnableLambda,
+)

 from misc import Settings

@@ -44,15 +48,13 @@ def rag_chain_setup(retriever, llm) -> RunnableParallel:
     filter_langsmith_dataset = RunnableLambda(
         lambda x: x["question"] if isinstance(x, dict) else x
     )
-
+
     rag_chain = RunnableParallel(
         {
             "question": filter_langsmith_dataset,
             "answer": filter_langsmith_dataset | context_retriever | generator,
-            "contexts": filter_langsmith_dataset
-            | retriever
-            | ragas_output_parser,
+            "contexts": filter_langsmith_dataset | retriever | ragas_output_parser,
         }
     )
-    return rag_chain
\ No newline at end of file
+    return rag_chain
diff --git a/src/rag_pipeline/reranker.py b/src/rag_pipeline/reranker.py
index 2ed8812..bab8a48 100644
--- a/src/rag_pipeline/reranker.py
+++ b/src/rag_pipeline/reranker.py
@@ -1,39 +1,61 @@
+from typing import Optional
+import logging
 from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
 from langchain.retrievers.document_compressors import CrossEncoderReranker
 from langchain_community.cross_encoders import HuggingFaceCrossEncoder
 from langchain_cohere import CohereRerank
-from langchain_community.llms import Cohere
+
+# No need to import Cohere if not directly used within the class
+

 class Reranker:
-    def __init__(self, retriever, top_n: int = 5, reranker_model = None, use_cohere_reranker: bool = False):
-        self.reranker_model = HuggingFaceCrossEncoder(model_name="BAAI/bge-reranker-base") if reranker_model is None else reranker_model
-        self.top_n = top_n
-        self.retriever = retriever
-        self.use_cohere_reranker = use_cohere_reranker
-        self.compression_retriever = None
-        self.compressor = None
-
-    def setup_opensource_model(self):
-        print("--USING OPEN SOURCE MODEL FOR RERANKING--")
-        self.compressor = CrossEncoderReranker(model=self.reranker_model, top_n=3)
-        return self.compression_retriever
-
-    def setup_cohere_model(self):
-        print("--USING COHERE MODEL FOR RERANKING--")
-        self.compressor = CohereRerank(model="rerank-english-v3.0")
-        return self.compression_retriever
-
-    def setup_compression_retriever(self):
-        self.compression_retriever = ContextualCompressionRetriever(
-            base_compressor=self.compressor, base_retriever=self.retriever
-        )
-
-    def initialize(self):
-        if self.use_cohere_reranker:
-            self.setup_cohere_model()
-        else:
-            self.setup_opensource_model()
-
-        self.setup_compression_retriever()
-
-        return self.compression_retriever
\ No newline at end of file
+    """
+    Enhances document retrieval by re-ranking results based on relevance to a query.
+    Offers a choice between an open-source cross-encoder and Cohere's commercial reranking model.
+    """
+
+    def __init__(
+        self,
+        retriever,
+        top_n: int = 5,
+        reranker_model: Optional[HuggingFaceCrossEncoder] = None,
+        use_cohere_reranker: bool = False,
+    ):
+        """
+        Initializes the Reranker.
+
+        Args:
+            retriever: The base document retriever to use.
+            top_n: The number of top-ranked documents to consider (default: 5).
+            reranker_model: A custom HuggingFaceCrossEncoder model (optional).
+            use_cohere_reranker: Whether to use Cohere's reranking model (default: False).
+        """
+
+        self.retriever = retriever
+        self.top_n = top_n
+        self.use_cohere_reranker = use_cohere_reranker
+
+        # Initialize with default model or provided custom model
+        self.reranker_model = reranker_model or HuggingFaceCrossEncoder(
+            model_name="BAAI/bge-reranker-base"
+        )
+
+        # Initialize after deciding on the model
+        self._initialize_reranker()
+
+    def _initialize_reranker(self):
+        """Initializes the appropriate reranking model and compression retriever."""
+
+        # Use logger instead of print statements for better logging practices
+        if self.use_cohere_reranker:
+            logging.info("Using Cohere model for reranking")
+            compressor = CohereRerank(model="rerank-english-v3.0")
+        else:
+            logging.info("Using open source model for reranking")
+            compressor = CrossEncoderReranker(
+                model=self.reranker_model, top_n=self.top_n
+            )
+
+        self.compression_retriever = ContextualCompressionRetriever(
+            base_compressor=compressor, base_retriever=self.retriever
+        )
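
Because the reworked Reranker builds its ContextualCompressionRetriever in __init__, callers can read compression_retriever directly after construction, which matches how setup_reranker in rag_system.py consumes it above. A brief sketch, assuming base_retriever is an already-configured LangChain retriever:

    from src.rag_pipeline.reranker import Reranker

    # base_retriever is assumed to exist, e.g. vectorstore.as_retriever(search_kwargs={"k": 20})
    reranker = Reranker(retriever=base_retriever, top_n=5, use_cohere_reranker=False)
    top_docs = reranker.compression_retriever.invoke("query about the article")
    # top_docs now holds the re-ranked, compressed result set
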
diff --git a/src/ragas/ragas_pipeline.py b/src/ragas/ragas_pipeline.py
index 551eda8..06d2fae 100644
--- a/src/ragas/ragas_pipeline.py
+++ b/src/ragas/ragas_pipeline.py
@@ -19,6 +19,7 @@
 EVALUATION_DATASET_DESCRIPTION = Settings.EVALUATION_DATASET_DESCRIPTION
 RESULTS_DIR = Settings.RESULTS_DIR

+
 def run_ragas_evaluation(
     rag_chain: Any,
     use_langsmith: bool = False,
@@ -39,7 +40,7 @@
         save_results (bool, optional): If True, saves the evaluation results to a CSV file. Defaults to False.
         dataset_description (str, optional): The description of the dataset to upload to LangSmith. Defaults to EVALUATION_DATASET_DESCRIPTION.
         upload_dataset_to_langsmith (bool, optional): If True, uploads the dataset to LangSmith. Defaults to False.
-
+
     Returns:
         pd.DataFrame: A DataFrame containing the evaluation results.

@@ -50,7 +51,7 @@
         answer_relevancy,
         context_precision,
     ]
-
+
     print("--LOADING EVALUATION DATA--")
     # Get the test set
     eval_data = load_evaluation_data()  # Load your evaluation data
@@ -60,20 +61,24 @@
         print("--USING LANGSMITH FOR EVALUATION--")
         # Input validation for LangSmith usage
         if dataset_name is None or experiment_name is None:
-            raise ValueError("dataset_name and experiment_name are required when using LangSmith.")
-
+            raise ValueError(
+                "dataset_name and experiment_name are required when using LangSmith."
+            )
+
         if upload_dataset_to_langsmith:
             # Check if dataset_description is provided - input validation
             if dataset_description is None:
-                raise ValueError("dataset_description is required when uploading dataset to LangSmith.")
-
+                raise ValueError(
+                    "dataset_description is required when uploading dataset to LangSmith."
+                )
+
             try:
                 print("--UPLOADING DATASET TO LANGSMITH--")
                 upload_dataset(testset, dataset_name, dataset_description)
                 print("--DATASET UPLOADED TO LANGSMITH--")
             except Exception as e:
                 print(f"Error uploading dataset: {e}")
-
+
         print("--EVALUATING ON LANGSMITH--")
         result = langsmith_evaluate(
             dataset_name=dataset_name,
@@ -90,7 +95,7 @@
         print("--EVALUATION COMPLETE--")

     df_results = result.to_pandas()
-
+
     if save_results and not use_langsmith:
         # TODO - place the save results logic in a separate function
         try:
@@ -99,20 +104,21 @@
             parent_dir = RESULTS_DIR
             if not os.path.exists(parent_dir):
                 os.makedirs(parent_dir)
-
-            df_results.to_csv(f"{parent_dir}/bm_{experiment_name}_results.csv", index=False)
-
+
+            df_results.to_csv(
+                f"{parent_dir}/bm_{experiment_name}_results.csv", index=False
+            )
+
             print("--RESULTS SAVED--")
         except Exception as e:
             print(f"An error occurred while saving results: {e}")
-
-
+
     return df_results


 def get_context_and_answer(
     evaluation_data: List[Dict[str, List[str]]],
-    rag_chain,
+    rag_chain,
 ) -> List[Dict[str, str]]:
     """Retrieves context and generates answers for each question in the evaluation data.

@@ -142,11 +148,11 @@
     ):
         response = rag_chain.invoke(question)
         contexts_list = response["contexts"]
-
+
         results["question"].append(question)
         results["contexts"].append(contexts_list)
         results["answer"].append(response["answer"])
         results["ground_truth"].append(ground_truth)
-
+
     dataset = Dataset.from_dict(results)
-    return dataset
\ No newline at end of file
+    return dataset
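
A hedged sketch of driving a local (non-LangSmith) evaluation with the chain built by RAGSystem; the experiment name is illustrative and, in this path, only affects the results file name:

    from src.ragas.ragas_pipeline import run_ragas_evaluation

    df_results = run_ragas_evaluation(
        rag_chain=rag.rag_chain,          # the chain assembled by RAGSystem.setup_rag_chain()
        use_langsmith=False,
        experiment_name="baseline_run",   # used in bm_<experiment_name>_results.csv
        save_results=True,
    )
    print(df_results.head())
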
diff --git a/src/ragas/ragas_utils.py b/src/ragas/ragas_utils.py
index 913aa0f..b4a13b6 100644
--- a/src/ragas/ragas_utils.py
+++ b/src/ragas/ragas_utils.py
@@ -8,6 +8,7 @@
 EVALUAION_DATASET_NAME = Settings.EVALUAION_DATASET_NAME
 EVALUATION_DATASET_DESCRIPTION = Settings.EVALUATION_DATASET_DESCRIPTION

+
 def load_evaluation_data(csv_file_path: str = EVALUATION_FILE_PATH) -> dict:
     """Loads evaluation data from a CSV file and returns questions and ground truths.

@@ -34,21 +35,26 @@
     except pd.errors.EmptyDataError:
         raise pd.errors.EmptyDataError(f"The file at path '{csv_file_path}' is empty.")
     except pd.errors.ParserError:
-        raise pd.errors.ParserError(f"The file at path '{csv_file_path}' is malformed and cannot be parsed.")
+        raise pd.errors.ParserError(
+            f"The file at path '{csv_file_path}' is malformed and cannot be parsed."
+        )

     # Check if required columns are present
     if "question" not in df.columns or "ground_truth" not in df.columns:
-        raise ValueError("The CSV file must contain 'question' and 'ground_truth' columns.")
+        raise ValueError(
+            "The CSV file must contain 'question' and 'ground_truth' columns."
+        )

     questions = df["question"].tolist()
     ground_truths = df["ground_truth"].tolist()

     return {"questions": questions, "ground_truths": ground_truths}

+
 def upload_csv_dataset_to_langsmith(
     csv_file_path: str = EVALUATION_FILE_PATH,
-    dataset_name: str = EVALUAION_DATASET_NAME,
-    dataset_desc: str = EVALUATION_DATASET_DESCRIPTION
+    dataset_name: str = EVALUAION_DATASET_NAME,
+    dataset_desc: str = EVALUATION_DATASET_DESCRIPTION,
 ) -> Dataset:
     """Uploads an evaluation dataset from a CSV file to LangSmith.

@@ -60,9 +66,9 @@
     Returns:
         Dataset: The uploaded dataset object.
     """
-
+
     df = pd.read_csv(csv_file_path)
     eval_set = Dataset.from_pandas(df)

     dataset = upload_dataset(eval_set, dataset_name, dataset_desc)
-    return dataset
\ No newline at end of file
+    return dataset
""" - + if head_or_tail not in {"head", "tail"}: raise ValueError("head_or_tail must be either 'head' or 'tail'") - + if df.empty: print("DataFrame is empty") return - + if head_or_tail == "head": - print(df.head(n_rows).to_markdown(index=False, numalign="left", stralign="left")) + print( + df.head(n_rows).to_markdown(index=False, numalign="left", stralign="left") + ) else: # head_or_tail == "tail" - print(df.tail(n_rows).to_markdown(index=False, numalign="left", stralign="left")) \ No newline at end of file + print( + df.tail(n_rows).to_markdown(index=False, numalign="left", stralign="left") + ) diff --git a/tests/utils/test_load_docs.py b/tests/utils/test_load_docs.py index 00b684d..b27a3ac 100644 --- a/tests/utils/test_load_docs.py +++ b/tests/utils/test_load_docs.py @@ -1,12 +1,14 @@ +import os + +os.chdir("../../") + import unittest import pandas as pd from langchain.docstore.document import Document +from unittest.mock import patch # patch for mocking -import os -os.chdir("../../") from src.rag_pipeline.load_docs import load_docs_from_csv -from unittest.mock import patch # patch for mocking class TestLoadDocsFromCSV(unittest.TestCase): def setUp(self): @@ -21,6 +23,7 @@ def setUp(self): ], } ) + # TODO - ADD test for loading documents as list of strings def test_load_as_documents(self):