Skip to content

Commit

Permalink
add mask check
Browse files Browse the repository at this point in the history
  • Loading branch information
kayzliu committed Jan 3, 2024
1 parent f8da1ea commit a0c252c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
13 changes: 9 additions & 4 deletions godm/godm.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(self,
self.t_min = None
self.t_max = None

def __call__(self, data):
def forward(self, data):
self.arg_parse(data)
data = self.preprocess(data)

Expand Down Expand Up @@ -219,10 +219,14 @@ def arg_parse(self, data):
raise TypeError('data must be torch_geometric.data.Data')

if not hasattr(data, 'x'):
raise ValueError('data must have feature x')
raise ValueError('data must has feature x')

if not hasattr(data, 'y'):
raise ValueError('data must have label y')
raise ValueError('data must has label y')

if not hasattr(data, 'train_mask') or not hasattr(data, 'val_mask') \
or not hasattr(data, 'test_mask'):
raise ValueError('data must has train_mask, val_mask, test_mask')

if self.hid_dim is None:
self.hid_dim = 2 ** int(math.log2(data.x.size(1)) - 1)
Expand Down Expand Up @@ -437,7 +441,8 @@ def recon_loss(self, x, x_, edge_label, edge_pred,
t=None, t_=None, p=None, p_=None):
loss_x = F.mse_loss(x_, x)
if self.verbose:
print("Batch Loss: feature: {:.4f}".format(loss_x.item()), end=' ')
print(" Batch Loss: feature: {:.4f}".format(loss_x.item()),
end=' ')
loss_e = F.binary_cross_entropy_with_logits(edge_pred, edge_label)
if self.verbose:
print("time: {:.4f}".format(loss_e.item()), end=' ')
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def readme():
author='kayzliu',
author_email='zliu234@uic.edu',
url='https://github.com/kayzliu/godm',
download_url='https://github.com/kayzliu/godm/archive/main.zip',
download_url='https://github.com/kayzliu/godm/archive/master.zip',
keywords=['outlier detection', 'data augmentation', 'diffusion models',
'graph neural networks', 'graph generative model'],
packages=find_packages(),
Expand Down

0 comments on commit a0c252c

Please sign in to comment.