Skip to content

Commit

Permalink
DH-4595 added support for adding db_connection from admin-console (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
DishenWang2023 committed May 7, 2024
1 parent efb1cf8 commit 140aab7
Show file tree
Hide file tree
Showing 16 changed files with 117 additions and 64 deletions.
22 changes: 21 additions & 1 deletion apps/ai/server/modules/db_connection/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi.security import HTTPBearer
from pydantic import Json

from modules.db_connection.models.responses import DBConnectionResponse
from modules.db_connection.models.responses import DBConnectionResponse, DriverResponse
from modules.db_connection.service import DBConnectionService
from utils.auth import Authorize, VerifyToken

Expand All @@ -16,6 +16,12 @@
db_connection_service = DBConnectionService()


@router.get("/drivers", status_code=status.HTTP_200_OK)
async def get_drivers(token: str = Depends(token_auth_scheme)) -> list[DriverResponse]:
VerifyToken(token.credentials).verify()
return db_connection_service.get_drivers()


@router.get("/list", status_code=status.HTTP_200_OK)
async def get_db_connections(
token: str = Depends(token_auth_scheme),
Expand Down Expand Up @@ -43,3 +49,17 @@ async def add_db_connection(
return await db_connection_service.add_db_connection(
db_connection_request_json, org_id, file
)


@router.put("/{id}", status_code=status.HTTP_200_OK)
async def update_db_connection(
id: str,
db_connection_request_json: Json = Form(),
file: UploadFile = None,
token: str = Depends(token_auth_scheme),
) -> DBConnectionResponse:
org_id = authorize.user_and_get_org_id(VerifyToken(token.credentials).verify())
authorize.db_connection_in_organization(id, org_id)
return await db_connection_service.update_db_connection(
id, db_connection_request_json, file
)
9 changes: 7 additions & 2 deletions apps/ai/server/modules/db_connection/models/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,22 @@ class SSHSettings(BaseModel):
class BaseDBConnection(BaseModel):
alias: str | None
use_ssh: bool = False
connection_uri: str | None

path_to_credentials_file: str | None
ssh_settings: SSHSettings | None


class DBConnection(BaseDBConnection):
id: Any = Field(alias="_id")
uri: str | None


class DBConnectionRef(BaseModel):
id: Any = Field(alias="_id")
db_connection_id: Any
organization_id: Any
alias: str


class Driver(BaseModel):
name: str | None
driver: str
1 change: 1 addition & 0 deletions apps/ai/server/modules/db_connection/models/requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@


class DBConnectionRequest(BaseDBConnection):
connection_uri: str | None
pass
7 changes: 6 additions & 1 deletion apps/ai/server/modules/db_connection/models/responses.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from modules.db_connection.models.entities import BaseDBConnection
from modules.db_connection.models.entities import BaseDBConnection, Driver


class DBConnectionResponse(BaseDBConnection):
id: str
uri: str


class DriverResponse(Driver):
pass
47 changes: 33 additions & 14 deletions apps/ai/server/modules/db_connection/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi import UploadFile

from config import settings
from modules.db_connection.models.entities import DBConnection, DBConnectionRef
from modules.db_connection.models.entities import DBConnectionRef, Driver
from modules.db_connection.models.requests import DBConnectionRequest
from modules.db_connection.models.responses import DBConnectionResponse
from modules.db_connection.repository import DBConnectionRepository
Expand All @@ -25,14 +25,18 @@ def get_db_connections(self, org_id: str) -> list[DBConnectionResponse]:
]
db_connections = self.repo.get_db_connections(db_connection_ids)
return [
self._get_mapped_db_connection_response(db_connection)
DBConnectionResponse(
id=str(db_connection.id), **db_connection.dict(exclude={"id"})
)
for db_connection in db_connections
]

def get_db_connection(self, db_connection_id: str) -> DBConnectionResponse:
db_connection = self.repo.get_db_connection(db_connection_id)
return (
self._get_mapped_db_connection_response(db_connection)
DBConnectionResponse(
id=str(db_connection.id), **db_connection.dict(exclude={"id"})
)
if db_connection
else None
)
Expand All @@ -58,25 +62,40 @@ async def add_db_connection(
raise_for_status(response.status_code, response.text)

response_json = response.json()
db_connection = DBConnection(**response_json)
db_connection.id = ObjectId(response_json["id"])
self.repo.add_db_connection_ref(
DBConnectionRef(
alias=db_connection_request.alias,
db_connection_id=db_connection.id,
db_connection_id=ObjectId(response_json["id"]),
organization_id=ObjectId(org_id),
).dict(exclude={"id"})
)

self.org_service.update_db_connection_id(org_id, response_json["id"])

return self._get_mapped_db_connection_response(db_connection)
return DBConnectionResponse(**response.json())

def _get_mapped_db_connection_response(
self, db_connection: DBConnection
async def update_db_connection(
self, id, db_connection_request_json: dict, file: UploadFile = None
) -> DBConnectionResponse:
db_connection_response = DBConnectionResponse(
id=str(db_connection.id), **db_connection.dict(exclude={"id"})
)
db_connection_response.id = str(db_connection_response.id)
return db_connection_response
db_connection_request = DBConnectionRequest(**db_connection_request_json)

if file:
s3 = S3()
db_connection_request.path_to_credentials_file = s3.upload(file)

async with httpx.AsyncClient() as client:
response = await client.put(
settings.k2_core_url + f"/database-connections/{id}",
json=db_connection_request.dict(),
)
raise_for_status(response.status_code, response.text)
return DBConnectionResponse(**response.json())

def get_drivers(self) -> list[Driver]:
return [
{"name": "PostgreSQL", "driver": "postgresql+psycopg2"},
{"name": "Databricks", "driver": "databricks"},
{"name": "Snowflake", "driver": "snowflake"},
{"name": "BigQuery", "driver": "bigquery"},
{"name": "AWS Athena", "driver": "awsathena+rest"},
]
4 changes: 2 additions & 2 deletions apps/ai/server/modules/golden_sql/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ async def add_golden_sql(
json=[golden_sql_request.dict()],
timeout=settings.default_k2_core_timeout,
)
raise_for_status(response.status_code, response.json())
raise_for_status(response.status_code, response.text)
response_json = response.json()[0]
golden_sql = GoldenSQL(**response_json)
golden_sql.id = ObjectId(response_json["id"])
Expand Down Expand Up @@ -108,7 +108,7 @@ async def delete_golden_sql(
settings.k2_core_url + f"/golden-records/{golden_id}",
timeout=settings.default_k2_core_timeout,
)
raise_for_status(response.status_code, response.json())
raise_for_status(response.status_code, response.text)
if response.json()["status"]:
if query_response_id:
matched_count = self.repo.delete_verified_golden_sql_ref(
Expand Down
2 changes: 1 addition & 1 deletion apps/ai/server/modules/organization/models/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@

class OrganizationResponse(BaseOrganization):
id: str
db_connection_id: str
db_connection_id: str | None
4 changes: 3 additions & 1 deletion apps/ai/server/modules/organization/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,5 +116,7 @@ def _get_mapped_organization_response(
) -> OrganizationResponse:
org_dict = organization.dict()
org_dict["id"] = str(org_dict["id"])
org_dict["db_connection_id"] = str(org_dict["db_connection_id"])
org_dict["db_connection_id"] = (
str(org_dict["db_connection_id"]) if org_dict["db_connection_id"] else None
)
return OrganizationResponse(**org_dict)
18 changes: 9 additions & 9 deletions apps/ai/server/modules/query/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from modules.golden_sql.models.entities import GoldenSQLSource
from modules.golden_sql.models.requests import GoldenSQLRequest
from modules.golden_sql.service import GoldenSQLService
from modules.organization.models.entities import Organization
from modules.organization.models.responses import OrganizationResponse
from modules.organization.service import OrganizationService
from modules.query.models.entities import (
QueryRef,
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(self):
self.user_service = UserService()

async def answer_question(
self, question_request: QuestionRequest, organization: Organization
self, question_request: QuestionRequest, organization: OrganizationResponse
) -> QuerySlackResponse:
question_string = remove_slack_mentions(question_request.question)

Expand All @@ -54,20 +54,20 @@ async def answer_question(
settings.k2_core_url + "/question",
json={
"question": question_string,
"db_connection_id": str(organization.db_connection_id),
"db_connection_id": organization.db_connection_id,
},
timeout=settings.default_k2_core_timeout,
)

raise_for_status(response.status_code, response.json())
raise_for_status(response.status_code, response.text)

# adds document that links user info to query response
query_response = CoreQueryResponse(**response.json())
query_id: str = response.json()["id"]["$oid"]

# if query ref doesn't exist, create one
if not self.repo.get_query_response_ref(query_id):
display_id = self.repo.get_next_display_id(str(organization.id))
display_id = self.repo.get_next_display_id(organization.id)

current_utc_time = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S")
username = SlackWebClient(
Expand Down Expand Up @@ -172,7 +172,7 @@ async def patch_query(
self,
query_id: str,
query_request: QueryUpdateRequest,
organization: Organization,
organization: OrganizationResponse,
user: User,
) -> QueryResponse:
is_golden_record = (
Expand All @@ -185,7 +185,7 @@ async def patch_query(
json={"sql_query": query_request.sql_query},
timeout=settings.default_k2_core_timeout,
)
raise_for_status(response.status_code, response.json())
raise_for_status(response.status_code, response.text)

new_query_response = CoreQueryResponse(**response.json())
question = self.repo.get_question(new_query_response.nl_question_id["$oid"])
Expand All @@ -200,7 +200,7 @@ async def patch_query(
)
await self.golden_sql_service.add_golden_sql(
golden_sql,
str(organization.id),
organization.id,
source=GoldenSQLSource.verified_query,
query_response_id=query_id,
)
Expand Down Expand Up @@ -244,7 +244,7 @@ async def run_query(self, query_id: str, query_request: QueryExecutionRequest):
},
timeout=settings.default_k2_core_timeout,
)
raise_for_status(response.status_code, response.json())
raise_for_status(response.status_code, response.text)
response_ref = self.repo.get_query_response_ref(query_id)
new_query_response = CoreQueryResponse(**response.json())
question = self.repo.get_question(new_query_response.nl_question_id["$oid"])
Expand Down
10 changes: 5 additions & 5 deletions apps/ai/server/modules/table_description/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ async def get_table_descriptions(
user = authorize.user(VerifyToken(token.credentials).verify())
organization = authorize.get_organization_by_user(user)
return await table_description_service.get_table_descriptions(
table_name, str(organization.db_connection_id)
table_name, organization.db_connection_id
)


Expand All @@ -41,16 +41,16 @@ async def get_database_table_descriptions(
user = authorize.user(VerifyToken(token.credentials).verify())
organization = authorize.get_organization_by_user(user)
return await table_description_service.get_database_table_descriptions(
str(organization.db_connection_id)
organization.db_connection_id
)


@router.post("/scan", status_code=status.HTTP_201_CREATED)
async def scan_table_descriptions(
@router.post("/sync-schemas", status_code=status.HTTP_201_CREATED)
async def sync_table_descriptions_schemas(
scan_request: ScanRequest, token: str = Depends(token_auth_scheme)
):
authorize.user(VerifyToken(token.credentials).verify())
return await table_description_service.scan_table_descriptions(scan_request)
return await table_description_service.sync_table_descriptions_schemas(scan_request)


@router.patch("/{id}", status_code=status.HTTP_200_OK)
Expand Down
7 changes: 7 additions & 0 deletions apps/ai/server/modules/table_description/models/entities.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel, Field


class SchemaStatus(Enum):
NOT_SYNCHRONIZED = "NOT_SYNCHRONIZED"
SYNCHRONIZING = "SYNCHRONIZING"
SYNCHRONIZED = "SYNCHRONIZED"


class ColumnDescription(BaseModel):
name: str | None
description: str | None
Expand Down
6 changes: 5 additions & 1 deletion apps/ai/server/modules/table_description/models/responses.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from pydantic import BaseModel

from modules.table_description.models.entities import BaseTableDescription
from modules.table_description.models.entities import BaseTableDescription, SchemaStatus


class TableDescriptionResponse(BaseTableDescription):
id: str
table_name: str
schemas_status: SchemaStatus | None
last_schemas_sync: str | None


class BasicTableDescriptionResponse(BaseModel):
id: str
name: str | None
columns: list[str] | None
schemas_status: SchemaStatus | None
last_schemas_sync: str | None


class DatabaseDescriptionResponse(BaseModel):
Expand Down
Loading

0 comments on commit 140aab7

Please sign in to comment.