diff --git a/.github/workflows/analyze.yml b/.github/workflows/analyze.yml index 25ddf32..c4c5db7 100644 --- a/.github/workflows/analyze.yml +++ b/.github/workflows/analyze.yml @@ -37,7 +37,7 @@ jobs: steps: - name: Checkout repository - uses: actions/checkout@v2 + uses: actions/checkout@v4 # Initializes the CodeQL tools for scanning. - name: Initialize CodeQL diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index d6b1742..bc3f9ff 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,9 +1,8 @@ name: build on: workflow_dispatch: - push: - branches: [ master ] pull_request: + push: branches: [ master ] env: PACKAGE_DIR: adbdgl_adapter @@ -16,31 +15,44 @@ jobs: python: ["3.8", "3.9", "3.10", "3.11"] name: Python ${{ matrix.python }} steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + - name: Setup Python ${{ matrix.python }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python }} + cache: 'pip' + cache-dependency-path: setup.py + - name: Set up ArangoDB Instance via Docker - run: docker create --name adb -p 8529:8529 -e ARANGO_ROOT_PASSWORD= arangodb/arangodb:3.9.1 + run: docker create --name adb -p 8529:8529 -e ARANGO_ROOT_PASSWORD= arangodb/arangodb + - name: Start ArangoDB Instance run: docker start adb + - name: Setup pip run: python -m pip install --upgrade pip setuptools wheel + - name: Install packages run: pip install .[dev] + - name: Run black run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run flake8 run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run isort run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run mypy run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} + - name: Run pytest run: pytest --cov=${{env.PACKAGE_DIR}} --cov-report xml --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes + - name: Publish to coveralls.io - if: matrix.python == '3.8' + if: matrix.python == '3.10' env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} run: coveralls --service=github \ No newline at end of file diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 553150f..ac040fc 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -3,76 +3,34 @@ on: workflow_dispatch: release: types: [published] -env: - PACKAGE_DIR: adbdgl_adapter - TESTS_DIR: tests jobs: - build: - runs-on: ubuntu-latest - strategy: - matrix: - python: ["3.8", "3.9", "3.10", "3.11"] - name: Python ${{ matrix.python }} - steps: - - uses: actions/checkout@v2 - - name: Setup Python ${{ matrix.python }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python }} - - name: Set up ArangoDB Instance via Docker - run: docker create --name adb -p 8529:8529 -e ARANGO_ROOT_PASSWORD= arangodb/arangodb:3.9.1 - - name: Start ArangoDB Instance - run: docker start adb - - name: Setup pip - run: python -m pip install --upgrade pip setuptools wheel - - name: Install packages - run: pip install .[dev] - - name: Run black - run: black --check --verbose --diff --color ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - name: Run flake8 - run: flake8 ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - name: Run isort - run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - name: Run mypy - run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}} - - name: Run pytest - run: pytest --cov=${{env.PACKAGE_DIR}} --cov-report xml --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes - - name: Publish to coveralls.io - if: matrix.python == '3.8' - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - run: coveralls --service=github - release: - needs: build runs-on: ubuntu-latest name: Release package steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Fetch complete history for all tags and branches run: git fetch --prune --unshallow - - name: Setup python - uses: actions/setup-python@v2 + - name: Setup Python + uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.10" - name: Install release packages run: pip install setuptools wheel twine setuptools-scm[toml] - - name: Install dependencies - run: pip install .[dev] - - name: Build distribution run: python setup.py sdist bdist_wheel - - name: Publish to PyPI Test + - name: Publish to Test PyPi env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD_TEST }} run: twine upload --repository testpypi dist/* #--skip-existing - - name: Publish to PyPI + + - name: Publish to PyPi env: TWINE_USERNAME: __token__ TWINE_PASSWORD: ${{ secrets.TWINE_PASSWORD }} @@ -83,7 +41,7 @@ jobs: runs-on: ubuntu-latest name: Update Changelog steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 with: fetch-depth: 0 @@ -95,10 +53,10 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - name: Setup python - uses: actions/setup-python@v2 + - name: Setup Python + uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.10" - name: Install release packages run: pip install wheel gitchangelog pystache @@ -110,12 +68,12 @@ jobs: run: gitchangelog ${{env.VERSION}} > CHANGELOG.md - name: Make commit for auto-generated changelog - uses: EndBug/add-and-commit@v7 + uses: EndBug/add-and-commit@v9 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} with: add: "CHANGELOG.md" - branch: actions/changelog + new_branch: actions/changelog message: "!gitchangelog" - name: Create pull request for the auto generated changelog @@ -128,4 +86,4 @@ jobs: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Alert developer of open PR - run: echo "Changelog $PR_URL is ready to be merged by developer." \ No newline at end of file + run: echo "Changelog $PR_URL is ready to be merged by developer." diff --git a/adbdgl_adapter/abc.py b/adbdgl_adapter/abc.py index 9f2b4e3..12d1746 100644 --- a/adbdgl_adapter/abc.py +++ b/adbdgl_adapter/abc.py @@ -15,16 +15,18 @@ def __init__(self) -> None: raise NotImplementedError # pragma: no cover def arangodb_to_dgl( - self, name: str, metagraph: ADBMetagraph, **query_options: Any + self, name: str, metagraph: ADBMetagraph, **adb_export_kwargs: Any ) -> DGLHeteroGraph: raise NotImplementedError # pragma: no cover def arangodb_collections_to_dgl( - self, name: str, v_cols: Set[str], e_cols: Set[str], **query_options: Any + self, name: str, v_cols: Set[str], e_cols: Set[str], **adb_export_kwargs: Any ) -> DGLHeteroGraph: raise NotImplementedError # pragma: no cover - def arangodb_graph_to_dgl(self, name: str, **query_options: Any) -> DGLHeteroGraph: + def arangodb_graph_to_dgl( + self, name: str, **adb_export_kwargs: Any + ) -> DGLHeteroGraph: raise NotImplementedError # pragma: no cover def dgl_to_arangodb( @@ -34,7 +36,7 @@ def dgl_to_arangodb( metagraph: DGLMetagraph = {}, explicit_metagraph: bool = True, overwrite_graph: bool = False, - **import_options: Any, + **adb_import_kwargs: Any, ) -> ArangoDBGraph: raise NotImplementedError # pragma: no cover diff --git a/adbdgl_adapter/adapter.py b/adbdgl_adapter/adapter.py index ac9ff1a..71a3092 100644 --- a/adbdgl_adapter/adapter.py +++ b/adbdgl_adapter/adapter.py @@ -3,14 +3,17 @@ import logging from collections import defaultdict from math import ceil -from typing import Any, DefaultDict, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, DefaultDict, Dict, List, Optional, Set, Tuple, Union from arango.cursor import Cursor -from arango.database import Database +from arango.database import StandardDatabase from arango.graph import Graph as ADBGraph from dgl import DGLGraph, DGLHeteroGraph, graph, heterograph from dgl.view import EdgeSpace, HeteroEdgeDataView, HeteroNodeDataView, NodeSpace from pandas import DataFrame, Series +from rich.console import Group +from rich.live import Live +from rich.progress import Progress from torch import Tensor, cat, tensor from .abc import Abstract_ADBDGL_Adapter @@ -28,7 +31,14 @@ DGLMetagraphValues, Json, ) -from .utils import logger, progress, validate_adb_metagraph, validate_dgl_metagraph +from .utils import ( + get_bar_progress, + get_export_spinner_progress, + get_import_spinner_progress, + logger, + validate_adb_metagraph, + validate_dgl_metagraph, +) class ADBDGL_Adapter(Abstract_ADBDGL_Adapter): @@ -49,14 +59,14 @@ class ADBDGL_Adapter(Abstract_ADBDGL_Adapter): def __init__( self, - db: Database, + db: StandardDatabase, controller: ADBDGL_Controller = ADBDGL_Controller(), logging_lvl: Union[str, int] = logging.INFO, ): self.set_logging(logging_lvl) - if not isinstance(db, Database): - msg = "**db** parameter must inherit from arango.database.Database" + if not isinstance(db, StandardDatabase): + msg = "**db** parameter must inherit from arango.database.StandardDatabase" raise TypeError(msg) if not isinstance(controller, ADBDGL_Controller): @@ -64,12 +74,13 @@ def __init__( raise TypeError(msg) self.__db = db + self.__async_db = db.begin_async_execution(return_result=False) self.__cntrl = controller logger.info(f"Instantiated ADBDGL_Adapter with database '{db.name}'") @property - def db(self) -> Database: + def db(self) -> StandardDatabase: return self.__db # pragma: no cover @property @@ -79,11 +90,15 @@ def cntrl(self) -> ADBDGL_Controller: def set_logging(self, level: Union[int, str]) -> None: logger.setLevel(level) + ########################### + # Public: ArangoDB -> DGL # + ########################### + def arangodb_to_dgl( - self, name: str, metagraph: ADBMetagraph, **query_options: Any + self, name: str, metagraph: ADBMetagraph, **adb_export_kwargs: Any ) -> Union[DGLGraph, DGLHeteroGraph]: - """Create a DGL graph from ArangoDB data. DOES carry - over node/edge features/labels, via the **metagraph**. + """Create a DGL graph from an ArangoDB Metagraph. Carries + over node/edge data via the **metagraph**. :param name: The DGL graph name. :type name: str @@ -113,10 +128,10 @@ def arangodb_to_dgl( See below for examples of **metagraph**. :type metagraph: adbdgl_adapter.typings.ADBMetagraph - :param query_options: Keyword arguments to specify AQL query options when + :param adb_export_kwargs: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. Full parameter list: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute - :type query_options: Any + :type adb_export_kwargs: Any :return: A DGL Homogeneous or Heterogeneous graph object :rtype: dgl.DGLGraph | dgl.DGLHeteroGraph :raise adbdgl_adapter.exceptions.ADBMetagraphError: If invalid metagraph. @@ -250,89 +265,58 @@ def udf_v1_x(v1_df): for v_col, meta in metagraph["vertexCollections"].items(): logger.debug(f"Preparing '{v_col}' vertices") - dgl_id = 0 - cursor = self.__fetch_adb_docs(v_col, meta, query_options) - while not cursor.empty(): - cursor_batch = len(cursor.batch()) # type: ignore - df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) - - # 1. Map each ArangoDB _key to a DGL node id - for adb_id in df["_key"]: - adb_map[v_col][adb_id] = dgl_id - dgl_id += 1 - - # 2. Set the DGL Node Data - self.__set_dgl_data(v_col, meta, ndata, df) - - if cursor.has_more(): - cursor.fetch() + # 1. Fetch ArangoDB vertices + v_col_cursor, v_col_size = self.__fetch_adb_docs( + v_col, meta, **adb_export_kwargs + ) - df.drop(df.index, inplace=True) + # 2. Process ArangoDB vertices + self.__process_adb_cursor( + "#319BF5", + v_col_cursor, + v_col_size, + self.__process_adb_vertex_df, + v_col, + adb_map, + meta, + ndata=ndata, + ) #################### # Edge Collections # #################### - # et = Edge Type - et_df: DataFrame - et_blacklist: List[DGLCanonicalEType] = [] # A list of skipped edge types + # The set of skipped edge types + edge_type_blacklist: Set[DGLCanonicalEType] = set() for e_col, meta in metagraph["edgeCollections"].items(): logger.debug(f"Preparing '{e_col}' edges") - cursor = self.__fetch_adb_docs(e_col, meta, query_options) - while not cursor.empty(): - cursor_batch = len(cursor.batch()) # type: ignore - df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) - - # 1. Split the ArangoDB _from & _to IDs into two columns - df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"]) - df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"]) - - # 2. Iterate over each edge type - for (from_col, to_col), count in ( - df[["from_col", "to_col"]].value_counts().items() - ): - edge_type: DGLCanonicalEType = (from_col, e_col, to_col) - - # 3. Check for partial Edge Collection import - if from_col not in v_cols or to_col not in v_cols: - logger.debug(f"Skipping {edge_type}") - et_blacklist.append(edge_type) - continue - - logger.debug(f"Preparing {count} '{edge_type}' edges") - - # 4. Get the edge data corresponding to the current edge type - et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)] - - # 5. Map each ArangoDB from/to _key to the corresponding DGL node id - from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() - to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() - - # 6. Set/Update the DGL Edge Index - if edge_type not in data_dict: - data_dict[edge_type] = (tensor(from_nodes), tensor(to_nodes)) - else: - previous_from_nodes, previous_to_nodes = data_dict[edge_type] - data_dict[edge_type] = ( - cat((previous_from_nodes, tensor(from_nodes))), - cat((previous_to_nodes, tensor(to_nodes))), - ) - - # 7. Set the DGL Edge Data - self.__set_dgl_data(edge_type, meta, edata, df) - - if cursor.has_more(): - cursor.fetch() + # 1. Fetch ArangoDB edges + e_col_cursor, e_col_size = self.__fetch_adb_docs( + e_col, meta, **adb_export_kwargs + ) - df.drop(df.index, inplace=True) + # 2. Process ArangoDB edges + self.__process_adb_cursor( + "#FCFDFC", + e_col_cursor, + e_col_size, + self.__process_adb_edge_df, + e_col, + adb_map, + meta, + edata=edata, + data_dict=data_dict, + v_cols=v_cols, + edge_type_blacklist=edge_type_blacklist, + ) if not data_dict: # pragma: no cover msg = f""" Can't create the DGL graph: no complete edge types found. The following edge types were skipped due to missing - vertex collection specifications: {et_blacklist} + vertex collection specifications: {edge_type_blacklist} """ raise ValueError(msg) @@ -348,10 +332,10 @@ def arangodb_collections_to_dgl( name: str, v_cols: Set[str], e_cols: Set[str], - **query_options: Any, + **adb_export_kwargs: Any, ) -> Union[DGLGraph, DGLHeteroGraph]: """Create a DGL graph from ArangoDB collections. Due to risk of - ambiguity, this method DOES NOT transfer ArangoDB attributes to DGL. + ambiguity, this method DOES NOT transfer ArangoDB attributes to DGL. :param name: The DGL graph name. :type name: str @@ -359,10 +343,10 @@ def arangodb_collections_to_dgl( :type v_cols: Set[str] :param e_cols: The set of ArangoDB edge collections to import to DGL. :type e_cols: Set[str] - :param query_options: Keyword arguments to specify AQL query options when + :param adb_export_kwargs: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. Full parameter list: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute - :type query_options: Any + :type adb_export_kwargs: Any :return: A DGL Homogeneous or Heterogeneous graph object :rtype: dgl.DGLGraph | dgl.DGLHeteroGraph :raise adbdgl_adapter.exceptions.ADBMetagraphError: If invalid metagraph. @@ -372,19 +356,19 @@ def arangodb_collections_to_dgl( "edgeCollections": {col: dict() for col in e_cols}, } - return self.arangodb_to_dgl(name, metagraph, **query_options) + return self.arangodb_to_dgl(name, metagraph, **adb_export_kwargs) def arangodb_graph_to_dgl( - self, name: str, **query_options: Any + self, name: str, **adb_export_kwargs: Any ) -> Union[DGLGraph, DGLHeteroGraph]: """Create a DGL graph from an ArangoDB graph. :param name: The ArangoDB graph name. :type name: str - :param query_options: Keyword arguments to specify AQL query options when + :param adb_export_kwargs: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. Full parameter list: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute - :type query_options: Any + :type adb_export_kwargs: Any :return: A DGL Homogeneous or Heterogeneous graph object :rtype: dgl.DGLGraph | dgl.DGLHeteroGraph :raise adbdgl_adapter.exceptions.ADBMetagraphError: If invalid metagraph. @@ -394,7 +378,13 @@ def arangodb_graph_to_dgl( edge_definitions: List[Json] = graph.edge_definitions() # type: ignore e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions} - return self.arangodb_collections_to_dgl(name, v_cols, e_cols, **query_options) + return self.arangodb_collections_to_dgl( + name, v_cols, e_cols, **adb_export_kwargs + ) + + ########################### + # Public: DGL -> ArangoDB # + ########################### def dgl_to_arangodb( self, @@ -404,7 +394,8 @@ def dgl_to_arangodb( explicit_metagraph: bool = True, overwrite_graph: bool = False, batch_size: Optional[int] = None, - **import_options: Any, + use_async: bool = False, + **adb_import_kwargs: Any, ) -> ADBGraph: """Create an ArangoDB graph from a DGL graph. @@ -437,7 +428,7 @@ def dgl_to_arangodb( See below for an example of **metagraph**. :type metagraph: adbdgl_adapter.typings.DGLMetagraph :param explicit_metagraph: Whether to take the metagraph at face value or not. - If False, node & edge types OMITTED from the metagraph will be + If False, node & edge types OMITTED from the metagraph will still be brought over into ArangoDB. Also applies to node & edge attributes. Defaults to True. :type explicit_metagraph: bool @@ -448,10 +439,13 @@ def dgl_to_arangodb( **batch_size**. Defaults to `None`, which processes each NodeStorage & EdgeStorage in one batch. :type batch_size: int - :param import_options: Keyword arguments to specify additional + :param use_async: Performs asynchronous ArangoDB ingestion if enabled. + Defaults to False. + :type use_async: bool + :param adb_import_kwargs: Keyword arguments to specify additional parameters for ArangoDB document insertion. Full parameter list: https://docs.python-arango.com/en/main/specs.html#arango.collection.Collection.import_bulk - :type import_options: Any + :type adb_import_kwargs: Any :return: The ArangoDB Graph API wrapper. :rtype: arango.graph.Graph :raise adbdgl_adapter.exceptions.DGLMetagraphError: If invalid metagraph. @@ -496,23 +490,27 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): validate_dgl_metagraph(metagraph) - is_explicit_metagraph = metagraph != {} and explicit_metagraph is_custom_controller = type(self.__cntrl) is not ADBDGL_Controller + is_explicit_metagraph = metagraph != {} and explicit_metagraph has_one_ntype = len(dgl_g.ntypes) == 1 has_one_etype = len(dgl_g.canonical_etypes) == 1 + # Get the Node & Edge types node_types, edge_types = self.__get_node_and_edge_types( name, dgl_g, metagraph, is_explicit_metagraph ) + # Create the ArangoDB Graph adb_graph = self.__create_adb_graph( name, overwrite_graph, node_types, edge_types ) - ############## - # Node Types # - ############## + spinner_progress = get_import_spinner_progress(" ") + + ############# + # DGL Nodes # + ############# n_meta = metagraph.get("nodeTypes", {}) for n_type in node_types: @@ -520,6 +518,7 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): n_key = None if has_one_ntype else n_type + ndata = dgl_g.nodes[n_key].data ndata_size = dgl_g.num_nodes(n_key) ndata_batch_size = batch_size or ndata_size @@ -527,45 +526,45 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): end_index = min(ndata_batch_size, ndata_size) batches = ceil(ndata_size / ndata_batch_size) - # For each batch of nodes - for _ in range(batches): - # 1. Map each DGL node id to an ArangoDB _key - adb_keys = [{"_key": str(i)} for i in range(start_index, end_index)] - - # 2. Set the ArangoDB Node Data - df = self.__set_adb_data( - DataFrame(adb_keys, index=range(start_index, end_index)), - meta, - dgl_g.nodes[n_key].data, - ndata_size, - start_index, - end_index, - is_explicit_metagraph, - ) + bar_progress = get_bar_progress(f"(DGL → ADB): '{n_type}'", "#97C423") + bar_progress_task = bar_progress.add_task(n_type, total=ndata_size) + + with Live(Group(bar_progress, spinner_progress)): + for _ in range(batches): + # 1. Process the Node batch + df = self.__process_dgl_node_batch( + n_type, + ndata, + ndata_size, + meta, + is_explicit_metagraph, + is_custom_controller, + start_index, + end_index, + ) - # 3. Apply the ArangoDB Node Controller (if provided) - if is_custom_controller: - f = lambda n: self.__cntrl._prepare_dgl_node(n, n_type) - df = df.apply(f, axis=1) + bar_progress.advance(bar_progress_task, advance=len(df)) - # 4. Insert the ArangoDB Node Documents - self.__insert_adb_docs(n_type, df, import_options) + # 2. Insert the ArangoDB Node Documents + self.__insert_adb_docs( + spinner_progress, df, n_type, use_async, **adb_import_kwargs + ) - # 5. Update the batch indices - start_index = end_index - end_index = min(end_index + ndata_batch_size, ndata_size) + # 3. Update the batch indices + start_index = end_index + end_index = min(end_index + ndata_batch_size, ndata_size) - ############## - # Edge Types # - ############## + ############# + # DGL Edges # + ############# e_meta = metagraph.get("edgeTypes", {}) for e_type in edge_types: meta = e_meta.get(e_type, {}) - from_col, _, to_col = e_type e_key = None if has_one_etype else e_type + edata = dgl_g.edges[e_key].data edata_size = dgl_g.num_edges(e_key) edata_batch_size = batch_size or edata_size @@ -573,93 +572,371 @@ def y_tensor_to_2_column_dataframe(dgl_tensor): end_index = min(edata_batch_size, edata_size) batches = ceil(edata_size / edata_batch_size) + bar_progress = get_bar_progress(f"(DGL → ADB): {e_type}", "#994602") + bar_progress_task = bar_progress.add_task(str(e_type), total=edata_size) + from_nodes, to_nodes = dgl_g.edges(etype=e_key) - # For each batch of edges - for _ in range(batches): - # 1. Map the DGL edges to ArangoDB _from & _to IDs - data = zip( - *( - from_nodes[start_index:end_index].tolist(), - to_nodes[start_index:end_index].tolist(), + with Live(Group(bar_progress, spinner_progress)): + for _ in range(batches): + # 1. Process the Edge batch + df = self.__process_dgl_edge_batch( + e_type, + edata, + edata_size, + meta, + from_nodes, + to_nodes, + is_explicit_metagraph, + is_custom_controller, + start_index, + end_index, ) - ) - # 2. Set the ArangoDB Edge Data - df = self.__set_adb_data( - DataFrame( - data, - index=range(start_index, end_index), - columns=["_from", "_to"], - ), - meta, - dgl_g.edges[e_key].data, - edata_size, - start_index, - end_index, - is_explicit_metagraph, + bar_progress.advance(bar_progress_task, advance=len(df)) + + # 2. Insert the ArangoDB Edge Documents + self.__insert_adb_docs( + spinner_progress, df, e_type[1], use_async, **adb_import_kwargs + ) + + # 3. Update the batch indices + start_index = end_index + end_index = min(end_index + edata_batch_size, edata_size) + + logger.info(f"Created ArangoDB '{name}' Graph") + return adb_graph + + ############################ + # Private: ArangoDB -> DGL # + ############################ + + def __fetch_adb_docs( + self, + col: str, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + **adb_export_kwargs: Any, + ) -> Tuple[Cursor, int]: + """ArangoDB -> DGL: Fetches ArangoDB documents within a collection. + Returns the documents in a DataFrame. + + :param col: The ArangoDB collection. + :type col: str + :param meta: The MetaGraph associated to **col** + :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] + :param adb_export_kwargs: Keyword arguments to specify AQL query options + when fetching documents from the ArangoDB instance. + :type adb_export_kwargs: Any + :return: A DataFrame representing the ArangoDB documents. + :rtype: pandas.DataFrame + """ + + def get_aql_return_value( + meta: Union[Set[str], Dict[str, ADBMetagraphValues]] + ) -> str: + """Helper method to formulate the AQL `RETURN` value based on + the document attributes specified in **meta** + """ + attributes = [] + + if type(meta) is set: + attributes = list(meta) + + elif type(meta) is dict: + for value in meta.values(): + if type(value) is str: + attributes.append(value) + elif type(value) is dict: + attributes.extend(list(value.keys())) + elif callable(value): + # Cannot determine which attributes to extract if UDFs are used + # Therefore we just return the entire document + return "doc" + + return f""" + MERGE( + {{ _key: doc._key, _from: doc._from, _to: doc._to }}, + KEEP(doc, {list(attributes)}) ) + """ + + col_size: int = self.__db.collection(col).count() # type: ignore - df["_from"] = from_col + "/" + df["_from"].astype(str) - df["_to"] = to_col + "/" + df["_to"].astype(str) + with get_export_spinner_progress(f"ADB Export: '{col}' ({col_size})") as p: + p.add_task(col) - # 3. Apply the ArangoDB Edge Controller (if provided) - if is_custom_controller: - f = lambda e: self.__cntrl._prepare_dgl_edge(e, e_type) - df = df.apply(f, axis=1) + cursor: Cursor = self.__db.aql.execute( # type: ignore + f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}", + bind_vars={"@col": col}, + **{**adb_export_kwargs, **{"stream": True}}, + ) - # 4. Insert the ArangoDB Edge Documents - self.__insert_adb_docs(e_type, df, import_options) + return cursor, col_size - # 5. Update the batch indices - start_index = end_index - end_index = min(end_index + edata_batch_size, edata_size) + def __process_adb_cursor( + self, + progress_color: str, + cursor: Cursor, + col_size: int, + process_adb_df: Callable[..., int], + col: str, + adb_map: ADBMap, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + **kwargs: Any, + ) -> None: + """ArangoDB -> DGL: Processes the ArangoDB Cursors for vertices and edges. + + :param progress_color: The progress bar color. + :type progress_color: str + :param cursor: The ArangoDB cursor for the current **col**. + :type cursor: arango.cursor.Cursor + :param col_size: The size of **col**. + :type col_size: int + :param process_adb_df: The function to process the cursor data + (in the form of a Dataframe). + :type process_adb_df: Callable + :param col: The ArangoDB collection for the current **cursor**. + :type col: str + :param adb_map: The ArangoDB -> DGL map. + :type adb_map: adbdgl_adapter.typings.ADBMap + :param meta: The metagraph for the current **col**. + :type meta: Set[str] | Dict[str, ADBMetagraphValues] + :param kwargs: Additional keyword arguments to pass to **process_adb_df**. + :type args: Any + """ - logger.info(f"Created ArangoDB '{name}' Graph") - return adb_graph + progress = get_bar_progress(f"(ADB → DGL): '{col}'", progress_color) + progress_task_id = progress.add_task(col, total=col_size) - def __create_adb_graph( + with Live(Group(progress)): + i = 0 + while not cursor.empty(): + cursor_batch = len(cursor.batch()) # type: ignore + df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) + + i = process_adb_df(i, df, col, adb_map, meta, **kwargs) + progress.advance(progress_task_id, advance=len(df)) + + df.drop(df.index, inplace=True) + + if cursor.has_more(): + cursor.fetch() + + def __process_adb_vertex_df( self, - name: str, - overwrite_graph: bool, - node_types: List[str], - edge_types: List[DGLCanonicalEType], - ) -> ADBGraph: - """Creates an ArangoDB graph. + i: int, + df: DataFrame, + v_col: str, + adb_map: ADBMap, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + ndata: DGLData, + ) -> int: + """ArangoDB -> DGL: Process the ArangoDB Vertex DataFrame + into the DGL NData object. + + :param i: The last DGL Node id value. + :type i: int + :param df: The ArangoDB Vertex DataFrame. + :type df: pandas.DataFrame + :param v_col: The ArangoDB Vertex Collection. + :type v_col: str + :param adb_map: The ArangoDB -> DGL map. + :type adb_map: adbdgl_adapter.typings.ADBMap + :param meta: The metagraph for the current **v_col**. + :type meta: Set[str] | Dict[str, ADBMetagraphValues] + :param node_data: The node data view for storing node features + :type node_data: adbdgl_adapter.typings.DGLData + :return: The last DGL Node id value. + :rtype: int + """ + # 1. Map each ArangoDB _key to a DGL node id + for adb_id in df["_key"]: + adb_map[v_col][adb_id] = i + i += 1 - :param name: The ArangoDB graph name. - :type name: str - :param overwrite_graph: Overwrites the graph if it already exists. - Does not drop associated collections. Defaults to False. - :type overwrite_graph: bool - :param node_types: A list of strings representing the DGL node types. - :type node_types: List[str] - :param edge_types: A list of string triplets (str, str, str) for - source node type, edge type and destination node type. - :type edge_types: List[adbdgl_adapter.typings.DGLCanonicalEType] - :return: The ArangoDB Graph API wrapper. - :rtype: arango.graph.Graph + # 2. Set the DGL Node Data + self.__set_dgl_data(v_col, meta, ndata, df) + + return i + + def __process_adb_edge_df( + self, + _: int, + df: DataFrame, + e_col: str, + adb_map: ADBMap, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + edata: DGLData, + data_dict: DGLDataDict, + v_cols: List[str], + edge_type_blacklist: Set[DGLCanonicalEType], + ) -> int: + """ArangoDB -> DGL: Process the ArangoDB Edge DataFrame + into the DGL EdgeData object. + + :param _: Not used. + :type _: int + :param df: The ArangoDB Edge DataFrame. + :type df: pandas.DataFrame + :param e_col: The ArangoDB Edge Collection. + :type e_col: str + :param adb_map: The ArangoDB -> DGL map. + :type adb_map: adbdgl_adapter.typings.ADBMap + :param meta: The metagraph for the current **e_col**. + :type meta: Set[str] | Dict[str, ADBMetagraphValues] + :param edata: The edge data view for storing edge features + :type edata: adbdgl_adapter.typings.DGLData + :param data_dict: The data for constructing a graph, + which takes the form of (U, V). + (U[i], V[i]) forms the edge with ID i in the graph. + :type data_dict: adbdgl_adapter.typings.DGLDataDict + :param v_cols: The list of ArangoDB Vertex Collections. + :type v_cols: List[str] + :param edge_type_blacklist: The set of skipped edge types + :type edge_type_blacklist: Set[DGLCanonicalEType] + :return: The last DGL Edge id value. This is a useless return value, + but is needed for type hinting. + :rtype: int """ - if overwrite_graph: - logger.debug("Overwrite graph flag is True. Deleting old graph.") - self.__db.delete_graph(name, ignore_missing=True) + # 1. Split the ArangoDB _from & _to IDs into two columns + df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"]) + df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"]) + + # 2. Iterate over each edge type + for (from_col, to_col), count in ( + df[["from_col", "to_col"]].value_counts().items() + ): + edge_type: DGLCanonicalEType = (from_col, e_col, to_col) + + # 3. Check for partial Edge Collection import + if from_col not in v_cols or to_col not in v_cols: + logger.debug(f"Skipping {edge_type}") + edge_type_blacklist.add(edge_type) + continue + + logger.debug(f"Preparing {count} {edge_type} edges") + + # 4. Get the edge data corresponding to the current edge type + et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)] + + # 5. Map each ArangoDB from/to _key to the corresponding DGL node id + from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() + to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() + + # 6. Set/Update the DGL Edge Index + if edge_type not in data_dict: + data_dict[edge_type] = (tensor(from_nodes), tensor(to_nodes)) + else: + previous_from_nodes, previous_to_nodes = data_dict[edge_type] + data_dict[edge_type] = ( + cat((previous_from_nodes, tensor(from_nodes))), + cat((previous_to_nodes, tensor(to_nodes))), + ) - if self.__db.has_graph(name): - return self.__db.graph(name) + # 7. Set the DGL Edge Data + self.__set_dgl_data(edge_type, meta, edata, df) - edge_definitions = self.__etypes_to_edefinitions(edge_types) - orphan_collections = self.__ntypes_to_ocollections(node_types, edge_types) + return 1 # Useless return value, but needed for type hinting - return self.__db.create_graph( # type: ignore[return-value] - name, - edge_definitions, - orphan_collections, + def __split_adb_ids(self, s: Series) -> Series: + """AranogDB -> DGL: Helper method to split the ArangoDB IDs + within a Series into two columns + + :param s: The Series containing the ArangoDB IDs. + :type s: pandas.Series + :return: A DataFrame with two columns: the ArangoDB Collection, + and the ArangoDB _key. + :rtype: pandas.Series + """ + return s.str.split(pat="/", n=1, expand=True) + + def __set_dgl_data( + self, + data_type: DGLDataTypes, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + dgl_data: DGLData, + df: DataFrame, + ) -> None: + """AranogDB -> DGL: A helper method to build the DGL NodeSpace or + EdgeSpace object for the DGL graph. Is responsible for preparing the + input **meta** such that it becomes a dictionary, and building DGL-ready + tensors from the ArangoDB DataFrame **df**. + + :param data_type: The current node or edge type of the soon-to-be DGL graph. + :type data_type: str | tuple[str, str, str] + :param meta: The metagraph associated to the current ArangoDB vertex or + edge collection. e.g metagraph['vertexCollections']['Users'] + :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] + :param dgl_data: The (currently empty) DefaultDict object storing the node or + edge features of the soon-to-be DGL graph. + :type dgl_data: adbdgl_adapter.typings.DGLData + :param df: The DataFrame representing the ArangoDB collection data + :type df: pandas.DataFrame + """ + valid_meta: Dict[str, ADBMetagraphValues] + valid_meta = meta if type(meta) is dict else {m: m for m in meta} + + for k, v in valid_meta.items(): + t = self.__build_tensor_from_dataframe(df, k, v) + dgl_data[k][data_type] = cat((dgl_data[k][data_type], t)) + + def __build_tensor_from_dataframe( + self, + adb_df: DataFrame, + meta_key: str, + meta_val: ADBMetagraphValues, + ) -> Tensor: + """AranogDB -> DGL: Constructs a DGL-ready Tensor from a Pandas + Dataframe, based on the nature of the user-defined metagraph. + + :param adb_df: The Pandas Dataframe representing ArangoDB data. + :type adb_df: pandas.DataFrame + :param meta_key: The current ArangoDB-DGL metagraph key + :type meta_key: str + :param meta_val: The value mapped to **meta_key** to + help convert **df** into a DGL-ready Tensor. + e.g the value of `metagraph['vertexCollections']['users']['x']`. + :type meta_val: adbdgl_adapter.typings.ADBMetagraphValues + :return: A DGL-ready tensor equivalent to the dataframe + :rtype: torch.Tensor + :raise adbdgl_adapter.exceptions.ADBMetagraphError: If invalid **meta_val**. + """ + logger.debug( + f"__build_tensor_from_dataframe(df, '{meta_key}', {type(meta_val)})" ) + if type(meta_val) is str: + return tensor(adb_df[meta_val].to_list()) + + if type(meta_val) is dict: + data = [] + for attr, encoder in meta_val.items(): + if encoder is None: + data.append(tensor(adb_df[attr].to_list())) + elif callable(encoder): + data.append(encoder(adb_df[attr])) + else: # pragma: no cover + msg = f"Invalid encoder for ArangoDB attribute '{attr}': {encoder}" + raise ADBMetagraphError(msg) + + return cat(data, dim=-1) + + if callable(meta_val): + # **meta_val** is a user-defined that returns a tensor + user_defined_result = meta_val(adb_df) + + if type(user_defined_result) is not Tensor: # pragma: no cover + msg = f"Invalid return type for function {meta_val} ('{meta_key}')" + raise ADBMetagraphError(msg) + + return user_defined_result + + raise ADBMetagraphError(f"Invalid {meta_val} type") # pragma: no cover + def __create_dgl_graph( self, data_dict: DGLDataDict, adb_map: ADBMap, metagraph: ADBMetagraph ) -> Union[DGLGraph, DGLHeteroGraph]: - """Creates a DGL graph from the given DGL data. + """AranogDB -> DGL: Creates a DGL graph from the given DGL data. :param data_dict: The data for constructing a graph, which takes the form of (U, V). @@ -686,6 +963,35 @@ def __create_dgl_graph( num_nodes_dict = {v_col: len(adb_map[v_col]) for v_col in adb_map} return heterograph(data_dict, num_nodes_dict) + def __link_dgl_data( + self, + dgl_data: Union[HeteroNodeDataView, HeteroEdgeDataView], + dgl_data_temp: DGLData, + has_one_type: bool, + ) -> None: + """Links **dgl_data_temp** to **dgl_data**. This method is (unfortunately) + required, since a dgl graph's `ndata` and `edata` properties can't be + manually set (i.e `g.ndata = ndata` is not possible). + + :param dgl_data: The (empty) ndata or edata instance attribute of a dgl graph, + which is about to receive **dgl_data_temp**. + :type dgl_data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView] + :param dgl_data_temp: A temporary place to store the ndata or edata features. + :type dgl_data_temp: adbdgl_adapter.typings.DGLData + :param has_one_type: Set to True if the DGL graph only has one + node type or edge type. + :type has_one_type: bool + """ + for feature_name, feature_map in dgl_data_temp.items(): + for data_type, dgl_tensor in feature_map.items(): + dgl_data[feature_name] = ( + dgl_tensor if has_one_type else {data_type: dgl_tensor} + ) + + ############################ + # Private: DGL -> ArangoDB # + ############################ + def __get_node_and_edge_types( self, name: str, @@ -693,8 +999,8 @@ def __get_node_and_edge_types( metagraph: DGLMetagraph, is_explicit_metagraph: bool, ) -> Tuple[List[str], List[DGLCanonicalEType]]: - """Returns the node & edge types of the DGL graph, based on the - metagraph and whether the graph has default canonical etypes. + """DGL -> ArangoDB: Returns the node & edge types of the DGL graph, + based on the metagraph and whether the graph has default canonical etypes. :param name: The DGL graph name. :type name: str @@ -795,152 +1101,171 @@ def __ntypes_to_ocollections( orphan_collections = set(node_types) ^ non_orphan_collections return list(orphan_collections) - def __fetch_adb_docs( + def __create_adb_graph( self, - col: str, - meta: Union[Set[str], Dict[str, ADBMetagraphValues]], - query_options: Any, - ) -> Cursor: - """Fetches ArangoDB documents within a collection. Returns the - documents in a DataFrame. + name: str, + overwrite_graph: bool, + node_types: List[str], + edge_types: List[DGLCanonicalEType], + ) -> ADBGraph: + """Creates an ArangoDB graph. - :param col: The ArangoDB collection. - :type col: str - :param meta: The MetaGraph associated to **col** - :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] - :param query_options: Keyword arguments to specify AQL query options - when fetching documents from the ArangoDB instance. - :type query_options: Any - :return: A DataFrame representing the ArangoDB documents. - :rtype: pandas.DataFrame + :param name: The ArangoDB graph name. + :type name: str + :param overwrite_graph: Overwrites the graph if it already exists. + Does not drop associated collections. Defaults to False. + :type overwrite_graph: bool + :param node_types: A list of strings representing the DGL node types. + :type node_types: List[str] + :param edge_types: A list of string triplets (str, str, str) for + source node type, edge type and destination node type. + :type edge_types: List[adbdgl_adapter.typings.DGLCanonicalEType] + :return: The ArangoDB Graph API wrapper. + :rtype: arango.graph.Graph """ + if overwrite_graph: + logger.debug("Overwrite graph flag is True. Deleting old graph.") + self.__db.delete_graph(name, ignore_missing=True) - def get_aql_return_value( - meta: Union[Set[str], Dict[str, ADBMetagraphValues]] - ) -> str: - """Helper method to formulate the AQL `RETURN` value based on - the document attributes specified in **meta** - """ - attributes = [] - - if type(meta) is set: - attributes = list(meta) - - elif type(meta) is dict: - for value in meta.values(): - if type(value) is str: - attributes.append(value) - elif type(value) is dict: - attributes.extend(list(value.keys())) - elif callable(value): - # Cannot determine which attributes to extract if UDFs are used - # Therefore we just return the entire document - return "doc" - - return f""" - MERGE( - {{ _key: doc._key, _from: doc._from, _to: doc._to }}, - KEEP(doc, {list(attributes)}) - ) - """ + if self.__db.has_graph(name): + return self.__db.graph(name) - with progress( - f"(ADB → DGL): {col}", - text_style="#319BF5", - spinner_style="#FCFDFC", - ) as p: - p.add_task("__fetch_adb_docs") - return self.__db.aql.execute( # type: ignore - f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}", - bind_vars={"@col": col}, - **{**{"stream": True}, **query_options}, - ) + edge_definitions = self.__etypes_to_edefinitions(edge_types) + orphan_collections = self.__ntypes_to_ocollections(node_types, edge_types) - def __insert_adb_docs( - self, doc_type: Union[str, DGLCanonicalEType], df: DataFrame, kwargs: Any - ) -> None: - """Insert ArangoDB documents into their ArangoDB collection. + return self.__db.create_graph( # type: ignore[return-value] + name, + edge_definitions, + orphan_collections, + ) - :param doc_type: The node or edge type of the soon-to-be ArangoDB documents - :type doc_type: str | tuple[str, str, str] - :param df: To-be-inserted ArangoDB documents, formatted as a DataFrame - :type df: pandas.DataFrame - :param kwargs: Keyword arguments to specify additional - parameters for ArangoDB document insertion. Full parameter list: - https://docs.python-arango.com/en/main/specs.html#arango.collection.Collection.import_bulk + def __process_dgl_node_batch( + self, + n_type: str, + ndata: NodeSpace, + ndata_size: int, + meta: Union[Set[str], Dict[Any, DGLMetagraphValues]], + is_explicit_metagraph: bool, + is_custom_controller: bool, + start_index: int, + end_index: int, + ) -> DataFrame: + """DGL -> ArangoDB: Processes the DGL Node batch + into an ArangoDB DataFrame. + + :param n_type: The DGL node type. + :type n_type: str + :param ndata: The DGL Node Space for the current **n_type**. + :type ndata: dgl.view.NodeSpace + :param ndata_size: The size of **ndata**. + :param ndata_size: int + :param meta: The metagraph for the current **n_type**. + :type meta: Set[str] | Dict[Any, adbdgl_adapter.typings.DGLMetagraphValues] + :param is_explicit_metagraph: Take the metagraph at face value or not. + :type is_explicit_metagraph: bool + :param is_custom_controller: Whether a custom controller is used. + :type is_custom_controller: bool + :param start_index: The start index of the current batch. + :type start_index: int + :param end_index: The end index of the current batch. + :type end_index: int + :return: The ArangoDB DataFrame representing the DGL Node batch. + :rtype: pandas.DataFrame """ - col = doc_type if type(doc_type) is str else doc_type[1] - - with progress( - f"(DGL → ADB): {doc_type} ({len(df)})", - text_style="#97C423", - spinner_style="#994602", - ) as p: - p.add_task("__insert_adb_docs") + # 1. Map each DGL node id to an ArangoDB _key + adb_keys = [{"_key": str(i)} for i in range(start_index, end_index)] + + # 2. Set the ArangoDB Node Data + df = self.__set_adb_data( + DataFrame(adb_keys, index=range(start_index, end_index)), + meta, + ndata, + ndata_size, + is_explicit_metagraph, + start_index, + end_index, + ) - docs = df.to_dict("records") - result = self.__db.collection(col).import_bulk(docs, **kwargs) - logger.debug(result) - df.drop(df.index, inplace=True) + # 3. Apply the ArangoDB Node Controller (if provided) + if is_custom_controller: + f = lambda n: self.__cntrl._prepare_dgl_node(n, n_type) + df = df.apply(f, axis=1) - def __split_adb_ids(self, s: Series) -> Series: - """Helper method to split the ArangoDB IDs within a Series into two columns""" - return s.str.split(pat="/", n=1, expand=True) + return df - def __set_dgl_data( + def __process_dgl_edge_batch( self, - data_type: DGLDataTypes, - meta: Union[Set[str], Dict[str, ADBMetagraphValues]], - dgl_data: DGLData, - df: DataFrame, - ) -> None: - """A helper method to build the DGL NodeSpace or EdgeSpace object - for the DGL graph. Is responsible for preparing the input **meta** such - that it becomes a dictionary, and building DGL-ready tensors from the - ArangoDB DataFrame **df**. - - :param data_type: The current node or edge type of the soon-to-be DGL graph. - :type data_type: str | tuple[str, str, str] - :param meta: The metagraph associated to the current ArangoDB vertex or - edge collection. e.g metagraph['vertexCollections']['Users'] - :type meta: Set[str] | Dict[str, adbdgl_adapter.typings.ADBMetagraphValues] - :param dgl_data: The (currently empty) DefaultDict object storing the node or - edge features of the soon-to-be DGL graph. - :type dgl_data: adbdgl_adapter.typings.DGLData - :param df: The DataFrame representing the ArangoDB collection data - :type df: pandas.DataFrame + e_type: DGLCanonicalEType, + edata: EdgeSpace, + edata_size: int, + meta: Union[Set[str], Dict[Any, DGLMetagraphValues]], + from_nodes: Tensor, + to_nodes: Tensor, + is_explicit_metagraph: bool, + is_custom_controller: bool, + start_index: int, + end_index: int, + ) -> DataFrame: + """DGL -> ArangoDB: Processes the DGL Edge batch + into an ArangoDB DataFrame. + + :param e_type: The DGL edge type. + :type e_type: adbdgl_adapter.typings.DGLCanonicalEType + :param edata: The DGL EdgeSpace for the current **e_type**. + :type edata: dgl.view.EdgeSpace + :param edata_size: The size of **edata**. + :param edata_size: int + :param meta: The metagraph for the current **e_type**. + :type meta: Set[str] | Dict[Any, adbdgl_adapter.typings.DGLMetagraphValues] + :param from_nodes: Tensor representing the Source Nodes of the **e_type**. + :type from_nodes: torch.Tensor + :param to_nodes: Tensor representing the Destination Nodes of the **e_type**. + :type to_nodes: torch.Tensor + :param is_explicit_metagraph: Take the metagraph at face value or not. + :type is_explicit_metagraph: bool + :param is_custom_controller: Whether a custom controller is used. + :type is_custom_controller: bool + :param start_index: The start index of the current batch. + :type start_index: int + :param end_index: The end index of the current batch. + :type end_index: int + :return: The ArangoDB DataFrame representing the DGL Edge batch. + :rtype: pandas.DataFrame """ - valid_meta: Dict[str, ADBMetagraphValues] - valid_meta = meta if type(meta) is dict else {m: m for m in meta} + from_col, _, to_col = e_type - for k, v in valid_meta.items(): - t = self.__build_tensor_from_dataframe(df, k, v) - dgl_data[k][data_type] = cat((dgl_data[k][data_type], t)) + # 1. Map the DGL edges to ArangoDB _from & _to IDs + data = zip( + *( + from_nodes[start_index:end_index].tolist(), + to_nodes[start_index:end_index].tolist(), + ) + ) - def __link_dgl_data( - self, - dgl_data: Union[HeteroNodeDataView, HeteroEdgeDataView], - dgl_data_temp: DGLData, - has_one_type: bool, - ) -> None: - """Links **dgl_data_temp** to **dgl_data**. This method is (unfortunately) - required, since a dgl graph's `ndata` and `edata` properties can't be - manually set (i.e `g.ndata = ndata` is not possible). + # 2. Set the ArangoDB Edge Data + df = self.__set_adb_data( + DataFrame( + data, + index=range(start_index, end_index), + columns=["_from", "_to"], + ), + meta, + edata, + edata_size, + is_explicit_metagraph, + start_index, + end_index, + ) - :param dgl_data: The (empty) ndata or edata instance attribute of a dgl graph, - which is about to receive **dgl_data_temp**. - :type dgl_data: Union[dgl.view.HeteroNodeDataView, dgl.view.HeteroEdgeDataView] - :param dgl_data_temp: A temporary place to store the ndata or edata features. - :type dgl_data_temp: adbdgl_adapter.typings.DGLData - :param has_one_type: Set to True if the DGL graph only has one - node type or edge type. - :type has_one_type: bool - """ - for feature_name, feature_map in dgl_data_temp.items(): - for data_type, dgl_tensor in feature_map.items(): - dgl_data[feature_name] = ( - dgl_tensor if has_one_type else {data_type: dgl_tensor} - ) + df["_from"] = from_col + "/" + df["_from"].astype(str) + df["_to"] = to_col + "/" + df["_to"].astype(str) + + # 3. Apply the ArangoDB Edge Controller (if provided) + if is_custom_controller: + f = lambda e: self.__cntrl._prepare_dgl_edge(e, e_type) + df = df.apply(f, axis=1) + + return df def __set_adb_data( self, @@ -948,9 +1273,9 @@ def __set_adb_data( meta: Union[Set[str], Dict[Any, DGLMetagraphValues]], dgl_data: Union[NodeSpace, EdgeSpace], dgl_data_size: int, + is_explicit_metagraph: bool, start_index: int, end_index: int, - is_explicit_metagraph: bool, ) -> DataFrame: """A helper method to build the ArangoDB Dataframe for the given collection. Is responsible for creating "sub-DataFrames" from DGL tensors, @@ -970,12 +1295,12 @@ def __set_adb_data( :param dgl_data_size: The size of the NodeStorage or EdgeStorage of the current DGL node or edge type. :type dgl_data_size: int + :param is_explicit_metagraph: Take the metagraph at face value or not. + :type is_explicit_metagraph: bool :param start_index: The starting index of the current batch to process. :type start_index: int :param end_index: The ending index of the current batch to process. :type end_index: int - :param is_explicit_metagraph: Take the metagraph at face value or not. - :type is_explicit_metagraph: bool :return: The completed DataFrame for the (soon-to-be) ArangoDB collection. :rtype: pandas.DataFrame :raise ValueError: If an unsupported DGL data value is found. @@ -1005,59 +1330,6 @@ def __set_adb_data( return df - def __build_tensor_from_dataframe( - self, - adb_df: DataFrame, - meta_key: str, - meta_val: ADBMetagraphValues, - ) -> Tensor: - """Constructs a DGL-ready Tensor from a Pandas Dataframe, based on - the nature of the user-defined metagraph. - - :param adb_df: The Pandas Dataframe representing ArangoDB data. - :type adb_df: pandas.DataFrame - :param meta_key: The current ArangoDB-DGL metagraph key - :type meta_key: str - :param meta_val: The value mapped to **meta_key** to - help convert **df** into a DGL-ready Tensor. - e.g the value of `metagraph['vertexCollections']['users']['x']`. - :type meta_val: adbdgl_adapter.typings.ADBMetagraphValues - :return: A DGL-ready tensor equivalent to the dataframe - :rtype: torch.Tensor - :raise adbdgl_adapter.exceptions.ADBMetagraphError: If invalid **meta_val**. - """ - logger.debug( - f"__build_tensor_from_dataframe(df, '{meta_key}', {type(meta_val)})" - ) - - if type(meta_val) is str: - return tensor(adb_df[meta_val].to_list()) - - if type(meta_val) is dict: - data = [] - for attr, encoder in meta_val.items(): - if encoder is None: - data.append(tensor(adb_df[attr].to_list())) - elif callable(encoder): - data.append(encoder(adb_df[attr])) - else: # pragma: no cover - msg = f"Invalid encoder for ArangoDB attribute '{attr}': {encoder}" - raise ADBMetagraphError(msg) - - return cat(data, dim=-1) - - if callable(meta_val): - # **meta_val** is a user-defined that returns a tensor - user_defined_result = meta_val(adb_df) - - if type(user_defined_result) is not Tensor: # pragma: no cover - msg = f"Invalid return type for function {meta_val} ('{meta_key}')" - raise ADBMetagraphError(msg) - - return user_defined_result - - raise ADBMetagraphError(f"Invalid {meta_val} type") # pragma: no cover - def __build_dataframe_from_tensor( self, dgl_tensor: Tensor, @@ -1130,3 +1402,39 @@ def __build_dataframe_from_tensor( return user_defined_result raise DGLMetagraphError(f"Invalid {meta_val} type") # pragma: no cover + + def __insert_adb_docs( + self, + spinner_progress: Progress, + df: DataFrame, + col: str, + use_async: bool, + **adb_import_kwargs: Any, + ) -> None: + """DGL -> ArangoDB: Insert ArangoDB documents into their ArangoDB collection. + + :param spinner_progress: The spinner progress bar. + :type spinner_progress: rich.progress.Progress + :param df: To-be-inserted ArangoDB documents, formatted as a DataFrame + :type df: pandas.DataFrame + :param col: The ArangoDB collection name. + :type col: str + :param use_async: Performs asynchronous ArangoDB ingestion if enabled. + :type use_async: bool + :param adb_import_kwargs: Keyword arguments to specify additional + parameters for ArangoDB document insertion. Full parameter list: + https://docs.python-arango.com/en/main/specs.html#arango.collection.Collection.import_bulk + :param adb_import_kwargs: Any + """ + action = f"ADB Import: '{col}' ({len(df)})" + spinner_progress_task = spinner_progress.add_task("", action=action) + + docs = df.to_dict("records") + db = self.__async_db if use_async else self.__db + result = db.collection(col).import_bulk(docs, **adb_import_kwargs) + logger.debug(result) + + df.drop(df.index, inplace=True) + + spinner_progress.stop_task(spinner_progress_task) + spinner_progress.update(spinner_progress_task, visible=False) diff --git a/adbdgl_adapter/utils.py b/adbdgl_adapter/utils.py index cd5e4b3..b88dc73 100644 --- a/adbdgl_adapter/utils.py +++ b/adbdgl_adapter/utils.py @@ -2,7 +2,14 @@ import os from typing import Any, Dict, Set, Union -from rich.progress import Progress, SpinnerColumn, TextColumn, TimeElapsedColumn +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) from .exceptions import ADBMetagraphError, DGLMetagraphError @@ -16,18 +23,34 @@ logger.addHandler(handler) -def progress( +def get_export_spinner_progress( text: str, - text_style: str = "none", - spinner_name: str = "aesthetic", - spinner_style: str = "#5BC0DE", - transient: bool = False, ) -> Progress: return Progress( - TextColumn(text, style=text_style), - SpinnerColumn(spinner_name, spinner_style), + TextColumn(text), + SpinnerColumn("aesthetic", "#5BC0DE"), + TimeElapsedColumn(), + transient=True, + ) + + +def get_import_spinner_progress(text: str) -> Progress: + return Progress( + TextColumn(text), + TextColumn("{task.fields[action]}"), + SpinnerColumn("aesthetic", "#5BC0DE"), + TimeElapsedColumn(), + transient=True, + ) + + +def get_bar_progress(text: str, color: str) -> Progress: + return Progress( + TextColumn(text), + BarColumn(complete_style=color, finished_style=color), + TaskProgressColumn(), + TextColumn("({task.completed}/{task.total})"), TimeElapsedColumn(), - transient=transient, ) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 51db111..4d913e4 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -237,7 +237,7 @@ def test_validate_dgl_metagraph(bad_metagraph: Dict[Any, Any]) -> None: @pytest.mark.parametrize( "adapter, name, dgl_g, metagraph, \ - explicit_metagraph, overwrite_graph, batch_size, import_options", + explicit_metagraph, overwrite_graph, batch_size, adb_import_kwargs", [ ( adbdgl_adapter, @@ -356,7 +356,7 @@ def test_dgl_to_adb( explicit_metagraph: bool, overwrite_graph: bool, batch_size: Optional[int], - import_options: Any, + adb_import_kwargs: Any, ) -> None: db.delete_graph(name, drop_collections=True, ignore_missing=True) adapter.dgl_to_arangodb( @@ -366,7 +366,7 @@ def test_dgl_to_adb( explicit_metagraph, overwrite_graph, batch_size, - **import_options + **adb_import_kwargs ) assert_dgl_to_adb(name, dgl_g, metagraph, explicit_metagraph) db.delete_graph(name, drop_collections=True)