Add .set_token_limit() method #28

Draft · wants to merge 1 commit into base: main
129 changes: 129 additions & 0 deletions chatlas/_chat.py
@@ -89,6 +89,7 @@ def __init__(
self.provider = provider
self._turns: list[Turn] = list(turns or [])
self._tools: dict[str, Tool] = {}
self.token_limits: Optional[tuple[int, int]] = None
self._echo_options: EchoOptions = {
"rich_markdown": {},
"rich_console": {},
@@ -381,6 +382,121 @@ async def token_count_async(
data_model=data_model,
)

def set_token_limits(
self,
context_window: int,
max_tokens: int,
):
"""
Set a limit on the number of tokens that can be sent to the model.

By default, the size of the chat history is unbounded -- it keeps
growing as you submit more input. This can be wasteful if you don't
need to keep the entire chat history around, and can also lead to
errors if the chat history gets too large for the model to handle.

This method lets you set a limit on the number of tokens that can be
sent to the model. If the limit would be exceeded, the chat history is
truncated to fit within it (i.e., the oldest turns are dropped).

Note that many models publish a context window as well as a maximum
output token limit. For example,

<https://platform.openai.com/docs/models/gp#gpt-4o-realtime>
<https://docs.anthropic.com/en/docs/about-claude/models#model-comparison-table>

Also, since the context window is the maximum number of input + output
tokens, the maximum number of tokens that can be sent to the model in a
single request is `context_window - max_tokens`.

Parameters
----------
context_window
The total number of tokens (input plus output) the model can handle in
a single request (i.e., the model's published context window).
max_tokens
The maximum number of tokens that the model is allowed to generate
in a single response.

Note
----
This method uses `.token_count()` to estimate the token count of new input
before truncating the chat history. This is only an estimate, so it may not be
perfect. Moreover, chat models based on `ChatOpenAI()` currently do not
take the tool loop into account when estimating token counts. This means that
if your input triggers many tool calls, and/or the tool results are large,
it's recommended to set a conservative limit on the `context_window`.

Examples
--------
```python
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
chat.set_token_limits(200000, 8192)
```
"""
if max_tokens >= context_window:
raise ValueError("`max_tokens` must be less than the `context_window`.")
self.token_limits = (context_window, max_tokens)
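
To make the budget arithmetic concrete, here is a small illustrative sketch (not part of this diff) of how a caller could pre-check whether a new prompt fits the input budget implied by these limits, using the existing `.token_count()` method; the prompt string and printed message are placeholders.

```python
# Hypothetical pre-flight check, assuming the `chat` object from the
# docstring example above and the (200_000, 8_192) limits it sets.
context_window, max_tokens = 200_000, 8_192
max_input_size = context_window - max_tokens  # tokens available for one request

# .token_count() already exists in chatlas and is what _maybe_drop_turns() uses
estimated = chat.token_count("Summarize the attached report, section by section.")

if estimated > max_input_size:
    # Truncating history can't help if the new input alone blows the budget
    print(f"Prompt needs ~{estimated} tokens but only {max_input_size} are available.")
```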

def _maybe_drop_turns(
self,
*args: Content | str,
data_model: Optional[type[BaseModel]] = None,
):
"""
Drop turns from the chat history if they exceed the token limits.
"""

# Do nothing if token limits are not set
if self.token_limits is None:
return None

turns = self.get_turns(include_system_prompt=False)

# Do nothing if this is the first turn
if len(turns) == 0:
return None

last_turn = turns[-1]

# Sanity checks (i.e., when about to submit new input, the last turn should
# be from the assistant and should contain token counts)
if last_turn.role != "assistant":
raise ValueError(
"Expected the last turn must be from the assistant. Please report this issue."
)

if last_turn.tokens is None:
raise ValueError(
"Can't impose token limits since assistant turns contain token counts. "
"Please report this issue and consider setting `.token_limits` to `None`."
)

context_window, max_tokens = self.token_limits
max_input_size = context_window - max_tokens

# Estimate the token count for the (new) user turn
input_tokens = self.token_count(*args, data_model=data_model)

# Do nothing if current history size plus input size is within the limit
remaining_tokens = max_input_size - input_tokens
if sum(last_turn.tokens) < remaining_tokens:
return None

tokens = self.tokens(values="discrete")

# Drop turns until they (plus the new input) fit within the token limits
# TODO: we also need to account for the fact that dropping part of a tool loop is problematic
while sum(tokens) >= remaining_tokens and turns:
# Drop the oldest user/assistant pair from the front of the history
del turns[:2]
del tokens[:2]

self.set_turns(turns)

return None
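
For readers skimming the diff, here is a standalone sketch (toy data, not part of this change) of the truncation strategy `_maybe_drop_turns()` implements: the oldest user/assistant pair is dropped from the front of the history until what remains, plus the new input, fits the budget.

```python
# Toy illustration of the drop-oldest-turns loop; names mirror the method's locals.
turns = ["u1", "a1", "u2", "a2", "u3", "a3"]   # oldest -> newest
tokens = [120, 340, 80, 210, 95, 400]          # per-turn token counts
remaining_tokens = 600                          # max_input_size - input_tokens

while sum(tokens) >= remaining_tokens and turns:
    del turns[:2]    # drop the oldest user/assistant pair
    del tokens[:2]

print(turns)   # ['u3', 'a3'] -- only the most recent exchange still fits
```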

def app(
self,
*,
@@ -531,6 +647,8 @@ def chat(
A (consumed) response from the chat. Apply `str()` to this object to
get the text content of the response.
"""
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -581,6 +699,9 @@ async def chat_async(
A (consumed) response from the chat. Apply `str()` to this object to
get the text content of the response.
"""
# TODO: async version?
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -627,6 +748,8 @@ def stream(
An (unconsumed) response from the chat. Iterate over this object to
consume the response.
"""
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -672,6 +795,9 @@ async def stream_async(
An (unconsumed) response from the chat. Iterate over this object to
consume the response.
"""
# TODO: async version?
self._maybe_drop_turns(*args)

turn = user_turn(*args)

display = self._markdown_display(echo=echo)
@@ -715,6 +841,7 @@ def extract_data(
dict[str, Any]
The extracted data.
"""
self._maybe_drop_turns(*args, data_model=data_model)

display = self._markdown_display(echo=echo)

@@ -775,6 +902,8 @@ async def extract_data_async(
dict[str, Any]
The extracted data.
"""
# TODO: async version?
self._maybe_drop_turns(*args, data_model=data_model)

display = self._markdown_display(echo=echo)

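
Finally, a hedged end-to-end usage sketch, assuming the PR lands as shown (`set_token_limits()` is added by this diff; `ChatAnthropic`, `.chat()`, and `.get_turns()` already exist in chatlas; the loop body is a placeholder prompt):

```python
from chatlas import ChatAnthropic

chat = ChatAnthropic(model="claude-3-5-sonnet-20241022")
chat.set_token_limits(200_000, 8_192)

for i in range(100):
    chat.chat(f"Question {i}: what's one more fact about token budgeting?")
    # Once the history would exceed context_window - max_tokens,
    # the oldest turns are dropped before each new request.

print(len(chat.get_turns()))  # bounded by the token budget, not by the number of calls
```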