diff --git a/adbpyg_adapter/adapter.py b/adbpyg_adapter/adapter.py index 6069d62..5a79c31 100644 --- a/adbpyg_adapter/adapter.py +++ b/adbpyg_adapter/adapter.py @@ -4,6 +4,7 @@ from collections import defaultdict from typing import Any, DefaultDict, Dict, List, Set, Union +import torch from arango.database import Database from arango.graph import Graph as ADBGraph from pandas import DataFrame, Series @@ -14,7 +15,7 @@ from .abc import Abstract_ADBPyG_Adapter from .controller import ADBPyG_Controller -from .exceptions import ADBMetagraphError, PyGMetagraphError +from .exceptions import ADBMetagraphError, InvalidADBEdgesError, PyGMetagraphError from .typings import ( ADBMap, ADBMetagraph, @@ -79,6 +80,7 @@ def arangodb_to_pyg( name: str, metagraph: ADBMetagraph, preserve_adb_keys: bool = False, + strict: bool = True, **query_options: Any, ) -> Union[Data, HeteroData]: """Create a PyG graph from ArangoDB data. DOES carry @@ -127,6 +129,9 @@ def arangodb_to_pyg( ArangoDB graph is Heterogeneous, the ArangoDB keys will be preserved under `_key` in your PyG graph. :type preserve_adb_keys: bool + :param strict: Set fault tolerance when loading a graph from ArangoDB. If set + to false, this will ignore invalid edges (e.g. dangling/half edges). + :type strict: bool :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. Full parameter list: https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute @@ -296,6 +301,18 @@ def udf_v1_x(v1_df): edge_data: EdgeStorage = data if is_homogeneous else data[edge_type] edge_data.edge_index = tensor([from_nodes, to_nodes]) + + if torch.any(torch.isnan(edge_data.edge_index)): + if strict: + raise InvalidADBEdgesError( + f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501 + ) + else: + # Remove the invalid edges + edge_data.edge_index = edge_data.edge_index[ + :, ~torch.any(edge_data.edge_index.isnan(), dim=0) + ] + self.__set_pyg_data(meta, edge_data, et_df) if preserve_adb_keys: @@ -654,28 +671,28 @@ def ntypes_to_ocollections( return list(orphan_collections) def __fetch_adb_docs( - self, col: str, empty_meta: bool, query_options: Any + self, col: str, meta_is_empty: bool, query_options: Any ) -> DataFrame: """Fetches ArangoDB documents within a collection. Returns the documents in a DataFrame. :param col: The ArangoDB collection. :type col: str - :param empty_meta: Set to True if the metagraph specification + :param meta_is_empty: Set to True if the metagraph specification for **col** is empty. - :type empty_meta: bool + :type meta_is_empty: bool :param query_options: Keyword arguments to specify AQL query options when fetching documents from the ArangoDB instance. :type query_options: Any :return: A DataFrame representing the ArangoDB documents. :rtype: pandas.DataFrame """ - # Only return the entire document if **empty_meta** is False + # Only return the entire document if **meta_is_empty** is False aql = f""" FOR doc IN @@col RETURN { "{ _key: doc._key, _from: doc._from, _to: doc._to }" - if empty_meta + if meta_is_empty else "doc" } """ diff --git a/adbpyg_adapter/exceptions.py b/adbpyg_adapter/exceptions.py index ed5a90a..cb5c81c 100644 --- a/adbpyg_adapter/exceptions.py +++ b/adbpyg_adapter/exceptions.py @@ -17,3 +17,16 @@ class ADBMetagraphError(ADBPyGValidationError): class PyGMetagraphError(ADBPyGValidationError): """Invalid PyG Metagraph value""" + + +################## +# ADB -> PyG # +################## + + +class ADBPyGImportError(ADBPyGError): + """Errors on import from ArangoDB to PyG""" + + +class InvalidADBEdgesError(ADBPyGImportError): + """Invalid edges on import from ArangoDB to PyG""" diff --git a/adbpyg_adapter/typings.py b/adbpyg_adapter/typings.py index 5bb9305..a19cbaa 100644 --- a/adbpyg_adapter/typings.py +++ b/adbpyg_adapter/typings.py @@ -8,7 +8,7 @@ "PyGMap", ] -from typing import Any, Callable, DefaultDict, Dict, List, Tuple, Union +from typing import Any, Callable, DefaultDict, Dict, List, Set, Tuple, Union from pandas import DataFrame from torch import Tensor @@ -20,7 +20,7 @@ ADBEncoders = Dict[str, DataFrameToTensor] ADBMetagraphValues = Union[str, DataFrameToTensor, ADBEncoders] -ADBMetagraph = Dict[str, Dict[str, Dict[str, ADBMetagraphValues]]] +ADBMetagraph = Dict[str, Dict[str, Union[Set[str], Dict[str, ADBMetagraphValues]]]] PyGDataTypes = Union[str, Tuple[str, str, str]] PyGMetagraphValues = Union[str, List[str], TensorToDataFrame] diff --git a/tests/test_adapter.py b/tests/test_adapter.py index 509de5b..d1ddf27 100644 --- a/tests/test_adapter.py +++ b/tests/test_adapter.py @@ -10,7 +10,11 @@ from adbpyg_adapter import ADBPyG_Adapter from adbpyg_adapter.encoders import CategoricalEncoder, IdentityEncoder -from adbpyg_adapter.exceptions import ADBMetagraphError, PyGMetagraphError +from adbpyg_adapter.exceptions import ( + ADBMetagraphError, + InvalidADBEdgesError, + PyGMetagraphError, +) from adbpyg_adapter.typings import ( ADBMap, ADBMetagraph, @@ -643,6 +647,68 @@ def test_adb_graph_to_pyg( db.delete_graph(name, drop_collections=True) +@pytest.mark.parametrize("adapter", [adbpyg_adapter]) +def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_strict( + adapter: ADBPyG_Adapter, +) -> None: + name = "Karate_3" + data = get_karate_graph() + db.delete_graph(name, drop_collections=True, ignore_missing=True) + + ADBPyG_Adapter(db).pyg_to_arangodb(name, data) + + graph = db.graph(name) + v_cols: Set[str] = graph.vertex_collections() # type: ignore + edge_definitions: List[Json] = graph.edge_definitions() # type: ignore + e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions} + + for v_col in v_cols: + vertex_collection = db.collection(v_col) + vertex_collection.delete("0") + + metagraph: ADBMetagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + with pytest.raises(InvalidADBEdgesError): + adapter.arangodb_to_pyg(name, metagraph=metagraph, strict=True) + + db.delete_graph(name, drop_collections=True) + + +@pytest.mark.parametrize("adapter", [adbpyg_adapter]) +def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_permissive( + adapter: ADBPyG_Adapter, +) -> None: + name = "Karate_3" + data = get_karate_graph() + db.delete_graph(name, drop_collections=True, ignore_missing=True) + + ADBPyG_Adapter(db).pyg_to_arangodb(name, data) + + graph = db.graph(name) + v_cols: Set[str] = graph.vertex_collections() # type: ignore + edge_definitions: List[Json] = graph.edge_definitions() # type: ignore + e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions} + + for v_col in v_cols: + vertex_collection = db.collection(v_col) + vertex_collection.delete("0") + + metagraph: ADBMetagraph = { + "vertexCollections": {col: {} for col in v_cols}, + "edgeCollections": {col: {} for col in e_cols}, + } + + data = adapter.arangodb_to_pyg(name, metagraph=metagraph, strict=False) + + collection_count: int = db.collection(list(e_cols)[0]).count() # type: ignore + assert len(data.edge_index[0]) < collection_count + + db.delete_graph(name, drop_collections=True) + + def test_full_cycle_imdb_without_preserve_adb_keys() -> None: name = "imdb" db.delete_graph(name, drop_collections=True, ignore_missing=True) @@ -1000,7 +1066,7 @@ def assert_adb_to_pyg( def assert_adb_to_pyg_meta( - meta: Union[str, Dict[str, ADBMetagraphValues]], + meta: Union[str, Set[str], Dict[str, ADBMetagraphValues]], df: DataFrame, pyg_data: Union[NodeStorage, EdgeStorage], ) -> None: