diff --git a/app.py b/app.py index 7ff5a04..507f69e 100644 --- a/app.py +++ b/app.py @@ -196,13 +196,16 @@ def query_docs(self, model, question, vector_store, prompt, chat_history): retriever = vector_store.as_retriever( search_type="similarity", search_kwargs={"k": 4} ) + pass_question = lambda input: input["question"] rag_chain = ( - {"context": retriever | format_docs, "chat_history": chat_history, "question": RunnablePassthrough()} + RunnablePassthrough.assign( + context= pass_question | retriever | format_docs + ) | prompt | model | StrOutputParser() ) - answer = rag_chain.invoke(question) + answer = rag_chain.invoke({"question": question, "chat_history": chat_history}) return answer def format_docs(docs): diff --git a/multi_tenant_rag.py b/multi_tenant_rag.py index 0d385e3..29ba011 100644 --- a/multi_tenant_rag.py +++ b/multi_tenant_rag.py @@ -109,7 +109,6 @@ def main(): if question := st.chat_input("Chat with your doc"): st.chat_message("user").markdown(question) - chat_history.append({"role": "user", "content": question}) with st.spinner(): answer = rag.query_docs(model=llm, question=question, @@ -118,6 +117,9 @@ def main(): chat_history=chat_history) print("####\n#### Answer received by querying docs: " + answer + "\n####") st.chat_message("assistant").markdown(answer) + chat_history.append({"role": "user", "content": question}) + chat_history.append({"role": "assistant", "content": answer}) + st.session_state["chat_history"] = chat_history if __name__ == "__main__": authenticator = authenticate("login")