Skip to content

Commit

Permalink
Merge pull request #151 from athina-ai/feature/async-llm-calls
Browse files Browse the repository at this point in the history
Adding support for async run prompt and api call steps
  • Loading branch information
vivek-athina authored Dec 23, 2024
2 parents c562460 + 7e36c1c commit 5e70c7c
Show file tree
Hide file tree
Showing 4 changed files with 187 additions and 69 deletions.
175 changes: 107 additions & 68 deletions athina/steps/api.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,83 @@
# Step to make an external api call
import json
import time
from typing import Union, Dict, Any, Iterable, Optional
import requests
from athina.steps import Step
from typing import Union, Dict, Any, Optional
import aiohttp
from jinja2 import Environment
from athina.helpers.jinja_helper import PreserveUndefined
import urllib.parse
from athina.steps.base import Step
import asyncio


def prepare_input_data(data: Dict[str, Any]) -> Dict[str, Any]:
    """Serialize nested values so they can be rendered inside templates.

    Lists and dicts are replaced by their JSON string form; every other
    value is passed through unchanged. A new dict is returned; *data* is
    not mutated.
    """
    prepared: Dict[str, Any] = {}
    for key, value in data.items():
        if isinstance(value, (list, dict)):
            prepared[key] = json.dumps(value)
        else:
            prepared[key] = value
    return prepared


def create_jinja_env() -> Environment:
    """Build the Jinja2 environment used to render request templates.

    Keeps the standard ``{{ ... }}`` variable delimiters and uses
    PreserveUndefined so unresolved placeholders are left verbatim rather
    than raising.
    """
    options = {
        "variable_start_string": "{{",
        "variable_end_string": "}}",
        "undefined": PreserveUndefined,
    }
    return Environment(**options)


def prepare_template_data(
    env: Environment,
    template_dict: Optional[Dict[str, str]],
    input_data: Dict[str, Any],
) -> Optional[Dict[str, str]]:
    """Render every value of *template_dict* as a Jinja2 template.

    Returns None when no template dict was supplied; otherwise a new dict
    mapping each key to its value rendered against *input_data*.
    """
    if template_dict is None:
        return None

    return {
        key: env.from_string(template).render(**input_data)
        for key, template in template_dict.items()
    }


def prepare_body(
    env: Environment, body_template: Optional[str], input_data: Dict[str, Any]
) -> Optional[str]:
    """Render the request-body template against *input_data*.

    Returns None when no body template is configured.
    """
    if body_template is None:
        return None

    template = env.from_string(body_template)
    return template.render(**input_data)


def process_response(
    status_code: int,
    response_text: str,
) -> Dict[str, Any]:
    """Turn a raw HTTP response into the step's result dict.

    Any 4xx/5xx status becomes an ``error`` result carrying the status code
    and response text. Otherwise the body is returned as parsed JSON when it
    is valid JSON, falling back to the raw text.
    """
    if status_code >= 400:
        message = (
            "Failed to make the API call."
            f"\nStatus code: {status_code}"
            f"\nError:\n{response_text}"
        )
        return {"status": "error", "data": message}

    try:
        parsed = json.loads(response_text)
    except json.JSONDecodeError:
        # Body is not JSON -- hand it back unchanged.
        return {"status": "success", "data": response_text}
    return {"status": "success", "data": parsed}


class ApiCall(Step):
"""
Step that makes an external API call.
Expand All @@ -35,85 +97,58 @@ class ApiCall(Step):
body: Optional[str] = None
env: Environment = None
name: Optional[str] = None
timeout: int = 30 # Default timeout in seconds
retries: int = 2 # Default number of retries

class Config:
arbitrary_types_allowed = True

def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
"""Make an API call and return the response."""
async def execute_async(self, input_data: Any) -> Union[Dict[str, Any], None]:
"""Make an async API call and return the response."""

if input_data is None:
input_data = {}

if not isinstance(input_data, dict):
raise TypeError("Input data must be a dictionary.")

# Create a custom Jinja2 environment with double curly brace delimiters and PreserveUndefined
self.env = Environment(
variable_start_string="{{",
variable_end_string="}}",
undefined=PreserveUndefined,
# Prepare the environment and input data
self.env = create_jinja_env()
prepared_input_data = prepare_input_data(input_data)

# Prepare request components
prepared_body = prepare_body(self.env, self.body, prepared_input_data)
prepared_headers = prepare_template_data(
self.env, self.headers, prepared_input_data
)
prepared_body = None
# Add a filter to the Jinja2 environment to convert the input data to JSON
if self.body is not None:
body_template = self.env.from_string(self.body)
prepared_input_data = prepare_input_data(input_data)
prepared_body = body_template.render(**prepared_input_data)

prepared_headers = self.headers.copy() if self.headers is not None else None
prepared_params = self.params.copy() if self.params is not None else None

if prepared_headers is not None:
for key, value in prepared_headers.items():
prepared_headers[key] = self.env.from_string(value).render(
**prepared_input_data
)

if prepared_params is not None:
for key, value in prepared_params.items():
prepared_params[key] = self.env.from_string(value).render(
**prepared_input_data
)

retries = 2 # number of retries
timeout = 30 # seconds
for attempt in range(retries):
prepared_params = prepare_template_data(
self.env, self.params, prepared_input_data
)

timeout = aiohttp.ClientTimeout(total=self.timeout)

for attempt in range(self.retries):
try:
response = requests.request(
method=self.method,
url=self.url,
headers=prepared_headers,
params=prepared_params,
json=(
async with aiohttp.ClientSession(timeout=timeout) as session:
json_body = (
json.loads(prepared_body, strict=False)
if prepared_body
else None
),
timeout=timeout,
)
if response.status_code >= 400:
# If the status code is an error, return the error message
return {
"status": "error",
"data": f"Failed to make the API call.\nStatus code: {response.status_code}\nError:\n{response.text}",
}
try:
json_response = response.json()
# If the response is JSON, return the JSON data
return {
"status": "success",
"data": json_response,
}
except json.JSONDecodeError:
# If the response is not JSON, return the text
return {
"status": "success",
"data": response.text,
}
except requests.Timeout:
if attempt < retries - 1:
time.sleep(2)
)

async with session.request(
method=self.method,
url=self.url,
headers=prepared_headers,
params=prepared_params,
json=json_body,
) as response:
response_text = await response.text()
return process_response(response.status, response_text)

except asyncio.TimeoutError:
if attempt < self.retries - 1:
await asyncio.sleep(2)
continue
# If the request times out after multiple attempts, return an error message
return {
Expand All @@ -126,3 +161,7 @@ def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
"status": "error",
"data": f"Failed to make the API call.\nError: {e.__class__.__name__}\nDetails:\n{str(e)}",
}

def execute(self, input_data: Any) -> Union[Dict[str, Any], None]:
    """Synchronous execute api call that runs the async method in an event loop.

    NOTE(review): asyncio.run() starts a fresh event loop and raises
    RuntimeError when called from a thread that already has a running
    loop (e.g. from within another coroutine) -- callers inside an async
    context should await execute_async() directly.
    """
    return asyncio.run(self.execute_async(input_data))
4 changes: 4 additions & 0 deletions athina/steps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,10 @@ def execute(self, input_data: Any) -> Any:
"""Execute the core logic of the step. This should be implemented by subclasses."""
raise NotImplementedError("Subclasses must implement this method")

async def execute_async(self, input_data: Any) -> Any:
    """Execute the core logic of the step asynchronously. This should be implemented by subclasses.

    NOTE(review): unlike execute(), which raises NotImplementedError, this
    default is a silent no-op returning None for steps that do not override
    it -- confirm that asymmetry is intentional.
    """
    pass


class Debug(Step):
"""
Expand Down
75 changes: 75 additions & 0 deletions athina/steps/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,78 @@ def execute(self, input_data: dict, **kwargs) -> str:
except Exception as e:
traceback.print_exc()
return {"status": "error", "data": str(e)}

