From 8820b7b73d9ddd8227ca648986d7186ca17d55dc Mon Sep 17 00:00:00 2001 From: Chris Clark Date: Mon, 23 Sep 2024 09:16:51 -0400 Subject: [PATCH] bug fixes --- explorer/assistant/utils.py | 2 +- explorer/schema.py | 12 +++++++++--- explorer/tests/test_assistant.py | 20 ++++++++++++++++++-- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/explorer/assistant/utils.py b/explorer/assistant/utils.py index 0a962406..45cf1626 100644 --- a/explorer/assistant/utils.py +++ b/explorer/assistant/utils.py @@ -40,7 +40,7 @@ def extract_response(r): def table_schema(db_connection, table_name): schema = schema_info(db_connection) - s = [table for table in schema if table[0] == table_name] + s = [table for table in schema if table[0].lower() == table_name.lower()] if len(s): return s[0][1] diff --git a/explorer/schema.py b/explorer/schema.py index 07a8cf2c..cc83d184 100644 --- a/explorer/schema.py +++ b/explorer/schema.py @@ -1,4 +1,5 @@ from django.core.cache import cache +from django.db import ProgrammingError from explorer.app_settings import ( EXPLORER_SCHEMA_EXCLUDE_TABLE_PREFIXES, @@ -102,10 +103,15 @@ def build_schema_info(db_connection): for table_name in tables_to_introspect: if not _include_table(table_name): continue + try: + table_description = connection.introspection.get_table_description( + cursor, table_name + ) + # Issue 675. A connection maybe not have permissions to access some tables in the DB. + except ProgrammingError: + continue + td = [] - table_description = connection.introspection.get_table_description( - cursor, table_name - ) for row in table_description: column_name = row[0] try: diff --git a/explorer/tests/test_assistant.py b/explorer/tests/test_assistant.py index d1bcde27..60c7ed2c 100644 --- a/explorer/tests/test_assistant.py +++ b/explorer/tests/test_assistant.py @@ -16,7 +16,8 @@ ROW_SAMPLE_SIZE, build_prompt, get_relevant_few_shots, - get_relevant_annotation + get_relevant_annotation, + table_schema ) from explorer.assistant.models import TableDescription @@ -211,7 +212,6 @@ def test_format_rows_from_table(self): self.assertEqual(ret, "col1 | col2\nval1 | val2") def test_schema_info_from_table_names(self): - from explorer.assistant.utils import table_schema ret = table_schema(default_db_connection(), "explorer_query") expected = [ ("id", "AutoField"), @@ -227,6 +227,22 @@ def test_schema_info_from_table_names(self): ("few_shot", "BooleanField")] self.assertEqual(ret, expected) + def test_schema_info_from_table_names_case_invariant(self): + ret = table_schema(default_db_connection(), "EXPLORER_QUERY") + expected = [ + ("id", "AutoField"), + ("title", "CharField"), + ("sql", "TextField"), + ("description", "TextField"), + ("created_at", "DateTimeField"), + ("last_run_date", "DateTimeField"), + ("created_by_user_id", "IntegerField"), + ("snapshot", "BooleanField"), + ("connection", "CharField"), + ("database_connection_id", "IntegerField"), + ("few_shot", "BooleanField")] + self.assertEqual(ret, expected) + @unittest.skipIf(not app_settings.has_assistant(), "assistant not enabled") class TestAssistantUtils(TestCase):