From ce8eb072a49243abe8a5bd88f92df7ff3833f8a1 Mon Sep 17 00:00:00 2001 From: Willy Douhard Date: Sat, 12 Aug 2023 16:21:15 +0800 Subject: [PATCH] fix lc streaming (#272) --- src/chainlit/langchain/callbacks.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/src/chainlit/langchain/callbacks.py b/src/chainlit/langchain/callbacks.py index 557637061c..867def95a9 100644 --- a/src/chainlit/langchain/callbacks.py +++ b/src/chainlit/langchain/callbacks.py @@ -119,11 +119,29 @@ def append_to_last_tokens(self, token: str) -> None: self.last_tokens.pop(0) self.last_tokens_stripped.pop(0) + def _compare_last_tokens(self, last_tokens: List[str]): + if last_tokens == self.answer_prefix_tokens_stripped: + # If tokens match perfectly we are done + return True + else: + # Some LLMs will consider all the tokens of the final answer as one token + # so we check if any last token contains all answer tokens + return any( + [ + all( + answer_token in last_token + for answer_token in self.answer_prefix_tokens_stripped + ) + for last_token in last_tokens + ] + ) + def check_if_answer_reached(self) -> bool: if self.strip_tokens: - return self.last_tokens_stripped == self.answer_prefix_tokens_stripped + return self._compare_last_tokens(self.last_tokens_stripped) + else: - return self.last_tokens == self.answer_prefix_tokens + return self._compare_last_tokens(self.last_tokens) def start_stream(self): author = self.get_author()