update callback imports #18

Closed · wants to merge 1 commit
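This PR switches every cookbook app from resolving the Chainlit callback handlers through the top-level `cl` namespace to importing them explicitly from their integration submodules (`chainlit.langchain.callbacks`, `chainlit.haystack.callbacks`, `chainlit.llama_index.callbacks`). The shape of the change, sketched here for the LangChain handler and mirrored by every hunk below:

    # Before: handler looked up on the top-level package
    import chainlit as cl
    cb = cl.AsyncLangchainCallbackHandler()

    # After: handler imported from its integration submodule
    import chainlit as cl
    from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler
    cb = AsyncLangchainCallbackHandler()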
chroma-qa-chat/app.py (3 changes: 2 additions & 1 deletion)

@@ -9,6 +9,7 @@
     HumanMessagePromptTemplate,
 )
 import chainlit as cl
+from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


 text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)

@@ -86,7 +87,7 @@ async def init():
 @cl.on_message
 async def main(message):
     chain = cl.user_session.get("chain")  # type: RetrievalQAWithSourcesChain
-    cb = cl.AsyncLangchainCallbackHandler(
+    cb = AsyncLangchainCallbackHandler(
         stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
     )
     cb.answer_reached = True
haystack/app.py (3 changes: 2 additions & 1 deletion)

@@ -10,6 +10,7 @@
 from haystack.pipelines import DocumentSearchPipeline

 import chainlit as cl
+from chainlit.haystack.callbacks import HaystackAgentCallbackHandler

 load_dotenv()

@@ -87,7 +88,7 @@ def get_agent(retriever):

     retriever = get_retriever()
     agent = get_agent(retriever)
-    cl.HaystackAgentCallbackHandler(agent)
+    HaystackAgentCallbackHandler(agent)


 @cl.author_rename
image-gen/app.py (9 changes: 5 additions & 4 deletions)

@@ -1,12 +1,13 @@
-import chainlit as cl
-from chainlit.action import Action
-
 from tools import generate_image_tool, edit_image_tool
 from langchain.agents import initialize_agent, AgentType, AgentExecutor
 from langchain.chat_models import ChatOpenAI
 from langchain.memory import ConversationBufferMemory
 from langchain.agents.structured_chat.prompt import SUFFIX

+import chainlit as cl
+from chainlit.action import Action
+from chainlit.langchain.callbacks import LangchainCallbackHandler
+

 @cl.action_callback("Create variation")
 async def create_variant(action: Action):

@@ -50,7 +51,7 @@ async def main(message):

     # No async implementation in the Stability AI client, fallback to sync
     res = await cl.make_async(agent.run)(
-        input=message, callbacks=[cl.LangchainCallbackHandler()]
+        input=message, callbacks=[LangchainCallbackHandler()]
     )

     elements = []
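Aside: the hunk above keeps the existing `cl.make_async` call, which wraps a blocking function so it can be awaited without stalling the event loop. A minimal sketch of that pattern, with a hypothetical `slow_sdk_call` standing in for the synchronous Stability AI client:

    import time
    import chainlit as cl

    def slow_sdk_call(prompt: str) -> str:
        # Hypothetical stand-in for a sync SDK call with no async variant
        time.sleep(2)
        return f"image for {prompt!r}"

    @cl.on_message
    async def main(message):
        # make_async runs the sync callable in a worker thread and returns an awaitable
        result = await cl.make_async(slow_sdk_call)(message)
        await cl.Message(content=result).send()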
langchain-aiplugins/app.py (4 changes: 3 additions & 1 deletion)

@@ -2,7 +2,9 @@
 from langchain.agents import load_tools, initialize_agent, AgentExecutor
 from langchain.agents import AgentType
 from langchain.tools import AIPluginTool
+
 import chainlit as cl
+from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


 @cl.on_chat_start

@@ -24,6 +26,6 @@ def start():
 @cl.on_message
 async def main(message):
     agent = cl.user_session.get("agent")  # type: AgentExecutor
-    res = await agent.arun(message, callbacks=[cl.AsyncLangchainCallbackHandler()])
+    res = await agent.arun(message, callbacks=[AsyncLangchainCallbackHandler()])

     await cl.Message(content=res).send()
langchain-ask-human/app.py (3 changes: 2 additions & 1 deletion)

@@ -6,6 +6,7 @@

 import chainlit as cl
 from chainlit.sync import run_sync
+from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


 class HumanInputChainlit(BaseTool):

@@ -63,5 +64,5 @@ def start():
 @cl.on_message
 async def main(message):
     agent = cl.user_session.get("agent")  # type: AgentExecutor
-    res = await agent.arun(message, callbacks=[cl.AsyncLangchainCallbackHandler()])
+    res = await agent.arun(message, callbacks=[AsyncLangchainCallbackHandler()])
     await cl.Message(content=res).send()
langflow/app.py (5 changes: 3 additions & 2 deletions)

@@ -2,8 +2,9 @@

 from langchain.agents import AgentExecutor

-from chainlit.langflow import load_flow
 import chainlit as cl
+from chainlit.langflow import load_flow
+from chainlit.langchain.callbacks import LangchainCallbackHandler


 with open("./schema.json", "r") as f:

@@ -26,7 +27,7 @@ async def main(message):

     # Run the flow
     res = await cl.make_async(flow.run)(
-        message, callbacks=[cl.LangchainCallbackHandler()]
+        message, callbacks=[LangchainCallbackHandler()]
     )

     # Send the response
llama-index-googledocs-qa/app.py (7 changes: 5 additions & 2 deletions)

@@ -1,4 +1,5 @@
 import os
+
 from llama_index import download_loader
 from llama_index import ServiceContext, VectorStoreIndex,LangchainEmbedding, PromptHelper, LLMPredictor
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings

@@ -11,7 +12,9 @@
 )
 from llama_index.callbacks.base import CallbackManager
 import openai
+
 import chainlit as cl
+from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler


 openai.api_key = os.environ.get("OPENAI_API_KEY")

@@ -21,7 +24,7 @@
 def load_context():
     try:
         # Rebuild the storage context
-        storage_context = StorageContext.from_defaults(persist_dir="./storage", callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]))
+        storage_context = StorageContext.from_defaults(persist_dir="./storage", callback_manager=CallbackManager([LlamaIndexCallbackHandler()]))
         # Load the index
         index = load_index_from_storage(storage_context, storage_context=storage_context)
     except:

@@ -53,7 +56,7 @@ def load_context():
         # embed_model=embed_model, ## (optional)
         # node_parser=node_parser, ## (optional)
         prompt_helper=prompt_helper,
-        callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
+        callback_manager=CallbackManager([LlamaIndexCallbackHandler()]),

     )
     index = VectorStoreIndex.from_documents(documents,service_context=service_context)
llama-index/app.py (5 changes: 3 additions & 2 deletions)

@@ -11,8 +11,9 @@
     load_index_from_storage,
 )
 from langchain.chat_models import ChatOpenAI
-import chainlit as cl

+import chainlit as cl
+from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler

 openai.api_key = os.environ.get("OPENAI_API_KEY")

@@ -43,7 +44,7 @@ async def factory():
     service_context = ServiceContext.from_defaults(
         llm_predictor=llm_predictor,
         chunk_size=512,
-        callback_manager=CallbackManager([cl.LlamaIndexCallbackHandler()]),
+        callback_manager=CallbackManager([LlamaIndexCallbackHandler()]),
     )

     query_engine = index.as_query_engine(
pdf-qa/app.py (3 changes: 2 additions & 1 deletion)

@@ -10,6 +10,7 @@

 import chainlit as cl
 from chainlit.types import AskFileResponse
+from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler

 pinecone.init(
     api_key=os.environ.get("PINECONE_API_KEY"),

@@ -103,7 +104,7 @@ async def start():
 @cl.on_message
 async def main(message):
     chain = cl.user_session.get("chain")  # type: RetrievalQAWithSourcesChain
-    cb = cl.AsyncLangchainCallbackHandler(
+    cb = AsyncLangchainCallbackHandler(
         stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
     )
     cb.answer_reached = True
pinecone/app.py (4 changes: 3 additions & 1 deletion)

@@ -4,7 +4,9 @@
 from langchain.chains import RetrievalQAWithSourcesChain
 from langchain.chat_models import ChatOpenAI
 import pinecone
+
 import chainlit as cl
+from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler

 pinecone.init(
     api_key=os.environ.get("PINECONE_API_KEY"),

@@ -43,7 +45,7 @@ async def start():
 async def main(message):
     chain = cl.user_session.get("chain")  # type: RetrievalQAWithSourcesChain

-    cb = cl.AsyncLangchainCallbackHandler(
+    cb = AsyncLangchainCallbackHandler(
         stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
     )
     cb.answer_reached = True
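For reference, the streaming options these hunks pass through are untouched by the PR: `stream_final_answer=True` makes the handler stream the final answer to the UI once the prefix tokens ["FINAL", "ANSWER"] appear, and presetting `cb.answer_reached = True` starts streaming from the first token. A minimal sketch of the wiring after this change, assuming the session-stored chain accepts a `callbacks` list (chain construction elided):

    import chainlit as cl
    from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler

    @cl.on_message
    async def main(message):
        chain = cl.user_session.get("chain")  # built in @cl.on_chat_start, elided
        cb = AsyncLangchainCallbackHandler(
            stream_final_answer=True, answer_prefix_tokens=["FINAL", "ANSWER"]
        )
        cb.answer_reached = True  # stream immediately instead of waiting for the prefix
        res = await chain.acall(message, callbacks=[cb])
        # cb streams the answer as it is generated; res still holds the full output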