Skip to content

Commit

Permalink
add edge_index.dtype assertions
Browse files Browse the repository at this point in the history
(tests should fail)
  • Loading branch information
aMahanna committed Feb 8, 2024
1 parent 6b00036 commit 2980240
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List, Optional, Set, Union

import pytest
import torch
from pandas import DataFrame
from torch import Tensor, cat, long, tensor
from torch_geometric.data import Data, HeteroData
Expand Down Expand Up @@ -554,6 +555,7 @@ def test_adb_partial_to_pyg() -> None:
assert type(pyg_g_new) is Data
assert pyg_g["v0"].x.tolist() == pyg_g_new.x.tolist()
assert pyg_g["v0"].y.tolist() == pyg_g_new.y.tolist()
assert pyg_g[e_t].edge_index.dtype == torch.int64
assert pyg_g[e_t].edge_index.tolist() == pyg_g_new.edge_index.tolist()
assert pyg_g[e_t].edge_attr.tolist() == pyg_g_new.edge_attr.tolist()

Expand Down Expand Up @@ -714,13 +716,17 @@ def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_permissive(

graph = db.graph(name)
v_cols: Set[str] = graph.vertex_collections()
assert len(v_cols) == 1
edge_definitions: List[Json] = graph.edge_definitions()
e_cols: Set[str] = {c["edge_collection"] for c in edge_definitions}
assert len(e_cols) == 1

for v_col in v_cols:
vertex_collection = db.collection(v_col)
vertex_collection.delete("0")

number_of_missing_edges = 32 # (i.e node 0 has 32 edges)

metagraph: ADBMetagraph = {
"vertexCollections": {col: {} for col in v_cols},
"edgeCollections": {col: {} for col in e_cols},
Expand All @@ -729,7 +735,8 @@ def test_adb_graph_to_pyg_to_arangodb_with_missing_document_and_permissive(
data = adapter.arangodb_to_pyg(name, metagraph=metagraph, strict=False)

collection_count: int = db.collection(list(e_cols)[0]).count()
assert len(data.edge_index[0]) < collection_count
assert data.edge_index.dtype == torch.int64
assert data.num_edges + number_of_missing_edges == collection_count

db.delete_graph(name, drop_collections=True)

Expand Down Expand Up @@ -1076,6 +1083,7 @@ def assert_adb_to_pyg(
from_nodes = et_df["from_key"].map(adb_map[from_col]).tolist()
to_nodes = et_df["to_key"].map(adb_map[to_col]).tolist()

assert edge_data.edge_index.dtype == torch.int64
assert from_nodes == edge_data.edge_index[0].tolist()
assert to_nodes == edge_data.edge_index[1].tolist()

Expand Down

0 comments on commit 2980240

Please sign in to comment.