
Please help me with loops! #8706

Answered by jakevdp
rog77 asked this question in Q&A
Nov 26, 2021 · 3 comments · 1 reply

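You can pass as many values as you like through the loop carry, as long as they are packed into a single list or tuple (more generally, any pytree). The body function unpacks them, updates them, and must return them in the same structure. For example, your function could look like this: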

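```python
import jax
import jax.numpy as jnp

def test_func2(nmax, mat, mat2, x, sum_res, sum_res2):
    def body_fun(n, carry):
        # Unpack both accumulators from the single carry tuple.
        sum_res, sum_res2 = carry
        sum_res = sum_res + jnp.dot(mat, n * x)
        sum_res2 = sum_res2 + jnp.dot(mat2, (n / 2) * x)
        # Return them in the same structure (same shapes and dtypes) as the input carry.
        return (sum_res, sum_res2)
    sum_res, sum_res2 = jax.lax.fori_loop(0, nmax, body_fun, (sum_res, sum_res2))
    return sum_res, sum_res2, jnp.sqrt(x * nmax)
```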

Answer selected by rog77
Category: Q&A · Labels: none · 2 participants