Hey there!
Thanks for this repository (and the accompanying blog post), really helpful to learn more about JAX and graph neural networks!
When I ran `python train.py`, I got an error. The reason is that `train.py` indexes tensors using Python ranges and lists, which the JAX authors have decided to deprecate, as can be seen here. It can be fixed by using `jnp.array(idx)` instead of `idx` (this PR does this for `idx_train`, `idx_val`, and `idx_test`).

The reason I'm running this is that I would like to implement the same thing as you, but using FLAX, the high-level API on top of JAX for deep learning. I have a notebook which you can run (training a GCN on Cora): https://drive.google.com/file/d/1D-GwuZH19p19RjnxuDbw4GsrmM3bCyNp/view?usp=sharing
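For reference, here is a minimal sketch of the change this PR makes. The variable names are illustrative stand-ins for the arrays in `train.py`; the point is only that a Python list index gets wrapped in `jnp.array` before being used to gather rows:

```python
import jax.numpy as jnp

# Toy stand-in for the node-feature matrix used in train.py.
features = jnp.arange(12.0).reshape(4, 3)

# A plain Python list of node indices, as in the original script.
idx_train = [0, 2]

# Indexing with the bare list is deprecated in recent JAX versions;
# wrapping it in jnp.array makes the gather explicit and supported.
selected = features[jnp.array(idx_train)]

print(selected.shape)  # (2, 3): one row per index in idx_train
```

The same one-line wrap applies to `idx_val` and `idx_test` wherever they are used as indices.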
What's weird is that my initial loss is the same as yours (1.94), but after that the loss stays at 1.81 and doesn't change anymore, even though I'm using the same optimizer and learning rate. It would be great if you could take a look!