Skip to content

Commit

Permalink
Merge pull request #17 from arangoml/feature/exceptions
Browse files Browse the repository at this point in the history
Add Strict Parameter & Handle Invalid Edges
  • Loading branch information
geenen124 committed Jul 18, 2023
2 parents b65d70d + b452084 commit c096ccc
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 10 deletions.
29 changes: 23 additions & 6 deletions adbpyg_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from typing import Any, DefaultDict, Dict, List, Set, Union

import torch
from arango.database import Database
from arango.graph import Graph as ADBGraph
from pandas import DataFrame, Series
Expand All @@ -14,7 +15,7 @@

from .abc import Abstract_ADBPyG_Adapter
from .controller import ADBPyG_Controller
from .exceptions import ADBMetagraphError, PyGMetagraphError
from .exceptions import ADBMetagraphError, InvalidADBEdgesError, PyGMetagraphError
from .typings import (
ADBMap,
ADBMetagraph,
Expand Down Expand Up @@ -79,6 +80,7 @@ def arangodb_to_pyg(
name: str,
metagraph: ADBMetagraph,
preserve_adb_keys: bool = False,
strict: bool = True,
**query_options: Any,
) -> Union[Data, HeteroData]:
"""Create a PyG graph from ArangoDB data. DOES carry
Expand Down Expand Up @@ -127,6 +129,9 @@ def arangodb_to_pyg(
ArangoDB graph is Heterogeneous, the ArangoDB keys will be preserved
under `_key` in your PyG graph.
:type preserve_adb_keys: bool
:param strict: Set fault tolerance when loading a graph from ArangoDB. If set
to false, this will ignore invalid edges (e.g. dangling/half edges).
:type strict: bool
:param query_options: Keyword arguments to specify AQL query options when
fetching documents from the ArangoDB instance. Full parameter list:
https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute
Expand Down Expand Up @@ -296,6 +301,18 @@ def udf_v1_x(v1_df):

edge_data: EdgeStorage = data if is_homogeneous else data[edge_type]
edge_data.edge_index = tensor([from_nodes, to_nodes])

if torch.any(torch.isnan(edge_data.edge_index)):
if strict:
raise InvalidADBEdgesError(
f"Invalid edges found in Edge Collection {e_col}, {from_col} -> {to_col}." # noqa: E501
)
else:
# Remove the invalid edges
edge_data.edge_index = edge_data.edge_index[
:, ~torch.any(edge_data.edge_index.isnan(), dim=0)
]

self.__set_pyg_data(meta, edge_data, et_df)

if preserve_adb_keys:
Expand Down Expand Up @@ -654,28 +671,28 @@ def ntypes_to_ocollections(
return list(orphan_collections)

def __fetch_adb_docs(
self, col: str, empty_meta: bool, query_options: Any
self, col: str, meta_is_empty: bool, query_options: Any
) -> DataFrame:
"""Fetches ArangoDB documents within a collection. Returns the
documents in a DataFrame.
:param col: The ArangoDB collection.
:type col: str
:param empty_meta: Set to True if the metagraph specification
:param meta_is_empty: Set to True if the metagraph specification
for **col** is empty.
:type empty_meta: bool
:type meta_is_empty: bool
:param query_options: Keyword arguments to specify AQL query options
when fetching documents from the ArangoDB instance.
:type query_options: Any
:return: A DataFrame representing the ArangoDB documents.
:rtype: pandas.DataFrame
"""
# Only return the entire document if **empty_meta** is False
# Only return the entire document if **meta_is_empty** is False
aql = f"""
FOR doc IN @@col
RETURN {
"{ _key: doc._key, _from: doc._from, _to: doc._to }"
if empty_meta
if meta_is_empty
else "doc"
}
"""
Expand Down
13 changes: 13 additions & 0 deletions adbpyg_adapter/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ class ADBMetagraphError(ADBPyGValidationError):

class PyGMetagraphError(ADBPyGValidationError):
"""Invalid PyG Metagraph value"""


##################
# ADB -> PyG #
##################


class ADBPyGImportError(ADBPyGError):
"""Errors on import from ArangoDB to PyG"""


class InvalidADBEdgesError(ADBPyGImportError):
"""Invalid edges on import from ArangoDB to PyG"""
4 changes: 2 additions & 2 deletions adbpyg_adapter/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
"PyGMap",
]

from typing import Any, Callable, DefaultDict, Dict, List, Tuple, Union
from typing import Any, Callable, DefaultDict, Dict, List, Set, Tuple, Union

from pandas import DataFrame
from torch import Tensor
Expand All @@ -20,7 +20,7 @@

ADBEncoders = Dict[str, DataFrameToTensor]
ADBMetagraphValues = Union[str, DataFrameToTensor, ADBEncoders]
ADBMetagraph = Dict[str, Dict[str, Dict[str, ADBMetagraphValues]]]
ADBMetagraph = Dict[str, Dict[str, Union[Set[str], Dict[str, ADBMetagraphValues]]]]

PyGDataTypes = Union[str, Tuple[str, str, str]]
PyGMetagraphValues = Union[str, List[str], TensorToDataFrame]
Expand Down
70 changes: 68 additions & 2 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,11 @@

from adbpyg_adapter import ADBPyG_Adapter
from adbpyg_adapter.encoders import CategoricalEncoder, IdentityEncoder
from adbpyg_adapter.exceptions import ADBMetagraphError, PyGMetagraphError
from adbpyg_adapter.exceptions import (
ADBMetagraphError,
InvalidADBEdgesError,
PyGMetagraphError,
)
from adbpyg_adapter.typings import (
ADBMap,
ADBMetagraph,
Expand Down Expand Up @@ -643,6 +647,68 @@ def test_adb_graph_to_pyg(
db.delete_graph(name, drop_collections=True)


@pytest.mark.parametrize("adapter", [adbpyg_adapter])
def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_strict(
adapter: ADBPyG_Adapter,
) -> None:
name = "Karate_3"
data = get_karate_graph()
db.delete_graph(name, drop_collections=True, ignore_missing=True)

ADBPyG_Adapter(db).pyg_to_arangodb(name, data)

graph = db.graph(name)
v_cols: Set[str] = graph.vertex_collections() # type: ignore
edge_definitions: List[Json] = graph.edge_definitions() # type: ignore
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}

for v_col in v_cols:
vertex_collection = db.collection(v_col)
vertex_collection.delete("0")

metagraph: ADBMetagraph = {
"vertexCollections": {col: {} for col in v_cols},
"edgeCollections": {col: {} for col in e_cols},
}

with pytest.raises(InvalidADBEdgesError):
adapter.arangodb_to_pyg(name, metagraph=metagraph, strict=True)

db.delete_graph(name, drop_collections=True)


@pytest.mark.parametrize("adapter", [adbpyg_adapter])
def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_permissive(
adapter: ADBPyG_Adapter,
) -> None:
name = "Karate_3"
data = get_karate_graph()
db.delete_graph(name, drop_collections=True, ignore_missing=True)

ADBPyG_Adapter(db).pyg_to_arangodb(name, data)

graph = db.graph(name)
v_cols: Set[str] = graph.vertex_collections() # type: ignore
edge_definitions: List[Json] = graph.edge_definitions() # type: ignore
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}

for v_col in v_cols:
vertex_collection = db.collection(v_col)
vertex_collection.delete("0")

metagraph: ADBMetagraph = {
"vertexCollections": {col: {} for col in v_cols},
"edgeCollections": {col: {} for col in e_cols},
}

data = adapter.arangodb_to_pyg(name, metagraph=metagraph, strict=False)

collection_count: int = db.collection(list(e_cols)[0]).count() # type: ignore
assert len(data.edge_index[0]) < collection_count

db.delete_graph(name, drop_collections=True)


def test_full_cycle_imdb_without_preserve_adb_keys() -> None:
name = "imdb"
db.delete_graph(name, drop_collections=True, ignore_missing=True)
Expand Down Expand Up @@ -1000,7 +1066,7 @@ def assert_adb_to_pyg(


def assert_adb_to_pyg_meta(
meta: Union[str, Dict[str, ADBMetagraphValues]],
meta: Union[str, Set[str], Dict[str, ADBMetagraphValues]],
df: DataFrame,
pyg_data: Union[NodeStorage, EdgeStorage],
) -> None:
Expand Down

0 comments on commit c096ccc

Please sign in to comment.