diff --git a/erniebot-agent/src/erniebot_agent/agents/agent.py b/erniebot-agent/src/erniebot_agent/agents/agent.py index d04a9c96..ed02f714 100644 --- a/erniebot-agent/src/erniebot_agent/agents/agent.py +++ b/erniebot-agent/src/erniebot_agent/agents/agent.py @@ -361,9 +361,37 @@ async def _run_llm_stream(self, messages: List[Message], **opts: Any) -> AsyncIt ) opts["system"] = self.system.content if self.system is not None else None opts["plugins"] = self._plugins - llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) - async for msg in llm_ret: - yield LLMResponse(message=msg) + # llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) + # async for msg in llm_ret: + # yield LLMResponse(message=msg) + # print(self.llm.extra_params.get("enable_multi_step_tool_call")) + # 流式时,无法同时处理多个工具调用 + # 所以只有关闭多步工具调用时才用流式 + if self.llm.extra_data.get("multi_step_tool_call_close", True): + llm_ret = await self.llm.chat(messages, stream=True, functions=functions, **opts) + async for msg in llm_ret: + print("_run_llm_stream", msg) + yield LLMResponse(message=msg) + else: + llm_ret = await self.llm.chat(messages, stream=False, functions=functions, **opts) + class MyAsyncIterator: + def __init__(self, data): + self.data = data + self.index = 0 + async def __anext__(self): + if self.index < len(self.data): + result = self.data[self.index] + self.index += 1 + return result + else: + raise StopAsyncIteration + def __aiter__(self): + return self + + llm_ret = MyAsyncIterator([llm_ret]) + async for msg in llm_ret: + print("_run_llm_stream", msg) + yield LLMResponse(message=msg) async def _run_tool(self, tool: BaseTool, tool_args: str) -> ToolResponse: parsed_tool_args = self._parse_tool_args(tool_args)