-
I'd really appreciate it if someone could explain how I would do something like the above with JAX. I'm a complete newbie and need to see if I can use JAX to accelerate some NumPy code. I need to be able to pass an array and some vectors into a function, loop over them an arbitrary number of times (so no unrolling loops), and, with reference to the loop step, update a vector as I go along. Ideally, I would want to jit this function and run it on a TPU (I am using Colab). I can see JAX has loop functionality, but I can't see how to pass arrays/vectors into a loop and have it update the result as I go along (an issue with immutable arrays?). The above code example isn't exactly what I am trying to do, but understanding the best way to do this type of thing would really help me. Thanks in advance!
-
The way to do this in a JIT-compatible way is using lax.fori_loop. Here is the equivalent of your function from above:

    import jax.numpy as jnp
    from jax import lax

    def test_func2(nmax, mat, x, sum_res):
        # body_fun receives the loop index n and the carried value sum_res
        def body_fun(n, sum_res):
            return sum_res + jnp.dot(mat, n * x)
        sum_res = lax.fori_loop(0, nmax, body_fun, sum_res)
        return sum_res, jnp.sqrt(x * nmax)

Keep in mind that
-
Thank you for your answer @jakevdp, it is much appreciated. I should have been more explicit in my original question, though. Say I tried to do the following: evaluate multiple separate matrix/vector dot products inside the body. How would I go about that? In the actual program loop I'd need a few dot products, plus a number of other calculations that depend on the value of n and some vec*vec operations, all evaluated in the same loop iteration. Running the code below (and trying to extrapolate from your example) gives the error "fori_loop() takes 4 positional arguments but 5 were given". How might I resolve this? I apologise for not being clearer in the first place.
-
You can pass as many arguments as you wish to the carry value if they're part of a single list or tuple. So, for example, your function could look like this:

    import jax
    import jax.numpy as jnp

    def test_func2(nmax, mat, mat2, x, sum_res, sum_res2):
        def body_fun(n, carry):
            # unpack the tuple carried between iterations
            sum_res, sum_res2 = carry
            sum_res = sum_res + jnp.dot(mat, n * x)
            sum_res2 = sum_res2 + jnp.dot(mat2, (n / 2) * x)
            return (sum_res, sum_res2)
        sum_res, sum_res2 = jax.lax.fori_loop(0, nmax, body_fun, (sum_res, sum_res2))
        return sum_res, sum_res2, jnp.sqrt(x * nmax)
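The same tuple carry can also cover the "update a vector as I go along" part of the original question. The sketch below is one possible way to do it, not something taken from the thread: it carries a result vector and writes element n on each iteration using the functional update syntax out.at[n].set(...), which returns a new array instead of mutating in place. The function name running_totals, the out argument, and the input values are all hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def running_totals(mat, x, out):
    # out is a hypothetical pre-allocated vector; its length sets the loop count.
    def body_fun(n, carry):
        sum_res, out = carry
        sum_res = sum_res + jnp.dot(mat, n * x)
        # Functional update: returns a new array with element n replaced.
        out = out.at[n].set(jnp.sum(sum_res))
        return (sum_res, out)

    sum_res0 = jnp.zeros(mat.shape[0], dtype=x.dtype)
    return lax.fori_loop(0, out.shape[0], body_fun, (sum_res0, out))

# Made-up inputs, only for illustration.
mat = jnp.arange(9.0).reshape(3, 3)
x = jnp.array([1.0, 2.0, 3.0])
out0 = jnp.zeros(4)

final_sum, totals = running_totals(mat, x, out0)
```

The carry returned by body_fun needs the same structure, shapes, and dtypes as the one passed in, which is why sum_res0 is created with the dtype of x.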