Skip to content

Commit

Permalink
Add bedrock support
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 7, 2024
1 parent 464d2ee commit 9b93ddc
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 0 deletions.
2 changes: 2 additions & 0 deletions aisploit/embedding/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .bedrock import BedrockEmbeddings
from .ollama import OllamaEmbeddings
from .openai import OpenAIEmbeddings

__all__ = [
"BedrockEmbeddings",
"OllamaEmbeddings",
"OpenAIEmbeddings",
]
19 changes: 19 additions & 0 deletions aisploit/embedding/bedrock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from langchain_community.embeddings import (
BedrockEmbeddings as LangchainBedrockEmbeddings,
)


from ..core import BaseEmbeddings


class BedrockEmbeddings(LangchainBedrockEmbeddings, BaseEmbeddings):
def __init__(
self,
*,
model_id: str = "amazon.titan-embed-text-v1",
**kwargs,
) -> None:
super().__init__(
model_id=model_id,
**kwargs,
)
1 change: 1 addition & 0 deletions aisploit/embedding/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ def __init__(
) -> None:
super().__init__(
model=model,
**kwargs,
)
2 changes: 2 additions & 0 deletions aisploit/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from .bedrock_chat import BedrockChat
from .chat_ollama import ChatOllama
from .chat_openai import ChatOpenAI

__all__ = [
"BedrockChat",
"ChatOllama",
"ChatOpenAI",
]
19 changes: 19 additions & 0 deletions aisploit/model/bedrock_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from langchain_community.chat_models import BedrockChat as LangchainBedrockChat

from ..core import BaseChatModel


class BedrockChat(LangchainBedrockChat, BaseChatModel):
def __init__(
self,
*,
model_id: str,
**kwargs,
) -> None:
super().__init__(
model_id=model_id,
**kwargs,
)

def supports_functions(self) -> bool:
return False

0 comments on commit 9b93ddc

Please sign in to comment.