fix: Async callback for langchain #353

Merged · 4 commits · Oct 26, 2023
Changes from all commits
623 changes: 595 additions & 28 deletions pdm.lock

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions pyproject.toml
@@ -10,7 +10,7 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
]
requires-python = ">=3.7.1"
requires-python = ">=3.8.1"
dependencies = [
"miservice_fork",
"openai",
@@ -22,11 +22,11 @@ dependencies = [
"EdgeGPT==0.1.26",
"langchain==0.0.301",
"datetime==5.2",
"bs4==0.0.1",
"beautifulsoup4>=4.12.0",
"chardet==5.1.0",
"typing==3.7.4.3",
"google-search-results==2.4.2",
"numexpr==2.8.6"
"google-search-results>=2.4.2",
"numexpr==2.8.6",
]
license = {text = "MIT"}
dynamic = ["version"]
40 changes: 28 additions & 12 deletions requirements.txt
@@ -1,8 +1,9 @@
# This file is @generated by PDM.
# Please do not edit it manually.

aiohttp==3.8.5
aiohttp==3.8.4
aiosignal==1.3.1
annotated-types==0.6.0
anyio==3.6.2
async-timeout==4.0.2
attrs==22.2.0
@@ -11,10 +12,13 @@ beautifulsoup4==4.12.2
BingImageCreator==0.1.3
browser-cookie3==0.19.1
cachetools==4.2.4
certifi==2023.7.22
certifi==2022.12.7
chardet==5.1.0
charset-normalizer==3.1.0
colorama==0.4.6
dataclasses==0.6
dataclasses-json==0.6.1
datetime==5.2
deep-translator==1.11.4
edge-tts==6.1.3
EdgeGPT==0.1.26
@@ -23,6 +27,7 @@ google-api-core==1.34.0
google-auth==1.35.0
google-cloud-core==1.7.3
google-cloud-translate==2.0.1
google-search-results==2.4.2
googleapis-common-protos==1.59.1
grpcio==1.56.2
grpcio-status==1.48.2
@@ -33,38 +38,49 @@ httpcore==0.16.3
httpx==0.24.1
hyperframe==6.0.1
idna==3.4
jsonpatch==1.33
jsonpointer==2.4
langchain==0.0.301
langsmith==0.0.52
lz4==4.3.2
markdown-it-py==2.2.0
marshmallow==3.20.1
mdurl==0.1.2
miservice-fork==2.1.1
multidict==6.0.4
mypy-extensions==1.0.0
numexpr==2.8.6
numpy==1.24.4
openai==0.27.2
packaging==23.2
prompt-toolkit==3.0.38
protobuf==3.20.3
pyasn1==0.5.0
pyasn1-modules==0.3.0
pycryptodomex==3.18.0
pygments==2.15.0
pydantic==2.4.2
pydantic-core==2.10.1
pygments==2.14.0
PyJWT==2.8.0
pytz==2023.3.post1
PyYAML==6.0.1
regex==2022.10.31
requests==2.31.0
requests==2.28.2
rich==13.3.2
rsa==4.9
setuptools==68.0.0
six==1.16.0
sniffio==1.3.0
soupsieve==2.4.1
SQLAlchemy==2.0.22
tenacity==8.2.3
tqdm==4.65.0
typing==3.7.4.3
typing-extensions==4.8.0
typing-inspect==0.9.0
urllib3==1.26.15
wcwidth==0.2.6
websockets==11.0
yarl==1.8.2
zhipuai==1.0.7
langchain==0.0.301
datetime==5.2
bs4==0.0.1
chardet==5.1.0
typing==3.7.4.3
google-search-results==2.4.2
numexpr==2.8.6

zope-interface==6.1
28 changes: 13 additions & 15 deletions xiaogpt/bot/langchain_bot.py
@@ -1,15 +1,14 @@
from __future__ import annotations

import openai
import asyncio
import os

from rich import print

from xiaogpt.bot.base_bot import BaseBot
from xiaogpt.utils import split_sentences

from xiaogpt.langchain.callbacks import AsyncIteratorCallbackHandler
from xiaogpt.langchain.chain import agent_search
from xiaogpt.langchain.stream_call_back import streaming_call_queue

import os
from xiaogpt.utils import split_sentences


class LangChainBot(BaseBot):
@@ -42,18 +41,17 @@ def from_config(cls, config):
async def ask(self, query, **options):
# TODO: Currently only supports stream
raise Exception(
"The bot does not support it. Please use 'ask_streamadd --stream'"
"The bot does not support it. Please use 'ask_stream, add: --stream'"
)

async def ask_stream(self, query, **options):
agent_search(query)
callback = AsyncIteratorCallbackHandler()
task = asyncio.create_task(agent_search(query, callback))
try:
while True:
if not streaming_call_queue.empty():
token = streaming_call_queue.get()
print(token, end="")
yield token
else:
break
async for message in split_sentences(callback.aiter()):
yield message
except Exception as e:
print("An error occurred:", str(e))
finally:
print()
await task
87 changes: 87 additions & 0 deletions xiaogpt/langchain/callbacks.py
@@ -0,0 +1,87 @@
from __future__ import annotations

import asyncio
from typing import Any, AsyncIterator
from uuid import UUID

from langchain.callbacks.base import AsyncCallbackHandler


class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
"""Callback handler that returns an async iterator."""

@property
def always_verbose(self) -> bool:
return True

def __init__(self) -> None:
self.queue = asyncio.Queue()
self.done = asyncio.Event()

async def on_chain_start(
self,
serialized: dict[str, Any],
inputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
metadata: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
self.done.clear()

async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
if token is not None and token != "":
print(token, end="")
Owner: Seems we need to drop this print.

@frostming (Collaborator, Author), Oct 26, 2023: Haha, that is to display the text to the console in stream mode. Intended; all other bots have such prints, too.

Owner: Copy that!

self.queue.put_nowait(token)

async def on_chain_end(
self,
outputs: dict[str, Any],
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.done.set()

async def on_chain_error(
self,
error: BaseException,
*,
run_id: UUID,
parent_run_id: UUID | None = None,
tags: list[str] | None = None,
**kwargs: Any,
) -> None:
self.done.set()

async def aiter(self) -> AsyncIterator[str]:
while not self.queue.empty() or not self.done.is_set():
# Wait for the next token in the queue,
# but stop waiting if the done event is set
done, other = await asyncio.wait(
[
# NOTE: If you add other tasks here, update the code below,
# which assumes each set has exactly one task each
asyncio.ensure_future(self.queue.get()),
asyncio.ensure_future(self.done.wait()),
],
return_when=asyncio.FIRST_COMPLETED,
)

# Cancel the other task
if other:
other.pop().cancel()

# Extract the value of the first completed task
token_or_done = done.pop().result()

# If the extracted value is the boolean True, the done event was set
if token_or_done is True:
break

# Otherwise, the extracted value is a token, which we yield
yield token_or_done
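
For context, a minimal sketch of how this handler is meant to be driven, mirroring the new ask_stream above: the agent call runs as a background task with the handler attached, while the consumer drains tokens from aiter(). The run_agent coroutine is a placeholder standing in for agent_search(query, callback); it is not part of this PR.

import asyncio

from xiaogpt.langchain.callbacks import AsyncIteratorCallbackHandler


async def consume(run_agent, query: str) -> None:
    # run_agent is a hypothetical coroutine taking (query, callback),
    # standing in for agent_search from this PR.
    callback = AsyncIteratorCallbackHandler()
    task = asyncio.create_task(run_agent(query, callback))
    try:
        async for token in callback.aiter():
            print(token, end="")  # tokens arrive as the LLM streams them
    finally:
        await task  # re-raise any exception from the agent run

Awaiting the task in finally ensures errors raised inside the agent run are not silently dropped once the iterator finishes.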
18 changes: 5 additions & 13 deletions xiaogpt/langchain/chain.py
@@ -1,23 +1,15 @@
from langchain.agents import initialize_agent, Tool
from langchain.agents import AgentType
from langchain.tools import BaseTool
from langchain.llms import OpenAI
from langchain.agents import AgentType, Tool, initialize_agent
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import LLMMathChain
from langchain.utilities import SerpAPIWrapper
from langchain.chat_models import ChatOpenAI
from langchain.memory import ChatMessageHistory
from xiaogpt.langchain.stream_call_back import StreamCallbackHandler
from langchain.agents.agent_toolkits import ZapierToolkit
from langchain.utilities.zapier import ZapierNLAWrapper
from langchain.memory import ConversationBufferMemory
from langchain.utilities import SerpAPIWrapper


def agent_search(query):
async def agent_search(query: str, callback: BaseCallbackHandler) -> None:
llm = ChatOpenAI(
streaming=True,
temperature=0,
model="gpt-3.5-turbo-0613",
callbacks=[StreamCallbackHandler()],
)

# Initialization: search chain, mathematical calculation chain
@@ -35,4 +27,4 @@ def agent_search(query):
)

# query eg:'杭州亚运会中国队获得了多少枚金牌?' // '计算3的2次方'
agent.run(query)
await agent.arun(query, callbacks=[callback])
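
The tool wiring between these two hunks is collapsed in the diff. Below is a hedged sketch of what that block plausibly contains, inferred only from the imports kept above (SerpAPIWrapper, LLMMathChain, Tool, initialize_agent, ConversationBufferMemory); tool names and descriptions are illustrative, not verbatim from the PR.

# Hypothetical reconstruction of the collapsed setup; not verbatim from the PR.
search = SerpAPIWrapper()
llm_math_chain = LLMMathChain.from_llm(llm=llm, verbose=False)
tools = [
    Tool(name="Search", func=search.run, description="useful for questions about current events"),
    Tool(name="Calculator", func=llm_math_chain.run, description="useful for math questions"),
]
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
agent = initialize_agent(
    tools, llm, agent=AgentType.OPENAI_FUNCTIONS, memory=memory, verbose=False
)

Passing callbacks=[callback] to arun, rather than attaching a handler to the ChatOpenAI constructor as the deleted stream_call_back.py did, lets each request use its own async iterator handler.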
16 changes: 0 additions & 16 deletions xiaogpt/langchain/stream_call_back.py

This file was deleted.

2 changes: 1 addition & 1 deletion xiaogpt/utils.py
@@ -29,7 +29,7 @@ def calculate_tts_elapse(text: str) -> float:
return len(_no_elapse_chars.sub("", text)) / speed


_ending_punctuations = ("。", "？", "！", "；", ".", "?", "!", ";")
_ending_punctuations = ("。", "？", "！", "；", "\n", "?", "!", ";")


async def split_sentences(text_stream: AsyncIterator[str]) -> AsyncIterator[str]:
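
The body of split_sentences is collapsed above. A minimal sketch of the kind of buffering it performs, assuming it accumulates streamed tokens and flushes whenever the buffer ends in one of _ending_punctuations (which, after this change, treats a newline rather than a bare "." as a sentence boundary); this is an illustration, not the project's exact implementation.

from typing import AsyncIterator

_ending_punctuations = ("。", "？", "！", "；", "\n", "?", "!", ";")


async def split_sentences_sketch(text_stream: AsyncIterator[str]) -> AsyncIterator[str]:
    # Buffer streamed tokens and yield a chunk whenever it ends with
    # sentence-ending punctuation; flush whatever remains at the end.
    buffer = ""
    async for chunk in text_stream:
        buffer += chunk
        if buffer and buffer[-1] in _ending_punctuations:
            yield buffer
            buffer = ""
    if buffer:
        yield buffer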