Skip to content

Commit

Permalink
DH-4705 Validate db connections
Browse files Browse the repository at this point in the history
  • Loading branch information
jcjc712 committed Sep 22, 2023
1 parent 10fcf3e commit ec9fa91
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 6 deletions.
21 changes: 20 additions & 1 deletion dataherald/api/fastapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
from dataherald.repositories.database_connections import DatabaseConnectionRepository
from dataherald.repositories.golden_records import GoldenRecordRepository
from dataherald.repositories.nl_question import NLQuestionRepository
from dataherald.sql_database.base import SQLDatabase, SQLInjectionError
from dataherald.sql_database.base import (
InvalidDBConnectionError,
SQLDatabase,
SQLInjectionError,
)
from dataherald.sql_database.models.types import DatabaseConnection
from dataherald.sql_generator import SQLGenerator
from dataherald.sql_generator.generates_nl_answer import GeneratesNlAnswer
Expand Down Expand Up @@ -136,8 +140,16 @@ def create_database_connection(
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
)

SQLDatabase.get_sql_engine(db_connection, True)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) # noqa: B904
except InvalidDBConnectionError as e:
raise HTTPException( # noqa: B904
status_code=400,
detail=f"{e}",
)

db_connection_repository = DatabaseConnectionRepository(self.storage)
return db_connection_repository.insert(db_connection)

Expand All @@ -161,8 +173,15 @@ def update_database_connection(
use_ssh=database_connection_request.use_ssh,
ssh_settings=database_connection_request.ssh_settings,
)

SQLDatabase.get_sql_engine(db_connection, True)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e)) # noqa: B904
except InvalidDBConnectionError as e:
raise HTTPException( # noqa: B904
status_code=400,
detail=f"{e}",
)
db_connection_repository = DatabaseConnectionRepository(self.storage)
return db_connection_repository.update(db_connection)

Expand Down
24 changes: 19 additions & 5 deletions dataherald/sql_database/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class SQLInjectionError(Exception):
pass


class InvalidDBConnectionError(Exception):
pass


class DBConnections:
db_connections = {}

Expand Down Expand Up @@ -62,9 +66,15 @@ def from_uri(
return cls(engine, **kwargs)

@classmethod
def get_sql_engine(cls, database_info: DatabaseConnection) -> "SQLDatabase":
def get_sql_engine(
cls, database_info: DatabaseConnection, refresh_connection=False
) -> "SQLDatabase":
logger.info(f"Connecting db: {database_info.id}")
if database_info.id in DBConnections.db_connections:
if (
database_info.id
and database_info.id in DBConnections.db_connections
and not refresh_connection
):
return DBConnections.db_connections[database_info.id]

fernet_encrypt = FernetEncrypt()
Expand All @@ -80,9 +90,13 @@ def get_sql_engine(cls, database_info: DatabaseConnection) -> "SQLDatabase":
file_path = s3.download(file_path)

db_uri = db_uri + f"?credentials_path={file_path}"

engine = cls.from_uri(db_uri)
DBConnections.add(database_info.id, engine)
try:
engine = cls.from_uri(db_uri)
DBConnections.add(database_info.id, engine)
except Exception as e:
raise InvalidDBConnectionError( # noqa: B904
f"Unable to connect to db: {database_info.alias}, {e}"
)
return engine

@classmethod
Expand Down

0 comments on commit ec9fa91

Please sign in to comment.