From aca385af473013cf40d57204c94c8b9181239305 Mon Sep 17 00:00:00 2001 From: aMahanna Date: Tue, 2 Aug 2022 21:46:51 -0400 Subject: [PATCH] new: lazy attempt at #4 --- adbpyg_adapter/encoders.py | 7 ++++++- adbpyg_adapter/typings.py | 7 ++++++- tests/conftest.py | 7 ++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/adbpyg_adapter/encoders.py b/adbpyg_adapter/encoders.py index fca7574..d9977dc 100644 --- a/adbpyg_adapter/encoders.py +++ b/adbpyg_adapter/encoders.py @@ -5,7 +5,12 @@ from typing import Any, Dict, Optional -from pandas import DataFrame +try: + # https://github.com/arangoml/pyg-adapter/issues/4 + from cudf import DataFrame +except ModuleNotFoundError: + from pandas import DataFrame + from torch import Tensor, from_numpy, zeros diff --git a/adbpyg_adapter/typings.py b/adbpyg_adapter/typings.py index 992d9d1..83affec 100644 --- a/adbpyg_adapter/typings.py +++ b/adbpyg_adapter/typings.py @@ -10,7 +10,12 @@ from typing import Any, Callable, DefaultDict, Dict, List, Tuple, Union -from pandas import DataFrame +try: + # https://github.com/arangoml/pyg-adapter/issues/4 + from cudf import DataFrame +except ModuleNotFoundError: + from pandas import DataFrame + from torch import Tensor Json = Dict[str, Any] diff --git a/tests/conftest.py b/tests/conftest.py index 89bee46..c3c14ee 100755 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,10 +4,15 @@ from pathlib import Path from typing import Any, Callable +try: + # https://github.com/arangoml/pyg-adapter/issues/4 + from cudf import DataFrame +except ModuleNotFoundError: + from pandas import DataFrame + from arango import ArangoClient from arango.database import StandardDatabase from arango.http import DefaultHTTPClient -from pandas import DataFrame from torch import Tensor, tensor from torch_geometric.data import Data, HeteroData from torch_geometric.datasets import Amazon, FakeDataset, FakeHeteroDataset, KarateClub