From 386f9da3352ef5b8d149564cb6605c6f326652d0 Mon Sep 17 00:00:00 2001 From: Sanket Sudake Date: Sat, 24 Aug 2024 12:59:19 +0530 Subject: [PATCH] Add flags for reranker/tools (#14) Signed-off-by: Sanket --- .envrc | 10 ++++++++-- app.py | 2 +- multi_tenant_rag.py | 28 ++++++++++------------------ 3 files changed, 19 insertions(+), 21 deletions(-) diff --git a/.envrc b/.envrc index e009bf5..490e3b1 100644 --- a/.envrc +++ b/.envrc @@ -1,3 +1,4 @@ +# Internal services export TGI_HOST=192.168.0.203 export TGI_PORT=80 export TEI_HOST=192.168.0.204 @@ -6,7 +7,9 @@ export RERANKER_HOST=192.168.0.205 export RERANKER_PORT=80 export VECTORDB_HOST=192.168.0.207 export VECTORDB_PORT=8000 -export STOP_TOKEN="<|endoftext|>" +export ENABLE_RERANKER="1" + +# External services export HUGGINGFACEHUB_API_TOKEN="$(cat ~/.hf_token)" #Replace with your own Hugging Face API token export TAVILY_API_KEY="$(cat ~/.tavily_token)" #Replace with your own Tavily API key export STACK_OVERFLOW_API_KEY="$(cat ~/.stack_exchange_token)" #Replace with your own Stack Exchange API key @@ -14,4 +17,7 @@ export PORTKEY_API_KEY="portkey_api_key" #Replace with your own Portkey API key export PORTKEY_PROVIDER="llm_provider_name" export PORTKEY_CUSTOM_HOST="llm_provider_host_ip_and_port" #Only if LLM is locally hosted export USE_PORTKEY="0" -export ENABLE_RERANKER="1" + +# Model specific options +export MODEL_ID="qwen/Qwen2-7B-Instruct" +export STOP_TOKEN="<|endoftext|>" diff --git a/app.py b/app.py index 0bee20b..8f17f50 100644 --- a/app.py +++ b/app.py @@ -337,7 +337,7 @@ def main(): if os.getenv("ENABLE_PORTKEY", "False") == "True": model = setup_portkey_integrated_model() else: - model = setup_huggingface_endpoint(model_id="qwen/Qwen2-7B-Instruct") + model = setup_huggingface_endpoint(model_id=os.getenv("MODEL_ID")) embedder = setup_huggingface_embeddings() agent_executor = setup_agent( diff --git a/multi_tenant_rag.py b/multi_tenant_rag.py index 7c4bdd2..9893e3a 100644 --- a/multi_tenant_rag.py +++ b/multi_tenant_rag.py @@ -1,4 +1,5 @@ import os +import logging import tempfile import yaml from yaml.loader import SafeLoader @@ -18,6 +19,8 @@ RAG, ) +logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO) + def configure_authenticator(): with open(".streamlit/config.yaml") as file: config = yaml.load(file, Loader=SafeLoader) @@ -75,7 +78,7 @@ def load_documents(self, doc): def main(): - llm = setup_huggingface_endpoint(model_id="qwen/Qwen2-7B-Instruct") + llm = setup_huggingface_endpoint(model_id=os.getenv("MODEL_ID")) embedding_svc = setup_huggingface_embeddings() @@ -112,9 +115,13 @@ def main(): f"user-collection-{user_id}", embedding_function=chroma_embeddings ) + use_reranker = st.sidebar.toggle("Use reranker", False) + use_tools = st.sidebar.toggle("Use tools", False) uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"]) question = st.chat_input("Chat with your doc") + logger = logging.getLogger(__name__) + logger.info(f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}") rag = MultiTenantRAG(user_id, collection.name, client) # prompt = hub.pull("rlm/rag-prompt") @@ -131,7 +138,6 @@ def main(): rag.insert_embeddings( chunks=chunks, chroma_embedding_function=chroma_embeddings, - # embedder=embedding_svc, batch_size=32, ) @@ -144,25 +150,11 @@ def main(): vector_store=vectorstore, prompt=prompt, chat_history=chat_history, - use_reranker=False, + use_reranker=use_reranker, ) with st.chat_message("assistant"): answer = st.write_stream(answer) - # print( - # "####\n#### Answer received by querying docs: " + answer + "\n####" - # ) - - # answer_with_reranker = rag.query_docs( - # model=llm, - # question=question, - # vector_store=vectorstore, - # prompt=prompt, - # chat_history=chat_history, - # use_reranker=True, - # ) - - # st.chat_message("assistant").markdown(answer) - # st.chat_message("assistant").markdown(answer_with_reranker) + logger.info(f"answer: {answer}") chat_history.append({"role": "user", "content": question}) chat_history.append({"role": "assistant", "content": answer})