Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Oct 19, 2023
1 parent e6b6920 commit fdf4d8a
Showing 1 changed file with 18 additions and 18 deletions.
36 changes: 18 additions & 18 deletions adbpyg_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import torch
from arango.cursor import Cursor
from arango.database import Database
from arango.database import StandardDatabase
from arango.graph import Graph as ADBGraph
from pandas import DataFrame, Series
from rich.console import Group
Expand Down Expand Up @@ -57,14 +57,14 @@ class ADBPyG_Adapter(Abstract_ADBPyG_Adapter):

def __init__(
self,
db: Database,
db: StandardDatabase,
controller: ADBPyG_Controller = ADBPyG_Controller(),
logging_lvl: Union[str, int] = logging.INFO,
):
self.set_logging(logging_lvl)

if not isinstance(db, Database):
msg = "**db** parameter must inherit from arango.database.Database"
if not isinstance(db, StandardDatabase):
msg = "**db** parameter must inherit from arango.database.StandardDatabase"
raise TypeError(msg)

if not isinstance(controller, ADBPyG_Controller):
Expand All @@ -78,7 +78,7 @@ def __init__(
logger.info(f"Instantiated ADBPyG_Adapter with database '{db.name}'")

@property
def db(self) -> Database:
def db(self) -> StandardDatabase:
return self.__db # pragma: no cover

@property
Expand All @@ -98,7 +98,7 @@ def arangodb_to_pyg(
metagraph: ADBMetagraph,
preserve_adb_keys: bool = False,
strict: bool = True,
**query_options: Any,
**adb_export_kwargs: Any,
) -> Union[Data, HeteroData]:
"""Create a PyG graph from ArangoDB data. DOES carry
over node/edge features/labels, via the **metagraph**.
Expand Down Expand Up @@ -149,10 +149,10 @@ def arangodb_to_pyg(
:param strict: Set fault tolerance when loading a graph from ArangoDB. If set
to false, this will ignore invalid edges (e.g. dangling/half edges).
:type strict: bool
:param query_options: Keyword arguments to specify AQL query options when
:param adb_export_kwargs: Keyword arguments to specify AQL query options when
fetching documents from the ArangoDB instance. Full parameter list:
https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute
:type query_options: Any
:type adb_export_kwargs: Any
:return: A PyG Data or HeteroData object
:rtype: torch_geometric.data.Data | torch_geometric.data.HeteroData
:raise adbpyg_adapter.exceptions.ADBMetagraphError: If invalid metagraph.
Expand Down Expand Up @@ -294,7 +294,7 @@ def udf_v1_x(v1_df):

# 1. Fetch ArangoDB vertices
v_col_cursor, v_col_size = self.__fetch_adb_docs(
v_col, meta, **query_options
v_col, meta, **adb_export_kwargs
)

# 2. Process ArangoDB vertices
Expand Down Expand Up @@ -323,7 +323,7 @@ def udf_v1_x(v1_df):

# 1. Fetch ArangoDB edges
e_col_cursor, e_col_size = self.__fetch_adb_docs(
e_col, meta, **query_options
e_col, meta, **adb_export_kwargs
)

# 2. Process ArangoDB edges
Expand Down Expand Up @@ -351,7 +351,7 @@ def arangodb_collections_to_pyg(
v_cols: Set[str],
e_cols: Set[str],
preserve_adb_keys: bool = False,
**query_options: Any,
**adb_export_kwargs: Any,
) -> Union[Data, HeteroData]:
"""Create a PyG graph from ArangoDB collections. Due to risk of
ambiguity, this method DOES NOT transfer ArangoDB attributes to PyG.
Expand All @@ -377,10 +377,10 @@ def arangodb_collections_to_pyg(
ArangoDB graph is Heterogeneous, the ArangoDB keys will be preserved
under `_key` in your PyG graph.
:type preserve_adb_keys: bool
:param query_options: Keyword arguments to specify AQL query options when
:param adb_export_kwargs: Keyword arguments to specify AQL query options when
fetching documents from the ArangoDB instance. Full parameter list:
https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute
:type query_options: Any
:type adb_export_kwargs: Any
:return: A PyG Data or HeteroData object
:rtype: torch_geometric.data.Data | torch_geometric.data.HeteroData
:raise adbpyg_adapter.exceptions.ADBMetagraphError: If invalid metagraph.
Expand All @@ -390,10 +390,10 @@ def arangodb_collections_to_pyg(
"edgeCollections": {col: dict() for col in e_cols},
}

return self.arangodb_to_pyg(name, metagraph, preserve_adb_keys, **query_options)
return self.arangodb_to_pyg(name, metagraph, preserve_adb_keys, **adb_export_kwargs)

def arangodb_graph_to_pyg(
self, name: str, preserve_adb_keys: bool = False, **query_options: Any
self, name: str, preserve_adb_keys: bool = False, **adb_export_kwargs: Any
) -> Union[Data, HeteroData]:
"""Create a PyG graph from an ArangoDB graph. Due to risk of
ambiguity, this method DOES NOT transfer ArangoDB attributes to PyG.
Expand All @@ -415,10 +415,10 @@ def arangodb_graph_to_pyg(
ArangoDB graph is Heterogeneous, the ArangoDB keys will be preserved
under `_key` in your PyG graph.
:type preserve_adb_keys: bool
:param query_options: Keyword arguments to specify AQL query options when
:param adb_export_kwargs: Keyword arguments to specify AQL query options when
fetching documents from the ArangoDB instance. Full parameter list:
https://docs.python-arango.com/en/main/specs.html#arango.aql.AQL.execute
:type query_options: Any
:type adb_export_kwargs: Any
:return: A PyG Data or HeteroData object
:rtype: torch_geometric.data.Data | torch_geometric.data.HeteroData
:raise adbpyg_adapter.exceptions.ADBMetagraphError: If invalid metagraph.
Expand All @@ -429,7 +429,7 @@ def arangodb_graph_to_pyg(
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}

return self.arangodb_collections_to_pyg(
name, v_cols, e_cols, preserve_adb_keys, **query_options
name, v_cols, e_cols, preserve_adb_keys, **adb_export_kwargs
)

###########################
Expand Down

0 comments on commit fdf4d8a

Please sign in to comment.