diff --git a/README.md b/README.md index 1f8bda3b..9185e9bb 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ Dataherald is a natural language-to-SQL engine built for enterprise-level questi This project is undergoing swift development, and as such, the API may be subject to change at any time. +If you would like to learn more, you can join the Discord or read the docs. + ## Overview ### Background @@ -162,7 +164,7 @@ curl -X 'POST' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "my_db_alias_identifier", + "alias": "my_db_alias", "use_ssh": false, "connection_uri": "sqlite:///mydb.db", "path_to_credentials_file": "my-folder/my-secret.json" # Required for bigquery @@ -176,7 +178,7 @@ curl -X 'POST' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "my_db_alias_identifier", + "alias": "my_db_alias", "use_ssh": true, "ssh_settings": { "db_name": "db_name", @@ -254,7 +256,7 @@ curl -X 'POST' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "db_name", + "db_connection_id": "db_connection_id", "table_name": "table_name" }' ``` @@ -267,7 +269,7 @@ Once a database was scanned you can use this endpoint to retrieve the tables nam ``` curl -X 'GET' \ - '/api/v1/scanned-databases?db_alias=databricks' \ + '/api/v1/scanned-databases?db_connection_id=64dfa0e103f5134086f7090c' \ -H 'accept: application/json' ``` @@ -289,11 +291,11 @@ curl -X 'POST' \ ``` #### Adding string descriptions -In addition to database table_info and golden_sql, you can add strings describing tables and/or columns to the context store manually from the `PATCH /api/v1/scanned-db/{db_name}/{table_name}` endpoint +In addition to database table_info and golden_sql, you can add strings describing tables and/or columns to the context store manually from the `PATCH /api/v1/scanned-db/{db_connection_id}/{table_name}` endpoint ``` curl -X 'PATCH' \ - 
'/api/v1/scanned-db/db_name/table_name' \ + '/api/v1/scanned-db/db_connection_id/table_name' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ @@ -322,7 +324,7 @@ curl -X 'POST' \ -H 'Content-Type: application/json' \ -d '{ "question": "Your question in natural language", - "db_alias": "db_name" + "db_connection_id": "db_connection_id" }' ``` diff --git a/dataherald/api/__init__.py b/dataherald/api/__init__.py index ea14d551..dcfa9b19 100644 --- a/dataherald/api/__init__.py +++ b/dataherald/api/__init__.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod -from typing import Any, List +from typing import List from dataherald.api.types import Query from dataherald.config import Component -from dataherald.eval import Evaluation +from dataherald.db_scanner.models.types import TableSchemaDetail from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( DatabaseConnectionRequest, @@ -12,7 +12,6 @@ GoldenRecordRequest, NLQueryResponse, QuestionRequest, - ScannedDBResponse, ScannerRequest, TableDescriptionRequest, UpdateQueryRequest, @@ -34,42 +33,57 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: pass @abstractmethod - def connect_database( + def create_database_connection( self, database_connection_request: DatabaseConnectionRequest ) -> DatabaseConnection: pass @abstractmethod - def add_description( + def list_database_connections(self) -> list[DatabaseConnection]: + pass + + @abstractmethod + def update_database_connection( + self, + db_connection_id: str, + database_connection_request: DatabaseConnectionRequest, + ) -> DatabaseConnection: + pass + + @abstractmethod + def update_table_description( self, - db_name: str, - table_name: str, + table_description_id: str, table_description_request: TableDescriptionRequest, - ) -> bool: + ) -> TableSchemaDetail: pass @abstractmethod - def add_golden_records( - self, golden_records: List[GoldenRecordRequest] 
- ) -> List[GoldenRecord]: + def list_table_descriptions( + self, db_connection_id: str | None = None, table_name: str | None = None + ) -> list[TableSchemaDetail]: pass @abstractmethod - def execute_query(self, query: Query) -> tuple[str, dict]: + def add_golden_records( + self, golden_records: List[GoldenRecordRequest] + ) -> List[GoldenRecord]: pass @abstractmethod - def update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse: + def execute_sql_query(self, query: Query) -> tuple[str, dict]: pass @abstractmethod - def execute_temp_query( - self, query_id: str, query: ExecuteTempQueryRequest + def update_nl_query_response( + self, query_id: str, query: UpdateQueryRequest ) -> NLQueryResponse: pass @abstractmethod - def get_scanned_databases(self, db_alias: str) -> ScannedDBResponse: + def get_nl_query_response( + self, query_request: ExecuteTempQueryRequest + ) -> NLQueryResponse: pass @abstractmethod diff --git a/dataherald/api/fastapi.py b/dataherald/api/fastapi.py index e080d61c..38ec1ca1 100644 --- a/dataherald/api/fastapi.py +++ b/dataherald/api/fastapi.py @@ -13,9 +13,11 @@ from dataherald.context_store import ContextStore from dataherald.db import DB from dataherald.db_scanner import Scanner +from dataherald.db_scanner.models.types import TableSchemaDetail from dataherald.db_scanner.repository.base import DBScannerRepository from dataherald.eval import Evaluator from dataherald.repositories.base import NLQueryResponseRepository +from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.golden_records import GoldenRecordRepository from dataherald.repositories.nl_question import NLQuestionRepository from dataherald.sql_database.base import SQLDatabase, SQLInjectionError @@ -30,8 +32,6 @@ NLQuery, NLQueryResponse, QuestionRequest, - ScannedDBResponse, - ScannedDBTable, ScannerRequest, TableDescriptionRequest, UpdateQueryRequest, @@ -53,27 +53,29 @@ def heartbeat(self) -> int: 
@override def scan_db(self, scanner_request: ScannerRequest) -> bool: - """Takes a db_alias and scan all the tables columns""" - db_connection = self.storage.find_one( - "database_connection", {"alias": scanner_request.db_alias} + """Takes a db_connection_id and scan all the tables columns""" + db_connection_repository = DatabaseConnectionRepository(self.storage) + + db_connection = db_connection_repository.find_by_id( + scanner_request.db_connection_id ) if not db_connection: raise HTTPException(status_code=404, detail="Database connection not found") - database_connection = DatabaseConnection(**db_connection) + try: - database = SQLDatabase.get_sql_engine(database_connection) + database = SQLDatabase.get_sql_engine(db_connection) except Exception as e: raise HTTPException( # noqa: B904 status_code=400, - detail=f"Unable to connect to db: {scanner_request.db_alias}, {e}", + detail=f"Unable to connect to db: {scanner_request.db_connection_id}, {e}", ) scanner = self.system.instance(Scanner) try: scanner.scan( database, - scanner_request.db_alias, - scanner_request.table_name, + scanner_request.db_connection_id, + scanner_request.table_names, DBScannerRepository(self.storage), ) except ValueError as e: @@ -89,19 +91,19 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: context_store = self.system.instance(ContextStore) user_question = NLQuery( - question=question_request.question, db_alias=question_request.db_alias + question=question_request.question, + db_connection_id=question_request.db_connection_id, ) nl_question_repository = NLQuestionRepository(self.storage) user_question = nl_question_repository.insert(user_question) - db_connection = self.storage.find_one( - "database_connection", {"alias": question_request.db_alias} + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection = db_connection_repository.find_by_id( + question_request.db_connection_id ) - if not db_connection: + if not 
database_connection: raise HTTPException(status_code=404, detail="Database connection not found") - database_connection = DatabaseConnection(**db_connection) - context = context_store.retrieve_context_for_question(user_question) start_generated_answer = time.time() try: @@ -123,38 +125,55 @@ def answer_question(self, question_request: QuestionRequest) -> NLQueryResponse: return json.loads(json_util.dumps(nl_query_response)) @override - def connect_database( + def create_database_connection( self, database_connection_request: DatabaseConnectionRequest ) -> DatabaseConnection: try: db_connection = DatabaseConnection( + alias=database_connection_request.alias, uri=database_connection_request.connection_uri, path_to_credentials_file=database_connection_request.path_to_credentials_file, - alias=database_connection_request.db_alias, use_ssh=database_connection_request.use_ssh, ssh_settings=database_connection_request.ssh_settings, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) # noqa: B904 - db_connection.id = str( - self.storage.update_or_create( - "database_connection", - {"alias": database_connection_request.db_alias}, - db_connection.dict(), - ) - ) + db_connection_repository = DatabaseConnectionRepository(self.storage) + return db_connection_repository.insert(db_connection) + + @override + def list_database_connections(self) -> list[DatabaseConnection]: + db_connection_repository = DatabaseConnectionRepository(self.storage) + return db_connection_repository.find_all() - return db_connection + @override + def update_database_connection( + self, + db_connection_id: str, + database_connection_request: DatabaseConnectionRequest, + ) -> DatabaseConnection: + try: + db_connection = DatabaseConnection( + id=db_connection_id, + alias=database_connection_request.alias, + uri=database_connection_request.connection_uri, + path_to_credentials_file=database_connection_request.path_to_credentials_file, + use_ssh=database_connection_request.use_ssh, 
+ ssh_settings=database_connection_request.ssh_settings, + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) # noqa: B904 + db_connection_repository = DatabaseConnectionRepository(self.storage) + return db_connection_repository.update(db_connection) @override - def add_description( + def update_table_description( self, - db_name: str, - table_name: str, + table_description_id: str, table_description_request: TableDescriptionRequest, - ) -> bool: + ) -> TableSchemaDetail: scanner_repository = DBScannerRepository(self.storage) - table = scanner_repository.get_table_info(db_name, table_name) + table = scanner_repository.find_by_id(table_description_id) if not table: raise HTTPException( @@ -169,8 +188,16 @@ def add_description( if column_request.name == column.name: column.description = column_request.description - scanner_repository.update(table) - return True + return scanner_repository.update(table) + + @override + def list_table_descriptions( + self, db_connection_id: str | None = None, table_name: str | None = None + ) -> list[TableSchemaDetail]: + scanner_repository = DBScannerRepository(self.storage) + return scanner_repository.find_by( + {"db_connection_id": db_connection_id, "table_name": table_name} + ) @override def add_golden_records( @@ -181,23 +208,23 @@ def add_golden_records( return context_store.add_golden_records(golden_records) @override - def execute_query(self, query: Query) -> tuple[str, dict]: + def execute_sql_query(self, query: Query) -> tuple[str, dict]: """Executes a SQL query against the database and returns the results""" - db_connection = self.storage.find_one( - "database_connection", {"alias": query.db_alias} + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection = db_connection_repository.find_by_id( + query.db_connection_id ) - if not db_connection: + if not database_connection: raise HTTPException(status_code=404, detail="Database connection not found") - 
database_connection = DatabaseConnection(**db_connection) database = SQLDatabase.get_sql_engine(database_connection) try: - result = database.run_sql(query.sql_statement) + result = database.run_sql(query.sql_query) except SQLInjectionError as e: raise HTTPException(status_code=404, detail=str(e)) from e return result @override - def update_query( + def update_nl_query_response( self, query_id: str, query: UpdateQueryRequest # noqa: ARG002 ) -> NLQueryResponse: nl_query_response_repository = NLQueryResponseRepository(self.storage) @@ -209,14 +236,14 @@ def update_query( if nl_query_response.sql_query.strip() != query.sql_query.strip(): nl_query_response.sql_query = query.sql_query evaluator = self.system.instance(Evaluator) - db_connection = self.storage.find_one( - "database_connection", {"alias": nl_question.db_alias} + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection = db_connection_repository.find_by_id( + nl_question.db_connection_id ) - if not db_connection: + if not database_connection: raise HTTPException( status_code=404, detail="Database connection not found" ) - database_connection = DatabaseConnection(**db_connection) try: confidence_score = evaluator.get_confidence_score( nl_question, nl_query_response, database_connection @@ -230,12 +257,14 @@ def update_query( return json.loads(json_util.dumps(nl_query_response)) @override - def execute_temp_query( - self, query_id: str, query: ExecuteTempQueryRequest # noqa: ARG002 + def get_nl_query_response( + self, query_request: ExecuteTempQueryRequest # noqa: ARG002 ) -> NLQueryResponse: nl_query_response_repository = NLQueryResponseRepository(self.storage) - nl_query_response = nl_query_response_repository.find_by_id(query_id) - nl_query_response.sql_query = query.sql_query + nl_query_response = nl_query_response_repository.find_by_id( + query_request.query_id + ) + nl_query_response.sql_query = query_request.sql_query try: generates_nl_answer = 
GeneratesNlAnswer(self.system, self.storage) nl_query_response = generates_nl_answer.execute(nl_query_response) @@ -243,24 +272,6 @@ def execute_temp_query( raise HTTPException(status_code=404, detail=str(e)) from e return json.loads(json_util.dumps(nl_query_response)) - @override - def get_scanned_databases(self, db_alias: str) -> ScannedDBResponse: - scanner_repository = DBScannerRepository(self.storage) - tables = scanner_repository.get_all_tables_by_db(db_alias) - process_tables = [] - for table in tables: - process_tables.append( - ScannedDBTable( - id=str(table.id), - name=table.table_name, - columns=[column.name for column in table.columns], - ) - ) - scanned_db_response = ScannedDBResponse( - db_alias=db_alias, tables=process_tables - ) - return json.loads(json_util.dumps(scanned_db_response)) - @override def delete_golden_record(self, golden_record_id: str) -> dict: context_store = self.system.instance(ContextStore) diff --git a/dataherald/api/types.py b/dataherald/api/types.py index 74e417e0..81d5173d 100644 --- a/dataherald/api/types.py +++ b/dataherald/api/types.py @@ -1,6 +1,5 @@ -from pydantic import BaseModel +from dataherald.types import DBConnectionValidation -class Query(BaseModel): - sql_statement: str - db_alias: str +class Query(DBConnectionValidation): + sql_query: str diff --git a/dataherald/context_store/default.py b/dataherald/context_store/default.py index f73ef5bc..b73de25f 100644 --- a/dataherald/context_store/default.py +++ b/dataherald/context_store/default.py @@ -23,7 +23,7 @@ def retrieve_context_for_question( logger.info(f"Getting context for {nl_question.question}") closest_questions = self.vector_store.query( query_texts=[nl_question.question], - db_alias=nl_question.db_alias, + db_connection_id=nl_question.db_connection_id, collection=self.golden_record_collection, num_results=number_of_samples, ) @@ -58,7 +58,7 @@ def add_golden_records( golden_record = GoldenRecord( question=question, sql_query=record.sql_query, - 
db_alias=record.db_alias, + db_connection_id=record.db_connection_id, ) retruned_golden_records.append(golden_record) golden_record = golden_records_repository.insert(golden_record) @@ -66,7 +66,10 @@ def add_golden_records( documents=question, collection=self.golden_record_collection, metadata=[ - {"tables_used": tables[0], "db_alias": record.db_alias} + { + "tables_used": tables[0], + "db_connection_id": record.db_connection_id, + } ], # this should be updated for multiple tables ids=[str(golden_record.id)], ) diff --git a/dataherald/db/__init__.py b/dataherald/db/__init__.py index 0961ec09..7d4adedb 100644 --- a/dataherald/db/__init__.py +++ b/dataherald/db/__init__.py @@ -12,6 +12,10 @@ def __init__(self, system: System): def insert_one(self, collection: str, obj: dict) -> int: pass + @abstractmethod + def rename(self, old_collection_name: str, new_collection_name) -> None: + pass + @abstractmethod def update_or_create(self, collection: str, query: dict, obj: dict) -> int: pass diff --git a/dataherald/db/mongo.py b/dataherald/db/mongo.py index b04cd15e..fec4146e 100644 --- a/dataherald/db/mongo.py +++ b/dataherald/db/mongo.py @@ -23,6 +23,10 @@ def find_one(self, collection: str, query: dict) -> dict: def insert_one(self, collection: str, obj: dict) -> int: return self._data_store[collection].insert_one(obj).inserted_id + @override + def rename(self, old_collection_name: str, new_collection_name) -> None: + self._data_store[old_collection_name].rename(new_collection_name) + @override def update_or_create(self, collection: str, query: dict, obj: dict) -> int: row = self.find_one(collection, query) diff --git a/dataherald/db_scanner/__init__.py b/dataherald/db_scanner/__init__.py index fac318ad..6308a023 100644 --- a/dataherald/db_scanner/__init__.py +++ b/dataherald/db_scanner/__init__.py @@ -12,8 +12,8 @@ class Scanner(Component, ABC): def scan( self, db_engine: SQLDatabase, - db_alias: str, - table_name: str | None, + db_connection_id: str, + table_names: 
list[str] | None, repository: DBScannerRepository, ) -> None: """ "Scan a db""" diff --git a/dataherald/db_scanner/models/types.py b/dataherald/db_scanner/models/types.py index 7d6c3f98..62efd2fc 100644 --- a/dataherald/db_scanner/models/types.py +++ b/dataherald/db_scanner/models/types.py @@ -20,7 +20,7 @@ class ColumnDetail(BaseModel): class TableSchemaDetail(BaseModel): id: Any - db_alias: str + db_connection_id: str table_name: str description: str | None table_schema: str | None diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index cebae0c2..f3a86fd9 100644 --- a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -4,26 +4,35 @@ from dataherald.db_scanner.models.types import TableSchemaDetail -DB_COLLECTION = "table_schema_detail" +DB_COLLECTION = "table_descriptions" class DBScannerRepository: def __init__(self, storage): self.storage = storage + def find_by_id(self, id: str) -> TableSchemaDetail | None: + row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) + if not row: + return None + obj = TableSchemaDetail(**row) + obj.id = str(row["_id"]) + return obj + def get_table_info( - self, db_alias: str, table_name: str + self, db_connection_id: str, table_name: str ) -> TableSchemaDetail | None: row = self.storage.find_one( - DB_COLLECTION, {"db_alias": db_alias, "table_name": table_name} + DB_COLLECTION, + {"db_connection_id": db_connection_id, "table_name": table_name}, ) if row: row["id"] = row["_id"] return TableSchemaDetail(**row) return None - def get_all_tables_by_db(self, db_alias: str) -> List[TableSchemaDetail]: - rows = self.storage.find(DB_COLLECTION, {"db_alias": db_alias}) + def get_all_tables_by_db(self, db_connection_id: str) -> List[TableSchemaDetail]: + rows = self.storage.find(DB_COLLECTION, {"db_connection_id": db_connection_id}) tables = [] for row in rows: row["id"] = row["_id"] @@ -33,7 +42,10 @@ def get_all_tables_by_db(self, db_alias: 
str) -> List[TableSchemaDetail]: def save_table_info(self, table_info: TableSchemaDetail) -> None: self.storage.update_or_create( DB_COLLECTION, - {"db_alias": table_info.db_alias, "table_name": table_info.table_name}, + { + "db_connection_id": table_info.db_connection_id, + "table_name": table_info.table_name, + }, table_info.dict(), ) @@ -44,3 +56,22 @@ def update(self, table_info: TableSchemaDetail) -> TableSchemaDetail: table_info.dict(exclude={"id"}), ) return table_info + + def find_all(self) -> list[TableSchemaDetail]: + rows = self.storage.find_all(DB_COLLECTION) + result = [] + for row in rows: + obj = TableSchemaDetail(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result + + def find_by(self, query: dict) -> list[TableSchemaDetail]: + query = {k: v for k, v in query.items() if v} + rows = self.storage.find(DB_COLLECTION, query) + result = [] + for row in rows: + obj = TableSchemaDetail(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result diff --git a/dataherald/db_scanner/sqlalchemy.py b/dataherald/db_scanner/sqlalchemy.py index 54812bf6..1fa6c058 100644 --- a/dataherald/db_scanner/sqlalchemy.py +++ b/dataherald/db_scanner/sqlalchemy.py @@ -117,7 +117,7 @@ def scan_single_table( meta: MetaData, table: str, db_engine: SQLDatabase, - db_alias: str, + db_connection_id: str, repository: DBScannerRepository, ) -> TableSchemaDetail: print(f"Scanning table: {table}") @@ -135,7 +135,7 @@ def scan_single_table( ) object = TableSchemaDetail( - db_alias=db_alias, + db_connection_id=db_connection_id, table_name=table, columns=table_columns, table_schema=self.get_table_schema( @@ -153,19 +153,18 @@ def scan_single_table( def scan( self, db_engine: SQLDatabase, - db_alias: str, - table_name: str | None, + db_connection_id: str, + table_names: list[str] | None, repository: DBScannerRepository, ) -> None: inspector = inspect(db_engine.engine) meta = MetaData(bind=db_engine.engine) MetaData.reflect(meta, views=True) tables = 
inspector.get_table_names() + inspector.get_view_names() - if table_name: + if table_names: + table_names = [table.lower() for table in table_names] tables = [ - table - for table in tables - if table and table.lower() == table_name.lower() + table for table in tables if table and table.lower() in table_names ] if len(tables) == 0: raise ValueError("No table found") @@ -175,7 +174,7 @@ def scan( meta=meta, table=table, db_engine=db_engine, - db_alias=db_alias, + db_connection_id=db_connection_id, repository=repository, ) result.append(obj) diff --git a/dataherald/repositories/base.py b/dataherald/repositories/base.py index 78bc2408..d4c7caff 100644 --- a/dataherald/repositories/base.py +++ b/dataherald/repositories/base.py @@ -2,7 +2,7 @@ from dataherald.types import NLQueryResponse -DB_COLLECTION = "nl_query_response" +DB_COLLECTION = "nl_query_responses" class NLQueryResponseRepository: diff --git a/dataherald/repositories/database_connections.py b/dataherald/repositories/database_connections.py new file mode 100644 index 00000000..af2019bb --- /dev/null +++ b/dataherald/repositories/database_connections.py @@ -0,0 +1,49 @@ +from bson.objectid import ObjectId + +from dataherald.sql_database.models.types import DatabaseConnection + +DB_COLLECTION = "database_connections" + + +class DatabaseConnectionRepository: + def __init__(self, storage): + self.storage = storage + + def insert(self, database_connection: DatabaseConnection) -> DatabaseConnection: + database_connection.id = str( + self.storage.insert_one( + DB_COLLECTION, database_connection.dict(exclude={"id"}) + ) + ) + return database_connection + + def find_one(self, query: dict) -> DatabaseConnection | None: + row = self.storage.find_one(DB_COLLECTION, query) + if not row: + return None + return DatabaseConnection(**row) + + def update(self, database_connection: DatabaseConnection) -> DatabaseConnection: + self.storage.update_or_create( + DB_COLLECTION, + {"_id": ObjectId(database_connection.id)}, + 
database_connection.dict(exclude={"id"}), + ) + return database_connection + + def find_by_id(self, id: str) -> DatabaseConnection | None: + row = self.storage.find_one(DB_COLLECTION, {"_id": ObjectId(id)}) + if not row: + return None + obj = DatabaseConnection(**row) + obj.id = str(row["_id"]) + return obj + + def find_all(self) -> list[DatabaseConnection]: + rows = self.storage.find_all(DB_COLLECTION) + result = [] + for row in rows: + obj = DatabaseConnection(**row) + obj.id = str(row["_id"]) + result.append(obj) + return result diff --git a/dataherald/repositories/nl_question.py b/dataherald/repositories/nl_question.py index 1bdf2ab6..65edcee6 100644 --- a/dataherald/repositories/nl_question.py +++ b/dataherald/repositories/nl_question.py @@ -2,7 +2,7 @@ from dataherald.types import NLQuery -DB_COLLECTION = "nl_question" +DB_COLLECTION = "nl_questions" class NLQuestionRepository: diff --git a/dataherald/scripts/__init__.py b/dataherald/scripts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/dataherald/scripts/migrate_v001_to_v002.py b/dataherald/scripts/migrate_v001_to_v002.py new file mode 100644 index 00000000..eee399e2 --- /dev/null +++ b/dataherald/scripts/migrate_v001_to_v002.py @@ -0,0 +1,69 @@ +import os + +from sql_metadata import Parser + +import dataherald.config +from dataherald.config import System +from dataherald.db import DB +from dataherald.vector_store import VectorStore + + +def add_db_connection_id(collection_name: str, storage) -> None: + collection_rows = storage.find_all(collection_name) + for collection_row in collection_rows: + if "db_alias" not in collection_row: + continue + database_connection = storage.find_one( + "database_connection", {"alias": collection_row["db_alias"]} + ) + if not database_connection: + continue + collection_row["db_connection_id"] = str(database_connection["_id"]) + # update object + storage.update_or_create( + collection_name, {"_id": collection_row["_id"]}, collection_row + ) + + +if 
__name__ == "__main__": + settings = dataherald.config.Settings() + system = System(settings) + system.start() + storage = system.instance(DB) + # Update relations + add_db_connection_id("table_schema_detail", storage) + add_db_connection_id("golden_records", storage) + add_db_connection_id("nl_question", storage) + # Refresh vector stores + golden_record_collection = os.environ.get( + "GOLDEN_RECORD_COLLECTION", "dataherald-staging" + ) + vector_store = system.instance(VectorStore) + try: + vector_store.delete_collection(golden_record_collection) + except Exception: # noqa: S110 + pass + # Upload golden records + golden_records = storage.find_all("golden_records") + for golden_record in golden_records: + tables = Parser(golden_record["sql_query"]).tables + question = golden_record["question"] + vector_store.add_record( + documents=question, + collection=golden_record_collection, + metadata=[ + { + "tables_used": tables[0], + "db_connection_id": golden_record["db_connection_id"], + } + ], # this should be updated for multiple tables + ids=[str(golden_record["_id"])], + ) + # Re-name collections + try: + storage.rename("nl_query_response", "nl_query_responses") + storage.rename("nl_question", "nl_questions") + storage.rename("database_connection", "database_connections") + storage.rename("table_schema_detail", "table_descriptions") + except Exception: # noqa: S110 + pass diff --git a/dataherald/server/fastapi/__init__.py b/dataherald/server/fastapi/__init__.py index 1585eff8..6f35fdd8 100644 --- a/dataherald/server/fastapi/__init__.py +++ b/dataherald/server/fastapi/__init__.py @@ -1,16 +1,15 @@ -from typing import Any, List, Union +from typing import Any, List import fastapi from fastapi import FastAPI as _FastAPI from fastapi import status -from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.routing import APIRoute import dataherald from dataherald.api.types import Query from dataherald.config import Settings 
-from dataherald.eval import Evaluation +from dataherald.db_scanner.models.types import TableSchemaDetail from dataherald.sql_database.models.types import DatabaseConnection, SSHSettings from dataherald.types import ( DatabaseConnectionRequest, @@ -19,7 +18,6 @@ GoldenRecordRequest, NLQueryResponse, QuestionRequest, - ScannedDBResponse, ScannerRequest, TableDescriptionRequest, UpdateQueryRequest, @@ -46,53 +44,98 @@ def __init__(self, settings: Settings): self.router = fastapi.APIRouter() self.router.add_api_route( - "/api/v1/question", self.answer_question, methods=["POST"] + "/api/v1/database-connections", + self.create_database_connection, + methods=["POST"], + tags=["Database connections"], ) - self.router.add_api_route("/api/v1/scanner", self.scan_db, methods=["POST"]) - - self.router.add_api_route("/api/v1/heartbeat", self.heartbeat, methods=["GET"]) - self.router.add_api_route( - "/api/v1/database", self.connect_database, methods=["POST"] + "/api/v1/database-connections", + self.list_database_connections, + methods=["GET"], + tags=["Database connections"], ) self.router.add_api_route( - "/api/v1/scanned-db/{db_name}/{table_name}", - self.add_description, - methods=["PATCH"], + "/api/v1/database-connections/{db_connection_id}", + self.update_database_connection, + methods=["PUT"], + tags=["Database connections"], ) - self.router.add_api_route("/api/v1/query", self.execute_query, methods=["POST"]) + self.router.add_api_route( + "/api/v1/table-descriptions/scan", + self.scan_db, + methods=["POST"], + tags=["Table descriptions"], + ) self.router.add_api_route( - "/api/v1/query/{query_id}", self.update_query, methods=["PATCH"] + "/api/v1/table-descriptions/{table_description_id}", + self.update_table_description, + methods=["PATCH"], + tags=["Table descriptions"], ) self.router.add_api_route( - "/api/v1/query/{query_id}/execution", - self.execute_temp_query, - methods=["POST"], + "/api/v1/table-descriptions", + self.list_table_descriptions, + methods=["GET"], + 
tags=["Table descriptions"], ) self.router.add_api_route( "/api/v1/golden-records/{golden_record_id}", self.delete_golden_record, methods=["DELETE"], + tags=["Golden records"], ) self.router.add_api_route( "/api/v1/golden-records", self.add_golden_records, methods=["POST"], + tags=["Golden records"], + ) + + self.router.add_api_route( + "/api/v1/golden-records", + self.get_golden_records, + methods=["GET"], + tags=["Golden records"], + ) + + self.router.add_api_route( + "/api/v1/question", + self.answer_question, + methods=["POST"], + tags=["Question"], + ) + + self.router.add_api_route( + "/api/v1/nl-query-responses", + self.get_nl_query_response, + methods=["POST"], + tags=["NL query responses"], ) self.router.add_api_route( - "/api/v1/golden-records", self.get_golden_records, methods=["GET"] + "/api/v1/nl-query-responses/{query_id}", + self.update_nl_query_response, + methods=["PATCH"], + tags=["NL query responses"], ) self.router.add_api_route( - "/api/v1/scanned-databases", self.get_scanned_databases, methods=["GET"] + "/api/v1/sql-query-executions", + self.execute_sql_query, + methods=["POST"], + tags=["SQL queries"], + ) + + self.router.add_api_route( + "/api/v1/heartbeat", self.heartbeat, methods=["GET"], tags=["System"] ) self._app.include_router(self.router) @@ -113,34 +156,57 @@ def root(self) -> dict[str, int]: def heartbeat(self) -> dict[str, int]: return self.root() - def connect_database( + def create_database_connection( self, database_connection_request: DatabaseConnectionRequest ) -> DatabaseConnection: - """Connects a database to the Dataherald service""" - return self._api.connect_database(database_connection_request) + """Creates a database connection""" + return self._api.create_database_connection(database_connection_request) + + def list_database_connections(self) -> list[DatabaseConnection]: + """List all database connections""" + return self._api.list_database_connections() + + def update_database_connection( + self, + db_connection_id: 
str, + database_connection_request: DatabaseConnectionRequest, + ) -> DatabaseConnection: + """Creates a database connection""" + return self._api.update_database_connection( + db_connection_id, database_connection_request + ) - def add_description( + def update_table_description( self, - db_name: str, - table_name: str, + table_description_id: str, table_description_request: TableDescriptionRequest, - ) -> bool: + ) -> TableSchemaDetail: """Add descriptions for tables and columns""" - return self._api.add_description(db_name, table_name, table_description_request) + return self._api.update_table_description( + table_description_id, table_description_request + ) - def execute_query(self, query: Query) -> tuple[str, dict]: - """Executes a query on the given db_alias""" - return self._api.execute_query(query) + def list_table_descriptions( + self, db_connection_id: str | None = None, table_name: str | None = None + ) -> list[TableSchemaDetail]: + """List table descriptions""" + return self._api.list_table_descriptions(db_connection_id, table_name) - def update_query(self, query_id: str, query: UpdateQueryRequest) -> NLQueryResponse: - """Executes a query on the given db_alias""" - return self._api.update_query(query_id, query) + def execute_sql_query(self, query: Query) -> tuple[str, dict]: + """Executes a query on the given db_connection_id""" + return self._api.execute_sql_query(query) - def execute_temp_query( - self, query_id: str, query: ExecuteTempQueryRequest + def update_nl_query_response( + self, query_id: str, query: UpdateQueryRequest ) -> NLQueryResponse: - """Executes a query on the given db_alias""" - return self._api.execute_temp_query(query_id, query) + """Executes a query on the given db_connection_id""" + return self._api.update_nl_query_response(query_id, query) + + def get_nl_query_response( + self, query_request: ExecuteTempQueryRequest + ) -> NLQueryResponse: + """Executes a query on the given db_connection_id""" + return 
self._api.get_nl_query_response(query_request) def delete_golden_record(self, golden_record_id: str) -> dict: """Deletes a golden record""" @@ -161,7 +227,3 @@ def add_golden_records( def get_golden_records(self, page: int = 1, limit: int = 10) -> List[GoldenRecord]: """Gets golden records""" return self._api.get_golden_records(page, limit) - - def get_scanned_databases(self, db_alias: str) -> ScannedDBResponse: - """Gets golden records""" - return self._api.get_scanned_databases(db_alias) diff --git a/dataherald/sql_database/base.py b/dataherald/sql_database/base.py index e9755630..ebeadc31 100644 --- a/dataherald/sql_database/base.py +++ b/dataherald/sql_database/base.py @@ -63,14 +63,14 @@ def from_uri( @classmethod def get_sql_engine(cls, database_info: DatabaseConnection) -> "SQLDatabase": - logger.info(f"Connecting db: {database_info.alias}") - if database_info.alias in DBConnections.db_connections: - return DBConnections.db_connections[database_info.alias] + logger.info(f"Connecting db: {database_info.id}") + if database_info.id in DBConnections.db_connections: + return DBConnections.db_connections[database_info.id] fernet_encrypt = FernetEncrypt() if database_info.use_ssh: engine = cls.from_uri_ssh(database_info) - DBConnections.add(database_info.alias, engine) + DBConnections.add(database_info.id, engine) return engine db_uri = unquote(fernet_encrypt.decrypt(database_info.uri)) if db_uri.lower().startswith("bigquery"): @@ -82,7 +82,7 @@ def get_sql_engine(cls, database_info: DatabaseConnection) -> "SQLDatabase": db_uri = db_uri + f"?credentials_path={file_path}" engine = cls.from_uri(db_uri) - DBConnections.add(database_info.alias, engine) + DBConnections.add(database_info.id, engine) return engine @classmethod diff --git a/dataherald/sql_generator/dataherald_sqlagent.py b/dataherald/sql_generator/dataherald_sqlagent.py index 2e775903..889641a1 100644 --- a/dataherald/sql_generator/dataherald_sqlagent.py +++ 
b/dataherald/sql_generator/dataherald_sqlagent.py @@ -564,7 +564,9 @@ def generate_response( context_store = self.system.instance(ContextStore) storage = self.system.instance(DB) repository = DBScannerRepository(storage) - db_scan = repository.get_all_tables_by_db(db_alias=database_connection.alias) + db_scan = repository.get_all_tables_by_db( + db_connection_id=database_connection.id + ) if not db_scan: raise ValueError("No scanned tables found for database") few_shot_examples = context_store.retrieve_context_for_question( diff --git a/dataherald/sql_generator/generates_nl_answer.py b/dataherald/sql_generator/generates_nl_answer.py index 404819eb..89fbf1e0 100644 --- a/dataherald/sql_generator/generates_nl_answer.py +++ b/dataherald/sql_generator/generates_nl_answer.py @@ -6,9 +6,9 @@ ) from dataherald.model.chat_model import ChatModel +from dataherald.repositories.database_connections import DatabaseConnectionRepository from dataherald.repositories.nl_question import NLQuestionRepository from dataherald.sql_database.base import SQLDatabase -from dataherald.sql_database.models.types import DatabaseConnection from dataherald.sql_generator.create_sql_query_status import create_sql_query_status from dataherald.types import NLQueryResponse @@ -36,10 +36,10 @@ def execute(self, nl_query_response: NLQueryResponse) -> NLQueryResponse: nl_query_response.nl_question_id ) - db_connection = self.storage.find_one( - "database_connection", {"alias": nl_question.db_alias} + db_connection_repository = DatabaseConnectionRepository(self.storage) + database_connection = db_connection_repository.find_by_id( + nl_question.db_connection_id ) - database_connection = DatabaseConnection(**db_connection) database = SQLDatabase.get_sql_engine(database_connection) nl_query_response = create_sql_query_status( database, nl_query_response.sql_query, nl_query_response diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index bc2365c4..684a62d7 100644 --- 
a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -10,9 +10,10 @@ class TestDB(DB): def __init__(self, system: System): super().__init__(system) self.memory = {} - self.memory["database_connection"] = [ + self.memory["database_connections"] = [ { - "alias": "foo", + "_id": "64dfa0e103f5134086f7090c", + "alias": "alias", "use_ssh": False, "uri": "gAAAAABkwD9Y9EpBxF1hRxhovjvedX1TeDNu-WaGqDebk_CJnpGjRlpXzDOl_puehMSbz9KDQ6OqPepl8XQpD0EchiV7he4j5tEXYE33eak87iORA7s8ko0=", # noqa: E501 "ssh_settings": None, @@ -70,3 +71,7 @@ def delete_by_id(self, collection: str, id: str) -> int: del collection[i] return 1 return 0 + + @override + def rename(self, old_collection_name: str, new_collection_name) -> None: + pass diff --git a/dataherald/tests/test_api.py b/dataherald/tests/test_api.py index 603a9f68..ebfa7db3 100644 --- a/dataherald/tests/test_api.py +++ b/dataherald/tests/test_api.py @@ -14,19 +14,24 @@ def test_heartbeat(): def test_scan_all_tables(): - response = client.post("/api/v1/scanner", json={"db_alias": "foo"}) + response = client.post( + "/api/v1/table-descriptions/scan", + json={"db_connection_id": "64dfa0e103f5134086f7090c"}, + ) assert response.status_code == HTTP_200_CODE def test_scan_one_table(): response = client.post( - "/api/v1/scanner", json={"db_alias": "foo", "table_name": "foo"} + "/api/v1/table-descriptions/scan", + json={"db_connection_id": "64dfa0e103f5134086f7090c", "table_names": ["foo"]}, ) assert response.status_code == HTTP_404_CODE def test_answer_question(): response = client.post( - "/api/v1/question", json={"question": "Who am I?", "db_alias": "foo"} + "/api/v1/question", + json={"question": "Who am I?", "db_connection_id": "64dfa0e103f5134086f7090c"}, ) assert response.status_code == HTTP_200_CODE diff --git a/dataherald/tests/vector_store/test_vector_store.py b/dataherald/tests/vector_store/test_vector_store.py index 4d0e8446..f2e18bf3 100644 --- a/dataherald/tests/vector_store/test_vector_store.py +++ 
b/dataherald/tests/vector_store/test_vector_store.py @@ -14,7 +14,7 @@ def __init__(self, system: System): def query( self, query_texts: List[str], - db_alias: str, + db_connection_id: str, collection: str, num_results: int, # noqa: ARG002 ) -> list: diff --git a/dataherald/types.py b/dataherald/types.py index 7ebe35a5..f235b39b 100644 --- a/dataherald/types.py +++ b/dataherald/types.py @@ -1,17 +1,31 @@ -# from datetime import datetime add this later from enum import Enum from typing import Any -from pydantic import BaseModel +from bson.errors import InvalidId +from bson.objectid import ObjectId +from pydantic import BaseModel, validator from dataherald.sql_database.models.types import SSHSettings +class DBConnectionValidation(BaseModel): + db_connection_id: str + + @validator("db_connection_id") + def object_id_validation(cls, v: str): + try: + ObjectId(v) + except InvalidId: + raise ValueError("Must be a valid ObjectId") # noqa: B904 + return v + + class UpdateQueryRequest(BaseModel): sql_query: str class ExecuteTempQueryRequest(BaseModel): + query_id: str sql_query: str @@ -23,20 +37,19 @@ class SQLQueryResult(BaseModel): class NLQuery(BaseModel): id: Any question: str - db_alias: str + db_connection_id: str -class GoldenRecordRequest(BaseModel): +class GoldenRecordRequest(DBConnectionValidation): question: str sql_query: str - db_alias: str class GoldenRecord(BaseModel): id: Any question: str sql_query: str - db_alias: str + db_connection_id: str class SQLGenerationStatus(Enum): @@ -61,17 +74,6 @@ class NLQueryResponse(BaseModel): # date_entered: datetime = datetime.now() add this later -class ScannedDBTable(BaseModel): - id: str - name: str - columns: list[str] - - -class ScannedDBResponse(BaseModel): - db_alias: str - tables: list[ScannedDBTable] - - class SupportedDatabase(Enum): POSTGRES = "POSTGRES" DATABRICKS = "DATABRICKS" @@ -80,18 +82,16 @@ class SupportedDatabase(Enum): BIGQUERY = "BIGQUERY" -class QuestionRequest(BaseModel): +class 
QuestionRequest(DBConnectionValidation): question: str - db_alias: str -class ScannerRequest(BaseModel): - db_alias: str - table_name: str | None +class ScannerRequest(DBConnectionValidation): + table_names: list[str] | None class DatabaseConnectionRequest(BaseModel): - db_alias: str + alias: str use_ssh: bool = False connection_uri: str | None path_to_credentials_file: str | None diff --git a/dataherald/vector_store/__init__.py b/dataherald/vector_store/__init__.py index c85e3e55..4f9b3c12 100644 --- a/dataherald/vector_store/__init__.py +++ b/dataherald/vector_store/__init__.py @@ -13,7 +13,11 @@ def __init__(self, system: System): @abstractmethod def query( - self, query_texts: List[str], db_alias: str, collection: str, num_results: int + self, + query_texts: List[str], + db_connection_id: str, + collection: str, + num_results: int, ) -> list: pass diff --git a/dataherald/vector_store/chroma.py b/dataherald/vector_store/chroma.py index 40148849..621807ed 100644 --- a/dataherald/vector_store/chroma.py +++ b/dataherald/vector_store/chroma.py @@ -22,7 +22,11 @@ def __init__( @override def query( - self, query_texts: List[str], db_alias: str, collection: str, num_results: int + self, + query_texts: List[str], + db_connection_id: str, + collection: str, + num_results: int, ) -> list: try: target_collection = self.chroma_client.get_collection(collection) @@ -32,7 +36,7 @@ def query( query_results = target_collection.query( query_texts=query_texts, n_results=num_results, - where={"db_alias": db_alias}, + where={"db_connection_id": db_connection_id}, ) return self.convert_to_pinecone_object_model(query_results) diff --git a/dataherald/vector_store/pinecone.py b/dataherald/vector_store/pinecone.py index 3afae4aa..e3ee324f 100644 --- a/dataherald/vector_store/pinecone.py +++ b/dataherald/vector_store/pinecone.py @@ -24,7 +24,11 @@ def __init__(self, system: System): @override def query( - self, query_texts: List[str], db_alias: str, collection: str, num_results: int + 
self, + query_texts: List[str], + db_connection_id: str, + collection: str, + num_results: int, ) -> list: index = pinecone.Index(collection) xq = openai.Embedding.create(input=query_texts[0], engine=EMBEDDING_MODEL)[ @@ -33,7 +37,7 @@ def query( query_response = index.query( queries=[xq], filter={ - "db_alias": {"$eq": db_alias}, + "db_connection_id": {"$eq": db_connection_id}, }, top_k=num_results, include_metadata=True, diff --git a/docs/api.add_descriptions.rst b/docs/api.add_descriptions.rst index 25de2631..5e7eadff 100644 --- a/docs/api.add_descriptions.rst +++ b/docs/api.add_descriptions.rst @@ -7,7 +7,7 @@ To return an accurate response set descriptions per table and column. Request this ``PATCH`` endpoint:: - /api/v1/scanned-db/{db_name}/{table_name} + /api/v1/table-descriptions/{table_description_id} **Parameters** @@ -16,8 +16,7 @@ Request this ``PATCH`` endpoint:: :header: "Name", "Type", "Description" :widths: 20, 20, 60 - "db_name", "String", "Database name, ``Required``" - "table_name", "String", "Table name, ``Required``" + "table_description_id", "String", "Table description id, ``Required``" **Request body** @@ -39,7 +38,30 @@ HTTP 200 code response .. code-block:: rst - true + { + "id": "string", + "db_connection_id": "string", + "table_name": "string", + "description": "string", + "table_schema": "string", + "columns": [ + { + "name": "string", + "is_primary_key": false, + "data_type": "str", + "description": "string", + "low_cardinality": false, + "categories": [ + "string" + ], + "foreign_key": { + "field_name": "string", + "reference_table": "string" + } + } + ], + "examples": [] + } **Example 1** @@ -48,7 +70,7 @@ Only set a table description .. code-block:: rst curl -X 'PATCH' \ - '/api/v1/scanned-db/foo_db/foo_table' \ + '/api/v1/table-descriptions/64fa09446cec0b4ff60d3ae3' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ @@ -62,7 +84,7 @@ Only set columns descriptions .. 
code-block:: rst curl -X 'PATCH' \ - '/api/v1/scanned-db/foo_db/foo_table' \ + '/api/v1/table-descriptions/64fa09446cec0b4ff60d3ae3' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ diff --git a/docs/api.database.rst b/docs/api.create_database_connection.rst similarity index 95% rename from docs/api.database.rst rename to docs/api.create_database_connection.rst index 9acdde71..ad5c2799 100644 --- a/docs/api.database.rst +++ b/docs/api.create_database_connection.rst @@ -11,14 +11,14 @@ You can find additional details on how to connect to each of the supported data **Request this POST endpoint**:: - /api/v1/database + /api/v1/database-connections **Request body** .. code-block:: rst { - "db_alias": "string", + "alias": "string", "use_ssh": true, "connection_uri": "string", "path_to_credentials_file": "string", @@ -69,11 +69,11 @@ Without a SSH connection .. code-block:: rst curl -X 'POST' \ - '/api/v1/database' \ + '/api/v1/database-connections' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "my_db_alias_identifier", + "alias": "my_db_alias_identifier", "use_ssh": false, "connection_uri": "sqlite:///mydb.db" }' @@ -85,11 +85,11 @@ With a SSH connection .. 
code-block:: rst curl -X 'POST' \ - 'http://localhost/api/v1/database' \ + '/api/v1/database-connections' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "my_db_alias_identifier", + "alias": "my_db_alias", "use_ssh": true, "ssh_settings": { "db_name": "db_name", diff --git a/docs/api.get_scanned_databases.rst b/docs/api.get_scanned_databases.rst deleted file mode 100644 index a07e6fdb..00000000 --- a/docs/api.get_scanned_databases.rst +++ /dev/null @@ -1,67 +0,0 @@ -Get a scanned database -============================= - -Once a database was scanned you can use this endpoint to retrieve the -tables names and columns - -Request this ``GET`` endpoint:: - - /api/v1/scanned-databases - -**Parameters** - -.. csv-table:: - :header: "Name", "Type", "Description" - :widths: 20, 20, 60 - - "db_alias", "string", "DB alias, ``Required``" - -**Responses** - -HTTP 200 code response - -.. code-block:: rst - - { - "db_alias": "string", - "tables": [ - { - "id": "string", - "name": "string", - "columns": [ - "string" - ] - } - ] - } - -**Request example** - -.. code-block:: rst - - curl -X 'GET' \ - '/api/v1/scanned-databases?db_alias=databricks' \ - -H 'accept: application/json' - -**Response example** - -.. code-block:: rst - - { - "db_alias": "databricks", - "tables": [ - { - "id": "64dfa18c03f5134086f7090d", - "name": "median_rent", - "columns": [ - "period_start", - "period_end", - "period_type", - "geo_type", - "property_type", - "location_name", - "metric_value" - ] - } - ] - } diff --git a/docs/api.golden_record.rst b/docs/api.golden_record.rst index e3e3fbb3..e234b614 100644 --- a/docs/api.golden_record.rst +++ b/docs/api.golden_record.rst @@ -22,7 +22,7 @@ Request this ``POST`` endpoint:: .. 
code-block:: rst [ - {"question": "question", "sql_query": "sql_query", "db_alias":"db_alias"}, + {"question": "question", "sql_query": "sql_query", "db_connection_id":"db_connection_id"}, ] **Responses** @@ -32,7 +32,7 @@ HTTP 200 code response .. code-block:: rst [ - {"id": "id", "question": "question", "sql_query":"sql", db_alias: "database alias"}, + {"id": "id", "question": "question", "sql_query":"sql", db_connection_id: "db_connection_id"}, ] **Example** @@ -48,7 +48,7 @@ HTTP 200 code response { "question": "what was the median home sale price in Califronia in Q1 2021?", "sql_query": "SELECT location_name, period_end, metric_value FROM redfin_median_sale_price rmsp WHERE geo_type = '\''state'\'' AND location_name='\''California'\'' AND property_type = '\''All Residential'\'' AND period_start BETWEEN '\''2021-01-01'\'' AND '\''2021-03-31'\'' ORDER BY period_end;", - "db_alias": "v2_real_estate", + "db_connection_id": "64dfa0e103f5134086f7090c", }]' Delete golden records @@ -109,7 +109,7 @@ HTTP 200 code response .. code-block:: rst [ - {"id": "id", "question": "question", "sql_query":"sql", db_alias: "database alias"}, + {"id": "id", "question": "question", "sql_query":"sql", db_connection_id: "db_connection_id"}, ] **Example** diff --git a/docs/api.list_database_connections.rst b/docs/api.list_database_connections.rst new file mode 100644 index 00000000..a7fde699 --- /dev/null +++ b/docs/api.list_database_connections.rst @@ -0,0 +1,55 @@ +List Database connections +============================= + +This endpoint list all the existing db connections + +**Request this GET endpoint**:: + + /api/v1/database-connections + + +**Responses** + +HTTP 200 code response + +.. 
code-block:: rst + + [ + { + "id": "64dfa0e103f5134086f7090c", + "alias": "databricks", + "use_ssh": false, + "uri": "foooAABk91Q4wjoR2h07GR7_72BdQnxi8Rm6i_EjyS-mzz_o2c3RAWaEqnlUvkK5eGD5kUfE5xheyivl1Wfbk_EM7CgV4SvdLmOOt7FJV-3kG4zAbar=", + "path_to_credentials_file": null, + "ssh_settings": null + }, + { + "id": "64e52c5f7d6dc4bc510d6d28", + "alias": "postgres", + "use_ssh": true, + "uri": null, + "path_to_credentials_file": null, + "ssh_settings": { + "db_name": "string", + "host": "string", + "username": "string", + "password": "foo-LWx6c1h6V0KkPRm9O148Pm9scvoO-wnasdasd1dQjf0ZQuFYI07uCjPiMcZ6uC19mUkiiYiHcKyok1NaLaGDAabkwg==", + "remote_host": "string", + "remote_db_name": "string", + "remote_db_password": "bar-LWxpBkRDuasdOwU__Sk4bdzruiYYqiyl8gBEEzCyFYCaCOcbQqOa_OwsS-asdasdsad==", + "private_key_path": "string", + "private_key_password": "fooo-LWxPdFcjQw9lU7CeK_2ELR3jGBq0G_uQ7E2rfPLk2RcFR4aDO9e2HmeAQtVpdvtrsQ_0zjsy9q7asdsadXExYJ0g==", + "db_driver": "string" + } + } + ] + +**Example 1** + +.. code-block:: rst + + curl -X 'GET' \ + '/api/v1/database-connections' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' + ' diff --git a/docs/api.list_table_description.rst b/docs/api.list_table_description.rst new file mode 100644 index 00000000..39551007 --- /dev/null +++ b/docs/api.list_table_description.rst @@ -0,0 +1,60 @@ +.. api.scan_database: + +List table descriptions +======================= + +Once you have scanned a db connection you can list the table descriptions by requesting this endpoint. + +Request this ``GET`` endpoint:: + + /api/v1/table-descriptions + +**Parameters** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "string", "Filter by connection id, ``Optional``" + "table_name", "string", "Filter by table name, ``Optional``" + +**Responses** + +HTTP 200 code response + +.. 
code-block:: rst + + [ + { + "id": "string", + "db_connection_id": "string", + "table_name": "string", + "description": "string", + "table_schema": "string", + "columns": [ + { + "name": "string", + "is_primary_key": false, + "data_type": "str", + "description": "string", + "low_cardinality": false, + "categories": [ + "string" + ], + "foreign_key": { + "field_name": "string", + "reference_table": "string" + } + } + ], + "examples": [] + } + ] + +**Request example** + +.. code-block:: rst + + curl -X 'GET' \ + '/api/v1/table-descriptions?db_connection_id=64fa09446cec0b4ff60d3ae3&table_name=foo' \ + -H 'accept: application/json' diff --git a/docs/api.process_nl_query_response.rst b/docs/api.process_nl_query_response.rst new file mode 100644 index 00000000..6f14c5ab --- /dev/null +++ b/docs/api.process_nl_query_response.rst @@ -0,0 +1,107 @@ +Process a NL query response +============================= + +Once you made a question you can try sending a new sql query to improve the response, this is not stored + +Request this ``POST`` endpoint:: + + /api/v1/nl-query-responses + +**Request body** + +.. code-block:: rst + + { + "query_id": "string", # required + "sql_query": "string" # required + } + +**Responses** + +HTTP 200 code response + +.. code-block:: rst + + { + "id": "string", + "nl_question_id": "string", + "nl_response": "string", + "intermediate_steps": [ + "string" + ], + "sql_query": "string", + "sql_query_result": { + "columns": [ + "string" + ], + "rows": [ + {} + ] + }, + "sql_generation_status": "NONE", + "error_message": "string", + "exec_time": 0, + "total_tokens": 0, + "total_cost": 0, + "confidence_score": 0 + } + +**Request example** + + +.. 
code-block:: rst + + curl -X 'POST' \ + '/api/v1/nl-query-responses' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "sql_query": "SELECT "dh_zip_code", MAX("metric_value") as max_rent + FROM db_table + WHERE "dh_county_name" = 'Los Angeles' AND "period_start" = '2022-05-01' AND "period_end" = '2022-05-31' + GROUP BY "zip_code" + ORDER BY max_rent DESC + LIMIT 1;", + "query_id": "64c424fa3f4036441e882352" + }' + +**Response example** + +.. code-block:: rst + + { + "id": { + "$oid": "64c424fa3f4036441e882352" + }, + "nl_question_id": { + "$oid": "64dbd8cf944f867b3c450467" + }, + "nl_response": "The most expensive zip to rent in Los Angeles city is 90210", + "intermediate_steps": [ + "", + ], + "sql_query": "SELECT "zip_code", MAX("metric_value") as max_rent + FROM db_table + WHERE "dh_county_name" = 'Los Angeles' AND "period_start" = '2022-05-01' AND "period_end" = '2022-05-31' + GROUP BY "zip_code" + ORDER BY max_rent DESC + LIMIT 1;", + "sql_query_result": { + "columns": [ + "zip_code", + "max_rent" + ], + "rows": [ + { + "zip_code": "90210", + "max_rent": 58279.6479072398192 + } + ] + }, + "sql_generation_status": "VALID", + "error_message": null, + "exec_time": 37.183526277542114, + "total_tokens": 17816, + "total_cost": 1.1087399999999998 + "confidence_score": 0.95 + } diff --git a/docs/api.question.rst b/docs/api.question.rst index fc019aa0..10537307 100644 --- a/docs/api.question.rst +++ b/docs/api.question.rst @@ -13,7 +13,7 @@ Request this ``POST`` endpoint:: .. 
code-block:: rst [ - {"nl_question": "question", "sql": "sql_query", "db":"db_alias"}, + {"nl_question": "question", "sql": "sql_query", "db_connection_id":"db_connection_id"}, ] **Responses** @@ -57,7 +57,7 @@ HTTP 200 code response -H 'Content-Type: application/json' \ -d '{ "question": "What is the median rent price for each property type in Los angeles city?", - "db_alias": "db_alias" + "db_connection_id": "db_connection_id" }' **Response example** diff --git a/docs/api.rst b/docs/api.rst index fea954c0..cd91503d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -1,13 +1,114 @@ API ======================= +The Dataherald Engine exposes RESTful APIs that can be used to: + +* 🔌 Connect to and manage connections to databases +* 🔑 Add context to the engine through scanning the databases, adding descriptions to tables and columns and adding golden records +* 🙋‍♀️ Ask natural language questions from the relational data + +Our APIs have resource-oriented URL built around standard HTTP response codes and verbs. The core resources are described below. + + +Database Connections +------------------------------ + +The ``database-connections`` object allows you to define connections to your relational data stores. + +Related endpoints are: + +* :doc:`Create database connection ` -- ``POST api/v1/database-connections`` +* :doc:`List database connections ` -- ``GET api/v1/database-connections`` +* :doc:`Update a database connection ` -- ``PUT api/v1/database-connections/{alias}`` + + +.. 
code-block:: json + + { + "alias": "string", + "use_ssh": false, + "connection_uri": "string", + "path_to_credentials_file": "string", + "ssh_settings": { + "db_name": "string", + "host": "string", + "username": "string", + "password": "string", + "remote_host": "string", + "remote_db_name": "string", + "remote_db_password": "string", + "private_key_path": "string", + "private_key_password": "string", + "db_driver": "string" + } + } + + +Query Response +------------------ +The ``query-response`` object is created from the answering natural language questions from the relational data. + +The related endpoints are: + +* :doc:`process_nl_query_response ` -- ``POST api/v1/nl-query-responses`` +* :doc:`update_nl_query_response ` -- ``PATCH api/v1/nl-query-responses/{query_id}`` + + +.. code-block:: json + + { + "confidence_score": "string", + "error_message": "string", + "exec_time": "float", + "intermediate_steps":["string"], + "nl_question_id": "string", + "nl_response": "string", + "sql_generation_status": "string", + "sql_query": "string", + "sql_query_result": {}, + "total_cost": "float", + "total_tokens": "int" + } + + +Table Descriptions +--------------------- +The ``table-descriptions`` object is used to add context about the tables and columns in the relational database. +These are then used to help the LLM build valid SQL to answer natural language questions. + +Related endpoints are: + +* :doc:`Scan table description ` -- ``POST api/v1/table-descriptions/scan`` +* :doc:`Add table description ` -- ``PATCH api/v1/table-descriptions/{table_description_id}`` +* :doc:`List table description ` -- ``GET api/v1/table-descriptions`` + +.. code-block:: json + + { + "columns": [{}], + "db_connection_id": "string", + "description": "string", + "examples": [{}], + "table_name": "string", + "table_schema": "string" + } + + + .. 
toctree:: :hidden: - api.database - api.scan_database - api.golden_record + api.create_database_connection + api.list_database_connections + api.update_database_connection + + api.scan_table_description api.add_descriptions + api.list_table_description + + api.golden_record + api.question - api.update_query - api.get_scanned_databases + + api.update_nl_query_response.rst + api.process_nl_query_response diff --git a/docs/api.scan_database.rst b/docs/api.scan_table_description.rst similarity index 64% rename from docs/api.scan_database.rst rename to docs/api.scan_table_description.rst index bfb6c3ff..ed7d057e 100644 --- a/docs/api.scan_database.rst +++ b/docs/api.scan_table_description.rst @@ -6,19 +6,19 @@ Scan a Database Once you have set your db credentials request this endpoint to scan your database. It maps all tables and columns so It will help the SQL Agent to generate an accurate answer. -It can scan all db tables or if you specify a `table_name` then It will only scan that table. +It can scan all db tables or if you specify a `table_names` then It will only scan those tables. Request this ``POST`` endpoint:: - /api/v1/scanner + /api/v1/table-descriptions/scan **Request body** .. code-block:: rst { - "db_alias": "string", - "table_name": "string" # Optional + "db_connection_id": "string", + "table_names": ["string"] # Optional } **Responses** @@ -31,7 +31,7 @@ HTTP 200 code response **Request example** -To scan all the tables in a db don't specify a `table_name` +To scan all the tables in a db don't specify a `table_names` .. 
code-block:: rst @@ -40,5 +40,5 @@ To scan all the tables in a db don't specify a `table_name` -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "db_alias" + "db_connection_id": "db_connection_id" }' diff --git a/docs/api.update_database_connection.rst b/docs/api.update_database_connection.rst new file mode 100644 index 00000000..28f25259 --- /dev/null +++ b/docs/api.update_database_connection.rst @@ -0,0 +1,108 @@ +Update a Database connection +============================= + +This endpoint is used to update a Database connection + +**Request this PUT endpoint**:: + + /api/v1/database-connections/{db_connection_id} + +**Parameters** + +.. csv-table:: + :header: "Name", "Type", "Description" + :widths: 20, 20, 60 + + "db_connection_id", "String", "Set the database connection id, ``Required``" + +**Request body** + +.. code-block:: rst + + { + "alias": "string", + "use_ssh": true, + "connection_uri": "string", + "path_to_credentials_file": "string", + "ssh_settings": { + "db_name": "string", + "host": "string", + "username": "string", + "password": "string", + "remote_host": "string", + "remote_db_name": "string", + "remote_db_password": "string", + "private_key_path": "string", + "private_key_password": "string", + "db_driver": "string" + } + } + +**Responses** + +HTTP 200 code response + +.. 
code-block:: rst + + { + "id": "64f251ce9614e0e94b0520bc", + "alias": "string_999", + "use_ssh": false, + "uri": "gAAAAABk8lHQNAUn5XARb94Q8H1OfHpVzOtzP3b2LCpwxUsNCe7LGkwkN8FX-IF3t65oI5mTzgDMR0BY2lzvx55gO0rxlQxRDA==", + "path_to_credentials_file": "string", + "ssh_settings": { + "db_name": "string", + "host": "string", + "username": "string", + "password": "gAAAAABk8lHQAaaSuoUKxddkMHw7jerwFmUeiE3hL6si06geRt8CV-r43fbckZjI6LbIULWPZ4HlQUF9_YpfaYfM6FarQbhDUQ==", + "remote_host": "string", + "remote_db_name": "string", + "remote_db_password": "gAAAAABk8lHQpZyZ6ow8EuYPWe5haP-roQbBWkZn3trLgdO632IDoKcXAW-8yjzDDQ4uH03iWFzEgJq8HRxkJTC6Ht7Qrlz2PQ==", + "private_key_path": "string", + "private_key_password": "gAAAAABk8lHQWilFpIbCADvunHGYFMqgoPKIml_WRXf5Yuowqng28DVsq6-sChl695y5D_mWrr1I3hcJCZqkmhDqpma6iz3PKA==", + "db_driver": "string" + } + } + +**Example 1** + +Without a SSH connection + +.. code-block:: rst + + curl -X 'PUT' \ + '/api/v1/database-connections/64f251ce9614e0e94b0520bc' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "alias": "my_db_alias_identifier", + "use_ssh": false, + "connection_uri": "sqlite:///mydb.db" + }' + +**Example 2** + +With a SSH connection + +.. 
code-block:: rst + + curl -X 'PUT' \ + '/api/v1/database-connections/64f251ce9614e0e94b0520bc' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "alias": "my_db_alias", + "use_ssh": true, + "ssh_settings": { + "db_name": "db_name", + "host": "string", + "username": "string", + "password": "string", + "remote_host": "string", + "remote_db_name": "string", + "remote_db_password": "string", + "private_key_path": "string", + "private_key_password": "string", + "db_driver": "string" + } + }' diff --git a/docs/api.update_query.rst b/docs/api.update_nl_query_response.rst similarity index 68% rename from docs/api.update_query.rst rename to docs/api.update_nl_query_response.rst index b7188298..9447afa0 100644 --- a/docs/api.update_query.rst +++ b/docs/api.update_nl_query_response.rst @@ -1,11 +1,11 @@ -Update a query -======================= +Update a NL query response +============================ -You can give feedback to improve the queries, and set a query response as a golden query +Once you ask a question, you can give feedback to improve the queries Request this ``PATCH`` endpoint:: - /api/v1/query/{query_id} + /api/v1/nl-query-responses/{query_id} **Parameters** @@ -20,8 +20,7 @@ Request this ``PATCH`` endpoint:: .. code-block:: rst { - "sql_query": "string", # optional - "golden_record": true # boolean and optional + "sql_query": "string", # required } **Responses** @@ -30,7 +29,29 @@ HTTP 200 code response .. code-block:: rst - true + { + "id": "string", + "nl_question_id": "string", + "nl_response": "string", + "intermediate_steps": [ + "string" + ], + "sql_query": "string", + "sql_query_result": { + "columns": [ + "string" + ], + "rows": [ + {} + ] + }, + "sql_generation_status": "NONE", + "error_message": "string", + "exec_time": 0, + "total_tokens": 0, + "total_cost": 0, + "confidence_score": 0 + } **Request example** @@ -38,7 +59,7 @@ HTTP 200 code response .. 
code-block:: rst curl -X 'POST' \ - '/api/v1/query/64c424fa3f4036441e882352' \ + '/api/v1/nl-query-responses/64c424fa3f4036441e882352' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ @@ -47,8 +68,7 @@ HTTP 200 code response WHERE "dh_county_name" = 'Los Angeles' AND "period_start" = '2022-05-01' AND "period_end" = '2022-05-31' GROUP BY "zip_code" ORDER BY max_rent DESC - LIMIT 1;", - "golden_record": true + LIMIT 1;" }' **Response example** @@ -88,7 +108,6 @@ HTTP 200 code response "error_message": null, "exec_time": 37.183526277542114, "total_tokens": 17816, - "total_cost": 1.1087399999999998, - "golden_record": true, + "total_cost": 1.1087399999999998 "confidence_score": 0.95 } diff --git a/docs/api_server.rst b/docs/api_server.rst index 6f8a9e6a..168b2b62 100644 --- a/docs/api_server.rst +++ b/docs/api_server.rst @@ -41,7 +41,7 @@ All implementations of the API module must inherit and implement the abstract :c :return: The NLQueryResponse containing the response to the user's question. :rtype: NLQueryResponse -.. method:: connect_database(self, database_connection_request: DatabaseConnectionRequest) -> bool +.. method:: create_database_connection(self, database_connection_request: DatabaseConnectionRequest) -> bool :noindex: Establishes a connection to a database using the provided connection request. @@ -51,13 +51,13 @@ All implementations of the API module must inherit and implement the abstract :c :return: True if the connection was established successfully; otherwise, False. :rtype: bool -.. method:: add_description(self, db_name: str, table_name: str, table_description_request: TableDescriptionRequest) -> bool +.. method:: add_description(self, db_connection_id: str, table_name: str, table_description_request: TableDescriptionRequest) -> bool :noindex: Adds a description to a specific table within a database based on the provided table description request. - :param db_name: The name of the database. 
- :type db_name: str + :param db_connection_id: The db connection id + :type db_connection_id: str :param table_name: The name of the table. :type table_name: str :param table_description_request: The table description request. @@ -109,13 +109,13 @@ All implementations of the API module must inherit and implement the abstract :c :return: The NLQueryResponse containing the result of the temporary query execution. :rtype: NLQueryResponse -.. method:: get_scanned_databases(self, db_alias: str) -> ScannedDBResponse +.. method:: get_scanned_databases(self, db_connection_id: str) -> ScannedDBResponse :noindex: - Retrieves information about scanned databases based on a database alias. + Retrieves information about scanned databases based on a database connection id. - :param db_alias: The alias of the database. - :type db_alias: str + :param db_connection_id: The database connection id. + :type db_connection_id: str :return: The ScannedDBResponse containing information about scanned databases. :rtype: ScannedDBResponse diff --git a/docs/conf.py b/docs/conf.py index dc5cde16..e99f88df 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,10 +11,11 @@ sys.path.insert(0, os.path.abspath("..")) -project = "Dataherald" +project = "Dataherald AI" copyright = "2023, Dataherald" author = "Dataherald" -release = "0.0.1" +release = "main" +html_title = project # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/docs/contributing.projects.rst b/docs/contributing.projects.rst new file mode 100644 index 00000000..79724f15 --- /dev/null +++ b/docs/contributing.projects.rst @@ -0,0 +1,11 @@ +Jumping in +==================== + +We are beyond thrilled that you are considering joining this project. 
There are a number of +community projects that are in development, spanning areas such as: + +* Connecting to public data sources +* Building integrations with front-end frameworks +* Testing and benchmarking new NL-to-SQL approaches proposed in academic literature + +The best place to jump in is to hop on the #projects channel on our `Discord server <https://discord.gg/A59Uxyy2k9>`_ \ No newline at end of file diff --git a/docs/envars.rst b/docs/envars.rst new file mode 100644 index 00000000..78e80012 --- /dev/null +++ b/docs/envars.rst @@ -0,0 +1,62 @@ +Environment Variables +======================= +The Dataherald engine has a number of environment variables that need to be set in order for it to work. The following is the sample +provided in the .env.example file with the default values. + + +.. code-block:: bash + + OPENAI_API_KEY = + ORG_ID = + LLM_MODEL = 'gpt-4-32k' + + + GOLDEN_RECORD_COLLECTION = 'my-golden-records' + PINECONE_API_KEY = + PINECONE_ENVIRONMENT = + + + API_SERVER = "dataherald.api.fastapi.FastAPI" + SQL_GENERATOR = "dataherald.sql_generator.dataherald_sqlagent.DataheraldSQLAgent" + EVALUATOR = "dataherald.eval.simple_evaluator.SimpleEvaluator" + DB = "dataherald.db.mongo.MongoDB" + VECTOR_STORE = 'dataherald.vector_store.chroma.Chroma' + CONTEXT_STORE = 'dataherald.context_store.default.DefaultContextStore' + DB_SCANNER = 'dataherald.db_scanner.sqlalchemy.SqlAlchemyScanner' + + + MONGODB_URI = "mongodb://admin:admin@mongodb:27017" + MONGODB_DB_NAME = 'dataherald' + MONGODB_DB_USERNAME = 'admin' + MONGODB_DB_PASSWORD = 'admin' + + ENCRYPT_KEY = + + S3_AWS_ACCESS_KEY_ID = + S3_AWS_SECRET_ACCESS_KEY = + + +.. csv-table:: + :header: "Variable Name", "Description", "Default Value", "Required" + :widths: 15, 55, 25, 5 + + "OPENAI_API_KEY", "The OpenAI key used by the Dataherald Engine", "None", "Yes" + "ORG_ID", "The OpenAI Organization ID used by the Dataherald Engine", "None", "Yes" + "LLM_MODEL", "The Language Model used by the Dataherald Engine.
Supported values include gpt-4-32k, gpt-4, gpt-3.5-turbo, gpt-3.5-turbo-16k", "``gpt-4-32k``", "No" + "GOLDEN_RECORD_COLLECTION", "The name of the collection in Mongo where golden records will be stored", "``my-golden-records``", "No" + "PINECONE_API_KEY", "The Pinecone API key used", "None", "Yes if using the Pinecone vector store" + "PINECONE_ENVIRONMENT", "The Pinecone environment", "None", "Yes if using the Pinecone vector store" + "API_SERVER", "The implementation of the API Module used by the Dataherald Engine.", "``dataherald.api.fastapi.FastAPI``", "Yes" + "SQL_GENERATOR", "The implementation of the SQLGenerator Module to be used.", "``dataherald.sql_generator. dataherald_sqlagent. DataheraldSQLAgent``", "Yes" + "EVALUATOR", "The implementation of the Evaluator Module to be used.", "``dataherald.eval. simple_evaluator.SimpleEvaluator``", "Yes" + "DB", "The implementation of the DB Module to be used.", "``dataherald.db.mongo.MongoDB``", "Yes" + "VECTOR_STORE", "The implementation of the Vector Store Module to be used. Chroma and Pinecone modules are currently included.", "``dataherald.vector_store. chroma.Chroma``", "Yes" + "CONTEXT_STORE", "The implementation of the Context Store Module to be used.", "``dataherald.context_store. default.DefaultContextStore``", "Yes" + "DB_SCANNER", "The implementation of the DB Scanner Module to be used.", "``dataherald.db_scanner. 
sqlalchemy.SqlAlchemyScanner``", "Yes" + "MONGODB_URI", "The URI of the MongoDB that will be used for application storage.", "``mongodb:// admin:admin@mongodb:27017``", "Yes" + "MONGODB_DB_NAME", "The name of the MongoDB database that will be used.", "``dataherald``", "Yes" + "MONGODB_DB_USERNAME", "The username of the MongoDB database", "``admin``", "Yes" + "MONGODB_DB_PASSWORD", "The password of the MongoDB database", "``admin``", "Yes" + "ENCRYPT_KEY", "The key that will be used to encrypt data at rest before storing", "None", "Yes" + "S3_AWS_ACCESS_KEY_ID", "The key used to access credential files if saved to S3", "None", "No" + "S3_AWS_SECRET_ACCESS_KEY", "The key used to access credential files if saved to S3", "None", "No " diff --git a/docs/getting_started.rst b/docs/getting_started.rst deleted file mode 100644 index 4c0fe45c..00000000 --- a/docs/getting_started.rst +++ /dev/null @@ -1,19 +0,0 @@ -.. _getting_started: - -Getting started -======================== - -Dataherald AI comes "batteries included." While the engine is modular and core modules can be easily replaced, we have included best-in-class implementations of core modules so you can get set up in minutes. - - - -.. toctree:: - :hidden: - - introduction - quickstart - - - - - diff --git a/docs/index.rst b/docs/index.rst index 1b5ceb23..cba4eb79 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -7,14 +7,38 @@ Dataherald AI ======================================== Welcome to the official documentation page of the Dataherald AI engine. This documentation is intended for developers who want to: -* Use the Dataherald AI engine to set up Natural Language interfaces from structured data in their own projects. -* Contribute to the Dataherald AI engine. +* 🖥️ Use the Dataherald AI engine to set up Natural Language interfaces from structured data in their own projects. +* 🏍️ Contribute to the Dataherald AI engine. 
These documents will cover how to get started, how to set up an API from your database that can answer questions in plain English and how to extend the core engine's functionality. .. toctree:: + :maxdepth: 1 + :caption: Getting Started + :hidden: + + introduction + quickstart + +.. toctree:: + :caption: References :hidden: - getting_started api - modules \ No newline at end of file + envars + modules + +.. toctree:: + :caption: Tutorials + :hidden: + + tutorial.sample_database + tutorial.finetune_sql_generator + tutorial.chatgpt_plugin + + +.. toctree:: + :caption: Contributing + :hidden: + + contributing.projects diff --git a/docs/introduction.rst b/docs/introduction.rst index 46cfccfc..e18bf8b2 100644 --- a/docs/introduction.rst +++ b/docs/introduction.rst @@ -12,8 +12,8 @@ You can use Dataherald to: Dataherald is built to: -* Be modular, allowing different implementations of core modules to be plugged-in -* Come batteries included: Have best-in-class implementations for modules like text to SQL, evaluation -* Be easy to set-up and use with major data warehouses -* Allow for Active Learning, allowing you to improve the performance with usage -* Be fast +* 🔌 Be modular, allowing different implementations of core modules to be plugged-in +* 🔋 Come batteries included: Have best-in-class implementations for modules like text to SQL, evaluation +* 📀 Be easy to set-up and use with major data warehouses +* 👨‍🏫 Allow for Active Learning, allowing you to improve the performance with usage +* 🏎️ Be fast \ No newline at end of file diff --git a/docs/modules.rst b/docs/modules.rst index 45dba705..94f54e72 100644 --- a/docs/modules.rst +++ b/docs/modules.rst @@ -48,9 +48,3 @@ The following diagram illustrates the overall system architecture. -Get Involved ------------- - -We encourage you to imeplemnt your own modules and engage with the Dataherald community on `Discord `_ or on `GitHub `_. Your input drives our ongoing development and improvement. 
- -

