chore: Added support for Groq, Cohere and Open AI. However, tool calling and evaluation problems need to be fixed.
anirbanbasu committed Jul 27, 2024
1 parent 83da142 commit a9c27af
Showing 2 changed files with 82 additions and 11 deletions.
requirements.txt: 2 additions & 0 deletions
@@ -13,6 +13,8 @@ langchain-experimental
langchain
langchainhub
langchain_groq
+langchain-openai
+langchain-cohere
langchain-ollama
rank_bm25

src/webapp.py: 80 additions & 11 deletions
@@ -24,7 +24,11 @@
except ImportError:  # Graceful fallback if IceCream isn't installed.
    ic = lambda *a: None if not a else (a[0] if len(a) == 1 else a)  # noqa

+from langchain_core.language_models.chat_models import BaseChatModel
from langchain_ollama import ChatOllama
+from langchain_groq.chat_models import ChatGroq
+from langchain_cohere.chat_models import ChatCohere
+from langchain_openai.chat_models import ChatOpenAI
from langchain_core.messages import HumanMessage, AIMessage
from langgraph.graph.message import AnyMessage

@@ -47,6 +51,7 @@ class GradioApp:
"""The main Gradio app class."""

_gr_state_user_input_text = gr.State(constants.EMPTY_STRING)
_llm: BaseChatModel = None

    def __init__(self):
        """Default constructor for the Gradio app."""
@@ -102,6 +107,73 @@ def __init__(self):
                ),
                format="json",
            )
+        elif self._llm_provider == "Groq":
+            self._llm = ChatGroq(
+                api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_GROQ_API_KEY),
+                model=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_GROQ_MODEL,
+                    default_value=constants.ENV_VAR_VALUE__LLM_GROQ_MODEL,
+                ),
+                temperature=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_TEMPERATURE,
+                    default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
+                    type_cast=float,
+                ),
+                # model_kwargs={
+                #     "top_p": self.parse_env(
+                #         constants.ENV_VAR_NAME__LLM_TOP_P,
+                #         default_value=constants.ENV_VAR_VALUE__LLM_TOP_P,
+                #         type_cast=float,
+                #     ),
+                #     "top_k": self.parse_env(
+                #         constants.ENV_VAR_NAME__LLM_TOP_K,
+                #         default_value=constants.ENV_VAR_VALUE__LLM_TOP_K,
+                #         type_cast=int,
+                #     ),
+                #     "repeat_penalty": self.parse_env(
+                #         constants.ENV_VAR_NAME__LLM_REPEAT_PENALTY,
+                #         default_value=constants.ENV_VAR_VALUE__LLM_REPEAT_PENALTY,
+                #         type_cast=float,
+                #     ),
+                #     "seed": self.parse_env(
+                #         constants.ENV_VAR_NAME__LLM_SEED,
+                #         default_value=constants.ENV_VAR_VALUE__LLM_SEED,
+                #         type_cast=int,
+                #     ),
+                # },
+                # Streaming is not compatible with JSON
+                # streaming=False,
+                # JSON response is not compatible with tool calling
+                # response_format={"type": "json_object"},
+            )
+        elif self._llm_provider == "Cohere":
+            self._llm = ChatCohere(
+                cohere_api_key=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_COHERE_API_KEY
+                ),
+                model=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_COHERE_MODEL,
+                    default_value=constants.ENV_VAR_VALUE__LLM_COHERE_MODEL,
+                ),
+                temperature=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_TEMPERATURE,
+                    default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
+                    type_cast=float,
+                ),
+            )
+        elif self._llm_provider == "Open AI":
+            self._llm = ChatOpenAI(
+                api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_OPENAI_API_KEY),
+                model=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_OPENAI_MODEL,
+                    default_value=constants.ENV_VAR_VALUE__LLM_OPENAI_MODEL,
+                ),
+                temperature=self.parse_env(
+                    constants.ENV_VAR_NAME__LLM_TEMPERATURE,
+                    default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
+                    type_cast=float,
+                ),
+            )
+        else:
+            raise ValueError(f"Unsupported LLM provider: {self._llm_provider}")
+        ic(self._llm_provider, self._llm)
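
As more providers are added, this elif chain grows linearly. A table-driven dispatch is one way to keep it flat; the sketch below is not the commit's code, just an illustration under the same LangChain classes (`PROVIDER_CLASSES` and `make_llm` are made-up names):

```python
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_cohere.chat_models import ChatCohere
from langchain_groq.chat_models import ChatGroq
from langchain_openai.chat_models import ChatOpenAI

# Map the UI-facing provider names used in this commit to their classes.
PROVIDER_CLASSES = {
    "Groq": ChatGroq,
    "Cohere": ChatCohere,
    "Open AI": ChatOpenAI,
}

def make_llm(provider: str, model: str, temperature: float, api_key: str) -> BaseChatModel:
    """Instantiate a chat model for the given provider name."""
    try:
        cls = PROVIDER_CLASSES[provider]
    except KeyError:
        raise ValueError(f"Unsupported LLM provider: {provider}") from None
    # ChatCohere takes its key as `cohere_api_key`; the others accept `api_key`.
    key_kwarg = "cohere_api_key" if provider == "Cohere" else "api_key"
    return cls(model=model, temperature=temperature, **{key_kwarg: api_key})
```

A table like this also keeps the "Unsupported LLM provider" error in a single place as new providers are registered.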
@@ -200,13 +272,13 @@ def find_solution(
            ai_message: AIMessage = result["solve"][
                constants.AGENT_STATE__KEY_MESSAGES
            ][-1]
-            if not ai_message.tool_calls:
-                raise ValueError("Coding agent did not produce a valid code block")
-            yield [
-                ai_message.tool_calls[0]["args"]["reasoning"],
-                ai_message.tool_calls[0]["args"]["pseudocode"],
-                ai_message.tool_calls[0]["args"]["code"],
-            ]
+            if ai_message.tool_calls:
+                # raise ValueError("Coding agent did not produce a valid code block")
+                yield [
+                    ai_message.tool_calls[0]["args"]["reasoning"],
+                    ai_message.tool_calls[0]["args"]["pseudocode"],
+                    ai_message.tool_calls[0]["args"]["code"],
+                ]

    def add_test_case(
        self, test_cases: list[TestCase], test_case_in: str, test_case_out: str
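
The hunk above inverts the guard: instead of raising when the model returns no tool calls, the app now yields nothing, so a malformed response fails silently in the UI. A hedged sketch of stricter extraction, assuming LangChain's standard `AIMessage.tool_calls` (a list of dicts with `name`, `args` and `id` keys); `extract_solution` is an illustrative helper, not part of the app:

```python
from langchain_core.messages import AIMessage

def extract_solution(ai_message: AIMessage) -> tuple[str, str, str]:
    """Return (reasoning, pseudocode, code) from the first tool call, or raise."""
    if not ai_message.tool_calls:
        # Surface the failure instead of silently yielding nothing.
        raise ValueError("Coding agent did not produce a valid code block")
    args = ai_message.tool_calls[0]["args"]
    return args["reasoning"], args["pseudocode"], args["code"]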
@@ -317,10 +389,7 @@ def construct_interface(self):
                    value=10,
                )
            with gr.Column(elem_id="ui_main_right"):
-                gr.Markdown(
-                    f"""# Solution
-                    Using `{self._llm.model}@{self._llm_provider}`"""
-                )
+                gr.Markdown("# Solution")
                output_reasoning = gr.Markdown(
                    label="Reasoning",
                    show_label=True,
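The `construct_interface` hunk drops the model name from the heading, plausibly because not every chat-model class exposes a `.model` attribute (ChatOpenAI, for instance, stores it as `model_name`). Elsewhere, every provider branch leans on the app's `parse_env` helper, whose definition sits outside this diff; judging only from the call sites, it might look something like the following (an assumption, not the repository's actual implementation):

```python
import os
from typing import Any, Callable

def parse_env(
    var_name: str,
    default_value: Any = None,
    type_cast: Callable[[str], Any] = str,
) -> Any:
    """Read an environment variable, falling back to a default, with an optional cast."""
    value = os.getenv(var_name, default_value)
    return type_cast(value) if value is not None else None
```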
