From e7cb4458f780e2e82ce365a74a2c53927398e57e Mon Sep 17 00:00:00 2001 From: Mohammadreza Pourreza <71866535+MohammadrezaPourreza@users.noreply.github.com> Date: Mon, 18 Sep 2023 15:09:27 -0400 Subject: [PATCH] DH-4537/sorting the tables and column (#166) --- dataherald/db/__init__.py | 2 +- dataherald/db/mongo.py | 4 +++- dataherald/db_scanner/repository/base.py | 4 +++- dataherald/tests/db/test_db.py | 2 +- 4 files changed, 8 insertions(+), 4 deletions(-) diff --git a/dataherald/db/__init__.py b/dataherald/db/__init__.py index 7d4adedb..6a650ab1 100644 --- a/dataherald/db/__init__.py +++ b/dataherald/db/__init__.py @@ -29,7 +29,7 @@ def find_by_id(self, collection: str, id: str) -> dict: pass @abstractmethod - def find(self, collection: str, query: dict) -> list: + def find(self, collection: str, query: dict, sort: list = None) -> list: pass @abstractmethod diff --git a/dataherald/db/mongo.py b/dataherald/db/mongo.py index fec4146e..7879d5e0 100644 --- a/dataherald/db/mongo.py +++ b/dataherald/db/mongo.py @@ -40,7 +40,9 @@ def find_by_id(self, collection: str, id: str) -> dict: return self._data_store[collection].find_one({"_id": ObjectId(id)}) @override - def find(self, collection: str, query: dict) -> list: + def find(self, collection: str, query: dict, sort: list = None) -> list: + if sort: + return self._data_store[collection].find(query).sort(sort) return self._data_store[collection].find(query) @override diff --git a/dataherald/db_scanner/repository/base.py b/dataherald/db_scanner/repository/base.py index f3a86fd9..00e0b6eb 100644 --- a/dataherald/db_scanner/repository/base.py +++ b/dataherald/db_scanner/repository/base.py @@ -1,6 +1,7 @@ from typing import List from bson.objectid import ObjectId +from pymongo import ASCENDING from dataherald.db_scanner.models.types import TableSchemaDetail @@ -68,10 +69,11 @@ def find_all(self) -> list[TableSchemaDetail]: def find_by(self, query: dict) -> list[TableSchemaDetail]: query = {k: v for k, v in query.items() if v} - rows = self.storage.find(DB_COLLECTION, query) + rows = self.storage.find(DB_COLLECTION, query, sort=[("table_name", ASCENDING)]) result = [] for row in rows: obj = TableSchemaDetail(**row) + obj.columns = sorted(obj.columns, key=lambda x: x.name) obj.id = str(row["_id"]) result.append(obj) return result diff --git a/dataherald/tests/db/test_db.py b/dataherald/tests/db/test_db.py index 684a62d7..2c7b39c3 100644 --- a/dataherald/tests/db/test_db.py +++ b/dataherald/tests/db/test_db.py @@ -48,7 +48,7 @@ def find_by_id(self, collection: str, id: str) -> dict: return None @override - def find(self, collection: str, query: dict) -> list: + def find(self, collection: str, query: dict, sort: list = None) -> list: return [] @override