diff --git a/adbpyg_adapter/adapter.py b/adbpyg_adapter/adapter.py index 5a79c31..6ceb2c4 100644 --- a/adbpyg_adapter/adapter.py +++ b/adbpyg_adapter/adapter.py @@ -2,9 +2,11 @@ # -*- coding: utf-8 -*- import logging from collections import defaultdict -from typing import Any, DefaultDict, Dict, List, Set, Union +from math import ceil +from typing import Any, DefaultDict, Dict, List, Optional, Set, Union import torch +from arango.cursor import Cursor from arango.database import Database from arango.graph import Graph as ADBGraph from pandas import DataFrame, Series @@ -259,65 +261,91 @@ def udf_v1_x(v1_df): for v_col, meta in metagraph["vertexCollections"].items(): logger.debug(f"Preparing '{v_col}' vertices") - df = self.__fetch_adb_docs(v_col, meta == {}, query_options) - adb_map[v_col] = { - adb_id: pyg_id for pyg_id, adb_id in enumerate(df["_key"]) - } - node_data: NodeStorage = data if is_homogeneous else data[v_col] - self.__set_pyg_data(meta, node_data, df) if preserve_adb_keys: - k = "_v_key" if is_homogeneous else "_key" - node_data[k] = list(adb_map[v_col].keys()) + preserve_key = "_v_key" if is_homogeneous else "_key" + node_data[preserve_key] = [] + + pyg_id = 0 + cursor = self.__fetch_adb_docs(v_col, meta, query_options) + while not cursor.empty(): + cursor_batch = len(cursor.batch()) # type: ignore + df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) + + for adb_id in df["_key"]: + adb_map[v_col][adb_id] = pyg_id + pyg_id += 1 + + self.__set_pyg_data(meta, node_data, df) + + if preserve_adb_keys: + node_data[preserve_key].extend(list(df["_key"])) + + if cursor.has_more(): + cursor.fetch() + + df.drop(df.index, inplace=True) et_df: DataFrame v_cols: List[str] = list(metagraph["vertexCollections"].keys()) for e_col, meta in metagraph.get("edgeCollections", {}).items(): logger.debug(f"Preparing '{e_col}' edges") - df = self.__fetch_adb_docs(e_col, meta == {}, query_options) - df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"]) - df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"]) + cursor = self.__fetch_adb_docs(e_col, meta, query_options) + while not cursor.empty(): + cursor_batch = len(cursor.batch()) # type: ignore + df = DataFrame([cursor.pop() for _ in range(cursor_batch)]) - for (from_col, to_col), count in ( - df[["from_col", "to_col"]].value_counts().items() - ): - edge_type = (from_col, e_col, to_col) - if from_col not in v_cols or to_col not in v_cols: - logger.debug(f"Skipping {edge_type}") - continue # partial edge collection import to pyg + df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"]) + df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"]) - logger.debug(f"Preparing {count} '{edge_type}' edges") + for (from_col, to_col), count in ( + df[["from_col", "to_col"]].value_counts().items() + ): + edge_type = (from_col, e_col, to_col) + edge_data: EdgeStorage = data if is_homogeneous else data[edge_type] - # Get the edge data corresponding to the current edge type - et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)] - adb_map[edge_type] = { - adb_id: pyg_id for pyg_id, adb_id in enumerate(et_df["_key"]) - } + if from_col not in v_cols or to_col not in v_cols: + logger.debug(f"Skipping {edge_type}") + continue # partial edge collection import to pyg - from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() - to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() + logger.debug(f"Preparing {count} '{edge_type}' edges") - edge_data: EdgeStorage = data if is_homogeneous else data[edge_type] - edge_data.edge_index = tensor([from_nodes, to_nodes]) + et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)] - if torch.any(torch.isnan(edge_data.edge_index)): - if strict: - raise InvalidADBEdgesError( - f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501 - ) - else: - # Remove the invalid edges - edge_data.edge_index = edge_data.edge_index[ - :, ~torch.any(edge_data.edge_index.isnan(), dim=0) - ] + from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist() + to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist() + edge_index = tensor([from_nodes, to_nodes]) - self.__set_pyg_data(meta, edge_data, et_df) + edge_data.edge_index = torch.cat( + (edge_data.get("edge_index", tensor([])), edge_index), dim=1 + ) - if preserve_adb_keys: - k = "_e_key" if is_homogeneous else "_key" - edge_data[k] = list(adb_map[edge_type].keys()) + if torch.any(torch.isnan(edge_data.edge_index)): + if strict: + raise InvalidADBEdgesError( + f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501 + ) + else: + # Remove the invalid edges + edge_data.edge_index = edge_data.edge_index[ + :, ~torch.any(edge_data.edge_index.isnan(), dim=0) + ] + + self.__set_pyg_data(meta, edge_data, et_df) + + if preserve_adb_keys: + preserve_key = "_e_key" if is_homogeneous else "_key" + if preserve_key not in edge_data: + edge_data[preserve_key] = [] + + edge_data[preserve_key].extend(list(et_df["_key"])) + + if cursor.has_more(): + cursor.fetch() + + df.drop(df.index, inplace=True) logger.info(f"Created PyG '{name}' Graph") return data @@ -416,6 +444,7 @@ def pyg_to_arangodb( metagraph: PyGMetagraph = {}, explicit_metagraph: bool = True, overwrite_graph: bool = False, + batch_size: Optional[int] = None, **import_options: Any, ) -> ADBGraph: """Create an ArangoDB graph from a PyG graph. @@ -456,6 +485,10 @@ def pyg_to_arangodb( :param overwrite_graph: Overwrites the graph if it already exists. Does not drop associated collections. Defaults to False. :type overwrite_graph: bool + :param batch_size: Process the PyG Nodes & Edges in batches of size + **batch_size**. Defaults to `None`, which processes each + NodeStorage & EdgeStorage in one batch. + :type batch_size: int :param import_options: Keyword arguments to specify additional parameters for ArangoDB document insertion. Full parameter list: https://docs.python-arango.com/en/main/specs.html#arango.collection.Collection.import_bulk @@ -503,6 +536,7 @@ def y_tensor_to_2_column_dataframe(pyg_tensor): logger.debug(f"--pyg_to_arangodb('{name}')--") validate_pyg_metagraph(metagraph) + is_custom_controller = type(self.__cntrl) is not ADBPyG_Controller is_homogeneous = type(pyg_g) is Data if is_homogeneous and pyg_g.num_nodes == pyg_g.num_edges and not metagraph: @@ -526,9 +560,9 @@ def y_tensor_to_2_column_dataframe(pyg_tensor): edge_types = metagraph.get("edgeTypes", {}).keys() # type: ignore elif is_homogeneous: - n_type = name + "_N" + n_type = f"{name}_N" node_types = [n_type] - edge_types = [(n_type, name + "_E", n_type)] + edge_types = [(n_type, f"{name}_E", n_type)] else: node_types = pyg_g.node_types @@ -553,53 +587,89 @@ def y_tensor_to_2_column_dataframe(pyg_tensor): n_meta = metagraph.get("nodeTypes", {}) for n_type in node_types: - node_data = pyg_g if is_homogeneous else pyg_g[n_type] - meta = n_meta.get(n_type, {}) - empty_df = DataFrame(index=range(node_data.num_nodes)) - df = self.__set_adb_data(empty_df, meta, node_data, explicit_metagraph) - if "_id" in df: - pyg_map[n_type] = df["_id"].to_dict() - else: - if "_key" not in df: - df["_key"] = df.index.astype(str) + node_data = pyg_g if is_homogeneous else pyg_g[n_type] + node_data_batch_size = batch_size or node_data.num_nodes + + start_index = 0 + end_index = min(node_data_batch_size, node_data.num_nodes) + batches = ceil(node_data.num_nodes / node_data_batch_size) + + for _ in range(batches): + df = self.__set_adb_data( + DataFrame(index=range(start_index, end_index)), + meta, + node_data, + node_data.num_nodes, + start_index, + end_index, + explicit_metagraph, + ) + + if "_id" in df: + pyg_map[n_type].update(df["_id"].to_dict()) + else: + df["_key"] = df.get("_key", df.index.astype(str)) + pyg_map[n_type].update((n_type + "/" + df["_key"]).to_dict()) - pyg_map[n_type] = (n_type + "/" + df["_key"]).to_dict() + if is_custom_controller: + f = lambda n: self.__cntrl._prepare_pyg_node(n, n_type) + df = df.apply(f, axis=1) - if type(self.__cntrl) is not ADBPyG_Controller: - f = lambda n: self.__cntrl._prepare_pyg_node(n, n_type) - df = df.apply(f, axis=1) + self.__insert_adb_docs(n_type, df, import_options) - self.__insert_adb_docs(n_type, df, import_options) + start_index = end_index + end_index = min(end_index + node_data_batch_size, node_data.num_nodes) e_meta = metagraph.get("edgeTypes", {}) for e_type in edge_types: - edge_data = pyg_g if is_homogeneous else pyg_g[e_type] + meta = e_meta.get(e_type, {}) src_n_type, _, dst_n_type = e_type - columns = ["_from", "_to"] - meta = e_meta.get(e_type, {}) - df = DataFrame(zip(*(edge_data.edge_index.tolist())), columns=columns) - df = self.__set_adb_data(df, meta, edge_data, explicit_metagraph) + edge_data = pyg_g if is_homogeneous else pyg_g[e_type] + edge_data_batch_size = batch_size or edge_data.num_edges + + start_index = 0 + end_index = min(edge_data_batch_size, edge_data.num_edges) + batches = ceil(edge_data.num_edges / edge_data_batch_size) + + for _ in range(batches): + edge_index = edge_data.edge_index[:, start_index:end_index] + df = self.__set_adb_data( + DataFrame( + zip(*(edge_index.tolist())), + index=range(start_index, end_index), + columns=["_from", "_to"], + ), + meta, + edge_data, + edge_data.num_edges, + start_index, + end_index, + explicit_metagraph, + ) - df["_from"] = ( - df["_from"].map(pyg_map[src_n_type]) - if pyg_map[src_n_type] - else src_n_type + "/" + df["_from"].astype(str) - ) + df["_from"] = ( + df["_from"].map(pyg_map[src_n_type]) + if pyg_map[src_n_type] + else src_n_type + "/" + df["_from"].astype(str) + ) - df["_to"] = ( - df["_to"].map(pyg_map[dst_n_type]) - if pyg_map[dst_n_type] - else dst_n_type + "/" + df["_to"].astype(str) - ) + df["_to"] = ( + df["_to"].map(pyg_map[dst_n_type]) + if pyg_map[dst_n_type] + else dst_n_type + "/" + df["_to"].astype(str) + ) + + if is_custom_controller: + f = lambda e: self.__cntrl._prepare_pyg_edge(e, e_type) + df = df.apply(f, axis=1) - if type(self.__cntrl) is not ADBPyG_Controller: - f = lambda e: self.__cntrl._prepare_pyg_edge(e, e_type) - df = df.apply(f, axis=1) + self.__insert_adb_docs(e_type, df, import_options) - self.__insert_adb_docs(e_type, df, import_options) + start_index = end_index + end_index = min(end_index + edge_data_batch_size, edge_data.num_edges) logger.info(f"Created ArangoDB '{name}' Graph") return adb_graph @@ -671,31 +741,53 @@ def ntypes_to_ocollections( return list(orphan_collections) def __fetch_adb_docs( - self, col: str, meta_is_empty: bool, query_options: Any - ) -> DataFrame: + self, + col: str, + meta: Union[Set[str], Dict[str, ADBMetagraphValues]], + query_options: Any, + ) -> Cursor: """Fetches ArangoDB documents within a collection. Returns the documents in a DataFrame. :param col: The ArangoDB collection. :type col: str - :param meta_is_empty: Set to True if the metagraph specification - for **col** is empty. - :type meta_is_empty: bool + :param meta: The MetaGraph associated to **col** + :type meta: Set[str] | Dict[str, adbpyg_adapter.typings.ADBMetagraphValues] :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. :type query_options: Any :return: A DataFrame representing the ArangoDB documents. :rtype: pandas.DataFrame """ - # Only return the entire document if **meta_is_empty** is False - aql = f""" - FOR doc IN @@col - RETURN { - "{ _key: doc._key, _from: doc._from, _to: doc._to }" - if meta_is_empty - else "doc" - } - """ + + def get_aql_return_value( + meta: Union[Set[str], Dict[str, ADBMetagraphValues]] + ) -> str: + """Helper method to formulate the AQL `RETURN` value based on + the document attributes specified in **meta** + """ + attributes = [] + + if type(meta) is set: + attributes = list(meta) + + elif type(meta) is dict: + for value in meta.values(): + if type(value) is str: + attributes.append(value) + elif type(value) is dict: + attributes.extend(list(value.keys())) + elif callable(value): + # Cannot determine which attributes to extract if UDFs are used + # Therefore we just return the entire document + return "doc" + + return f""" + MERGE( + {{ _key: doc._key, _from: doc._from, _to: doc._to }}, + KEEP(doc, {list(attributes)}) + ) + """ with progress( f"(ADB → PyG): {col}", @@ -703,11 +795,10 @@ def __fetch_adb_docs( spinner_style="#40A6F5", ) as p: p.add_task("__fetch_adb_docs") - - return DataFrame( - self.__db.aql.execute( - aql, count=True, bind_vars={"@col": col}, **query_options - ) + return self.__db.aql.execute( # type: ignore + f"FOR doc IN @@col RETURN {get_aql_return_value(meta)}", + bind_vars={"@col": col}, + **{**{"stream": True}, **query_options}, ) def __insert_adb_docs( @@ -735,6 +826,7 @@ def __insert_adb_docs( docs = df.to_dict("records") result = self.__db.collection(col).import_bulk(docs, **kwargs) logger.debug(result) + df.drop(df.index, inplace=True) def __split_adb_ids(self, s: Series) -> Series: """Helper method to split the ArangoDB IDs within a Series into two columns""" @@ -764,13 +856,24 @@ def __set_pyg_data( valid_meta = meta if type(meta) is dict else {m: m for m in meta} for k, v in valid_meta.items(): - pyg_data[k] = self.__build_tensor_from_dataframe(df, k, v) + t = self.__build_tensor_from_dataframe(df, k, v) + + if k not in pyg_data: + pyg_data[k] = t + elif isinstance(pyg_data[k], Tensor): + pyg_data[k] = cat((pyg_data[k], t)) + else: + m = f"'{k}' key in PyG Data must point to a Tensor" + raise TypeError(m) def __set_adb_data( self, df: DataFrame, meta: Union[Set[str], Dict[Any, PyGMetagraphValues]], pyg_data: Union[Data, NodeStorage, EdgeStorage], + pyg_data_size: int, + start_index: int, + end_index: int, explicit_metagraph: bool, ) -> DataFrame: """A helper method to build the ArangoDB Dataframe for the given @@ -788,6 +891,14 @@ def __set_adb_data( :param pyg_data: The NodeStorage or EdgeStorage of the current PyG node or edge type. :type pyg_data: torch_geometric.data.storage.(NodeStorage | EdgeStorage) + :param pyg_data_size: The size of the NodeStorage or EdgeStorage of the + current PyG node or edge type. + :type pyg_data_size: int + :param start_index: The starting index of the current batch to process. + :type start_index: int + :param end_index: The ending index of the current batch to process. + :type end_index: int + :type pyg_data: torch_geometric.data.storage.(NodeStorage | EdgeStorage) :param explicit_metagraph: The value of **explicit_metagraph** in **pyg_to_arangodb**. :type explicit_metagraph: bool @@ -805,24 +916,34 @@ def __set_adb_data( if explicit_metagraph: pyg_keys = set(valid_meta.keys()) else: - # can't do keys() (not compatible with Homogeneous graphs) + # can't do pyg_data.keys() (not compatible with Homogeneous graphs) pyg_keys = set(k for k, _ in pyg_data.items()) - for k in pyg_keys: - if k == "edge_index": + for meta_key in pyg_keys: + if meta_key == "edge_index": continue - data = pyg_data[k] - meta_val = valid_meta.get(k, str(k)) - - if type(meta_val) is str and type(data) is list and len(data) == len(df): - if meta_val in ["_v_key", "_e_key"]: # Homogeneous situation - meta_val = "_key" + data = pyg_data[meta_key] + meta_val = valid_meta.get(meta_key, str(meta_key)) - df = df.join(DataFrame(data, columns=[meta_val])) - - if type(data) is Tensor and len(data) == len(df): - df = df.join(self.__build_dataframe_from_tensor(data, k, meta_val)) + if ( + type(meta_val) is str + and type(data) is list + and len(data) == pyg_data_size + ): + meta_val = "_key" if meta_val in ["_v_key", "_e_key"] else meta_val + df = df.join(DataFrame(data[start_index:end_index], columns=[meta_val])) + + if type(data) is Tensor and len(data) == pyg_data_size: + df = df.join( + self.__build_dataframe_from_tensor( + data[start_index:end_index], + start_index, + end_index, + meta_key, + meta_val, + ) + ) return df @@ -868,7 +989,7 @@ def __build_tensor_from_dataframe( return cat(data, dim=-1) if callable(meta_val): - # **meta_val** is a user-defined that returns a tensor + # **meta_val** is a user-defined function that returns a tensor user_defined_result = meta_val(adb_df) if type(user_defined_result) is not Tensor: # pragma: no cover @@ -882,6 +1003,8 @@ def __build_tensor_from_dataframe( def __build_dataframe_from_tensor( self, pyg_tensor: Tensor, + start_index: int, + end_index: int, meta_key: Any, meta_val: PyGMetagraphValues, ) -> DataFrame: @@ -890,6 +1013,10 @@ def __build_dataframe_from_tensor( :param pyg_tensor: The Tensor representing PyG data. :type pyg_tensor: torch.Tensor + :param start_index: The starting index of the current batch to process. + :type start_index: int + :param end_index: The ending index of the current batch to process. + :type end_index: int :param meta_key: The current PyG-ArangoDB metagraph key :type meta_key: Any :param meta_val: The value mapped to the PyG-ArangoDB metagraph key to @@ -905,7 +1032,7 @@ def __build_dataframe_from_tensor( ) if type(meta_val) is str: - df = DataFrame(columns=[meta_val]) + df = DataFrame(index=range(start_index, end_index), columns=[meta_val]) df[meta_val] = pyg_tensor.tolist() return df @@ -919,7 +1046,7 @@ def __build_dataframe_from_tensor( """ raise PyGMetagraphError(msg) - df = DataFrame(columns=meta_val) + df = DataFrame(index=range(start_index, end_index), columns=meta_val) df[meta_val] = pyg_tensor.tolist() return df diff --git a/adbpyg_adapter/typings.py b/adbpyg_adapter/typings.py index a19cbaa..dbd7cbc 100644 --- a/adbpyg_adapter/typings.py +++ b/adbpyg_adapter/typings.py @@ -24,7 +24,9 @@ PyGDataTypes = Union[str, Tuple[str, str, str]] PyGMetagraphValues = Union[str, List[str], TensorToDataFrame] -PyGMetagraph = Dict[str, Dict[PyGDataTypes, Dict[Any, PyGMetagraphValues]]] +PyGMetagraph = Dict[ + str, Dict[PyGDataTypes, Union[Set[str], Dict[Any, PyGMetagraphValues]]] +] ADBMap = DefaultDict[PyGDataTypes, Dict[str, int]] PyGMap = DefaultDict[PyGDataTypes, Dict[int, str]] diff --git a/tests/conftest.py b/tests/conftest.py index 50dffd6..816c099 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,6 @@ from arango import ArangoClient from arango.database import StandardDatabase -from arango.http import DefaultHTTPClient from pandas import DataFrame from torch import Tensor, tensor from torch_geometric.data import Data, HeteroData @@ -45,11 +44,8 @@ def pytest_configure(config: Any) -> None: print("Database: " + con["dbName"]) print("----------------------------------------") - class NoTimeoutHTTPClient(DefaultHTTPClient): - REQUEST_TIMEOUT = None # type: ignore - global db - db = ArangoClient(hosts=con["url"], http_client=NoTimeoutHTTPClient()).db( + db = ArangoClient(hosts=con["url"]).db( con["dbName"], con["username"], con["password"], verify=True ) diff --git a/tests/test_adapter.py b/tests/test_adapter.py index d1ddf27..ead5bf3 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -233,7 +233,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: @pytest.mark.parametrize( "adapter, name, pyg_g, metagraph, \ - explicit_metagraph, overwrite_graph, import_options", + explicit_metagraph, overwrite_graph, batch_size, import_options", [ ( adbpyg_adapter, @@ -242,6 +242,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {"nodeTypes": {"Karate_1_N": {"x": "node_features"}}}, False, False, + 33, {}, ), ( @@ -251,6 +252,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {"nodeTypes": {"Karate_2_N": {"x": "node_features"}}}, True, False, + 1000, {}, ), ( @@ -260,6 +262,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {"nodeTypes": {"FakeHomoGraph_1_N": {"y": "label"}}}, False, False, + 1, {}, ), ( @@ -269,6 +272,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {}, False, False, + 1000, {}, ), ( @@ -278,6 +282,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {}, True, False, + None, {}, ), ( @@ -294,6 +299,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: }, True, False, + None, {}, ), ( @@ -307,6 +313,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: }, True, False, + None, {}, ), ( @@ -316,6 +323,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {}, False, False, + 1, {}, ), ( @@ -325,15 +333,17 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {"nodeTypes": {"v2": {"x": udf_v2_x_tensor_to_df}}}, True, False, + 1000, {}, ), ( adbpyg_adapter, - "FakeHeteroGraph_2", + "FakeHeteroGraph_3", get_fake_hetero_graph(avg_num_nodes=2), {"nodeTypes": {"v0": {"x", "y"}, "v2": {"x"}}}, True, False, + None, {}, ), ( @@ -343,6 +353,7 @@ def test_validate_pyg_metagraph(bad_metagraph: Dict[Any, Any]) -> None: {"nodeTypes": {"user": {"x": ["age", "gender"]}}}, False, True, + None, {}, ), ], @@ -354,11 +365,18 @@ def test_pyg_to_adb( metagraph: PyGMetagraph, explicit_metagraph: bool, overwrite_graph: bool, + batch_size: Optional[int], import_options: Any, ) -> None: db.delete_graph(name, drop_collections=True, ignore_missing=True) adapter.pyg_to_arangodb( - name, pyg_g, metagraph, explicit_metagraph, overwrite_graph, **import_options + name, + pyg_g, + metagraph, + explicit_metagraph, + overwrite_graph, + batch_size, + **import_options, ) assert_pyg_to_adb(name, pyg_g, metagraph, explicit_metagraph) db.delete_graph(name, drop_collections=True) @@ -389,7 +407,7 @@ def test_pyg_to_arangodb_with_controller() -> None: @pytest.mark.parametrize( - "adapter, name, metagraph, pyg_g_old", + "adapter, name, metagraph, pyg_g_old, batch_size", [ ( adbpyg_adapter, @@ -403,6 +421,7 @@ def test_pyg_to_arangodb_with_controller() -> None: }, }, get_karate_graph(), + 33, ), ( adbpyg_adapter, @@ -416,6 +435,7 @@ def test_pyg_to_arangodb_with_controller() -> None: }, }, get_fake_homo_graph(avg_num_nodes=3, edge_dim=1), + 1, ), ( adbpyg_adapter, @@ -431,6 +451,7 @@ def test_pyg_to_arangodb_with_controller() -> None: }, }, get_fake_hetero_graph(avg_num_nodes=2, edge_dim=2), + 1000, ), ( adbpyg_adapter, @@ -446,6 +467,7 @@ def test_pyg_to_arangodb_with_controller() -> None: }, }, get_fake_hetero_graph(avg_num_nodes=2, edge_dim=2), + None, ), ( adbpyg_adapter, @@ -461,6 +483,7 @@ def test_pyg_to_arangodb_with_controller() -> None: }, }, get_fake_hetero_graph(avg_num_nodes=2, edge_dim=2), + None, ), ( adbpyg_adapter, @@ -479,6 +502,7 @@ def test_pyg_to_arangodb_with_controller() -> None: }, }, get_fake_hetero_graph(avg_num_nodes=2, edge_dim=2), + None, ), ], ) @@ -487,12 +511,13 @@ def test_adb_to_pyg( name: str, metagraph: ADBMetagraph, pyg_g_old: Optional[Union[Data, HeteroData]], + batch_size: Optional[int], ) -> None: if pyg_g_old: db.delete_graph(name, drop_collections=True, ignore_missing=True) adapter.pyg_to_arangodb(name, pyg_g_old) - pyg_g_new = adapter.arangodb_to_pyg(name, metagraph) + pyg_g_new = adapter.arangodb_to_pyg(name, metagraph, batch_size=batch_size) assert_adb_to_pyg(pyg_g_new, metagraph) if pyg_g_old: @@ -850,9 +875,9 @@ def test_full_cycle_imdb_with_preserve_adb_keys() -> None: pyg_to_adb_metagraph: PyGMetagraph = { "nodeTypes": { "Users": {"x": ["Age", "Gender"], "_key": "_key"}, - "Movies": {"_id": "_id"}, + "Movies": {"_id"}, # Note: we can either use _id or _key here }, - "edgeTypes": {("Users", "Ratings", "Movies"): {"_key": "_key"}}, + "edgeTypes": {("Users", "Ratings", "Movies"): {"_key"}}, } adbpyg_adapter.pyg_to_arangodb(