From 9b93ddcba5c49aefa356958bb4f31aa00fd7217a Mon Sep 17 00:00:00 2001 From: hupe1980 Date: Mon, 8 Apr 2024 00:39:46 +0200 Subject: [PATCH] Add bedrock support --- aisploit/embedding/__init__.py | 2 ++ aisploit/embedding/bedrock.py | 19 +++++++++++++++++++ aisploit/embedding/ollama.py | 1 + aisploit/model/__init__.py | 2 ++ aisploit/model/bedrock_chat.py | 19 +++++++++++++++++++ 5 files changed, 43 insertions(+) create mode 100644 aisploit/embedding/bedrock.py create mode 100644 aisploit/model/bedrock_chat.py diff --git a/aisploit/embedding/__init__.py b/aisploit/embedding/__init__.py index 2da7e1b..1bcb5d4 100644 --- a/aisploit/embedding/__init__.py +++ b/aisploit/embedding/__init__.py @@ -1,7 +1,9 @@ +from .bedrock import BedrockEmbeddings from .ollama import OllamaEmbeddings from .openai import OpenAIEmbeddings __all__ = [ + "BedrockEmbeddings", "OllamaEmbeddings", "OpenAIEmbeddings", ] diff --git a/aisploit/embedding/bedrock.py b/aisploit/embedding/bedrock.py new file mode 100644 index 0000000..a7ec86a --- /dev/null +++ b/aisploit/embedding/bedrock.py @@ -0,0 +1,19 @@ +from langchain_community.embeddings import ( + BedrockEmbeddings as LangchainBedrockEmbeddings, +) + + +from ..core import BaseEmbeddings + + +class BedrockEmbeddings(LangchainBedrockEmbeddings, BaseEmbeddings): + def __init__( + self, + *, + model_id: str = "amazon.titan-embed-text-v1", + **kwargs, + ) -> None: + super().__init__( + model_id=model_id, + **kwargs, + ) diff --git a/aisploit/embedding/ollama.py b/aisploit/embedding/ollama.py index 6c90676..bf8768a 100644 --- a/aisploit/embedding/ollama.py +++ b/aisploit/embedding/ollama.py @@ -13,4 +13,5 @@ def __init__( ) -> None: super().__init__( model=model, + **kwargs, ) diff --git a/aisploit/model/__init__.py b/aisploit/model/__init__.py index a7dd628..1e582ed 100644 --- a/aisploit/model/__init__.py +++ b/aisploit/model/__init__.py @@ -1,7 +1,9 @@ +from .bedrock_chat import BedrockChat from .chat_ollama import ChatOllama from .chat_openai import ChatOpenAI __all__ = [ + "BedrockChat", "ChatOllama", "ChatOpenAI", ] diff --git a/aisploit/model/bedrock_chat.py b/aisploit/model/bedrock_chat.py new file mode 100644 index 0000000..4d3b7f9 --- /dev/null +++ b/aisploit/model/bedrock_chat.py @@ -0,0 +1,19 @@ +from langchain_community.chat_models import BedrockChat as LangchainBedrockChat + +from ..core import BaseChatModel + + +class BedrockChat(LangchainBedrockChat, BaseChatModel): + def __init__( + self, + *, + model_id: str, + **kwargs, + ) -> None: + super().__init__( + model_id=model_id, + **kwargs, + ) + + def supports_functions(self) -> bool: + return False