How to implement nested for loop in JAX for high-dimensional PDE solvers #15607
-
Hi all, I would like to ask a question similar to #9191, which was not fully answered. I am writing a PDE solver in JAX; it works, but it is too slow because of the nested loops it contains. My problem can be reduced to the code snippet below. Could anyone help me vectorize the nested for loops? I want to keep the matrix
-
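One way to do this is via broadcasted indices; for example:

```python
import jax.numpy as jnp

# nx, ny, K_mat and alpha are as defined in the question
i = jnp.arange(nx)[None, :]  # shape=(1, nx)
j = jnp.arange(ny)[:, None]  # shape=(ny, 1)
k = j*nx + i                 # shape=(ny, nx): flat index for each (i, j) pair
J = jnp.zeros(nx * ny).at[k].set(K_mat[i, j] * alpha[k])
```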
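To check that the broadcasted version matches the loop it replaces, here is a small self-contained sketch. The question's own snippet is not reproduced in this thread, so the nested loop below is an assumption inferred from the index arithmetic `k = j*nx + i`, and the sizes and random inputs are hypothetical. It also shows that, for this particular index pattern, the scatter can be replaced by a transpose and an elementwise product.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical sizes and inputs; the question's real K_mat and alpha are not shown.
nx, ny = 4, 3
K_mat = jax.random.normal(jax.random.PRNGKey(0), (nx, ny))
alpha = jax.random.normal(jax.random.PRNGKey(1), (nx * ny,))

def loop_version(K_mat, alpha):
    # Assumed form of the original nested loop (inferred from k = j*nx + i).
    J = np.zeros(nx * ny)
    for i in range(nx):
        for j in range(ny):
            k = j * nx + i
            J[k] = K_mat[i, j] * alpha[k]
    return J

@jax.jit
def broadcast_version(K_mat, alpha):
    # Broadcasted indices, as in the reply above.
    i = jnp.arange(nx)[None, :]  # shape=(1, nx)
    j = jnp.arange(ny)[:, None]  # shape=(ny, 1)
    k = j * nx + i               # shape=(ny, nx)
    return jnp.zeros(nx * ny).at[k].set(K_mat[i, j] * alpha[k])

@jax.jit
def reshape_version(K_mat, alpha):
    # k = j*nx + i visits every flat index exactly once in row-major (ny, nx)
    # order, so the scatter equals a transpose plus an elementwise product.
    return (K_mat.T * alpha.reshape(ny, nx)).ravel()

J_loop = loop_version(np.asarray(K_mat), np.asarray(alpha))
J_bcast = broadcast_version(K_mat, alpha)
J_reshape = reshape_version(K_mat, alpha)
print(np.allclose(J_loop, J_bcast), np.allclose(J_loop, J_reshape))
```

The reshape form only works because `k` enumerates every output index exactly once; for less regular index patterns, the `.at[k].set(...)` form is the more general tool.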
-
Jake's already answered this, but you may also find this nonlinear heat equation example a useful reference point.
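For a sense of how the same loop-free style carries over to a whole solver, here is a minimal sketch (not the linked example) of an explicit solver for the 1D nonlinear heat equation `u_t = (D(u) u_x)_x`. The diffusivity `D(u) = 1 + u**2`, the fixed boundary values, the forward-Euler time step, and the names `diffusivity`, `step` and `solve` are all hypothetical choices for illustration. Whole-array slicing replaces the spatial loop, and `jax.lax.fori_loop` keeps the time loop inside the jitted function.

```python
from functools import partial

import jax
import jax.numpy as jnp

def diffusivity(u):
    # Hypothetical nonlinear diffusivity D(u) = 1 + u**2 (an assumption).
    return 1.0 + u**2

def step(u, dx, dt):
    # Face-centred diffusivity: average of neighbouring cell values, shape (n-1,).
    D_face = 0.5 * (diffusivity(u[1:]) + diffusivity(u[:-1]))
    # Fluxes F_{i+1/2} = -D_{i+1/2} (u_{i+1} - u_i) / dx for all faces at once.
    flux = -D_face * (u[1:] - u[:-1]) / dx
    # Interior update du_i/dt = -(F_{i+1/2} - F_{i-1/2}) / dx, shape (n-2,).
    du = -(flux[1:] - flux[:-1]) / dx
    # Boundary values u[0] and u[-1] are left unchanged.
    return u.at[1:-1].add(dt * du)

@partial(jax.jit, static_argnames="n_steps")
def solve(u0, dx, dt, n_steps):
    # The time loop runs inside the compiled function rather than in Python.
    return jax.lax.fori_loop(0, n_steps, lambda _, u: step(u, dx, dt), u0)

# Example run: sin(pi x) initial profile on [0, 1] with zero boundary values.
n = 101
x = jnp.linspace(0.0, 1.0, n)
dx = float(x[1] - x[0])
dt = 0.2 * dx**2  # below the explicit limit dx**2 / (2 * max D) for D <= 2
u0 = jnp.sin(jnp.pi * x)
u = solve(u0, dx, dt, n_steps=500)
```

Since the initial profile lies in [0, 1], `D(u)` stays at or below about 2 here, so `dt = 0.2 * dx**2` is under the explicit stability limit; a different diffusivity would call for a different `dt`.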