async def execute_async(self, input_data: dict, **kwargs) -> dict:
    """Execute a prompt with the LLM service asynchronously.

    Renders the prompt template with *input_data*, awaits the LLM service,
    and optionally coerces the raw response according to
    ``kwargs["output_type"]`` ("string", "number", "array" or "object").

    Returns a dict with "status" ("success"/"error"), "data" (the coerced
    response or an error message) and, on success, "metadata" parsed from
    the service response.
    """
    if input_data is None:
        input_data = {}

    if not isinstance(input_data, dict) and self.input_key:
        raise ValueError("PromptExecution Error: Input data must be a dictionary")

    try:
        messages = self.template.resolve(**input_data)
        # Convert messages to API format.
        # TODO: api_formatted_messages is never passed anywhere; kept so
        # to_api_format() still runs (it may validate the messages) --
        # confirm whether chat_completion_async should receive it instead.
        api_formatted_messages = [msg.to_api_format() for msg in messages]

        llm_service_response = await self.llm_service.chat_completion_async(
            messages,
            model=self.model,
            **self.model_options.model_dump(),
            **(self.tool_config.model_dump() if self.tool_config else {}),
            **({"response_format": self.response_format}),
            **(
                kwargs.get("search_domain_filter", {})
                if isinstance(kwargs.get("search_domain_filter"), dict)
                else {}
            ),
        )
        llmresponse = llm_service_response["value"]
        output_type = kwargs.get("output_type", None)
        error = None
        # Bind `response` up front so it is always defined; previously an
        # unrecognized output_type fell through the if/elif chain and the
        # success branch raised NameError on the unbound name.
        response = llmresponse

        if output_type:
            if output_type == "string":
                if not isinstance(llmresponse, str):
                    error = "LLM response is not a string"
                response = llmresponse

            elif output_type == "number":
                extracted_response = ExtractNumberFromString().execute(llmresponse)
                if not isinstance(extracted_response, (int, float)):
                    error = "LLM response is not a number"
                response = extracted_response

            elif output_type == "array":
                extracted_response = ExtractJsonFromString().execute(llmresponse)
                if not isinstance(extracted_response, list):
                    error = "LLM response is not an array"
                response = extracted_response

            elif output_type == "object":
                extracted_response = ExtractJsonFromString().execute(llmresponse)
                if not isinstance(extracted_response, dict):
                    error = "LLM response is not an object"
                response = extracted_response

            else:
                # Surface an unknown coercion target explicitly instead of
                # crashing into the broad except handler below.
                error = f"Unsupported output_type: {output_type}"

        elif not isinstance(llmresponse, str):
            error = "LLM service response is not a string"

        if error:
            return {"status": "error", "data": error}
        return {
            "status": "success",
            "data": response,
            "metadata": (
                json.loads(llm_service_response.get("metadata", "{}"))
                if llm_service_response.get("metadata")
                else {}
            ),
        }
    except Exception as e:
        traceback.print_exc()
        return {"status": "error", "data": str(e)}
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "athina"
version = "1.6.30"
version = "1.6.31"
description = "Python SDK to configure and run evaluations for your LLM-based application"
authors = ["Shiv Sakhuja <shiv@athina.ai>", "Akshat Gupta <akshat@athina.ai>", "Vivek Aditya <vivek@athina.ai>", "Akhil Bisht <akhil@athina.ai>"]
readme = "README.md"
Expand Down

0 comments on commit 5e70c7c

Please sign in to comment.