diff --git a/aisploit/core/model.py b/aisploit/core/model.py index f732675..468635c 100644 --- a/aisploit/core/model.py +++ b/aisploit/core/model.py @@ -13,6 +13,12 @@ class BaseLLM(Runnable[LanguageModelInput, str]): class BaseChatModel(Runnable[LanguageModelInput, BaseMessage]): @abstractmethod def supports_functions(self) -> bool: + """ + Check if the model supports additional functions beyond basic chat. + + Returns: + bool: True if the model supports additional functions, False otherwise. + """ pass diff --git a/aisploit/core/report.py b/aisploit/core/report.py index a6dddb6..1185eb8 100644 --- a/aisploit/core/report.py +++ b/aisploit/core/report.py @@ -32,7 +32,7 @@ def __iter__(self): def __len__(self): return len(self._entries) - + def __getitem__(self, index: int) -> T: """Get an entry from the report by index.""" return self._entries[index] diff --git a/aisploit/model/__init__.py b/aisploit/model/__init__.py index 1e582ed..af82a94 100644 --- a/aisploit/model/__init__.py +++ b/aisploit/model/__init__.py @@ -1,9 +1,11 @@ from .bedrock_chat import BedrockChat +from .chat_anthropic import ChatAnthropic from .chat_ollama import ChatOllama from .chat_openai import ChatOpenAI __all__ = [ "BedrockChat", + "ChatAnthropic", "ChatOllama", "ChatOpenAI", ] diff --git a/aisploit/model/bedrock_chat.py b/aisploit/model/bedrock_chat.py index 4d3b7f9..3464bfa 100644 --- a/aisploit/model/bedrock_chat.py +++ b/aisploit/model/bedrock_chat.py @@ -16,4 +16,10 @@ def __init__( ) def supports_functions(self) -> bool: + """ + Check if the model supports additional functions beyond basic chat. + + Returns: + bool: True if the model supports additional functions, False otherwise. + """ return False diff --git a/aisploit/model/chat_anthropic.py b/aisploit/model/chat_anthropic.py new file mode 100644 index 0000000..f5fd33e --- /dev/null +++ b/aisploit/model/chat_anthropic.py @@ -0,0 +1,42 @@ +from typing import Optional +from langchain_core.utils.utils import convert_to_secret_str +from langchain_anthropic import ChatAnthropic as LangchainChatAnthropic + +from ..core import BaseChatModel + + +class ChatAnthropic(LangchainChatAnthropic, BaseChatModel): + """A chat model based on Anthropic's language generation technology.""" + + def __init__( + self, + *, + api_key: Optional[str], + model_name: str = "claude-3-opus-20240229", + temperature: float = 1.0, + **kwargs, + ) -> None: + """ + Initialize the ChatAnthropic instance. + + Args: + api_key (str or None): The API key for accessing the Anthropic API. + model_name (str): The name of the language model to use. + temperature (float): The temperature parameter controlling the randomness of the generated text. + **kwargs: Additional keyword arguments to be passed to the base class constructor. + """ + super().__init__( + anthropic_api_key=convert_to_secret_str(api_key) if api_key else None, + model_name=model_name, + temperature=temperature, + **kwargs, + ) + + def supports_functions(self) -> bool: + """ + Check if the model supports additional functions beyond basic chat. + + Returns: + bool: True if the model supports additional functions, False otherwise. + """ + return False diff --git a/aisploit/model/chat_ollama.py b/aisploit/model/chat_ollama.py index a5c216c..df05bd5 100644 --- a/aisploit/model/chat_ollama.py +++ b/aisploit/model/chat_ollama.py @@ -30,4 +30,10 @@ def __init__( ) def supports_functions(self) -> bool: + """ + Check if the model supports additional functions beyond basic chat. + + Returns: + bool: True if the model supports additional functions, False otherwise. + """ return False diff --git a/aisploit/model/chat_openai.py b/aisploit/model/chat_openai.py index 8c71e36..c58e360 100644 --- a/aisploit/model/chat_openai.py +++ b/aisploit/model/chat_openai.py @@ -38,4 +38,10 @@ def __init__( ) def supports_functions(self) -> bool: + """ + Check if the model supports additional functions beyond basic chat. + + Returns: + bool: True if the model supports additional functions, False otherwise. + """ return True diff --git a/poetry.lock b/poetry.lock index 7b2a7b5..62dbf65 100644 --- a/poetry.lock +++ b/poetry.lock @@ -120,6 +120,30 @@ files = [ {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, ] +[[package]] +name = "anthropic" +version = "0.25.0" +description = "The official Python library for the anthropic API" +optional = false +python-versions = ">=3.7" +files = [ + {file = "anthropic-0.25.0-py3-none-any.whl", hash = "sha256:b5dfe4dfebace1641a02cfda939cd6dffac0152ab305ca1ef0c11023043a51a2"}, + {file = "anthropic-0.25.0.tar.gz", hash = "sha256:63372443e699da7ffb467b2d0eb5ee7740acf877368b364a1137d795ae4e4c16"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tokenizers = ">=0.13.0" +typing-extensions = ">=4.7,<5" + +[package.extras] +bedrock = ["boto3 (>=1.28.57)", "botocore (>=1.31.57)"] +vertex = ["google-auth (>=2,<3)"] + [[package]] name = "anyio" version = "4.3.0" @@ -1986,6 +2010,22 @@ websocket-client = ">=0.32.0,<0.40.0 || >0.40.0,<0.41.dev0 || >=0.43.dev0" [package.extras] adal = ["adal (>=1.0.2)"] +[[package]] +name = "langchain-anthropic" +version = "0.1.6" +description = "An integration package connecting AnthropicMessages and LangChain" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_anthropic-0.1.6-py3-none-any.whl", hash = "sha256:5626f9f2f0d3cc1665a2f5817ea1856dbfa4c745bc6f95b7043c56b6ab85e0c1"}, + {file = "langchain_anthropic-0.1.6.tar.gz", hash = "sha256:544e5c8c365964c594b80eb1db994e67d90722be9efde460229e5888524545de"}, +] + +[package.dependencies] +anthropic = ">=0.23.0,<1" +defusedxml = ">=0.7.1,<0.8.0" +langchain-core = ">=0.1.33,<0.2.0" + [[package]] name = "langchain-community" version = "0.0.31" @@ -5587,4 +5627,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "7217142ebf4dd86262bb9161091ea39148cc89ac7a71290c98dfc5515392bbb4" +content-hash = "24c6e4eaaad9b04171d93c2871146a8f801f494159622402d9ad5343d9ce39cf" diff --git a/pyproject.toml b/pyproject.toml index 7433b8f..564a635 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ torch = "^2.2.2" jinja2 = "^3.1.3" ipython = "^8.23.0" imapclient = "^3.0.1" +langchain-anthropic = "^0.1.6" [tool.poetry.group.dev.dependencies] chromadb = "^0.4.23"