You cannot use …

I don't immediately see any symmetries in this problem that would lend themselves to an efficient vectorized operation, so the best option here is probably to implement it in terms of `jax.lax.fori_loop`:

```python
import jax
import jax.numpy as jnp
A = jnp.asarray(A)  # shape (2, N)
result = jax.lax.fori_loop(1, A.shape[1], lambda i, val: my_compare(val, A[:, i]), A[:, 0])  # carry: current winning column
```
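The `fori_loop` snippet calls `my_compare`, which comes from the asker's code and isn't shown in the thread. Below is a self-contained sketch with a hypothetical `my_compare` (and a hypothetical wrapper `reduce_columns`) that applies the question's rule to two columns. Note that, like the snippet, it carries the winning *column* through the loop, not its index.

```python
import jax
import jax.numpy as jnp


def my_compare(a, b):
    """Hypothetical comparator: keep column `a` if the 2-D cross product
    z = a[0]*b[1] - b[0]*a[1] is positive, otherwise take column `b`."""
    z = a[0] * b[1] - b[0] * a[1]
    return jnp.where(z > 0, a, b)


def reduce_columns(A):
    A = jnp.asarray(A)
    # Start with column 0 as the carry and compare it against columns 1..N-1.
    # A[:, i] with a traced loop index i is allowed inside fori_loop.
    return jax.lax.fori_loop(
        1, A.shape[1], lambda i, val: my_compare(val, A[:, i]), A[:, 0]
    )


A = jnp.array([[1.0, 0.0, -1.0, 2.0, 0.5],
               [0.0, 1.0, 0.5, 1.0, -1.0]])
print(reduce_columns(A))  # the surviving column
```

Because the carry is a fixed-shape `(2,)` array, the loop traces once and also works under `jax.jit`.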
I want to perform a reduce operation on an array in a slightly different way than `jnp.argmax`. I have an array `A` with shape `(2, 5)`, and I want to reduce `A` as follows: for a pair of indices `(i, j)`, calculate `z = A[0][i]*A[1][j] - A[0][j]*A[1][i]`; if `z > 0`, pick `i`, otherwise pick `j`. Continue this computation until only the final index remains. What's the best way to perform this in JAX?
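One way to express this rule in JAX is a left-to-right sweep with `jax.lax.fori_loop` that carries the index of the current winner rather than the column itself. This is a minimal sketch: `reduce_index` is a hypothetical name, and the sweep order (column 0 first, then 1, 2, …) is an assumption, since the question doesn't specify the order in which pairs are compared. The pairwise rule is not obviously associative, so a different order may give a different answer.

```python
import jax
import jax.numpy as jnp


def reduce_index(A):
    """Sketch: return the index chosen by the pairwise rule
    z = A[0][i]*A[1][j] - A[0][j]*A[1][i]; keep i if z > 0, else take j."""
    A = jnp.asarray(A)

    def body(j, i):
        # i is the index kept so far (loop carry); j is the next candidate.
        z = A[0, i] * A[1, j] - A[0, j] * A[1, i]
        return jnp.where(z > 0, i, j).astype(jnp.int32)

    # Start from index 0 and compare it with columns 1..N-1 in order.
    return jax.lax.fori_loop(1, A.shape[1], body, jnp.int32(0))


A = jnp.array([[1.0, 0.0, -1.0, 2.0, 0.5],
               [0.0, 1.0, 0.5, 1.0, -1.0]])
print(reduce_index(A))
```

For this example `A`, column 0 wins against columns 1 through 3 (its `z` values are 1, 0.5 and 1) and then loses to column 4 (`z = -1`), so the result is index 4.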
Below is my code attempting to implement this reduce operation, which threw the error:

```
IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.
```