
q parameter does not receive gradient even when trainable_q = True in MagNet_link_prediction #60

Open
ClaudMor opened this issue Jun 18, 2024 · 5 comments

@ClaudMor

Describe the bug
The q parameter does not receive a gradient even when trainable_q = True in MagNet_link_prediction.

To Reproduce
The following MWE (taken from here) performs a .backward() on MagNet_link_prediction. The .grad field of q is not updated, while the .grad field of weight is.

import numpy as np
from sklearn.metrics import accuracy_score
import torch

from torch_geometric_signed_directed.utils import link_class_split, in_out_degree
from torch_geometric_signed_directed.nn.directed import MagNet_link_prediction
from torch_geometric_signed_directed.data import load_directed_real_data
from torch_geometric import seed_everything

# Set the seed
seed = 12345
seed_everything(seed)
# Set the device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load data as specified by the paper
data = load_directed_real_data(dataset='cora_ml', root="../data/pygsd/").to(device)
link_data = link_class_split(data, prob_val=0.05, prob_test=0.15, task='existence', device=device, splits=10, seed=seed)

# Load the model with parameters as specified in the paper + trainable_q = True
model = MagNet_link_prediction(q=0.05, K=1, num_features=2, dropout=0.5, hidden=16, label_dim=2, trainable_q=True).to(device)
criterion = torch.nn.NLLLoss()

split = list(link_data.keys())[1]
edge_index = link_data[split]['graph']
edge_weight = link_data[split]['weights']
query_edges = link_data[split]['train']['edges']
y = link_data[split]['train']['label']
X_real = in_out_degree(edge_index, size=len(data.x)).to(device)
X_img = X_real.clone()
out = model(X_real, X_img, edge_index=edge_index,
            query_edges=query_edges,
            edge_weight=edge_weight)
loss = criterion(out, y)
loss.backward()

model.Chebs[0].q.grad       # not updated
model.Chebs[0].weight.grad  # properly updated
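
A quick generic check (plain PyTorch, nothing package-specific) that lists which of the model's parameters actually received a gradient after the backward pass:

# List which parameters got a gradient from loss.backward()
for name, p in model.named_parameters():
    print(name, 'has grad' if p.grad is not None else 'no grad')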

torchviz confirms that q is not part of the backward graph:

from torchviz import make_dot
viz = make_dot(loss, params=dict(model.named_parameters()),)
viz.view()

I believe the non-differentiable operation happens somewhere in __norm__ or get_magnetic_Laplacian, but I haven't been able to identify it exactly.
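
As a check that does not rely on torchviz, a small helper of my own (using only standard autograd node attributes) can walk loss.grad_fn and report whether a given leaf parameter appears anywhere in the backward graph:

def param_in_graph(loss, param):
    # AccumulateGrad nodes hold their leaf tensor in the `variable` attribute.
    seen, stack = set(), [loss.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        if getattr(fn, 'variable', None) is param:
            return True
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return False

param_in_graph(loss, model.Chebs[0].q)       # False in this setup
param_in_graph(loss, model.Chebs[0].weight)  # True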

Expected behavior
model.Chebs[0].q.grad should be properly updated.

Desktop:

  • OS: Windows 10

Thanks in advance for your help.

@SherylHYX
Owner

Thank you for your interest in our package. Do you have any idea how to fix this, and could you perhaps open a pull request for it?

@ClaudMor
Author

Hi @SherylHYX,

I think the issue is a (perhaps intended) mismatch in nomenclature between the paper, the documentation, and the code. The paper and the documentation define K as the order of the Chebyshev polynomial: setting it to 1 yields the convolution defined immediately below equation (6) of the paper. In the code, however, if K is set to 1, the $W_{neigh}$ weights are never defined or used. When the code does not enter that if statement, the Laplacian never becomes part of the computation graph, so q does not receive any gradient. This explains why the default value of K is 2, but it caught me off guard, since in the paper K=2 would imply an additional term in the convolution.
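
For intuition, here is a minimal dense sketch (my own simplification, not the library's actual implementation, which uses complex features and sparse message passing) of a Chebyshev-style forward pass where len(weights) plays the role of the code's K; with a single weight matrix the Laplacian term is skipped entirely:

import torch

def cheb_forward(x, laplacian, weights):
    # weights: list of K dense weight matrices (K = filter size in the code's convention)
    out = x @ weights[0]                     # with K == 1 only this term is used, so the
    if len(weights) > 1:                     # Laplacian (and any q it was built from)
        Tx_prev, Tx_curr = x, laplacian @ x  # never enters the computation graph
        out = out + Tx_curr @ weights[1]
        for k in range(2, len(weights)):     # Chebyshev recurrence T_k = 2 L T_{k-1} - T_{k-2}
            Tx_next = 2 * laplacian @ Tx_curr - Tx_prev
            out = out + Tx_next @ weights[k]
            Tx_prev, Tx_curr = Tx_curr, Tx_next
    return out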

If this interpretation is correct (please check!), it looks to me like a matter of conventions, so I'll leave it to you to decide whether to follow the paper or keep the current behavior (in the latter case, the documentation could perhaps be corrected to specify that K is the order of Chebyshev - 1).

Please let me know if I made any mistake :) .

@SherylHYX
Owner

SherylHYX commented Jun 21, 2024

Thank you @ClaudMor for pointing this out. I have now fixed the documentation. Note that $K$ is the order of the Chebyshev polynomial plus 1 instead of minus 1, which is indeed the size of the Chebyshev filter. I have also updated the released version of the package.
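
Concretely, under this convention the MWE above should keep the Laplacian (and hence q) in the computation graph once the model is built with K=2, which corresponds to the paper's first-order filter, e.g.:

model = MagNet_link_prediction(q=0.05, K=2, num_features=2, dropout=0.5,
                               hidden=16, label_dim=2, trainable_q=True).to(device)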

@ClaudMor
Author

ClaudMor commented Aug 9, 2024

Hi @SherylHYX,

There may also be another case where q is not trained even when trainable_q = True, namely when cached = True as well. In fact, in that case get_magnetic_Laplacian and __norm__ are never invoked, so q does not receive gradients.
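
To illustrate the mechanism, here is a toy module (my own sketch, not the package's code) showing how reusing a cached, detached tensor keeps q out of later backward passes while the other weights still train:

import math
import torch

class CachedNorm(torch.nn.Module):
    def __init__(self, cached: bool):
        super().__init__()
        self.q = torch.nn.Parameter(torch.tensor(0.05))      # analogue of the trainable q
        self.weight = torch.nn.Parameter(torch.tensor(1.0))  # analogue of the conv weights
        self.cached = cached
        self._cache = None

    def forward(self, w):
        if self.cached and self._cache is not None:
            norm = self._cache  # reused, detached tensor: no dependence on the current q
        else:
            norm = torch.cos(2 * math.pi * self.q) * w  # the only place q is used
            if self.cached:
                self._cache = norm.detach()
        return (self.weight * norm).sum()

m = CachedNorm(cached=True)
w = torch.ones(3)
m(w).backward()                   # first pass: q.grad and weight.grad are populated
m.q.grad, m.weight.grad = None, None
m(w).backward()                   # cached pass: weight.grad is populated, q.grad stays None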

If you think it's correct, perhaps it should be mentioned in the documentation?

Thanks

@SherylHYX
Owner

SherylHYX commented Aug 9, 2024 via email
