Skip to content

Commit

Permalink
Make callbacks available at the main module level again, but lazily
Browse files Browse the repository at this point in the history
This keeps the imports fast, but with a few advantages over the
current implementation:

1. we don't introduce a breaking change to the end users;
2. it's easier to import `chainlit` once and get everything from
there;
3. doing so also makes it easier to read the code and understand what
comes from chainlit: `cl.LangchainCallbackHandler` makes it obvious
that it's Chainlit's callback handler for Langchain (as opposed to
`LangchainCallbackHandler`, without `cl.` in front).
  • Loading branch information
ramnes committed Aug 14, 2023
1 parent ce8eb07 commit 5e7719c
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 11 deletions.
3 changes: 1 addition & 2 deletions cypress/e2e/author_rename/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from langchain import LLMMathChain, OpenAI

import chainlit as cl
from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


@cl.author_rename
Expand All @@ -14,6 +13,6 @@ def rename(orig_author: str):
async def main(message: str):
llm = OpenAI(temperature=0)
llm_math = LLMMathChain.from_llm(llm=llm)
res = await llm_math.acall(message, callbacks=[AsyncLangchainCallbackHandler()])
res = await llm_math.acall(message, callbacks=[cl.AsyncLangchainCallbackHandler()])

await cl.Message(content="Hello").send()
5 changes: 3 additions & 2 deletions cypress/e2e/haystack_cb/main.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
print("hi")
from haystack.agents import Agent, Tool
from haystack.agents.agent_step import AgentStep
from haystack.nodes import PromptNode

import chainlit as cl
from chainlit.haystack.callbacks import HaystackAgentCallbackHandler


@cl.on_chat_start
async def start():
print("start")
await cl.Message(content="HaystackCb").send()

fake_prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="fakekey")

agent = Agent(fake_prompt_node)
cb = HaystackAgentCallbackHandler(agent)
cb = cl.HaystackAgentCallbackHandler(agent)

cb.on_agent_start(name="agent")

Expand Down
3 changes: 1 addition & 2 deletions cypress/e2e/langchain_cb/main_async.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from langchain.schema import Generation, LLMResult, SystemMessage

import chainlit as cl
from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


@cl.on_chat_start
async def main():
await cl.Message(content="AsyncLangchainCb").send()

acb = AsyncLangchainCallbackHandler()
acb = cl.AsyncLangchainCallbackHandler()

await acb.on_chain_start(serialized={"id": ["TestChain1"]}, inputs={})

Expand Down
3 changes: 1 addition & 2 deletions cypress/e2e/langchain_cb/main_sync.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from langchain.schema import Generation, LLMResult, SystemMessage

import chainlit as cl
from chainlit.langchain.callbacks import LangchainCallbackHandler


@cl.on_chat_start
async def main():
await cl.Message(content="AsyncLangchainCb").send()

cb = LangchainCallbackHandler()
cb = cl.LangchainCallbackHandler()

cb.on_chain_start(serialized={"id": ["TestChain1"]}, inputs={})

Expand Down
3 changes: 1 addition & 2 deletions cypress/e2e/llama_index_cb/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from llama_index.schema import NodeWithScore, TextNode

import chainlit as cl
from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler


@cl.on_chat_start
async def start():
await cl.Message(content="LlamaIndexCb").send()

cb = LlamaIndexCallbackHandler()
cb = cl.LlamaIndexCallbackHandler()

cb.start_trace()

Expand Down
15 changes: 14 additions & 1 deletion src/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from chainlit.telemetry import trace
from chainlit.types import LLMSettings
from chainlit.user_session import user_session
from chainlit.utils import wrap_user_function
from chainlit.utils import make_module_getattr, wrap_user_function
from chainlit.version import __version__

env_found = load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
Expand Down Expand Up @@ -181,6 +181,15 @@ def sleep(duration: int):
return asyncio.sleep(duration)


# PEP 562 module-level __getattr__: expose the integration callback handlers
# as top-level `chainlit` attributes without importing the langchain /
# llama_index / haystack submodules at startup — each submodule is imported
# lazily on first attribute access, keeping `import chainlit` fast.
__getattr__ = make_module_getattr(
    {
        "LangchainCallbackHandler": "chainlit.langchain.callbacks",
        "AsyncLangchainCallbackHandler": "chainlit.langchain.callbacks",
        "LlamaIndexCallbackHandler": "chainlit.llama_index.callbacks",
        "HaystackAgentCallbackHandler": "chainlit.haystack.callbacks",
    }
)

__all__ = [
"user_session",
"LLMSettings",
Expand Down Expand Up @@ -213,4 +222,8 @@ def sleep(duration: int):
"run_sync",
"make_async",
"cache",
"LangchainCallbackHandler",
"AsyncLangchainCallbackHandler",
"LlamaIndexCallbackHandler",
"HaystackAgentCallbackHandler",
]
16 changes: 16 additions & 0 deletions src/chainlit/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import inspect
from typing import Callable

Expand Down Expand Up @@ -49,3 +50,18 @@ async def wrapper(*args):
await emitter.task_end()

return wrapper


def make_module_getattr(registry):
    """Leverage PEP 562 to make imports lazy in an ``__init__.py``.

    The registry must be a dictionary with the items to import as keys and
    the modules they belong to as values.

    Returns a module-level ``__getattr__`` function that, on first access to
    a registered name, imports the owning module and returns the attribute.

    Raises:
        AttributeError: if the requested name is not in the registry, as
            required by PEP 562 (raising KeyError here would break
            ``hasattr``/``getattr``-with-default on the package).
    """

    def __getattr__(name):
        try:
            module_path = registry[name]
        except KeyError:
            # Unregistered names must surface as AttributeError so normal
            # attribute-lookup machinery (hasattr, dir tooling) keeps working.
            raise AttributeError(
                f"module {__package__!r} has no attribute {name!r}"
            ) from None
        module = importlib.import_module(module_path, __package__)
        return getattr(module, name)

    return __getattr__

0 comments on commit 5e7719c

Please sign in to comment.