From 0e8b64699c073c7d6e494634b58392d9ebc398dc Mon Sep 17 00:00:00 2001 From: Kay Liu Date: Tue, 2 Jan 2024 18:38:47 -0600 Subject: [PATCH] revise readme --- README.md | 13 +++++++++++++ godm/godm.py | 6 +++--- 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index b22f093..2ac7a1e 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,19 @@ aug_data = godm(data) # augment data detector(aug_data) # train on data ``` +The input data should be [`torch_geometric.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object with the following keys: + +- `x`: node features, +- `edge_index`: edge index, +- `edge_time`: edge times (optional, name can be changed by `time_attr`), +- `edge_type`: edge types (optional, name can be changed by `type_attr`), +- `y`: node labels, +- `train_mask`: training node mask, +- `val_mask`: validation node mask, +- `test_mask`: testing node mask. + +So far, no additional keys is allowed. We may support more keys by padding in the future. + ## Parameters - ```hid_dim``` (type: `int`, default: `None`): hidden dimension for VAE, i.e., latent embedding dimension. `None` means the largest power of 2 that is less than or equal to the feature dimension divided by two. diff --git a/godm/godm.py b/godm/godm.py index b64675a..77938fd 100644 --- a/godm/godm.py +++ b/godm/godm.py @@ -219,14 +219,14 @@ def arg_parse(self, data): raise TypeError('data must be torch_geometric.data.Data') if not hasattr(data, 'x'): - raise ValueError('data must has feature x') + raise ValueError('data must have feature x') if not hasattr(data, 'y'): - raise ValueError('data must has label y') + raise ValueError('data must have label y') if not hasattr(data, 'train_mask') or not hasattr(data, 'val_mask') \ or not hasattr(data, 'test_mask'): - raise ValueError('data must has train_mask, val_mask, test_mask') + raise ValueError('data must have train_mask, val_mask, test_mask') if self.hid_dim is None: self.hid_dim = 2 ** int(math.log2(data.x.size(1)) - 1)