
Add .set_token_limit() method to automatically drop old turns when specified limits are reached
cpsievert committed Dec 19, 2024
1 parent 200e26c commit fd841fb
Showing 1 changed file with 129 additions and 0 deletions.
chatlas/_chat.py
@@ -89,6 +89,7 @@ def __init__(
self.provider = provider
self._turns: list[Turn] = list(turns or [])
self._tools: dict[str, Tool] = {}
self.token_limits: Optional[tuple[int, int]] = None
self._echo_options: EchoOptions = {
"rich_markdown": {},
"rich_console": {},
@@ -381,6 +382,121 @@ async def token_count_async(
data_model=data_model,
)

def set_token_limits(
self,
context_window: int,
max_tokens: int,
):
"""
Set a limit on the number of tokens that can be sent to the model.
By default, the size of the chat history is unbounded -- it keeps
growing as you submit more input. This can be wasteful if you don't
need to keep the entire chat history around, and can also lead to
errors if the chat history gets too large for the model to handle.
This method allows you to set a limit to the number of tokens that can
be sent to the model. If the limit is exceeded, the chat history will be
truncated to fit within the limit (i.e., the oldest turns will be
dropped).
Note that many models publish a context window as well as a maximum
output token limit. For example,
<https://platform.openai.com/docs/models/gp#gpt-4o-realtime>
<https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>
Also, since the context window is the maximum number of input + output
tokens, the maximum number of tokens that can be sent to the model in a
single request is `context_window - max_tokens`.
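For example, with a 200,000-token context window and an 8,192-token
output limit, at most 191,808 tokens can be sent in a single request.
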
Parameters
----------
context_window
The size of the model's context window (i.e., the maximum number of
input + output tokens it supports).
max_tokens
The maximum number of tokens that the model is allowed to generate
in a single response.

Note
----
This method uses `.token_count()` to estimate the token count for new
input before truncating the chat history. This is an estimate, so it may
not be perfect. Moreover, chat models based on `ChatOpenAI()` currently
do not take the tool loop into account when estimating token counts. So,
if your input is likely to trigger many tool calls and/or the tool
results are large, it's recommended to set a conservative
`context_window` limit.

Examples
--------
```python
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
chat.set_token_limits(200000, 8192)
```
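
(These values match Claude 3.5 Sonnet's published 200,000-token context
window and 8,192-token maximum output.)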
"""
if max_tokens >= context_window:
raise ValueError("`max_tokens` must be less than the `context_window`.")
self.token_limits = (context_window, max_tokens)

def _maybe_drop_turns(
self,
*args: Content | str,
data_model: Optional[type[BaseModel]] = None,
):
"""
Drop turns from the chat history if they exceed the token limits.
"""

# Do nothing if token limits are not set
if self.token_limits is None:
return None

turns = self.get_turns(include_system_prompt=False)

# Do nothing if this is the first turn
if len(turns) == 0:
return None

last_turn = turns[-1]

# Sanity checks (i.e., when about to submit new input, the last turn should
# be from the assistant and should contain token counts)
if last_turn.role != "assistant":
raise ValueError(
"Expected the last turn must be from the assistant. Please report this issue."
)

if last_turn.tokens is None:
raise ValueError(
"Can't impose token limits since assistant turns contain token counts. "
"Please report this issue and consider setting `.token_limits` to `None`."
)

context_window, max_tokens = self.token_limits
max_input_size = context_window - max_tokens
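# e.g., 200_000 - 8_192 leaves 191_808 tokens available for input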

# Estimate the token count for the (new) user turn
input_tokens = self.token_count(*args, data_model=data_model)

# Do nothing if current history size plus input size is within the limit
remaining_tokens = max_input_size - input_tokens
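# The last assistant turn's input token count covers the full prior
# history, so input + output for that turn approximates the current
# history size.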
if sum(last_turn.tokens) < remaining_tokens:
return None

tokens = self.tokens()

# Drop turns until they (plus the new input) fit within the token limits
# TODO: we also need to account for the fact that dropping part of a tool loop is problematic
while sum(tokens) >= remaining_tokens:
# Drop the oldest pair of turns (turns are stored oldest-first)
del turns[:2]
del tokens[:2]

self.set_turns(turns)

return None

def app(
self,
*,
@@ -531,6 +647,8 @@ def chat(
A (consumed) response from the chat. Apply `str()` to this object to
get the text content of the response.
"""
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -581,6 +699,9 @@ async def chat_async(
A (consumed) response from the chat. Apply `str()` to this object to
get the text content of the response.
"""
# TODO: async version?
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -627,6 +748,8 @@ def stream(
An (unconsumed) response from the chat. Iterate over this object to
consume the response.
"""
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -672,6 +795,9 @@ async def stream_async(
An (unconsumed) response from the chat. Iterate over this object to
consume the response.
"""
# TODO: async version?
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -715,6 +841,7 @@ def extract_data(
dict[str, Any]
The extracted data.
"""
self._maybe_drop_turns(*args, data_model=data_model)

display = self._markdown_display(echo=echo)

@@ -775,6 +902,8 @@ async def extract_data_async(
dict[str, Any]
The extracted data.
"""
# TODO: async version?
self._maybe_drop_turns(*args, data_model=data_model)

display = self._markdown_display(echo=echo)

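For reference, here's a minimal sketch of how the new method is meant to be used (the model name and prompt are illustrative, and an Anthropic API key is assumed to be configured):

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
chat.set_token_limits(200_000, 8_192)

# Each submission first runs ._maybe_drop_turns(), which estimates the new
# input's token count and drops the oldest turns until history + input fit
# within context_window - max_tokens.
chat.chat("Summarize our discussion so far.")
```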
