diff --git a/src/chainlit/langchain/callbacks.py b/src/chainlit/langchain/callbacks.py index 557637061c..867def95a9 100644 --- a/src/chainlit/langchain/callbacks.py +++ b/src/chainlit/langchain/callbacks.py @@ -119,11 +119,29 @@ def append_to_last_tokens(self, token: str) -> None: self.last_tokens.pop(0) self.last_tokens_stripped.pop(0) + def _compare_last_tokens(self, last_tokens: List[str]): + if last_tokens == self.answer_prefix_tokens_stripped: + # If tokens match perfectly we are done + return True + else: + # Some LLMs will consider all the tokens of the final answer as one token + # so we check if any last token contains all answer tokens + return any( + [ + all( + answer_token in last_token + for answer_token in self.answer_prefix_tokens_stripped + ) + for last_token in last_tokens + ] + ) + def check_if_answer_reached(self) -> bool: if self.strip_tokens: - return self.last_tokens_stripped == self.answer_prefix_tokens_stripped + return self._compare_last_tokens(self.last_tokens_stripped) + else: - return self.last_tokens == self.answer_prefix_tokens + return self._compare_last_tokens(self.last_tokens) def start_stream(self): author = self.get_author()