
How to implement a nested for loop in JAX for high-dimensional PDE solvers #15607

Answered by jakevdp
ToshiyukiBandai asked this question in Q&A

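One way to do this is to replace the nested loop with broadcasted index arrays, so that every (i, j) pair is handled in a single vectorized update; for example: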

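```python
import jax.numpy as jnp

# nx, ny, K_mat (indexed as K_mat[i, j]) and alpha (length nx*ny) come from the original loop
i = jnp.arange(nx)[None, :]  # shape=(1, nx)
j = jnp.arange(ny)[:, None]  # shape=(ny, 1)
k = j*nx + i                 # shape=(ny, nx): flat index j*nx + i for every (i, j) pair
J = jnp.zeros(nx * ny).at[k].set(K_mat[i, j] * alpha[k])
```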
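To see what the broadcasted version computes, the sketch below compares it against the plain nested loop it appears to replace. The question's original loop is not reproduced on this page, so the loop form, the shapes (K_mat as (nx, ny), alpha as (nx*ny,)) and the helper names `loop_reference`, `broadcast_version` and `reshape_version` are assumptions for illustration. Because `k = j*nx + i` visits every flat index exactly once in row-major order, the `.at[k].set(...)` scatter can also be written as a transpose, multiply and flatten, which is shown as a third variant.

```python
import numpy as np
import jax
import jax.numpy as jnp

def loop_reference(K_mat, alpha, nx, ny):
    # Hypothetical plain-Python form of the nested loop (assumed, not from the question)
    J = np.zeros(nx * ny, dtype=np.float32)
    for j in range(ny):
        for i in range(nx):
            k = j * nx + i
            J[k] = K_mat[i, j] * alpha[k]
    return J

def broadcast_version(K_mat, alpha, nx, ny):
    # Same broadcasted-index approach as in the answer above
    i = jnp.arange(nx)[None, :]  # shape (1, nx)
    j = jnp.arange(ny)[:, None]  # shape (ny, 1)
    k = j * nx + i               # shape (ny, nx)
    return jnp.zeros(nx * ny).at[k].set(K_mat[i, j] * alpha[k])

def reshape_version(K_mat, alpha, nx, ny):
    # K_mat.T[j, i] == K_mat[i, j]; alpha.reshape(ny, nx)[j, i] == alpha[j*nx + i];
    # ravel() in row-major order puts entry (j, i) at flat position j*nx + i
    return (K_mat.T * alpha.reshape(ny, nx)).ravel()

nx, ny = 4, 3
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
K_mat = jax.random.normal(key1, (nx, ny))    # assumed shape (nx, ny)
alpha = jax.random.normal(key2, (nx * ny,))  # assumed length nx*ny

J_loop = loop_reference(np.asarray(K_mat), np.asarray(alpha), nx, ny)
# nx and ny determine array shapes, so they must be static under jit
J_bcast = jax.jit(broadcast_version, static_argnums=(2, 3))(K_mat, alpha, nx, ny)
J_reshape = reshape_version(K_mat, alpha, nx, ny)
```

The broadcasted scatter generalizes to index maps that are not a simple flatten (for example, when some (i, j) pairs are skipped or remapped), while the reshape form only applies when the flat index is exactly the row-major position j*nx + i.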

Answer selected by ToshiyukiBandai