Skip to content

Commit

Permalink
Streamline show table (#907)
Browse files Browse the repository at this point in the history
Co-authored-by: Philipp Rudiger <prudiger@anaconda.com>
  • Loading branch information
ahuang11 and philippjfr authored Jan 2, 2025
1 parent 0b69b09 commit 5d347d5
Showing 1 changed file with 23 additions and 11 deletions.
34 changes: 23 additions & 11 deletions lumen/ai/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -459,16 +459,26 @@ class SQLAgent(LumenBaseAgent):

_output_type = SQLOutput

async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource]:
async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, BaseSQLSource, bool]:
"""Select the most relevant table based on the user query."""
join_required = None
sources = self._memory["sources"]
tables_to_source, tables_schema_str = await gather_table_sources(sources)
tables = tuple(tables_to_source)
if messages and messages[-1]["content"].startswith("Show the table: '"):

user_message = ""
for message in messages[::-1]:
if message["role"] == "user":
user_message = message["content"]
break

if messages and "Show the table: " in user_message:
# Handle the case where explicitly requested a table
table = re.search(r"Show the table: '([^']+)'", messages[-1]["content"]).group(1)
table = re.search(r"Show the table: '([^']+)'", user_message).group(1)
join_required = False
elif len(tables) == 1:
table = tables[0]
join_required = False
else:
with self.interface.add_step(title="Choosing the most relevant table...", steps_layout=self._steps_layout) as step:
if len(tables) > 1:
Expand Down Expand Up @@ -499,7 +509,7 @@ async def _select_relevant_table(self, messages: list[Message]) -> tuple[str, Ba
sources = [src for src in sources if table in src]
source = sources[0] if sources else self._memory["source"]

return table, source
return table, source, join_required

@retry_llm_output()
async def _create_valid_sql(
Expand Down Expand Up @@ -690,19 +700,21 @@ async def respond(
8. If a join is required, remove source/table prefixes from the last message.
9. Construct the SQL query with `_create_valid_sql`.
"""
table, source = await self._select_relevant_table(messages)
table, source, join_required = await self._select_relevant_table(messages)
if not hasattr(source, "get_sql_expr"):
return None

# include min max for more context for data cleaning
schema = await get_schema(source, table, include_min_max=True)
join_required = await self._check_requires_joins(messages, schema, table)

tables_to_source = {table: source}
if join_required is None:
return None
if join_required:
tables_to_source = await self.find_join_tables(messages)
else:
tables_to_source = {table: source}
join_required = await self._check_requires_joins(messages, schema, table)
if join_required is None:
# Bail if query was cancelled or errored out
return None
if join_required:
tables_to_source = await self.find_join_tables(messages)

tables_sql_schemas = {}
for source_table, source in tables_to_source.items():
Expand Down

0 comments on commit 5d347d5

Please sign in to comment.