I am encountering an issue where the jitted version of my function keeps making copies of a large matrix instead of performing in-place updates. I understand that JAX/XLA may decline to turn these operations into in-place updates because, for some inputs, an in-place update would give the wrong result. However, I have ensured that my input arguments are safe for in-place updates. Despite this, I cannot get JAX to perform the updates in place.
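For reference, the usual plain-JAX route to an in-place update is buffer donation. Without it, XLA has to keep the caller's array intact, so the update must be written into a fresh buffer. A minimal sketch, assuming the update is "overwrite one row" (the function name `set_row` is hypothetical):

```python
import jax
import jax.numpy as jnp


def set_row(mat, vec, idx):
    # Functional update: semantically returns a new array, but XLA may write it
    # into the input buffer when that buffer is allowed to be reused.
    return mat.at[idx].set(vec)


# donate_argnums=0 tells XLA that `mat`'s buffer may be reused for the result,
# so the caller's matrix no longer has to be preserved by a copy.
set_row_donated = jax.jit(set_row, donate_argnums=0)

num = 4
mat = jnp.arange(num * num, dtype=jnp.int32).reshape((num, num))
vec = jnp.ones(num, dtype=jnp.int32)
out = set_row_donated(mat, vec, 2)
# `mat` must not be used after this call: its buffer may have been donated.
print(out)
```

One way to check whether a copy remains is to inspect the optimized HLO text from `set_row_donated.lower(...).compile().as_text()` for copy operations.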
To resolve this, I tried using Pallas to create a kernel for in-place updates. I made a minimal working example (MWE) but encountered two problems:
1. The code only works if the input argument dimensions are powers of 2. I don't understand why this restriction exists. Am I missing some input arguments for `pallas_call`? I couldn't find documentation on this issue.
2. While I can update `o_ref` in place, I need to update the input argument directly. Is there a way to make `o_ref` point to the same matrix as `mat_ref`, similar to `donate_argnums`?
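Regarding the second point, `pl.pallas_call` accepts an `input_output_aliases` argument that maps an input index to an output index. With `{0: 0}`, output 0 shares its buffer with input 0 (`mat`), so the kernel only has to write the changed row and the full `o_ref[:] = mat_ref[...]` copy should become unnecessary. Pairing it with `donate_argnums=0` on the surrounding `jax.jit` lets the caller's matrix be reused; without donation, XLA may still insert a copy to protect the caller's array. A sketch under those assumptions; `row_update_kernel` and `inplace_row_update` are hypothetical names, and `interpret=True` is only there so it can run on CPU:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def row_update_kernel(mat_ref, vec_ref, idx_ref, o_ref):
    # o_ref aliases mat_ref's buffer (input_output_aliases={0: 0}), so it
    # already holds the matrix; only the selected row is written.
    del mat_ref  # unused: we write through the aliased output ref
    i = idx_ref[...]
    o_ref[pl.ds(i, 1), :] = vec_ref[...][None, :]


def inplace_row_update(mat, vec, idx, interpret=False):
    return pl.pallas_call(
        row_update_kernel,
        out_shape=jax.ShapeDtypeStruct(mat.shape, mat.dtype),
        input_output_aliases={0: 0},  # input 0 (mat) -> output 0
        interpret=interpret,
    )(mat, vec, idx)


# Donating `mat` lets jit hand its buffer to the kernel instead of copying it.
update = jax.jit(inplace_row_update, donate_argnums=0, static_argnames="interpret")

num = 4
mat = jnp.arange(num * num, dtype=jnp.int32).reshape((num, num))
vec = jnp.ones(num, dtype=jnp.int32)
idx = jnp.array(2, dtype=jnp.int32)
out = update(mat, vec, idx, interpret=True)  # `mat` must not be reused afterwards
print(out)
```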
Here is the MWE:
```python
import jax
import jax.numpy as jnp
import jax.experimental.pallas as pl
from jax import make_jaxpr


def inplace_kernel(mat_ref, vec_ref, idx_ref, o_ref):
    _vec = vec_ref[...]
    # Is it possible to do this without copying mat_ref[...] into o_ref?
    o_ref[:] = mat_ref[...]  # Uncommenting this returns the correct result, but doesn't this copy mat_ref[...] into o_ref?
    pl.store(o_ref, (idx_ref[...],), _vec)


def inplace(_mat, _vec, _idx):
    return pl.pallas_call(
        inplace_kernel,
        out_shape=jax.ShapeDtypeStruct(_mat.shape, _mat.dtype),
    )(_mat, _vec, _idx)


num = 4  # Only works for powers of 2. Why is this a problem?
mat = jnp.arange(num * num, dtype=jnp.int32).reshape((num, num))
vec = jnp.ones(num, dtype=jnp.int32)
idx = jnp.array(2 % num, dtype=jnp.int32)  # modulo to ensure idx validity
print(make_jaxpr(inplace)(mat, vec, idx))
print(jax.jit(inplace)(mat, vec, idx))
```
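On the power-of-two restriction: assuming the kernel runs on GPU through Pallas's Triton lowering, this appears to be a Triton limitation rather than a missing `pallas_call` argument. As far as I know, the Triton lowering requires the arrays a kernel operates on to have power-of-two sizes and raises an error otherwise (TPU lowering has different block-shape rules). If the matrix size cannot be chosen freely, one option that keeps the copy-free goal is to allocate the large matrix at padded size once and keep it padded for its whole lifetime, padding only the small per-call vector. The helpers `next_pow2`, `alloc_padded` and `pad_vec` are hypothetical names for this sketch:

```python
import jax.numpy as jnp


def next_pow2(n: int) -> int:
    """Smallest power of two that is >= n (for n >= 1)."""
    return 1 << (n - 1).bit_length()


def alloc_padded(num_rows: int, num_cols: int, dtype=jnp.int32):
    """Allocate the large matrix once at power-of-two size.

    Only the top-left (num_rows, num_cols) corner holds real data; keeping the
    padding around means later kernel calls never need to re-pad (copy) it.
    """
    return jnp.zeros((next_pow2(num_rows), next_pow2(num_cols)), dtype=dtype)


def pad_vec(vec):
    """Zero-pad a row vector to the padded column count (cheap: it is small)."""
    n = vec.shape[0]
    return jnp.pad(vec, (0, next_pow2(n) - n))


num = 5  # not a power of two
mat_padded = alloc_padded(num, num)
vec_padded = pad_vec(jnp.ones(num, dtype=jnp.int32))
print(mat_padded.shape, vec_padded.shape)  # (8, 8) (8,)
```

The trade-off is extra memory for the padding in exchange for never copying the large matrix just to satisfy the shape constraint.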
I would appreciate any guidance on these issues.