Skip to content

Commit

Permalink
DH-4620/format sql queries in engine to be user-friendly (#156)
Browse files Browse the repository at this point in the history
* DH-4620/format sql queries in engine to be user-friendly

* DH-4620/adding the string utils
  • Loading branch information
MohammadrezaPourreza authored Sep 13, 2023
1 parent ed77fa0 commit c167a43
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 4 deletions.
15 changes: 15 additions & 0 deletions dataherald/sql_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datetime import date, datetime
from typing import Any, List, Tuple

import sqlparse
from langchain.schema import AgentAction

from dataherald.config import Component, System
Expand All @@ -12,6 +13,7 @@
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator.create_sql_query_status import create_sql_query_status
from dataherald.types import NLQuery, NLQueryResponse, SQLQueryResult
from dataherald.utils.strings import contains_line_breaks


class SQLGenerator(Component, ABC):
Expand Down Expand Up @@ -42,6 +44,19 @@ def format_intermediate_representations(
)
return formatted_intermediate_representation

def format_sql_query(self, sql_query: str) -> str:
    """Re-indent a single-line SQL query while keeping its ``--`` comments.

    Line comments are pulled out of the query, the remaining SQL is passed
    through ``sqlparse.format(..., reindent=True)``, and the comments are
    appended after the formatted statement, one per line.

    If the query (ignoring comments and surrounding whitespace) already
    spans several lines it is returned unchanged.

    Args:
        sql_query: The SQL text to format.

    Returns:
        The formatted query, or ``sql_query`` itself when it already
        contains line breaks.
    """
    # A single alternation so quoted literals are consumed as a whole:
    # a "--" inside 'a--b' must not be mistaken for the start of a comment.
    # Doubled quotes ('' / "") are treated as escaped quotes; backslash
    # escapes are not recognised.
    token_pattern = re.compile(
        r"'(?:[^']|'')*'"
        r'|"(?:[^"]|"")*"'
        r"|(--[^\n]*)"
    )
    comments = []

    def _drop_comment(match):
        comment = match.group(1)
        if comment is None:
            # Quoted literal: keep it exactly as written.
            return match.group(0)
        comments.append(comment)
        return ""

    sql_query_without_comments = token_pattern.sub(_drop_comment, sql_query)

    if contains_line_breaks(sql_query_without_comments.strip()):
        return sql_query

    parsed = sqlparse.format(sql_query_without_comments, reindent=True)

    # Only add the separator when there is something to append; otherwise
    # the result would end with a dangling newline.
    if not comments:
        return parsed
    return parsed + "\n" + "\n".join(comments)

@abstractmethod
def generate_response(
self,
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/dataherald_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,7 @@ def generate_response(
for step in result["intermediate_steps"]:
action = step[0]
if type(action) == AgentAction and action.tool == "sql_db_query":
sql_query_list.append(action.tool_input)
sql_query_list.append(self.format_sql_query(action.tool_input))
intermediate_steps = self.format_intermediate_representations(
result["intermediate_steps"]
)
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/langchain_sqlagent.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def generate_response(
for step in result["intermediate_steps"]:
action = step[0]
if type(action) == AgentAction and action.tool == "sql_db_query":
sql_query_list.append(action.tool_input)
sql_query_list.append(self.format_sql_query(action.tool_input))
intermediate_steps = self.format_intermediate_representations(
result["intermediate_steps"]
)
Expand Down
2 changes: 1 addition & 1 deletion dataherald/sql_generator/langchain_sqlchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ def generate_response(
exec_time=exec_time,
total_cost=cb.total_cost,
total_tokens=cb.total_tokens,
sql_query=result["intermediate_steps"][1],
sql_query=self.format_sql_query(result["intermediate_steps"][1]),
)
return self.create_sql_query_status(self.database, response.sql_query, response)
2 changes: 1 addition & 1 deletion dataherald/sql_generator/llamaindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,6 @@ def generate_response(
total_tokens=token_counter.total_llm_token_count,
total_cost=total_cost,
intermediate_steps=[str(result.metadata)],
sql_query=result.metadata["sql_query"],
sql_query=self.format_sql_query(result.metadata["sql_query"]),
)
return self.create_sql_query_status(self.database, response.sql_query, response)
9 changes: 9 additions & 0 deletions dataherald/utils/strings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import re


def remove_whitespace(input_string: str) -> str:
    """Collapse every run of whitespace into one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", input_string)
    return collapsed.strip()


def contains_line_breaks(input_string: str) -> bool:
    """Return True when the string holds at least one ``\\n`` character."""
    newline_index = input_string.find("\n")
    return newline_index != -1

0 comments on commit c167a43

Please sign in to comment.