Skip to content

Commit

Permalink
revise readme
Browse files Browse the repository at this point in the history
  • Loading branch information
kayzliu committed Jan 3, 2024
1 parent a0c252c commit 0e8b646
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 3 deletions.
13 changes: 13 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ aug_data = godm(data) # augment data
detector(aug_data) # train on data
```

The input data should be [`torch_geometric.Data`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.data.Data.html#torch_geometric.data.Data) object with the following keys:

- `x`: node features,
- `edge_index`: edge index,
- `edge_time`: edge times (optional, name can be changed by `time_attr`),
- `edge_type`: edge types (optional, name can be changed by `type_attr`),
- `y`: node labels,
- `train_mask`: training node mask,
- `val_mask`: validation node mask,
- `test_mask`: testing node mask.

So far, no additional keys is allowed. We may support more keys by padding in the future.

## Parameters

- ```hid_dim``` (type: `int`, default: `None`): hidden dimension for VAE, i.e., latent embedding dimension. `None` means the largest power of 2 that is less than or equal to the feature dimension divided by two.
Expand Down
6 changes: 3 additions & 3 deletions godm/godm.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,14 +219,14 @@ def arg_parse(self, data):
raise TypeError('data must be torch_geometric.data.Data')

if not hasattr(data, 'x'):
raise ValueError('data must has feature x')
raise ValueError('data must have feature x')

if not hasattr(data, 'y'):
raise ValueError('data must has label y')
raise ValueError('data must have label y')

if not hasattr(data, 'train_mask') or not hasattr(data, 'val_mask') \
or not hasattr(data, 'test_mask'):
raise ValueError('data must has train_mask, val_mask, test_mask')
raise ValueError('data must have train_mask, val_mask, test_mask')

if self.hid_dim is None:
self.hid_dim = 2 ** int(math.log2(data.x.size(1)) - 1)
Expand Down

0 comments on commit 0e8b646

Please sign in to comment.