From 5e7719c0aea116ed61aaf6ef2f87df9c4b0c3434 Mon Sep 17 00:00:00 2001 From: ramnes Date: Mon, 14 Aug 2023 11:48:18 +0200 Subject: [PATCH] Make callbacks available at the main module level again, but lazily This keeps the imports fast, but with a few advantages over the current implementation: 1. we don't introduce a breaking change to the end users; 2. it's easier to import `chainlit` once and get all things from there; 3. doing so also makes it easier to read the code and understand what comes from chainlit: `cl.LangchainCallbackHandler` makes it obvious that it's Chainlit's callback handler for Langchain (as opposed to `LangchainCallbackHandler`, without `cl.` in front). --- cypress/e2e/author_rename/main.py | 3 +-- cypress/e2e/haystack_cb/main.py | 3 +-- cypress/e2e/langchain_cb/main_async.py | 3 +-- cypress/e2e/langchain_cb/main_sync.py | 3 +-- cypress/e2e/llama_index_cb/main.py | 3 +-- src/chainlit/__init__.py | 15 ++++++++++++++- src/chainlit/utils.py | 16 ++++++++++++++++ 7 files changed, 35 insertions(+), 11 deletions(-) diff --git a/cypress/e2e/author_rename/main.py b/cypress/e2e/author_rename/main.py index d6f4e7086d..e8ae325ec0 100644 --- a/cypress/e2e/author_rename/main.py +++ b/cypress/e2e/author_rename/main.py @@ -1,7 +1,6 @@ from langchain import LLMMathChain, OpenAI import chainlit as cl -from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler @cl.author_rename @@ -14,6 +13,6 @@ def rename(orig_author: str): async def main(message: str): llm = OpenAI(temperature=0) llm_math = LLMMathChain.from_llm(llm=llm) - res = await llm_math.acall(message, callbacks=[AsyncLangchainCallbackHandler()]) + res = await llm_math.acall(message, callbacks=[cl.AsyncLangchainCallbackHandler()]) await cl.Message(content="Hello").send() diff --git a/cypress/e2e/haystack_cb/main.py b/cypress/e2e/haystack_cb/main.py index 14bb2624e4..ed81a4915b 100644 --- a/cypress/e2e/haystack_cb/main.py +++ b/cypress/e2e/haystack_cb/main.py @@ -1,19 +1,18 @@ 
from haystack.agents import Agent, Tool from haystack.agents.agent_step import AgentStep from haystack.nodes import PromptNode import chainlit as cl -from chainlit.haystack.callbacks import HaystackAgentCallbackHandler @cl.on_chat_start async def start(): await cl.Message(content="HaystackCb").send() fake_prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="fakekey") agent = Agent(fake_prompt_node) - cb = HaystackAgentCallbackHandler(agent) + cb = cl.HaystackAgentCallbackHandler(agent) cb.on_agent_start(name="agent") diff --git a/cypress/e2e/langchain_cb/main_async.py b/cypress/e2e/langchain_cb/main_async.py index 10bdcd090f..164b24b1a0 100644 --- a/cypress/e2e/langchain_cb/main_async.py +++ b/cypress/e2e/langchain_cb/main_async.py @@ -1,14 +1,13 @@ from langchain.schema import Generation, LLMResult, SystemMessage import chainlit as cl -from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler @cl.on_chat_start async def main(): await cl.Message(content="AsyncLangchainCb").send() - acb = AsyncLangchainCallbackHandler() + acb = cl.AsyncLangchainCallbackHandler() await acb.on_chain_start(serialized={"id": ["TestChain1"]}, inputs={}) diff --git a/cypress/e2e/langchain_cb/main_sync.py b/cypress/e2e/langchain_cb/main_sync.py index b3a9cf6fc2..78603d490c 100644 --- a/cypress/e2e/langchain_cb/main_sync.py +++ b/cypress/e2e/langchain_cb/main_sync.py @@ -1,14 +1,13 @@ from langchain.schema import Generation, LLMResult, SystemMessage import chainlit as cl -from chainlit.langchain.callbacks import LangchainCallbackHandler @cl.on_chat_start async def main(): await cl.Message(content="AsyncLangchainCb").send() - cb = LangchainCallbackHandler() + cb = cl.LangchainCallbackHandler() cb.on_chain_start(serialized={"id": ["TestChain1"]}, inputs={}) diff --git a/cypress/e2e/llama_index_cb/main.py b/cypress/e2e/llama_index_cb/main.py index 0479190393..23af1ccca8 100644 --- a/cypress/e2e/llama_index_cb/main.py +++ 
b/cypress/e2e/llama_index_cb/main.py @@ -2,14 +2,13 @@ from llama_index.schema import NodeWithScore, TextNode import chainlit as cl -from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler @cl.on_chat_start async def start(): await cl.Message(content="LlamaIndexCb").send() - cb = LlamaIndexCallbackHandler() + cb = cl.LlamaIndexCallbackHandler() cb.start_trace() diff --git a/src/chainlit/__init__.py b/src/chainlit/__init__.py index b44d3b0a63..d76202bd9d 100644 --- a/src/chainlit/__init__.py +++ b/src/chainlit/__init__.py @@ -32,7 +32,7 @@ from chainlit.telemetry import trace from chainlit.types import LLMSettings from chainlit.user_session import user_session -from chainlit.utils import wrap_user_function +from chainlit.utils import make_module_getattr, wrap_user_function from chainlit.version import __version__ env_found = load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env")) @@ -181,6 +181,15 @@ def sleep(duration: int): return asyncio.sleep(duration) +__getattr__ = make_module_getattr( + { + "LangchainCallbackHandler": "chainlit.langchain.callbacks", + "AsyncLangchainCallbackHandler": "chainlit.langchain.callbacks", + "LlamaIndexCallbackHandler": "chainlit.llama_index.callbacks", + "HaystackAgentCallbackHandler": "chainlit.haystack.callbacks", + } +) + __all__ = [ "user_session", "LLMSettings", @@ -213,4 +222,8 @@ def sleep(duration: int): "run_sync", "make_async", "cache", + "LangchainCallbackHandler", + "AsyncLangchainCallbackHandler", + "LlamaIndexCallbackHandler", + "HaystackAgentCallbackHandler", ] diff --git a/src/chainlit/utils.py b/src/chainlit/utils.py index dc2794ba61..31d6477a33 100644 --- a/src/chainlit/utils.py +++ b/src/chainlit/utils.py @@ -1,3 +1,4 @@ +import importlib import inspect from typing import Callable @@ -49,3 +50,18 @@ async def wrapper(*args): await emitter.task_end() return wrapper + + +def make_module_getattr(registry): + """Leverage PEP 562 to make imports lazy in an __init__.py + + The registry must be a 
dictionary with the items to import as keys and the + modules they belong to as a value. + """ + + def __getattr__(name): + module_path = registry[name] + module = importlib.import_module(module_path, __package__) + return getattr(module, name) + + return __getattr__