Skip to content

Commit

Permalink
housekeeping
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Jul 4, 2023
1 parent d767dcd commit 6ee5011
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 29 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- name: Run isort
run: isort --check --profile=black ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
- name: Run mypy
run: mypy ${{env.PACKAGE_DIR}}
run: mypy ${{env.PACKAGE_DIR}} ${{env.TESTS_DIR}}
- name: Run pytest
run: pytest --cov=${{env.PACKAGE_DIR}} --cov-report xml --cov-report term-missing -v --color=yes --no-cov-on-fail --code-highlight=yes
- name: Publish to coveralls.io
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ celerybeat-schedule

# Environments
.env
.venv
.venv*
env/
venv/
ENV/
Expand Down
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,10 @@ Prerequisite: `arangorestore`
1. `git clone https://github.com/arangoml/pyg-adapter.git`
2. `cd pyg-adapter`
3. (create virtual environment of choice)
4. `pip install -e .[dev]`
5. (create an ArangoDB instance with method of choice)
6. `pytest --url <> --dbName <> --username <> --password <>`
4. `pip install torch`
5. `pip install -e .[dev]`
6. (create an ArangoDB instance with method of choice)
7. `pytest --url <> --dbName <> --username <> --password <>`

**Note**: A `pytest` parameter can be omitted if the endpoint is using its default value:
```python
Expand Down
19 changes: 12 additions & 7 deletions adbpyg_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from arango.database import Database
from arango.graph import Graph as ADBGraph
from pandas import DataFrame
from pandas import DataFrame, Series
from torch import Tensor, cat, tensor
from torch_geometric.data import Data, HeteroData
from torch_geometric.data.storage import EdgeStorage, NodeStorage
Expand Down Expand Up @@ -68,7 +68,7 @@ def db(self) -> Database:
return self.__db # pragma: no cover

@property
def cntrl(self) -> Database:
def cntrl(self) -> ADBPyG_Controller:
return self.__cntrl # pragma: no cover

def set_logging(self, level: Union[int, str]) -> None:
Expand Down Expand Up @@ -272,8 +272,8 @@ def udf_v1_x(v1_df):
logger.debug(f"Preparing '{e_col}' edges")

df = self.__fetch_adb_docs(e_col, meta == {}, query_options)
df[["from_col", "from_key"]] = df["_from"].str.split("/", 1, True)
df[["to_col", "to_key"]] = df["_to"].str.split("/", 1, True)
df[["from_col", "from_key"]] = self.__split_adb_ids(df["_from"])
df[["to_col", "to_key"]] = self.__split_adb_ids(df["_to"])

for (from_col, to_col), count in (
df[["from_col", "to_col"]].value_counts().items()
Expand Down Expand Up @@ -384,8 +384,9 @@ def arangodb_graph_to_pyg(
:raise adbpyg_adapter.exceptions.ADBMetagraphError: If invalid metagraph.
"""
graph = self.__db.graph(name)
v_cols = graph.vertex_collections()
e_cols = {col["edge_collection"] for col in graph.edge_definitions()}
v_cols: Set[str] = graph.vertex_collections() # type: ignore
edge_definitions: List[Json] = graph.edge_definitions() # type: ignore
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}

return self.arangodb_collections_to_pyg(
name, v_cols, e_cols, preserve_adb_keys, **query_options
Expand Down Expand Up @@ -526,7 +527,7 @@ def y_tensor_to_2_column_dataframe(pyg_tensor):
edge_definitions = self.etypes_to_edefinitions(edge_types)
orphan_collections = self.ntypes_to_ocollections(node_types, edge_types)
adb_graph = self.__db.create_graph(
name, edge_definitions, orphan_collections
name, edge_definitions, orphan_collections # type: ignore
)

# Define PyG data properties
Expand Down Expand Up @@ -718,6 +719,10 @@ def __insert_adb_docs(
result = self.__db.collection(col).import_bulk(docs, **kwargs)
logger.debug(result)

def __split_adb_ids(self, s: Series) -> Series:
"""Helper method to split the ArangoDB IDs within a Series into two columns"""
return s.str.split(pat="/", n=1, expand=True)

def __set_pyg_data(
self,
meta: Union[Set[str], Dict[str, ADBMetagraphValues]],
Expand Down
1 change: 0 additions & 1 deletion adbpyg_adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def validate_pyg_metagraph(metagraph: Dict[Any, Dict[Any, Any]]) -> None:

for parent_key in ["nodeTypes", "edgeTypes"]:
for k, meta in metagraph.get(parent_key, {}).items():

if type(meta) == set:
for m in meta:
if type(m) != str:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ def pytest_configure(config: Any) -> None:
print("Database: " + con["dbName"])
print("----------------------------------------")

class NoTimeoutHTTPClient(DefaultHTTPClient): # type: ignore
REQUEST_TIMEOUT = None
class NoTimeoutHTTPClient(DefaultHTTPClient):
REQUEST_TIMEOUT = None # type: ignore

global db
db = ArangoClient(hosts=con["url"], http_client=NoTimeoutHTTPClient()).db(
Expand Down
31 changes: 17 additions & 14 deletions tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ADBMap,
ADBMetagraph,
ADBMetagraphValues,
Json,
PyGMap,
PyGMetagraph,
PyGMetagraphValues,
Expand Down Expand Up @@ -45,7 +46,7 @@ class Bad_ADBPyG_Controller:
pass

with pytest.raises(TypeError):
ADBPyG_Adapter(bad_db)
ADBPyG_Adapter(bad_db) # type: ignore

with pytest.raises(TypeError):
ADBPyG_Adapter(db, Bad_ADBPyG_Controller()) # type: ignore
Expand Down Expand Up @@ -372,11 +373,11 @@ def test_pyg_to_arangodb_with_controller() -> None:

ADBPyG_Adapter(db, Custom_ADBPyG_Controller()).pyg_to_arangodb(name, data)

for doc in db.collection(name + "_N"):
for doc in db.collection(f"{name}_N"): # type: ignore
assert "foo" in doc
assert doc["foo"] == "bar"

for edge in db.collection(name + "_E"):
for edge in db.collection(f"{name}_E"): # type: ignore
assert "bar" in edge
assert edge["bar"] == "foo"

Expand Down Expand Up @@ -618,9 +619,10 @@ def test_adb_graph_to_pyg(

pyg_g_new = adapter.arangodb_graph_to_pyg(name)

arango_graph = db.graph(name)
v_cols = arango_graph.vertex_collections()
e_cols = {col["edge_collection"] for col in arango_graph.edge_definitions()}
graph = db.graph(name)
v_cols: Set[str] = graph.vertex_collections() # type: ignore
edge_definitions: List[Json] = graph.edge_definitions() # type: ignore
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}

# Manually set the number of nodes (since nodes are feature-less)
for v_col in v_cols:
Expand Down Expand Up @@ -705,10 +707,11 @@ def test_full_cycle_homogeneous_with_preserve_adb_keys() -> None:

pyg_g = adbpyg_adapter.arangodb_graph_to_pyg(name, preserve_adb_keys=True)

# Establish ground truth
arango_graph = db.graph(name)
v_cols = arango_graph.vertex_collections()
e_cols = {col["edge_collection"] for col in arango_graph.edge_definitions()}
graph = db.graph(name)
v_cols: Set[str] = graph.vertex_collections() # type: ignore
edge_definitions: List[Json] = graph.edge_definitions() # type: ignore
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}

metagraph: ADBMetagraph = {
"vertexCollections": {col: {} for col in v_cols},
"edgeCollections": {col: {} for col in e_cols},
Expand Down Expand Up @@ -849,8 +852,8 @@ def assert_pyg_to_adb(
collection = db.collection(e_col)

df = DataFrame(collection.all())
df[["from_col", "from_key"]] = df["_from"].str.split("/", 1, True)
df[["to_col", "to_key"]] = df["_to"].str.split("/", 1, True)
df[["from_col", "from_key"]] = df["_from"].str.split(pat="/", n=1, expand=True)
df[["to_col", "to_key"]] = df["_to"].str.split(pat="/", n=1, expand=True)

et_df = df[(df["from_col"] == from_col) & (df["to_col"] == to_col)]
assert len(et_df) == edge_data.num_edges
Expand Down Expand Up @@ -959,8 +962,8 @@ def assert_adb_to_pyg(
assert collection.count() <= pyg_g.num_edges

df = DataFrame(collection.all())
df[["from_col", "from_key"]] = df["_from"].str.split("/", 1, True)
df[["to_col", "to_key"]] = df["_to"].str.split("/", 1, True)
df[["from_col", "from_key"]] = df["_from"].str.split(pat="/", n=1, expand=True)
df[["to_col", "to_key"]] = df["_to"].str.split(pat="/", n=1, expand=True)

for (from_col, to_col), count in (
df[["from_col", "to_col"]].value_counts().items()
Expand Down

0 comments on commit 6ee5011

Please sign in to comment.