diff --git a/pyproject.toml b/pyproject.toml
index 4085f4b..0c3405b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,8 @@ classifiers = [
 ]
 dependencies = [
     "tiktoken>=0.7.0",
-    "aiohttp>=3.9.3"
+    "aiohttp>=3.9.3",
+    "anthropic>=0.34.0"
 ]
 
 [project.optional-dependencies]
diff --git a/tokencost/costs.py b/tokencost/costs.py
index feb931a..ce29d0d 100644
--- a/tokencost/costs.py
+++ b/tokencost/costs.py
@@ -1,8 +1,9 @@
 """
 Costs dictionary and utility tool for counting tokens
 """
 
 import tiktoken
+import anthropic
 from typing import Union, List, Dict
 from .constants import TOKEN_COSTS
 from decimal import Decimal
@@ -39,6 +40,14 @@ def count_message_tokens(messages: List[Dict[str, str]], model: str) -> int:
     """
     model = model.lower()
     model = strip_ft_model_name(model)
+
+    if "claude-" in model:
+        # NOTE: only accurate for older models, e.g. `claude-2.1`. For newer
+        # models this is a _very_ rough estimate; rely on the `usage` property
+        # in the API response for exact counts instead.
+        prompt = "".join(message["content"] for message in messages)
+        return count_string_tokens(prompt, model)
+
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError:
@@ -104,6 +113,14 @@ def count_string_tokens(prompt: str, model: str) -> int:
         int: The number of tokens in the text string.
     """
     model = model.lower()
+    if "claude-" in model:
+        # NOTE: only accurate for older models, e.g. `claude-2.1`. For newer
+        # models this is a _very_ rough estimate; rely on the `usage` property
+        # in the API response for exact counts instead.
+        client = anthropic.Client()
+        return client.count_tokens(prompt)
+
     try:
         encoding = tiktoken.encoding_for_model(model)
     except KeyError: