Thanks for the question! The difference here is that if you use ...
-
Dear developers,
I have a function as follows:
where t1.shape = (8, 38) and t2.shape = (8, 8, 38, 38). The JIT-compiled function is quite simple and uses a buffer of about 1 MB, as expected.
However, if I replace numpy.tril_indices with jax.numpy.tril_indices, the compiled function becomes much more complicated and the buffer size doubles (see module_0309.jit_amplitudes_to_vector.cpu_after_optimizations-buffer-assignment.txt). What causes this, and should I just use numpy instead of jax.numpy for indexing purposes? Thank you in advance.
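The original function did not survive in this thread, so the sketch below uses a hypothetical stand-in, amplitudes_to_vector, which assumes the function packs t1 together with the lower triangle of t2 viewed as an (nov, nov) matrix, where nov = 8 * 38. It illustrates the NumPy variant. Inside jax.jit, shapes are static Python ints, so np.tril_indices runs once in Python at trace time and its index arrays are embedded in the compiled program as constants. jax.numpy calls, by contrast, are traced and may add their own work to the compiled program; how much depends on the JAX version, which is not claimed here. The last lines dump the optimized program with jax.jit(...).lower(...).compile().as_text(), so the two variants can be diffed directly.

```python
import numpy as np
import jax
import jax.numpy as jnp


def amplitudes_to_vector(t1, t2):
    """Hypothetical stand-in for the function in the question (not shown there):
    pack t1 plus the lower triangle of t2 viewed as an (nov, nov) matrix."""
    nocc, nvir = t1.shape          # static Python ints, even under jit
    nov = nocc * nvir
    # t2[i, j, a, b] -> matrix indexed by the pairs (i, a) and (j, b)
    t2_mat = t2.transpose(0, 2, 1, 3).reshape(nov, nov)
    # np.tril_indices runs in Python while the function is traced; the result
    # is a pair of concrete integer arrays that end up as constants in XLA.
    rows, cols = np.tril_indices(nov)
    return jnp.concatenate([t1.ravel(), t2_mat[rows, cols]])


jitted = jax.jit(amplitudes_to_vector)

t1 = jnp.ones((8, 38), dtype=jnp.float32)
t2 = jnp.arange(8 * 8 * 38 * 38, dtype=jnp.float32).reshape(8, 8, 38, 38)
vec = jitted(t1, t2)

# Optimized program text; swapping np.tril_indices for jnp.tril_indices above
# and diffing this output is one way to compare the two versions.
hlo_text = jitted.lower(t1, t2).compile().as_text()
print(vec.shape, len(hlo_text))
```

When the indices depend only on static shapes, computing them with NumPy (inside the function or once at module level) keeps them out of the traced computation entirely. This is a general property of tracing, not a statement about how jnp.tril_indices is implemented.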