Using vmap() with batched lists of GraphTuples #22243
Fixed by using jnp.stack to stack my GraphsTuples! Updating in case anyone runs into a similar issue.
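For anyone landing here, a minimal sketch of what stacking might look like. The reply does not show the actual code, so this is an assumption about the approach: `Graph` is a hypothetical stand-in NamedTuple with the same field names as `jraph.GraphsTuple`, and `make_graph`, `stack_graphs`, and `sum_incoming` are illustrative helpers, not jraph functions. It only works when every graph in the list has identical array shapes.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Graph(NamedTuple):
    # Hypothetical stand-in mirroring the field names of jraph.GraphsTuple.
    nodes: jnp.ndarray
    edges: jnp.ndarray
    receivers: jnp.ndarray
    senders: jnp.ndarray
    globals: jnp.ndarray
    n_node: jnp.ndarray
    n_edge: jnp.ndarray


def make_graph(seed: int) -> Graph:
    # Toy ring graph: 6 nodes with 2 features each, 6 directed edges.
    senders = jnp.arange(6)
    return Graph(
        nodes=jax.random.normal(jax.random.PRNGKey(seed), (6, 2)),
        edges=jnp.ones((6, 1)),
        receivers=(senders + 1) % 6,
        senders=senders,
        globals=jnp.zeros((1, 1)),
        n_node=jnp.array([6]),
        n_edge=jnp.array([6]),
    )


def stack_graphs(graphs):
    # Stack matching leaves of every graph along a new leading axis,
    # turning a Python list of graphs into one graph with batched arrays.
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *graphs)


def sum_incoming(graph: Graph) -> jnp.ndarray:
    # Toy per-graph function: add each sender's features into its receiver.
    return jax.ops.segment_sum(
        graph.nodes[graph.senders],
        graph.receivers,
        num_segments=graph.nodes.shape[0],
    )


window = [make_graph(0), make_graph(1)]   # batch_size = 2
batched = stack_graphs(window)            # every leaf gains a leading axis of size 2
out = jax.vmap(sum_incoming)(batched)     # default in_axes=0 maps over that axis
```

After stacking, the batch is a single pytree whose leaves all share a leading batch axis, so the default `in_axes=0` is enough and no per-field specification is needed.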
Hi all,
To preface: I'm very new to JAX, so I apologize if this is a silly question. I've looked around for help with GraphTuples and vmap (especially #16641), but I haven't had much luck. I'd appreciate any help!
I'm writing a two-layer GNN that currently predicts pretty well. However, it runs somewhat slowly, so I'm looking for ways to optimize it on my GPU. I decided to implement batching during training; previously, my code trained on one time-series forecast window of GraphTuples at a time. Now, I want to group batch_size time-series forecast windows into some number of batches, and use jax.vmap() to train the windows in parallel.
Here's a snippet of my code. I pass in my input and target batches, then use jax.vmap() to batch them into their time-series forecast windows:
Here's a very small version of how input_batch_graphs is structured. Here, batch_size = 2, and each graph contains 6 nodes, each with 2 layers:
When I try to run this code, I get the following error:
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (GraphsTuple(nodes=1, edges=1, receivers=1, senders=1, globals=1, n_node=None, n_edge=None), GraphsTuple(nodes=1, edges=1, receivers=1, senders=1, globals=1, n_node=None, n_edge=None)) for value tree PyTreeDef(([CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *]), CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *])], [CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *]), CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *])])).
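One reading of this error, offered as a sketch and not a diagnosis of the original code (which is not shown): the `in_axes` entry for each argument is a single GraphsTuple, while the argument itself is a Python list of GraphsTuples, so the specification is not a tree prefix of the value. The example below reproduces that mismatch with a hypothetical two-field stand-in, `MiniGraph`, and then shows a specification that does match once the list is replaced by one stacked tuple.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class MiniGraph(NamedTuple):
    # Hypothetical, reduced stand-in for a GraphsTuple (two fields only).
    nodes: jnp.ndarray
    n_node: jnp.ndarray


def node_sum(graph: MiniGraph) -> jnp.ndarray:
    # Toy per-graph function: sum of all node features.
    return graph.nodes.sum()


# A Python list of two graphs, each with 6 nodes and 2 features per node.
graphs = [MiniGraph(jnp.ones((6, 2)) * i, jnp.array([6])) for i in range(2)]

# One in_axes entry per positional argument; None means "do not map this leaf".
spec = (MiniGraph(nodes=0, n_node=None),)

# Mismatch: the spec is a MiniGraph, but the value is a list of MiniGraphs.
try:
    jax.vmap(node_sum, in_axes=spec)(graphs)
    prefix_error = False
except (ValueError, TypeError):
    prefix_error = True

# Match: a single MiniGraph whose `nodes` leaf carries the batch on axis 0,
# while `n_node` is shared by every graph and left unmapped.
stacked = MiniGraph(
    nodes=jnp.stack([g.nodes for g in graphs]),
    n_node=graphs[0].n_node,
)
totals = jax.vmap(node_sum, in_axes=spec)(stacked)
```

The same specification works in the second call because the value now has the same tree structure as the spec: one tuple whose mapped leaves have a batch axis.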
Is there a way to either use in_axes or reformat my batch data structure such that I can use jax.vmap() to batch over the windows? I'd also appreciate any other tips for batching/using vmap :)
For more context, here are my jax/jraph versions:
Thanks,
Mia