A JIT puzzle on how to tell that the sizes of arrays are actually known, in the context of a set difference and union problem #6901
Consider the following simplified example:

```python
import jax.numpy as jnp
import jax

n = 16
k = 3
# A given set of k unique indices
A = jnp.array([7, 2, 5], dtype=jnp.int32)
# Another set of 2*k unique indices
B = jnp.array([4, 3, 2, 1, 9, 10], dtype=jnp.int32)

# Find k new indices in B and combine them with A
def extend_A_by_k_from_B(A, B, k):
    C = jnp.setdiff1d(B, A)
    D = jnp.hstack((A, C[:k]))
    return D

D = extend_A_by_k_from_B(A, B, k)
print(D)

# Under jit, this call fails
f = jax.jit(extend_A_by_k_from_B, static_argnums=(2,))
D = f(A, B, k)
print(D)
```

In the function
While the non-JIT version works fine, I cannot figure out a way to convert this logic into a form that is compatible with JIT compilation. The error I face is:
I tried various ideas using arrays of flags with

Is there a way to rewrite this logic so that the JIT compiler will happily accept it?

P.S.: This is a simplified set-union step in an algorithm called Compressive Sampling Matching Pursuit (CoSaMP).
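A minimal sketch of one JIT-compatible rewrite, assuming a recent JAX release in which `jnp.setdiff1d` takes the keyword-only `size` and `fill_value` arguments. Under `jit`, the length of `setdiff1d(B, A)` depends on the values in `A` and `B`, not just their shapes, so it cannot be traced; passing the static `k` as `size` fixes the output shape at trace time. Because `setdiff1d` returns sorted values, this keeps the `k` smallest new indices, the same ones the non-JIT version selects. The `fill_value=-1` padding is an arbitrary choice of this sketch; with `B` holding `2*k` unique indices and `A` only `k`, at least `k` new indices always exist, so the padding is never used.

```python
import jax
import jax.numpy as jnp

A = jnp.array([7, 2, 5], dtype=jnp.int32)
B = jnp.array([4, 3, 2, 1, 9, 10], dtype=jnp.int32)
k = 3

def extend_A_by_k_from_B(A, B, k):
    # size=k gives setdiff1d a static output shape, so it can be traced.
    # The result holds the k smallest values of B that are not in A (sorted);
    # if fewer than k exist, the remaining slots are filled with fill_value.
    C = jnp.setdiff1d(B, A, size=k, fill_value=-1)
    return jnp.concatenate((A, C))

f = jax.jit(extend_A_by_k_from_B, static_argnums=(2,))
D = f(A, B, k)
print(D)  # [7 2 5 1 3 4]
```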
Replies: 1 comment
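Here's a version that might work for you, given the assumptions about the inputs:

```python
import jax.numpy as jnp

def extend_A_by_k_from_B(A, B, k):
    # A stable argsort of the membership flags puts the entries of B that are
    # not in A first, in their original order; keep the first k of them.
    # (Originally written with jnp.in1d(B, A); jnp.isin is the current name.)
    ind = jnp.argsort(jnp.isin(B, A))[:k]
    return jnp.hstack((A, B[ind]))
```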
It chooses the first `k` entries of `B` that do not appear in `A`, and concatenates them to `A`.
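To see why the reply's trick works, here is a sketch that runs it under `jit` on the question's example inputs, with the flags cast to `int32` before sorting (a choice of this sketch). `jnp.isin(B, A)` marks the entries already in `A`, and because `jnp.argsort` is stable by default, the unmarked entries come first in their original order. Note that this selects the first `k` new entries in `B`'s order (`[4, 3, 1]` here), whereas the question's `setdiff1d` version selects the `k` smallest (`[1, 3, 4]`). The second function, `extend_via_nonzero`, is a hypothetical alternative not taken from the reply: it finds the same positions with `jnp.nonzero` and a static `size`, which avoids the sort.

```python
import jax
import jax.numpy as jnp

A = jnp.array([7, 2, 5], dtype=jnp.int32)
B = jnp.array([4, 3, 2, 1, 9, 10], dtype=jnp.int32)
k = 3

def extend_via_argsort(A, B, k):
    # 1 where B[i] is already in A, 0 where it is new.
    # For the example inputs: [0 0 1 0 0 0]
    flags = jnp.isin(B, A).astype(jnp.int32)
    # Stable sort: positions of new entries come first, in original order.
    # For the example inputs: [0 1 3]
    ind = jnp.argsort(flags)[:k]
    return jnp.concatenate((A, B[ind]))

def extend_via_nonzero(A, B, k):
    # Hypothetical alternative: positions of the first k entries of B that
    # are not in A. size=k makes the output shape static; if fewer than k
    # such entries existed, the missing positions would be padded with 0.
    (ind,) = jnp.nonzero(~jnp.isin(B, A), size=k)
    return jnp.concatenate((A, B[ind]))

f_sort = jax.jit(extend_via_argsort, static_argnums=(2,))
f_nz = jax.jit(extend_via_nonzero, static_argnums=(2,))
D_sort = f_sort(A, B, k)
D_nz = f_nz(A, B, k)
print(D_sort)  # [7 2 5 4 3 1]
print(D_nz)    # [7 2 5 4 3 1]
```

Both versions rely on the question's guarantee that at least `k` entries of `B` are missing from `A`; without it, the argsort version would pick entries already in `A`, and the `nonzero` version would repeat `B[0]`.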