Skip to content

Commit

Permalink
Tools integration (#15)
Browse files Browse the repository at this point in the history
* Add internal tools to query data

Signed-off-by: Sanket <sanketsudake@gmail.com>

* Temp correction

Signed-off-by: Sanket <sanketsudake@gmail.com>

* partially working tools code

Signed-off-by: Sanket <sanketsudake@gmail.com>

* Cleanup code an reorganize app.py (#17)

* Cleanup code an reorganize app.py

Signed-off-by: Sanket <sanketsudake@gmail.com>

* cleanup retriever code

Signed-off-by: Sanket <sanketsudake@gmail.com>

---------

Signed-off-by: Sanket <sanketsudake@gmail.com>

* fix tool configurations (#18)

* fix tool configurations

* update tools base url in config files

---------

Co-authored-by: Sameer Kulkarni <sameer@acalvio.io>

---------

Signed-off-by: Sanket <sanketsudake@gmail.com>
Co-authored-by: Sameer Kulkarni <samkulkarni20@gmail.com>
Co-authored-by: Sameer Kulkarni <sameer@acalvio.io>
  • Loading branch information
3 people authored Aug 25, 2024
1 parent 3ce47df commit 900907d
Show file tree
Hide file tree
Showing 16 changed files with 735 additions and 118 deletions.
5 changes: 3 additions & 2 deletions .envrc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export RERANKER_PORT=80
export VECTORDB_HOST=192.168.0.207
export VECTORDB_PORT=8000
export ENABLE_RERANKER="1"
export TOOLS_BASE_URL="http://192.168.0.209"

# External services
export HUGGINGFACEHUB_API_TOKEN="$(cat ~/.hf_token)" #Replace with your own Hugging Face API token
Expand All @@ -19,9 +20,9 @@ export PORTKEY_CUSTOM_HOST="llm_provider_host_ip_and_port" #Only if LLM is local
export USE_PORTKEY="0"

# Model specific options
export MODEL_ID="qwen/Qwen2-7B-Instruct"
export STOP_TOKEN="<|endoftext|>"
export MAX_TOKENS=1024

# Streamlit configurations
export AUTH_CONFIG_FILE_PATH=".streamlit/config.yaml"
export STREAMLIT_CLIENT_SHOW_ERROR_DETAILS=False
export STREAMLIT_SERVER_HEADLESS=True
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
__pycache__/**
insf_venv/**
*.pyc
136 changes: 43 additions & 93 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@
import uuid
import datasets
from langchain_huggingface import HuggingFaceEndpointEmbeddings
from langchain_community.chat_models import ChatHuggingFace
from langchain_community.llms import HuggingFaceEndpoint
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.agents import create_react_agent, AgentExecutor
from langchain.tools.retriever import create_retriever_tool
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.utilities import StackExchangeAPIWrapper
from langchain_community.tools.stackexchange.tool import StackExchangeTool
from langchain_core.messages import SystemMessage
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
Expand All @@ -19,19 +15,24 @@
import chromadb
from chromadb.config import Settings
from chromadb.utils.embedding_functions import HuggingFaceEmbeddingServer

from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from urllib3.exceptions import ProtocolError
from langchain.retrievers import ContextualCompressionRetriever
from tei_rerank import TEIRerank
from transformers import AutoTokenizer

from tools import get_tools
from tei_rerank import TEIRerank

import streamlit as st
import streamlit_authenticator as stauth

import yaml
from yaml.loader import SafeLoader

from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type
from urllib3.exceptions import ProtocolError
from langchain.globals import set_verbose, set_debug

set_verbose(True)
set_debug(True)

st.set_page_config(layout="wide", page_title="InSightful")

Expand Down Expand Up @@ -80,24 +81,15 @@ def hf_embedding_server():

# Set up HuggingFaceEndpoint model
@st.cache_resource
def setup_huggingface_endpoint(model_id):
llm = HuggingFaceEndpoint(
endpoint_url="http://{host}:{port}".format(
def setup_chat_endpoint():
model = ChatOpenAI(
base_url="http://{host}:{port}/v1".format(
host=os.getenv("TGI_HOST", "localhost"), port=os.getenv("TGI_PORT", "8080")
),
temperature=0.3,
task="conversational",
stop_sequences=[
"<|im_end|>",
"<|eot_id|>",
"{your_token}".format(
your_token=os.getenv("STOP_TOKEN", "<|end_of_text|>")
),
],
max_tokens=os.getenv("MAX_TOKENS", 1024),
temperature=0.7,
api_key="dummy",
)

model = ChatHuggingFace(llm=llm, model_id=model_id)

return model


Expand Down Expand Up @@ -159,8 +151,6 @@ def load_prompt_and_system_ins(

class RAG:
def __init__(self, collection_name, db_client):
# self.llm = llm
# self.embedding_svc = embedding_svc
self.collection_name = collection_name
self.db_client = db_client

Expand Down Expand Up @@ -208,17 +198,9 @@ def insert_embeddings(self, chunks, chroma_embedding_function, batch_size=32):
documents = [chunk.page_content for chunk in batch]

collection.add(ids=chunk_ids, metadatas=metadatas, documents=documents)
# db = Chroma(
# embedding_function=embedder,
# collection_name=self.collection_name,
# client=self.db_client,
# )
print("Embeddings inserted\n")
# return db

def query_docs(
self, model, question, vector_store, prompt, chat_history, use_reranker=False
):
def get_retriever(self, vector_store, use_reranker=False):
retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
)
Expand All @@ -234,7 +216,12 @@ def query_docs(
retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
return retriever

def query_docs(
self, model, question, vector_store, prompt, chat_history, use_reranker=False
):
retriever = self.get_retriever(vector_store, use_reranker)
pass_question = lambda input: input["question"]
rag_chain = (
RunnablePassthrough.assign(context=pass_question | retriever | format_docs)
Expand All @@ -245,6 +232,7 @@ def query_docs(

return rag_chain.stream({"question": question, "chat_history": chat_history})


def format_docs(docs):
return "\n\n".join(doc.page_content for doc in docs)

Expand All @@ -262,70 +250,22 @@ def create_retriever(
collection_name=collection_name,
client=client,
)
if reranker:
compressor = TEIRerank(
url="http://{host}:{port}".format(
host=os.getenv("RERANKER_HOST", "localhost"),
port=os.getenv("RERANKER_PORT", "8082"),
),
top_n=10,
batch_size=16,
)

retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 100}
)
compression_retriever = ContextualCompressionRetriever(
base_compressor=compressor, base_retriever=retriever
)
info_retriever = create_retriever_tool(compression_retriever, name, description)
else:
retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
)
info_retriever = create_retriever_tool(retriever, name, description)

return info_retriever


def setup_tools(_model, _client, _chroma_embedding_function, _embedder):
tools = []
if (
os.getenv("STACK_OVERFLOW_API_KEY")
and os.getenv("STACK_OVERFLOW_API_KEY").strip()
):
stackexchange_wrapper = StackExchangeAPIWrapper(max_results=3)
stackexchange_tool = StackExchangeTool(api_wrapper=stackexchange_wrapper)
tools.append(stackexchange_tool)

if os.getenv("TAVILY_API_KEY") and os.getenv("TAVILY_API_KEY").strip():
web_search_tool = TavilySearchResults(max_results=10, handle_tool_error=True)
tools.append(web_search_tool)
retriever = rag.get_retriever(vector_store, use_reranker=reranker)

use_reranker = os.getenv("USE_RERANKER", "False") == "True"
retriever = create_retriever(
"slack_conversations_retriever",
"Useful for when you need to answer from Slack conversations.",
_client,
_chroma_embedding_function,
_embedder,
reranker=use_reranker,
retriever = vector_store.as_retriever(
search_type="similarity", search_kwargs={"k": 10}
)
tools.append(retriever)

return tools

return create_retriever_tool(retriever, name, description)

@st.cache_resource
def setup_agent(_model, _prompt, _client, _chroma_embedding_function, _embedder):
tools = setup_tools(_model, _client, _chroma_embedding_function, _embedder)
def setup_agent(_model, _prompt, _tools):
agent = create_react_agent(
llm=_model,
prompt=_prompt,
tools=tools,
tools=_tools,
)
agent_executor = AgentExecutor(
agent=agent, verbose=True, tools=tools, handle_parsing_errors=True
agent=agent, verbose=True, tools=_tools, handle_parsing_errors=True
)
return agent_executor

Expand All @@ -337,12 +277,22 @@ def main():
if os.getenv("ENABLE_PORTKEY", "False") == "True":
model = setup_portkey_integrated_model()
else:
model = setup_huggingface_endpoint(model_id=os.getenv("MODEL_ID"))
model = setup_chat_endpoint()
embedder = setup_huggingface_embeddings()
use_reranker = os.getenv("USE_RERANKER", "False") == "True"

agent_executor = setup_agent(
model, prompt, client, chroma_embedding_function, embedder
retriever_tool = create_retriever(
"slack_conversations_retriever",
"Useful for when you need to answer from Slack conversations.",
client,
chroma_embedding_function,
embedder,
reranker=use_reranker,
)
_tools = get_tools()
_tools.append(retriever_tool)

agent_executor = setup_agent(model, prompt, _tools)

st.title("InSightful: Your AI Assistant for community questions")
st.text("Made with ❤️ by InfraCloud Technologies")
Expand Down
4 changes: 1 addition & 3 deletions k8s-manifests/env.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,10 @@ data:
RERANKER_PORT: "80"
VECTORDB_HOST: "ai-stack-vectordb"
VECTORDB_PORT: "8000"
STOP_TOKEN: "<|endoftext|>"
TOOLS_BASE_URL: "http://192.168.0.209"
PORTKEY_PROVIDER: "llm_provider_name"
PORTKEY_CUSTOM_HOST: "llm_provider_host_ip_and_port"
USE_PORTKEY: "0"
USE_RERANKER: "1"
AUTH_CONFIG_FILE_PATH: "/opt/auth-config/config.yaml"
MODEL_ID: "meta-llama/Meta-Llama-3.1-8B-Instruct"
STOP_TOKEN: "<|endoftext|>"
STREAMLIT_CLIENT_SHOW_ERROR_DETAILS: False
64 changes: 44 additions & 20 deletions multi_tenant_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,21 @@
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.vectorstores.chroma import Chroma
from unstructured.cleaners.core import clean_extra_whitespace, group_broken_paragraphs
from tools import get_tools

from app import (
setup_chroma_client,
hf_embedding_server,
load_prompt_and_system_ins,
setup_huggingface_embeddings,
setup_huggingface_endpoint,
setup_chat_endpoint,
RAG,
setup_agent,
)

logging.basicConfig(format="%(asctime)s - %(message)s", level=logging.INFO)


def configure_authenticator():
auth_config = os.getenv("AUTH_CONFIG_FILE_PATH", default=".streamlit/config.yaml")
print(f"auth_config: {auth_config}")
Expand Down Expand Up @@ -81,7 +84,12 @@ def load_documents(self, doc):


def main():
llm = setup_huggingface_endpoint(model_id=os.getenv("MODEL_ID"))
use_reranker = st.sidebar.toggle("Use reranker", False)
use_tools = st.sidebar.toggle("Use tools", False)
uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
question = st.chat_input("Chat with your docs or apis")

llm = setup_chat_endpoint()

embedding_svc = setup_huggingface_embeddings()

Expand All @@ -97,8 +105,12 @@ def main():
Be concise and always provide accurate, specific, and relevant information.
"""

template_file_path = "templates/multi_tenant_rag_prompt_template.tmpl"
if use_tools:
template_file_path = "templates/multi_tenant_rag_prompt_template_tools.tmpl"

prompt, system_instructions = load_prompt_and_system_ins(
template_file_path="templates/multi_tenant_rag_prompt_template.tmpl",
template_file_path=template_file_path,
template=template,
)

Expand All @@ -118,15 +130,16 @@ def main():
f"user-collection-{user_id}", embedding_function=chroma_embeddings
)

use_reranker = st.sidebar.toggle("Use reranker", False)
use_tools = st.sidebar.toggle("Use tools", False)
uploaded_file = st.sidebar.file_uploader("Upload a document", type=["pdf"])
question = st.chat_input("Chat with your doc")

logger = logging.getLogger(__name__)
logger.info(f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}")
logger.info(
f"user_id: {user_id} use_reranker: {use_reranker} use_tools: {use_tools} question: {question}"
)
rag = MultiTenantRAG(user_id, collection.name, client)

if use_tools:
tools = get_tools()
agent_executor = setup_agent(llm, prompt, tools)

# prompt = hub.pull("rlm/rag-prompt")

vectorstore = Chroma(
Expand All @@ -147,17 +160,28 @@ def main():
if question:
st.chat_message("user").markdown(question)
with st.spinner():
answer = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
chat_history=chat_history,
use_reranker=use_reranker,
)
with st.chat_message("assistant"):
answer = st.write_stream(answer)
logger.info(f"answer: {answer}")
if use_tools:
answer = agent_executor.invoke(
{
"question": question,
"chat_history": chat_history,
}
)["output"]
with st.chat_message("assistant"):
answer = st.write(answer)
logger.info(f"answer: {answer}")
else:
answer = rag.query_docs(
model=llm,
question=question,
vector_store=vectorstore,
prompt=prompt,
chat_history=chat_history,
use_reranker=use_reranker,
)
with st.chat_message("assistant"):
answer = st.write_stream(answer)
logger.info(f"answer: {answer}")

chat_history.append({"role": "user", "content": question})
chat_history.append({"role": "assistant", "content": answer})
Expand Down
Loading

0 comments on commit 900907d

Please sign in to comment.