:ref:`Back to Top ` diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 49d1390e..67a42d9e 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -94,7 +94,7 @@ You can define a DB connection through a call to the following API endpoint `/ap -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "my_db_alias_identifier", + "alias": "my_db_alias", "use_ssh": false, "connection_uri": "sqlite:///mydb.db" }' @@ -110,7 +110,7 @@ If you need to connect to your database through an SSH tunnel, you will need to -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ - "db_alias": "my_db_alias_identifier", + "alias": "my_db_alias", "use_ssh": true, "ssh_settings": { "db_name": "db_name", @@ -150,7 +150,7 @@ Once you have connected the engine to your data warehouse (and preferably added -H 'Content-Type: application/json' \ -d '{ "question": "what was the most expensive zip code to rent in Los Angeles county in May 2022?"", - "db_alias": "db_name" + "db_connection_id": "db_connection_id" }' diff --git a/docs/tutorial.chatgpt_plugin.rst b/docs/tutorial.chatgpt_plugin.rst new file mode 100644 index 00000000..ebbea8f0 --- /dev/null +++ b/docs/tutorial.chatgpt_plugin.rst @@ -0,0 +1,4 @@ +Create a ChatGPT plug-in from your structured data +===================================================== + +Coming soon ... \ No newline at end of file diff --git a/docs/tutorial.finetune_sql_generator.rst b/docs/tutorial.finetune_sql_generator.rst new file mode 100644 index 00000000..1d67e80d --- /dev/null +++ b/docs/tutorial.finetune_sql_generator.rst @@ -0,0 +1,4 @@ +Using a Custom Text to SQL Engine +================================== + +Coming soon ...
diff --git a/docs/tutorial.sample_database.rst b/docs/tutorial.sample_database.rst new file mode 100644 index 00000000..65a8fd8a --- /dev/null +++ b/docs/tutorial.sample_database.rst @@ -0,0 +1,4 @@ +Setting up a sample Database for accurate NL-to-SQL +==================================================== + +Coming soon ... \ No newline at end of file diff --git a/docs/vector_store.rst b/docs/vector_store.rst index 4f9c6f29..3cc64421 100644 --- a/docs/vector_store.rst +++ b/docs/vector_store.rst @@ -22,15 +22,15 @@ This abstract class defines the common methods that both ChromaDB and Pinecone v :param system: The system object. :type system: System -.. method:: query(self, query_texts: List[str], db_alias: str, collection: str, num_results: int) -> list +.. method:: query(self, query_texts: List[str], db_connection_id: str, collection: str, num_results: int) -> list :noindex: Executes a query to retrieve similar vectors from the vector store. :param query_texts: A list of query texts. :type query_texts: List[str] - :param db_alias: The alias for the database. - :type db_alias: str + :param db_connection_id: The db connection id. + :type db_connection_id: str :param collection: The name of the collection. :type collection: str :param num_results: The number of results to retrieve.