diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index 38ec1ca1..255b5c7e 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -20,7 +20,11 @@ from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.repositories.nl_question import NLQuestionRepository -from dataherald.sql_database.base import SQLDatabase, SQLInjectionError +from dataherald.sql_database.base import ( + InvalidDBConnectionError, + SQLDatabase, + SQLInjectionError, +) from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator import SQLGenerator from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer @@ -136,8 +140,16 @@ def create_database_connection( use_ssh=database_connection_request.use_ssh, ssh_settings=database_connection_request.ssh_settings, ) + + SQLDatabase.get_sql_engine(db_connection, True) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) # noqa: B904 + except InvalidDBConnectionError as e: + raise HTTPException( # noqa: B904 + status_code=400, + detail=f"{e}", + ) + db_connection_repository = DatabaseConnectionRepository(self.storage) return db_connection_repository.insert(db_connection) @@ -161,8 +173,15 @@ def update_database_connection( use_ssh=database_connection_request.use_ssh, ssh_settings=database_connection_request.ssh_settings, ) + + SQLDatabase.get_sql_engine(db_connection, True) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) # noqa: B904 + except InvalidDBConnectionError as e: + raise HTTPException( # noqa: B904 + status_code=400, + detail=f"{e}", + ) db_connection_repository = DatabaseConnectionRepository(self.storage) return db_connection_repository.update(db_connection) diff --git a/dataherald/sql_database/base.py b/dataherald/sql_database/base.py index f4f465ea..f66541f1 100644 --- a/dataherald/sql_database/base.py +++ b/dataherald/sql_database/base.py @@ -21,6 +21,10 @@ class SQLInjectionError(Exception): pass +class InvalidDBConnectionError(Exception): + pass + + class DBConnections: db_connections = {} @@ -62,9 +66,15 @@ def from_uri( return cls(engine, **kwargs) @classmethod - def get_sql_engine(cls, database_info: DatabaseConnection) -> "SQLDatabase": + def get_sql_engine( + cls, database_info: DatabaseConnection, refresh_connection=False + ) -> "SQLDatabase": logger.info(f"Connecting db: {database_info.id}") - if database_info.id in DBConnections.db_connections: + if ( + database_info.id + and database_info.id in DBConnections.db_connections + and not refresh_connection + ): return DBConnections.db_connections[database_info.id] fernet_encrypt = FernetEncrypt() @@ -80,9 +90,13 @@ def get_sql_engine(cls, database_info: DatabaseConnection) -> "SQLDatabase": file_path = s3.download(file_path) db_uri = db_uri + f"?credentials_path={file_path}" - - engine = cls.from_uri(db_uri) - DBConnections.add(database_info.id, engine) + try: + engine = cls.from_uri(db_uri) + DBConnections.add(database_info.id, engine) + except Exception as e: + raise InvalidDBConnectionError( # noqa: B904 + f"Unable to connect to db: {database_info.alias}, {e}" + ) return engine @classmethod