I ran into a strange timing issue today, and was curious whether anyone has some insight. Basically, the time to run a function under JIT changes drastically when a value returned by that function is needed by the Python interpreter. Here is a simple example:

```python
import numpy as onp
import jax.numpy as np
from jax import jit
from time import process_time as timer

A = onp.random.rand(1000, 1000)

@jit
def f(A):
    y = np.linalg.pinv(A)
    return y[0, 0]

f(A)  # call once so compilation is not included in the timings

start1 = timer()
y = f(A)
stop1 = timer()

start = timer()
y = f(A)
print(y)
stop = timer()

print("\nTime without print: " + str(stop1 - start1))
print("Time with print: " + str(stop - start))
```

The results on my machine are:
My guess is that the JIT-compiled function returns control to the Python interpreter before the computation has finished, and then issues some sort of wait when that value is needed, e.g., in this case at the `print` call. What I'd like to do is measure the actual time of just the JIT function, i.e., force the Python interpreter to wait until the JIT computation is finished before running `stop = timer()`, without using another function such as the `print` call.
JAX uses asynchronous dispatch in the backend, so the Python program will continue executing while the JAX computation is running. For this sort of benchmark, the best approach is to call `block_until_ready()` on the result, which will cause the backend to block until all results are computed:

```python
y = f(A).block_until_ready()
```

When you call `print`, it implicitly blocks until the computation is complete in order to display the result.
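For completeness, here is a minimal sketch of how the benchmark above could be adapted; it assumes the same `f` and `A` already defined, and I've switched to `time.perf_counter` purely as a wall-clock choice on my part:

```python
import time
import numpy as onp
import jax.numpy as np
from jax import jit

A = onp.random.rand(1000, 1000)

@jit
def f(A):
    y = np.linalg.pinv(A)
    return y[0, 0]

# Compile once (and wait for it) so the timed calls measure execution only.
f(A).block_until_ready()

start = time.perf_counter()
y = f(A)                      # returns almost immediately due to async dispatch
dispatch_time = time.perf_counter() - start

start = time.perf_counter()
y = f(A).block_until_ready()  # waits until the result is actually computed
compute_time = time.perf_counter() - start

print("Dispatch only:    ", dispatch_time)
print("Full computation: ", compute_time)
```

The first timing captures only the cost of dispatching the work, while the second reflects the actual runtime of the jitted function.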