Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make callbacks available at the main module level again, but lazily #277

Merged
merged 1 commit into from
Aug 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions cypress/e2e/author_rename/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from langchain import LLMMathChain, OpenAI

import chainlit as cl
from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


@cl.author_rename
Expand All @@ -14,6 +13,6 @@ def rename(orig_author: str):
async def main(message: str):
llm = OpenAI(temperature=0)
llm_math = LLMMathChain.from_llm(llm=llm)
res = await llm_math.acall(message, callbacks=[AsyncLangchainCallbackHandler()])
res = await llm_math.acall(message, callbacks=[cl.AsyncLangchainCallbackHandler()])

await cl.Message(content="Hello").send()
3 changes: 1 addition & 2 deletions cypress/e2e/haystack_cb/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from haystack.nodes import PromptNode

import chainlit as cl
from chainlit.haystack.callbacks import HaystackAgentCallbackHandler


@cl.on_chat_start
Expand All @@ -13,7 +12,7 @@ async def start():
fake_prompt_node = PromptNode(model_name_or_path="gpt-3.5-turbo", api_key="fakekey")

agent = Agent(fake_prompt_node)
cb = HaystackAgentCallbackHandler(agent)
cb = cl.HaystackAgentCallbackHandler(agent)

cb.on_agent_start(name="agent")

Expand Down
3 changes: 1 addition & 2 deletions cypress/e2e/langchain_cb/main_async.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from langchain.schema import Generation, LLMResult, SystemMessage

import chainlit as cl
from chainlit.langchain.callbacks import AsyncLangchainCallbackHandler


@cl.on_chat_start
async def main():
await cl.Message(content="AsyncLangchainCb").send()

acb = AsyncLangchainCallbackHandler()
acb = cl.AsyncLangchainCallbackHandler()

await acb.on_chain_start(serialized={"id": ["TestChain1"]}, inputs={})

Expand Down
3 changes: 1 addition & 2 deletions cypress/e2e/langchain_cb/main_sync.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from langchain.schema import Generation, LLMResult, SystemMessage

import chainlit as cl
from chainlit.langchain.callbacks import LangchainCallbackHandler


@cl.on_chat_start
async def main():
await cl.Message(content="AsyncLangchainCb").send()

cb = LangchainCallbackHandler()
cb = cl.LangchainCallbackHandler()

cb.on_chain_start(serialized={"id": ["TestChain1"]}, inputs={})

Expand Down
3 changes: 1 addition & 2 deletions cypress/e2e/llama_index_cb/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,13 @@
from llama_index.schema import NodeWithScore, TextNode

import chainlit as cl
from chainlit.llama_index.callbacks import LlamaIndexCallbackHandler


@cl.on_chat_start
async def start():
await cl.Message(content="LlamaIndexCb").send()

cb = LlamaIndexCallbackHandler()
cb = cl.LlamaIndexCallbackHandler()

cb.start_trace()

Expand Down
15 changes: 14 additions & 1 deletion src/chainlit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from chainlit.telemetry import trace
from chainlit.types import LLMSettings
from chainlit.user_session import user_session
from chainlit.utils import wrap_user_function
from chainlit.utils import make_module_getattr, wrap_user_function
from chainlit.version import __version__

env_found = load_dotenv(dotenv_path=os.path.join(os.getcwd(), ".env"))
Expand Down Expand Up @@ -181,6 +181,15 @@ def sleep(duration: int):
return asyncio.sleep(duration)


__getattr__ = make_module_getattr(
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I'm new in python, I wonder why not just import the handlers like this:

from chainlit.langchain.callbacks import LangchainCallbackHandler, AsyncLangchainCallbackHandler, LlamaIndexCallbackHandler, HaystackAgentCallbackHandler

Does it has any advantages I didn't know?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be working fine but would also import the Langchain library by default, which we want to avoid for performance reasons. This PR will make the Langchain (and other libraries) imports lazy.

"LangchainCallbackHandler": "chainlit.langchain.callbacks",
"AsyncLangchainCallbackHandler": "chainlit.langchain.callbacks",
"LlamaIndexCallbackHandler": "chainlit.llama_index.callbacks",
"HaystackAgentCallbackHandler": "chainlit.haystack.callbacks",
}
)

__all__ = [
"user_session",
"LLMSettings",
Expand Down Expand Up @@ -213,4 +222,8 @@ def sleep(duration: int):
"run_sync",
"make_async",
"cache",
"LangchainCallbackHandler",
"AsyncLangchainCallbackHandler",
"LlamaIndexCallbackHandler",
"HaystackAgentCallbackHandler",
]
16 changes: 16 additions & 0 deletions src/chainlit/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import inspect
from typing import Callable

Expand Down Expand Up @@ -49,3 +50,18 @@ async def wrapper(*args):
await emitter.task_end()

return wrapper


def make_module_getattr(registry):
"""Leverage PEP 562 to make imports lazy in an __init__.py

The registry must be a dictionary with the items to import as keys and the
modules they belong to as a value.
"""

def __getattr__(name):
module_path = registry[name]
module = importlib.import_module(module_path, __package__)
return getattr(module, name)

return __getattr__