Skip to content

Commit

Permalink
fix: Fixed the parse_env function.
Browse files Browse the repository at this point in the history
  • Loading branch information
anirbanbasu committed Aug 7, 2024
1 parent 37b5a0e commit f977a6a
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 70 deletions.
60 changes: 60 additions & 0 deletions src/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Copyright 2024 Anirban Basu

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Various utility functions used in the project."""

import os
from typing import Any
import constants


def parse_env(
self,
var_name: str,
default_value: str | None = None,
type_cast=str,
convert_to_list=False,
list_split_char=constants.SPACE_STRING,
) -> Any | list[Any]:
"""
Parse the environment variable and return the value.
Args:
var_name (str): The name of the environment variable.
default_value (str | None): The default value to use if the environment variable is not set. Defaults to None.
type_cast (str): The type to cast the value to.
convert_to_list (bool): Whether to convert the value to a list.
list_split_char (str): The character to split the list on.
Returns:
(Any | list[Any]) The parsed value, either as a single value or a list. The type of the returned single
value or individual elements in the list depends on the supplied type_cast parameter.
"""
if os.getenv(var_name) is None and default_value is None:
raise ValueError(
f"Environment variable {var_name} does not exist and a default value has not been provided."
)
parsed_value = None
if type_cast is bool:
parsed_value = (
os.getenv(var_name, default_value).lower() in constants.TRUE_VALUES_LIST
)
else:
parsed_value = os.getenv(var_name, default_value)

value: Any | list[Any] = (
type_cast(parsed_value)
if not convert_to_list
else [type_cast(v) for v in parsed_value.split(list_split_char)]
)
return value
98 changes: 28 additions & 70 deletions src/webapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,10 @@
# limitations under the License.
"""The main web application module for the Gradio app."""

import os
from dotenv import load_dotenv
from typing import Any

from coder_agent import CoderAgent, TestCase
from utils import parse_env

try:
from icecream import ic
Expand Down Expand Up @@ -57,51 +56,51 @@ class GradioApp:
def __init__(self):
"""Default constructor for the Gradio app."""
ic(load_dotenv())
self._gradio_host: str = self.parse_env(
self._gradio_host: str = parse_env(
constants.ENV_VAR_NAME__GRADIO_SERVER_HOST,
default_value=constants.ENV_VAR_VALUE__GRADIO_SERVER_HOST,
)
self._gradio_port: int = self.parse_env(
self._gradio_port: int = parse_env(
constants.ENV_VAR_NAME__GRADIO_SERVER_PORT,
default_value=constants.ENV_VAR_VALUE__GRADIO_SERVER_PORT,
type_cast=int,
)
ic(self._gradio_host, self._gradio_port)
self._llm_provider = self.parse_env(
self._llm_provider = parse_env(
constants.ENV_VAR_NAME__LLM_PROVIDER,
default_value=constants.ENV_VAR_VALUE__LLM_PROVIDER,
)
if self._llm_provider == "Ollama":
self._llm = ChatOllama(
base_url=self.parse_env(
base_url=parse_env(
constants.ENV_VAR_NAME__LLM_OLLAMA_URL,
default_value=constants.ENV_VAR_VALUE__LLM_OLLAMA_URL,
),
model=self.parse_env(
model=parse_env(
constants.ENV_VAR_NAME__LLM_OLLAMA_MODEL,
default_value=constants.ENV_VAR_VALUE__LLM_OLLAMA_MODEL,
),
temperature=self.parse_env(
temperature=parse_env(
constants.ENV_VAR_NAME__LLM_TEMPERATURE,
default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
type_cast=float,
),
top_p=self.parse_env(
top_p=parse_env(
constants.ENV_VAR_NAME__LLM_TOP_P,
default_value=constants.ENV_VAR_VALUE__LLM_TOP_P,
type_cast=float,
),
top_k=self.parse_env(
top_k=parse_env(
constants.ENV_VAR_NAME__LLM_TOP_K,
default_value=constants.ENV_VAR_VALUE__LLM_TOP_K,
type_cast=int,
),
repeat_penalty=self.parse_env(
repeat_penalty=parse_env(
constants.ENV_VAR_NAME__LLM_REPEAT_PENALTY,
default_value=constants.ENV_VAR_VALUE__LLM_REPEAT_PENALTY,
type_cast=float,
),
seed=self.parse_env(
seed=parse_env(
constants.ENV_VAR_NAME__LLM_SEED,
default_value=constants.ENV_VAR_VALUE__LLM_SEED,
type_cast=int,
Expand All @@ -110,33 +109,33 @@ def __init__(self):
)
elif self._llm_provider == "Groq":
self._llm = ChatGroq(
api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_GROQ_API_KEY),
model=self.parse_env(
api_key=parse_env(constants.ENV_VAR_NAME__LLM_GROQ_API_KEY),
model=parse_env(
constants.ENV_VAR_NAME__LLM_GROQ_MODEL,
default_value=constants.ENV_VAR_VALUE__LLM_GROQ_MODEL,
),
temperature=self.parse_env(
temperature=parse_env(
constants.ENV_VAR_NAME__LLM_TEMPERATURE,
default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
type_cast=float,
),
# model_kwargs={
# "top_p": self.parse_env(
# "top_p": parse_env(
# constants.ENV_VAR_NAME__LLM_TOP_P,
# default_value=constants.ENV_VAR_VALUE__LLM_TOP_P,
# type_cast=float,
# ),
# "top_k": self.parse_env(
# "top_k": parse_env(
# constants.ENV_VAR_NAME__LLM_TOP_K,
# default_value=constants.ENV_VAR_VALUE__LLM_TOP_K,
# type_cast=int,
# ),
# "repeat_penalty": self.parse_env(
# "repeat_penalty": parse_env(
# constants.ENV_VAR_NAME__LLM_REPEAT_PENALTY,
# default_value=constants.ENV_VAR_VALUE__LLM_REPEAT_PENALTY,
# type_cast=float,
# ),
# "seed": self.parse_env(
# "seed": parse_env(
# constants.ENV_VAR_NAME__LLM_SEED,
# default_value=constants.ENV_VAR_VALUE__LLM_SEED,
# type_cast=int,
Expand All @@ -149,40 +148,38 @@ def __init__(self):
)
elif self._llm_provider == "Anthropic":
self._llm = ChatAnthropic(
api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_ANTHROPIC_API_KEY),
model=self.parse_env(
api_key=parse_env(constants.ENV_VAR_NAME__LLM_ANTHROPIC_API_KEY),
model=parse_env(
constants.ENV_VAR_NAME__LLM_ANTHROPIC_MODEL,
default_value=constants.ENV_VAR_VALUE__LLM_ANTHROPIC_MODEL,
),
temperature=self.parse_env(
temperature=parse_env(
constants.ENV_VAR_NAME__LLM_TEMPERATURE,
default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
type_cast=float,
),
)
elif self._llm_provider == "Cohere":
self._llm = ChatCohere(
cohere_api_key=self.parse_env(
constants.ENV_VAR_NAME__LLM_COHERE_API_KEY
),
model=self.parse_env(
cohere_api_key=parse_env(constants.ENV_VAR_NAME__LLM_COHERE_API_KEY),
model=parse_env(
constants.ENV_VAR_NAME__LLM_COHERE_MODEL,
default_value=constants.ENV_VAR_VALUE__LLM_COHERE_MODEL,
),
temperature=self.parse_env(
temperature=parse_env(
constants.ENV_VAR_NAME__LLM_TEMPERATURE,
default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
type_cast=float,
),
)
elif self._llm_provider == "Open AI":
self._llm = ChatOpenAI(
api_key=self.parse_env(constants.ENV_VAR_NAME__LLM_OPENAI_API_KEY),
model=self.parse_env(
api_key=parse_env(constants.ENV_VAR_NAME__LLM_OPENAI_API_KEY),
model=parse_env(
constants.ENV_VAR_NAME__LLM_OPENAI_MODEL,
default_value=constants.ENV_VAR_VALUE__LLM_OPENAI_MODEL,
),
temperature=self.parse_env(
temperature=parse_env(
constants.ENV_VAR_NAME__LLM_TEMPERATURE,
default_value=constants.ENV_VAR_VALUE__LLM_TEMPERATURE,
type_cast=float,
Expand All @@ -192,45 +189,6 @@ def __init__(self):
raise ValueError(f"Unsupported LLM provider: {self._llm_provider}")
ic(self._llm_provider, self._llm)

def parse_env(
self,
var_name: str,
default_value: str = None,
type_cast=str,
convert_to_list=False,
list_split_char=constants.SPACE_STRING,
) -> Any | list[Any]:
"""
Parse the environment variable and return the value.
Args:
var_name (str): The name of the environment variable.
default_value (str): The default value to use if the environment variable is not set.
type_cast (str): The type to cast the value to.
convert_to_list (bool): Whether to convert the value to a list.
list_split_char (str): The character to split the list on.
Returns:
(Any | list[Any]) The parsed value, either as a single value or a list. The type of the returned single
value or individual elements in the list depends on the supplied type_cast parameter.
"""
if var_name not in os.environ and default_value is None:
raise ValueError(f"Environment variable {var_name} does not exist.")
parsed_value = None
if type_cast is bool:
parsed_value = (
os.getenv(var_name, default_value).lower() in constants.TRUE_VALUES_LIST
)
else:
parsed_value = os.getenv(var_name, default_value)

value: Any | list[Any] = (
type_cast(parsed_value)
if not convert_to_list
else [type_cast(v) for v in parsed_value.split(list_split_char)]
)
return value

def find_solution(
self, user_question: str, runtime_limit: int, test_cases: list[TestCase] = None
):
Expand All @@ -253,7 +211,7 @@ def find_solution(
},
messages=[
SystemMessagePromptTemplate.from_template(
template=self.parse_env(
template=parse_env(
constants.ENV_VAR_NAME__LLM_SYSTEM_PROMPT,
constants.ENV_VAR_VALUE__LLM_SYSTEM_PROMPT,
)
Expand Down

0 comments on commit f977a6a

Please sign in to comment.