Skip to content

Commit

Permalink
fix: udf behaviour when batch_size is set
Browse files Browse the repository at this point in the history
  • Loading branch information
aMahanna committed Jul 21, 2023
1 parent a66326c commit 2a6abad
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 32 deletions.
19 changes: 12 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,20 +79,25 @@ Note: If the PyG graph contains `_key`, `_v_key`, or `_e_key` properties for any
adb_g = adbpyg_adapter.pyg_to_arangodb("FakeData", data)

# 1.2: PyG to ArangoDB with a (completely optional) metagraph for customized adapter behaviour
def y_tensor_to_2_column_dataframe(pyg_tensor):
def y_tensor_to_2_column_dataframe(pyg_tensor, adb_df):
"""
A user-defined function to create two
ArangoDB attributes out of the 'y' label tensor
ArangoDB attributes out of the 'y' label tensor
NOTE: user-defined functions must return a Pandas Dataframe
:param pyg_tensor: The PyG Tensor containing the data
:type pyg_tensor: torch.Tensor
:param adb_df: The ArangoDB DataFrame to populate, whose
size is preset to the length of **pyg_tensor**.
:type adb_df: pandas.DataFrame
NOTE: user-defined functions must return the modified **adb_df**
"""
label_map = {0: "Kiwi", 1: "Blueberry", 2: "Avocado"}

df = pandas.DataFrame(columns=["label_num", "label_str"])
df["label_num"] = pyg_tensor.tolist()
df["label_str"] = df["label_num"].map(label_map)
adb_df["label_num"] = pyg_tensor.tolist()
adb_df["label_str"] = adb_df["label_num"].map(label_map)

return df
return adb_df


metagraph = {
Expand Down
32 changes: 23 additions & 9 deletions adbpyg_adapter/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,16 +500,15 @@ def pyg_to_arangodb(
**metagraph** example
.. code-block:: python
def y_tensor_to_2_column_dataframe(pyg_tensor):
def y_tensor_to_2_column_dataframe(pyg_tensor, adb_df):
# A user-defined function to create two ArangoDB attributes
# out of the 'y' label tensor
label_map = {0: "Kiwi", 1: "Blueberry", 2: "Avocado"}
df = pandas.DataFrame(columns=["label_num", "label_str"])
df["label_num"] = pyg_tensor.tolist()
df["label_str"] = df["label_num"].map(label_map)
adb_df["label_num"] = pyg_tensor.tolist()
adb_df["label_str"] = adb_df["label_num"].map(label_map)
return df
return adb_df
metagraph = {
"nodeTypes": {
Expand Down Expand Up @@ -1051,11 +1050,26 @@ def __build_dataframe_from_tensor(
return df

if callable(meta_val):
# **meta_val** is a user-defined function that returns a dataframe
user_defined_result = meta_val(pyg_tensor)
# **meta_val** is a user-defined function that populates the
# pre-indexed (initially empty) dataframe and returns it
empty_df = DataFrame(index=range(start_index, end_index))
user_defined_result = meta_val(pyg_tensor, empty_df)

if type(user_defined_result) is not DataFrame: # pragma: no cover
msg = f"Invalid return type for function {meta_val} ('{meta_key}')"
if not isinstance(user_defined_result, DataFrame): # pragma: no cover
msg = f"""
Invalid return type for function {meta_val} ('{meta_key}').
Function must return Pandas DataFrame.
"""
raise PyGMetagraphError(msg)

if (
user_defined_result.index.start != start_index
or user_defined_result.index.stop != end_index
): # pragma: no cover
msg = f"""
User Defined Function {meta_val} ('{meta_key}') must return
DataFrame with start index {start_index} & stop index {end_index}
"""
raise PyGMetagraphError(msg)

return user_defined_result
Expand Down
2 changes: 1 addition & 1 deletion adbpyg_adapter/typings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
Json = Dict[str, Any]

DataFrameToTensor = Callable[[DataFrame], Tensor]
TensorToDataFrame = Callable[[Tensor], DataFrame]
TensorToDataFrame = Callable[[Tensor, DataFrame], DataFrame]

ADBEncoders = Dict[str, DataFrameToTensor]
ADBMetagraphValues = Union[str, DataFrameToTensor, ADBEncoders]
Expand Down
9 changes: 4 additions & 5 deletions examples/ArangoDB_PyG_Adapter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -664,14 +664,13 @@
"nx.draw(to_networkx(pyg_hetero_graph.to_homogeneous()), with_labels=True)\n",
"\n",
"# Define the metagraph\n",
"def y_tensor_to_2_column_dataframe(pyg_tensor):\n",
"def y_tensor_to_2_column_dataframe(pyg_tensor, adb_df):\n",
" label_map = {0: \"Kiwi\", 1: \"Blueberry\", 2: \"Avocado\"}\n",
"\n",
" df = pandas.DataFrame(columns=[\"label_num\", \"label_str\"])\n",
" df[\"label_num\"] = pyg_tensor.tolist()\n",
" df[\"label_str\"] = df[\"label_num\"].map(label_map)\n",
" adb_df[\"label_num\"] = pyg_tensor.tolist()\n",
" adb_df[\"label_str\"] = adb_df[\"label_num\"].map(label_map)\n",
"\n",
" return df\n",
" return adb_df\n",
"\n",
"metagraph = {\n",
" \"nodeTypes\": {\n",
Expand Down
18 changes: 8 additions & 10 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,17 @@ def get_social_graph() -> HeteroData:


# For PyG to ArangoDB testing purposes
def udf_v2_x_tensor_to_df(t: Tensor) -> DataFrame:
df = DataFrame(columns=["x"])
df["x"] = t.tolist()
# do more things with df["v2_features"] here ...
return df
def udf_v2_x_tensor_to_df(t: Tensor, adb_df: DataFrame) -> DataFrame:
adb_df["x"] = t.tolist()
# do more things with adb_df["x"] here ...
return adb_df


# For PyG to ArangoDB testing purposes
def udf_users_x_tensor_to_df(t: Tensor) -> DataFrame:
df = DataFrame(columns=["age", "gender"])
df[["age", "gender"]] = t.tolist()
df["gender"] = df["gender"].map({0: "Male", 1: "Female"})
return df
def udf_users_x_tensor_to_df(t: Tensor, adb_df: DataFrame) -> DataFrame:
adb_df[["age", "gender"]] = t.tolist()
adb_df["gender"] = adb_df["gender"].map({0: "Male", 1: "Female"})
return adb_df


# For ArangoDB to PyG testing purposes
Expand Down

0 comments on commit 2a6abad

Please sign in to comment.