Helper functions for Graph permutation #50
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
One of the fundamental concepts of graph based data is, that changing the order of the nodes (or the edges) does not change the graph properties.
Basically all GNNs work can and maybe should be tested if the results are invariant (or at least equivariant) under such permutations.
And I thought it might be helpful if we had a way to apply such a permutation quickly to the graphs and see if that is true.
I needed some helper functions for my own work that applies such permutation, and since I appreciate this open-source project, I was wondering if you would like to include them in the official jraph repo.
I implemented 2 helper functions that can apply permutations in the nodes, and permutations in the edges.
One fundamental requirement I set for myself was, that this should still work for batched and padded graphs, and the permutations should be easily applied only inside the individual graphs inside a graph tuple.
I added a user friendly way to achieve this randomly, but also have an option for more controlled permutation.
As for example it is needed in order to invert a permutation.
I am open for feedback on this design decision.
JIT
Unfortunately, I can't quite figure out a way to write the node-wise permutation sequence generation in a way that is jitable.
I don't this it fundamentally necessary, but it would nice if we could make that possible. I am open for ideas.
The problem, I am running into is, that
jax.random.permutation
cannot bevmap
for variable sizes of the permutation.Tests
I added some test for the helper functions.
At the moment it is more an end-to-end test then a unit test, so if you have suggestions I am willing to improve there.
Also, I am using the helper function to generate random graphs to test with.
I needed a bit more control about the size of individual graphs.
With too small graphs, the probability of having a permutation that doesn't actually do anything was too big, so the test would falsely detect that the permutation didn't work.
Generally, I since my test is build on random graphs with random permutations, a little brittleness remains.
Let me know if I should switch to fixed graphs instead to eliminate the brittleness.
I hope this a contribution you would like to include.