diff --git a/dataherald/model/chat_model.py b/dataherald/model/chat_model.py index ccd64939..5fe8d4c7 100644 --- a/dataherald/model/chat_model.py +++ b/dataherald/model/chat_model.py @@ -1,7 +1,7 @@ import os from typing import Any -from langchain.chat_models import ChatAnthropic, ChatGooglePalm, ChatOpenAI +from langchain.chat_models import ChatLiteLLM from overrides import override from dataherald.model import LLMModel @@ -11,18 +11,18 @@ class ChatModel(LLMModel): def __init__(self, system): super().__init__(system) self.model_name = os.environ.get("LLM_MODEL", "gpt-4-32k") - self.openai_api_key = os.environ.get("OPENAI_API_KEY") - self.anthropic_api_key = os.environ.get("ANTHROPIC_API_KEY") - self.google_api_key = os.environ.get("GOOGLE_API_KEY") + self.api_keys = { + "openai_api_key": os.environ.get("OPENAI_API_KEY"), + "anthropic_api_key": os.environ.get("ANTHROPIC_API_KEY"), + "google_api_key": os.environ.get("GOOGLE_API_KEY"), + "cohere_api_key": os.environ.get("COHERE_API_KEY"), + } @override def get_model(self, **kwargs: Any) -> Any: - if self.openai_api_key: - self.model = ChatOpenAI(model_name=self.model_name, **kwargs) - elif self.anthropic_api_key: - self.model = ChatAnthropic(model=self.model, **kwargs) - elif self.google_api_key: - self.model = ChatGooglePalm(model_name=self.model_name, **kwargs) - else: - raise ValueError("No valid API key environment variable found") - return self.model + for _, api_key in self.api_keys.items(): + if api_key: + self.model = ChatLiteLLM(model_name=self.model_name, **kwargs) + return self.model + + raise ValueError("No valid API key environment variable found") diff --git a/requirements.txt b/requirements.txt index 9e9d60f5..8097c79d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ dnspython==2.3.0 fastapi==0.98.0 httpx==0.24.1 -langchain==0.0.230 +langchain==0.0.286 load-dotenv==0.1.0 mypy-extensions==1.0.0 openai==0.27.8 @@ -33,3 +33,4 @@ sphinx==6.2.1 sphinx-book-theme==1.0.1 boto3==1.28.38 botocore==1.31.38 +litellm >= 0.1.574