Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisclark committed Sep 23, 2024
1 parent d79d21c commit 8820b7b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 6 deletions.
2 changes: 1 addition & 1 deletion explorer/assistant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def extract_response(r):

def table_schema(db_connection, table_name):
schema = schema_info(db_connection)
s = [table for table in schema if table[0] == table_name]
s = [table for table in schema if table[0].lower() == table_name.lower()]
if len(s):
return s[0][1]

Expand Down
12 changes: 9 additions & 3 deletions explorer/schema.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from django.core.cache import cache
from django.db import ProgrammingError

from explorer.app_settings import (
EXPLORER_SCHEMA_EXCLUDE_TABLE_PREFIXES,
Expand Down Expand Up @@ -102,10 +103,15 @@ def build_schema_info(db_connection):
for table_name in tables_to_introspect:
if not _include_table(table_name):
continue
try:
table_description = connection.introspection.get_table_description(
cursor, table_name
)
# Issue 675. A connection maybe not have permissions to access some tables in the DB.
except ProgrammingError:
continue

td = []
table_description = connection.introspection.get_table_description(
cursor, table_name
)
for row in table_description:
column_name = row[0]
try:
Expand Down
20 changes: 18 additions & 2 deletions explorer/tests/test_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
ROW_SAMPLE_SIZE,
build_prompt,
get_relevant_few_shots,
get_relevant_annotation
get_relevant_annotation,
table_schema
)

from explorer.assistant.models import TableDescription
Expand Down Expand Up @@ -211,7 +212,6 @@ def test_format_rows_from_table(self):
self.assertEqual(ret, "col1 | col2\nval1 | val2")

def test_schema_info_from_table_names(self):
from explorer.assistant.utils import table_schema
ret = table_schema(default_db_connection(), "explorer_query")
expected = [
("id", "AutoField"),
Expand All @@ -227,6 +227,22 @@ def test_schema_info_from_table_names(self):
("few_shot", "BooleanField")]
self.assertEqual(ret, expected)

def test_schema_info_from_table_names_case_invariant(self):
ret = table_schema(default_db_connection(), "EXPLORER_QUERY")
expected = [
("id", "AutoField"),
("title", "CharField"),
("sql", "TextField"),
("description", "TextField"),
("created_at", "DateTimeField"),
("last_run_date", "DateTimeField"),
("created_by_user_id", "IntegerField"),
("snapshot", "BooleanField"),
("connection", "CharField"),
("database_connection_id", "IntegerField"),
("few_shot", "BooleanField")]
self.assertEqual(ret, expected)


@unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled")
class TestAssistantUtils(TestCase):
Expand Down

0 comments on commit 8820b7b

Please sign in to comment.