Skip to content

Commit

Permalink
Add sql context (#890)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahuang11 authored Dec 30, 2024
1 parent 2808c74 commit 63935df
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 29 deletions.
22 changes: 11 additions & 11 deletions lumen/ai/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@
from .models import Validity, make_agent_model, make_plan_models
from .tools import FunctionTool, Tool
from .utils import (
get_schema, log_debug, mutate_user_message, retry_llm_output,
gather_table_sources, get_schema, log_debug, mutate_user_message,
retry_llm_output,
)
from .views import LumenOutput

Expand Down Expand Up @@ -678,7 +679,8 @@ async def _lookup_schemas(
if table in provided:
continue
provided.append(table)
schema_info += f'- {table}: {cache[table]}\n\n'
schema = cache[table]
schema_info += f'- {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n'
return schema_info

async def _make_plan(
Expand All @@ -691,7 +693,8 @@ async def _make_plan(
reason_model: type[BaseModel],
plan_model: type[BaseModel],
step: ChatStep,
schemas: dict[str, dict] | None = None
schemas: dict[str, dict] | None = None,
tables_schema_str: str = ""
) -> BaseModel:
info = ''
reasoning = None
Expand All @@ -704,7 +707,6 @@ async def _make_plan(
while reasoning is None or requested:
log_debug(f"Creating plan for \033[91m{requested}\033[0m")
info += await self._lookup_schemas(tables, requested, provided, cache=schemas)
available = [t for t in tables if t not in provided]
system = await self._render_prompt(
"main",
messages,
Expand All @@ -714,7 +716,7 @@ async def _make_plan(
candidates=[agent for agent in agents.values() if not unmet_dependencies or set(agent.provides) & unmet_dependencies],
previous_plans=previous_plans,
table_info=info,
tables=available
tables_schema_str=tables_schema_str,
)
model_spec = self.prompts["main"].get("llm_spec", "default")
reasoning = await self.llm.invoke(
Expand All @@ -729,7 +731,7 @@ async def _make_plan(
step.stream(reasoning.chain_of_thought, replace=True)
previous_plans.append(reasoning.chain_of_thought)
requested = [
t for t in getattr(reasoning, 'tables', [])
t for t in getattr(reasoning, 'requested_tables', [])
if t and t not in provided
]

Expand Down Expand Up @@ -815,10 +817,8 @@ async def _resolve_plan(self, plan, agents, messages) -> tuple[list[ExecutionNod
async def _compute_execution_graph(self, messages: list[Message], agents: dict[str, Agent]) -> list[ExecutionNode]:
tool_names = [tool.name for tool in self._tools["__main__"]]
agent_names = [sagent.name[:-5] for sagent in agents.values()]
tables = {}
for src in self._memory['sources']:
for table in src.get_tables():
tables[table] = src

tables, tables_schema_str = await gather_table_sources(self._memory['sources'])

reason_model, plan_model = self._get_model(
"main",
Expand All @@ -838,7 +838,7 @@ async def _compute_execution_graph(self, messages: list[Message], agents: dict[s
plan = None
try:
plan = await self._make_plan(
messages, agents, tables, unmet_dependencies, previous_plans, reason_model, plan_model, istep, schemas
messages, agents, tables, unmet_dependencies, previous_plans, reason_model, plan_model, istep, schemas, tables_schema_str
)
except asyncio.CancelledError as e:
istep.failed_title = 'Planning was cancelled, please try again.'
Expand Down
9 changes: 5 additions & 4 deletions lumen/ai/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ class Sql(BaseModel):
If it's simple, just provide one step. However, if applicable, be sure to carefully
study the schema, discuss the values in the columns, and whether you need to
wrangle the data before you can use it, before finally writing a correct and valid
SQL query that fully answers the user's query.
SQL query that fully answers the user's query. If using CTEs, comment on the purpose of each.
Everything should be made as simple as possible, but no simpler.
"""
)

Expand Down Expand Up @@ -104,18 +105,18 @@ def make_plan_models(experts_or_tools: list[str], tables: list[str]):
)
extras = {}
if tables:
extras['tables'] = (
extras['requested_tables'] = (
list[Literal[tuple(tables)]],
FieldInfo(
description="A list of tables to load into memory before coming up with a plan. NOTE: Simple queries asking to list the tables/datasets do not require loading the tables. Table names MUST match verbatim including the quotations, apostrophes, periods, or lack thereof."
description="A list of the most relevant tables to explore and load into memory before coming up with a plan, based on the chain of thought. NOTE: Simple queries asking to list the tables/datasets do not require loading the tables. Table names MUST match verbatim including the quotations, apostrophes, periods, or lack thereof."
)
)
reasoning = create_model(
'Reasoning',
chain_of_thought=(
str,
FieldInfo(
description="Describe at a high-level how the actions of each expert will solve the user query."
description="Describe at a high-level how the actions of each expert will solve the user query. If the user asks a question about data that is seemingly unavailable, please request to see tables' schemas."
),
),
**extras
Expand Down
9 changes: 3 additions & 6 deletions lumen/ai/prompts/Planner/main.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ Agent Rules:
{%- if 'table' in memory %}- The result of the previous step was the `{{ memory['table'] }}` table. Consider carefully if it contains all the information you need and only invoke the SQL agent if some other calculation needs to be performed.{% endif -%}

{%- if table_info %}
Here are schemas for tables that were recently used (note that these schemas are computed on a subset of data):
Here are tables and schemas that are available to you:
{{ table_info }}
{%- endif %}
{%- if tables %}
Additionally the following datasets/tables are available and you may request to look at them before revising your plan:
{% for table in tables %}
- {{ table }}
{% endfor %}
{%- if tables_schema_str %}
{{ tables_schema_str }}
{%- endif -%}
{% if memory.get('document_sources') %}
Here are the documents you have access to:
Expand Down
2 changes: 1 addition & 1 deletion lumen/ai/prompts/SQLAgent/main.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ CAST or TO_DATE
- Try to pretty print the SQL output with newlines and indentation.
- Specify data types explicitly to avoid type mismatches.
- Handle NULL values using functions like COALESCE or IS NULL.
- Capture only the required numeric values while removing all whitespace, like `(\d+)`, or remove characters like `$`, `%`, `,`, etc, if necessary.
- Capture only the required numeric values while removing all whitespace, like `(\d+)`, or remove characters like `$`, `%`, `,`, etc, only if needed.
- Use parameterized queries to prevent SQL injection attacks.
- Use Common Table Expressions (CTEs) and subqueries to break down complex queries into manageable parts.
- Be sure to remove suspiciously large or small values that may be invalid, like -9999.
Expand Down
5 changes: 3 additions & 2 deletions lumen/ai/prompts/SQLAgent/select_table.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@

{% block instructions %}
Select the most appropriate table to use in the SQL query to answer the user's question.
Note the table names may be outdated or incorrect, so check the schemas for the most
accurate information if it's available.
{% endblock %}

{% block context %}
Here are the table schemas:

Here are the tables (and schemas if available):
{{ tables_schema_str }}
{% endblock %}
11 changes: 6 additions & 5 deletions lumen/ai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,16 +339,17 @@ async def gather_table_sources(sources: list[Source]) -> tuple[dict[str, Source]
and a markdown string of the tables and their schemas.
"""
tables_to_source = {}
tables_schema_str = "\nHere are the tables\n"
tables_schema_str = ""
for source in sources:
for table in source.get_tables():
tables_to_source[table] = source
if isinstance(source, DuckDBSource) and source.ephemeral:
schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=1)
tables_schema_str += f"### {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\n"
sql = source.get_sql_expr(table)
schema = await get_schema(source, table, include_min_max=False, include_enum=True, limit=3)
tables_schema_str += f"- {table}\nSchema:\n```yaml\n{yaml.dump(schema)}```\nSQL:\n```sql\n{sql}\n```\n\n"
else:
tables_schema_str += f"### {table}\n"
return tables_to_source, tables_schema_str
tables_schema_str += f"- {table}\n\n"
return tables_to_source, tables_schema_str.strip()


def log_debug(msg: str | list, offset: int = 24, prefix: str = "", suffix: str = "", show_sep: bool = False, show_length: bool = False):
Expand Down
1 change: 1 addition & 0 deletions lumen/sources/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ def create_sql_expr_source(
if 'uri' not in kwargs and 'initializers' not in kwargs:
params['_connection'] = self._connection
params.pop('name', None)
params["ephemeral"] = True
source = type(self)(**params)
if not materialize:
return source
Expand Down

0 comments on commit 63935df

Please sign in to comment.